"""
Multi-participant aggregation (iMotions-class group analysis).

Takes a set of recorded sessions (one or more per participant) and combines them
into group-level results — the step that turns individual recordings into a study:

  * physiology   — mean +/- SD of HR/HRV, EDA SCR across participants,
  * AOI          — per stimulus per AOI: mean dwell, mean TTFF, hit ratio
                   (fraction of trials, pooled over all sessions, that hit
                   the AOI), n (number of pooled trials),
  * gaze heatmap — summed gaze density across everyone (the group attention map),
  * gaze metrics — group mean fixation count/duration, saccades, scanpath.

Each function accepts a list of session *paths* (or loaded dicts). Reuses the
per-session analyzers (`physio`, `aoi`, `eyetracking`). Pure numpy + optional
plotly for the figures.
"""

from __future__ import annotations

import numpy as np

from . import load as A
from . import physio, aoi as aoimod, eyetracking as ET


def _load(sessions):
    out = []
    for s in sessions:
        out.append(s if isinstance(s, dict) else A.load_session(s))
    return out


def aggregate_physio(sessions) -> dict:
    """Summarize the key physiology metrics over a set of sessions.

    For every session, ``physio.summary`` is called and the heart metrics
    (``hr_bpm``, ``sdnn_ms``, ``rmssd_ms``) and EDA metrics
    (``tonic_mean_uS``, ``scr_per_min``) that are present are collected.

    Returns a dict with ``n_participants`` (the number of sessions loaded)
    plus, for each metric that was found in at least one session, a
    ``{"mean", "sd", "n"}`` entry. ``sd`` uses ``np.std`` with its default
    settings; ``n`` is the number of sessions that contributed a value.
    """
    loaded = _load(sessions)

    # summary section -> metric names read from that section
    wanted = (
        ("heart", ("hr_bpm", "sdnn_ms", "rmssd_ms")),
        ("eda", ("tonic_mean_uS", "scr_per_min")),
    )
    collected = {name: [] for _, names in wanted for name in names}

    for session in loaded:
        summary = physio.summary(session)
        for section, names in wanted:
            # A missing section simply contributes nothing.
            section_values = summary.get(section, {})
            for name in names:
                if name in section_values:
                    collected[name].append(section_values[name])

    result = {"n_participants": len(loaded)}
    for name, values in collected.items():
        if not values:
            continue  # metric absent from every session: leave it out
        result[name] = {
            "mean": float(np.mean(values)),
            "sd": float(np.std(values)),
            "n": len(values),
        }
    return result


def aggregate_aoi(sessions, aois, *, window=2.0) -> dict:
    """Pool AOI metrics over all sessions, per stimulus and per AOI.

    ``aoimod.per_stimulus`` is run on each session; every trial it reports
    contributes one observation to the matching (stimulus, AOI) pool.

    Returns ``{stimulus: {aoi_name: stats}}`` where ``stats`` holds:

    * ``mean_dwell_s`` / ``sd_dwell_s`` — mean and ``np.std`` of dwell time,
    * ``mean_ttff_s`` — mean time-to-first-fixation over the trials that
      have one, or ``None`` when no trial does,
    * ``hit_ratio`` — hits divided by the number of pooled trials,
    * ``n`` — number of pooled trials.
    """
    pooled: dict = {}
    for session in _load(sessions):
        by_stimulus = aoimod.per_stimulus(session, aois, window=window)
        for stim, trials in by_stimulus.items():
            for trial in trials:
                for aoi_name, metrics in trial.items():
                    stim_pool = pooled.setdefault(stim, {})
                    if aoi_name not in stim_pool:
                        stim_pool[aoi_name] = {"dwell": [], "ttff": [],
                                               "hits": 0, "n": 0}
                    slot = stim_pool[aoi_name]
                    slot["dwell"].append(metrics["dwell_time_s"])
                    ttff = metrics["ttff_s"]
                    if ttff is not None:
                        # trials without a first fixation are excluded
                        # from the TTFF mean rather than counted as 0
                        slot["ttff"].append(ttff)
                    if metrics["hit"]:
                        slot["hits"] += 1
                    slot["n"] += 1

    summary = {}
    for stim, stim_pool in pooled.items():
        stim_stats = {}
        summary[stim] = stim_stats
        for aoi_name, slot in stim_pool.items():
            dwell = slot["dwell"]
            ttff = slot["ttff"]
            n_trials = slot["n"]
            stim_stats[aoi_name] = {
                "mean_dwell_s": float(np.mean(dwell)) if dwell else 0.0,
                "sd_dwell_s": float(np.std(dwell)) if dwell else 0.0,
                "mean_ttff_s": float(np.mean(ttff)) if ttff else None,
                "hit_ratio": float(slot["hits"] / n_trials) if n_trials else 0.0,
                "n": n_trials,
            }
    return summary


