"""
Live signal-quality monitoring.

Given a short rolling window of recent samples from a stream, compute the quality
indicators a researcher needs to catch problems DURING a recording — not after:

  * sample-rate health   — actual Hz vs the stream's nominal rate,
  * gaze validity         — fraction of on-screen (non-NaN, in-range) samples,
  * flatline / dropout    — signal not changing (dead electrode, frozen tracker),
  * saturation / clipping — pinned at the rail,
  * NaN ratio, RMS, range — generic amplitude health.

Each stream gets a status: "good" / "warn" / "bad", with the reasons, so the
monitor UI can show a light per sensor. Thresholds are sensible defaults per
modality and tunable. Pure numpy; works on any stream window.
"""

from __future__ import annotations

import numpy as np

# Per-modality expectations, keyed by stream type (range is in the signal's own units).
#   valid_min : minimum acceptable fraction of valid gaze samples
#   flat_eps  : peak-to-peak range below which the signal counts as a flatline
#   lo/hi/sat : declared per modality; stream_quality() does not read them
PROFILES = {
    "Gaze":   {"valid_min": 0.80, "lo": 0.0, "hi": 1.0},        # normalized screen
    "GSR":    {"flat_eps": 1e-4, "lo": 0.0, "hi": 100.0},       # microsiemens
    "ECG":    {"flat_eps": 1e-3},
    "EEG":    {"flat_eps": 1e-2, "sat": 0.95},                  # uV (board-scaled)
    "EMG":    {"flat_eps": 1e-3},
    "EDA":    {"flat_eps": 1e-4},
}


def _empty_quality(reason: str) -> dict:
    """Result returned when there is nothing to analyse: always "bad"."""
    return {"status": "bad", "reasons": [reason], "rate_hz": 0.0,
            "n": 0, "metrics": {}}


def stream_quality(times, data, *, stype: str = "", srate: float = 0.0,
                   channels=None, window: float = 2.0) -> dict:
    """
    Quality of the most recent `window` seconds of a stream.

    The window is anchored at the last timestamp in `times`, not at wall-clock
    time. `times` and `data` are paired by index; if their lengths differ only
    the common leading part is analysed.

    times    : 1-D array of sample timestamps (s)
    data     : list/array of samples (scalar or per-channel); for per-channel
               data the NaN and amplitude checks look at the first column only
    stype    : key into PROFILES; "Gaze" enables the validity check and
               disables the amplitude checks
    srate    : nominal rate in Hz; a falsy or non-positive value skips the
               sample-rate check
    channels : accepted for interface compatibility; not used here
    window   : length of the rolling window (s)

    Returns a dict with "status" ("good" / "warn" / "bad"), "reasons" (list of
    strings), "rate_hz", "n" (samples in the window) and "metrics".
    """
    t = np.asarray(times, dtype="float64").ravel()
    if len(t) == 0:
        return _empty_quality("no data")
    x = np.asarray(data, dtype="float64")
    if x.ndim == 0:
        # a bare scalar is a single sample
        x = x.reshape(1)

    # A live producer may hand over buffers whose lengths differ; samples are
    # paired by index, so keep only the part where both sides exist.
    n_pairs = min(len(t), len(x))
    if n_pairs == 0:
        return _empty_quality("no data")
    t, x = t[:n_pairs], x[:n_pairs]

    # restrict to the rolling window
    keep = t >= (t[-1] - window)
    if not keep.any():
        # happens when the last timestamp (or `window`) is NaN, or window < 0
        return _empty_quality("no samples in window")
    t, x = t[keep], x[keep]

    span = (t[-1] - t[0]) or 1e-9          # avoid dividing by a zero span
    rate = (len(t) - 1) / span if len(t) > 1 else 0.0
    reasons = []
    metrics = {"rate_hz": round(float(rate), 1)}
    status = "good"

    prof = PROFILES.get(stype, {})

    # --- sample-rate health (regular streams only) ---
    if srate and srate > 0:
        ratio = rate / srate
        metrics["rate_ratio"] = round(float(ratio), 2)
        if ratio < 0.5:
            status = _worse(status, "bad")
            reasons.append(f"low rate {rate:.0f}/{srate:.0f} Hz")
        elif ratio < 0.8:
            status = _worse(status, "warn")
            reasons.append("rate dropping")

    # NaN ratio and amplitude checks use the first channel only
    flat = x[:, 0] if x.ndim > 1 else x
    nan_ratio = float(np.mean(~np.isfinite(flat))) if len(flat) else 1.0
    metrics["nan_ratio"] = round(nan_ratio, 3)
    finite = flat[np.isfinite(flat)]

    # --- gaze validity ---
    if stype == "Gaze":
        xy = x if x.ndim > 1 else x.reshape(-1, 1)
        gx = xy[:, 0]
        gy = xy[:, 1] if xy.shape[1] > 1 else xy[:, 0]
        # on-screen = finite and inside the unit square with a 5 % margin
        on = (np.isfinite(gx) & np.isfinite(gy)
              & (gx >= -0.05) & (gx <= 1.05)
              & (gy >= -0.05) & (gy <= 1.05))
        valid = float(np.mean(on)) if len(on) else 0.0
        metrics["validity"] = round(valid, 3)
        vmin = prof.get("valid_min", 0.8)
        if valid < vmin * 0.6:
            status = _worse(status, "bad")
            reasons.append(f"gaze validity {valid*100:.0f}%")
        elif valid < vmin:
            status = _worse(status, "warn")
            reasons.append(f"gaze validity {valid*100:.0f}%")

    # --- amplitude / flatline / saturation (continuous signals) ---
    if len(finite) > 2 and stype != "Gaze":
        hi, lo = np.max(finite), np.min(finite)
        rng = float(hi - lo)
        rms = float(np.sqrt(np.mean((finite - np.mean(finite)) ** 2)))
        metrics["range"] = round(rng, 4)
        metrics["rms"] = round(rms, 4)
        if rng < prof.get("flat_eps", 1e-6):
            status = _worse(status, "bad")
            reasons.append("flatline (no signal)")
        # clipping: many samples pinned at the observed extremes.
        # The max and the min are each attained at least once by definition,
        # so one occurrence of each is discounted; otherwise every short
        # window (<= 6 samples) would be reported as clipped.
        at_rail = (np.abs(finite - hi) < 1e-9) | (np.abs(finite - lo) < 1e-9)
        repeats = max(int(np.count_nonzero(at_rail)) - 2, 0)
        pinned = repeats / len(finite)
        if pinned > 0.30 and rng > 0:
            status = _worse(status, "warn")
            reasons.append("clipping/saturation")

    if nan_ratio > 0.5:
        status = _worse(status, "bad")
        reasons.append("mostly invalid samples")
    elif nan_ratio > 0.1:
        status = _worse(status, "warn")
        reasons.append("some invalid samples")

    return {"status": status, "reasons": reasons, "rate_hz": metrics["rate_hz"],
            "n": int(len(t)), "metrics": metrics}


def _worse(a, b):
    """Return the more severe of two statuses ("good" < "warn" < "bad")."""
    order = {"good": 0, "warn": 1, "bad": 2}
    return a if order[a] >= order[b] else b


def session_quality(window_streams: dict, *, window: float = 2.0) -> dict:
    """Quality for every stream in a {name: {type, srate, t, x}} mapping.

    Each value is a dict with optional keys:
      t     : sample timestamps (default: empty)
      x     : samples (default: empty)
      type  : stream modality, forwarded as `stype`; the key "stype" is
              accepted as a fallback when "type" is absent
      srate : nominal rate in Hz (default 0.0)

    Returns {name: result of stream_quality(...)} for every entry, each
    evaluated over the same `window`.
    """
    return {
        name: stream_quality(
            s.get("t", []),
            s.get("x", []),
            # "type" keeps precedence so existing callers are unaffected
            stype=s.get("type", s.get("stype", "")),
            srate=s.get("srate", 0.0),
            window=window,
        )
        for name, s in window_streams.items()
    }