def aggregate_heatmap(sessions, *, stream="EyeTracker", bins=64, sigma=1.8):
    """Sum the per-session gaze-density grids and scale the peak to 1.

    Sessions that do not contain *stream* are skipped. Each remaining
    session's grid comes from ``ET.heatmap_grid``; the grids are added
    element-wise and the sum is divided by its maximum.

    Returns ``(grid, x_edges, y_edges)``. The edges are those returned for
    the last session that had the stream. When no session has the stream,
    an all-zero ``bins x bins`` grid with edges spanning 0..1 is returned.
    """
    grid_sum = None
    for session in _load(sessions):
        if stream in session["streams"]:
            H, xe, ye = ET.heatmap_grid(session, stream=stream, bins=bins,
                                        sigma=sigma)
            if grid_sum is None:
                grid_sum = H
            else:
                grid_sum = grid_sum + H

    if grid_sum is None:
        # nothing to plot: fall back to an empty grid on the unit square
        grid_sum = np.zeros((bins, bins))
        xe = ye = np.linspace(0, 1, bins + 1)

    # An all-zero grid has max 0; divide by 1 instead to avoid 0/0.
    peak = grid_sum.max() or 1.0
    return grid_sum / peak, xe, ye


def aggregate_gaze_metrics(sessions, **kw) -> dict:
    """Mean and SD of the eye-movement metrics over the sessions.

    Only sessions with an ``"EyeTracker"`` stream are analysed; extra
    keyword arguments are forwarded to ``ET.gaze_metrics``.

    Returns ``{}`` when no session qualifies. Otherwise returns
    ``n_participants`` (number of sessions analysed) plus a
    ``{"mean", "sd"}`` entry per metric, with ``sd`` from ``np.std`` at its
    default settings.
    """
    per_session = []
    for session in _load(sessions):
        if "EyeTracker" in session["streams"]:
            per_session.append(ET.gaze_metrics(session, **kw))

    if not per_session:
        return {}

    metric_names = (
        "fixation_count",
        "mean_fixation_ms",
        "saccade_count",
        "mean_saccade_amp_deg",
        "scanpath_len_norm",
    )
    summary = {"n_participants": len(per_session)}
    for metric in metric_names:
        values = [row[metric] for row in per_session]
        summary[metric] = {
            "mean": float(np.mean(values)),
            "sd": float(np.std(values)),
        }
    return summary


def group_summary(sessions, aois=None, **kw) -> dict:
    """Everything at once — the group-dashboard data bundle.

    Parameters
    ----------
    sessions : iterable of session paths and/or already-loaded session dicts.
    aois : optional AOI collection; when truthy, AOI aggregation is included.
    **kw : forwarded to ``aggregate_aoi`` only (e.g. ``window``).

    Returns a dict with ``"physio"`` and ``"gaze"`` entries, plus ``"aoi"``
    when *aois* is given.
    """
    # Load once and hand the loaded dicts to every aggregator. Each
    # aggregator calls ``_load`` itself, which passes dicts straight through,
    # so this avoids re-reading every path from disk per aggregator and keeps
    # a one-shot iterable from being exhausted by the first call.
    loaded = _load(sessions)
    out = {"physio": aggregate_physio(loaded),
           "gaze": aggregate_gaze_metrics(loaded)}
    if aois:
        out["aoi"] = aggregate_aoi(loaded, aois, **kw)
    return out


# --------------------------------------------------------------------------
# figures
# --------------------------------------------------------------------------
def fig_group_heatmap(sessions, **kw):
    """Build a plotly heatmap figure of the group gaze density.

    Keyword arguments are forwarded to ``aggregate_heatmap``. Cells are
    positioned at the bin centers (midpoints of consecutive edges), and the
    y axis is drawn reversed.
    """
    import plotly.graph_objects as go

    density, x_edges, y_edges = aggregate_heatmap(sessions, **kw)
    # midpoints of consecutive bin edges
    x_centers = (x_edges[:-1] + x_edges[1:]) / 2
    y_centers = (y_edges[:-1] + y_edges[1:]) / 2

    heat = go.Heatmap(z=density, x=x_centers, y=y_centers,
                      colorscale="Turbo", zsmooth="best")
    fig = go.Figure(heat)
    fig.update_yaxes(autorange="reversed", range=[1, 0], title="screen y")
    fig.update_xaxes(range=[0, 1], title="screen x")
    fig.update_layout(
        title="Group gaze heatmap (all participants)",
        height=430,
        margin=dict(l=50, r=20, t=50, b=40),
    )
    return fig


def fig_group_aoi(sessions, aois, *, window=2.0):
    """Build a grouped bar chart of mean AOI dwell time per stimulus.

    One bar series is added per AOI (in the order of *aois*, using each
    AOI's ``name``), with the dwell SD as error bars. A stimulus that has no
    data for an AOI is plotted as 0 with a 0 error bar.
    """
    import plotly.graph_objects as go

    agg = aggregate_aoi(sessions, aois, window=window)
    series_names = [a.name for a in aois]
    stims = list(agg)

    fig = go.Figure()
    for series in series_names:
        means, sds = [], []
        for stim in stims:
            cell = agg[stim].get(series, {})
            means.append(cell.get("mean_dwell_s", 0))
            sds.append(cell.get("sd_dwell_s", 0))
        error_bars = dict(type="data", array=sds, visible=True)
        fig.add_trace(go.Bar(name=series, x=stims, y=means, error_y=error_bars))

    fig.update_layout(
        barmode="group",
        title="Mean AOI dwell per stimulus (± SD across participants)",
        height=400,
        yaxis_title="dwell (s)",
        xaxis_title="stimulus",
        margin=dict(l=50, r=20, t=50, b=50),
    )
    return fig


def fig_group_gaze(sessions, **kw):
    """Build a plotly table of the group eye-movement metrics (mean ± SD).

    Keyword arguments are forwarded to ``aggregate_gaze_metrics`` (and from
    there to the per-session gaze analysis). When no session yields metrics
    the table is empty and the title reports ``n=0``.
    """
    import plotly.graph_objects as go

    # Forward **kw: it was previously accepted but silently dropped, so
    # caller-supplied analysis options had no effect on the figure.
    agg = aggregate_gaze_metrics(sessions, **kw)
    label = {"fixation_count": "fixations", "mean_fixation_ms": "mean fixation (ms)",
             "saccade_count": "saccades", "mean_saccade_amp_deg": "mean saccade (deg)",
             "scanpath_len_norm": "scanpath length"}
    # Only metrics present in the aggregate get a row (agg is {} when empty).
    rows = [(label[k], f"{agg[k]['mean']:.2f} ± {agg[k]['sd']:.2f}")
            for k in label if k in agg]
    fig = go.Figure(go.Table(
        header=dict(values=["metric", "mean ± SD"], fill_color="#2f6470",
                    font=dict(color="white")),
        cells=dict(values=[[r[0] for r in rows], [r[1] for r in rows]])))
    # Height grows with the number of rows; at least one row's worth of space.
    fig.update_layout(title=f"Group eye-movement metrics (n={agg.get('n_participants',0)})",
                      height=60 + 28 * max(len(rows), 1), margin=dict(l=20, r=20, t=50, b=20))
    return fig


def fig_group_physio(sessions):
    """Build a plotly table of the group physiology metrics.

    Every entry of ``aggregate_physio``'s result that is a dict containing a
    ``"mean"`` becomes one row: metric key, ``mean ± SD`` and ``n``.
    """
    import plotly.graph_objects as go

    agg = aggregate_physio(sessions)
    # Scalar entries such as n_participants are not dicts and are skipped.
    rows = [
        (metric, f"{stats['mean']:.2f} ± {stats['sd']:.2f}", stats["n"])
        for metric, stats in agg.items()
        if isinstance(stats, dict) and "mean" in stats
    ]
    metric_col = [row[0] for row in rows]
    stat_col = [row[1] for row in rows]
    n_col = [row[2] for row in rows]

    table = go.Table(
        header=dict(values=["metric", "mean ± SD", "n"], fill_color="#2f6470",
                    font=dict(color="white")),
        cells=dict(values=[metric_col, stat_col, n_col]),
    )
    fig = go.Figure(table)
    # Height grows with the number of rows; at least one row's worth of space.
    fig.update_layout(
        title=f"Group physiology (n={agg.get('n_participants',0)})",
        height=60 + 28 * max(len(rows), 1),
        margin=dict(l=20, r=20, t=50, b=20),
    )
    return fig
