"""
Presentation layer — Plotly/Dash dashboard (Phase 5).

Reads a recorded session file and renders the iMotions "Data Visualization
Dashboard" equivalents:

  * synchronized multi-stream timeline (the proof that everything is on one clock),
    with stimulus markers overlaid,
  * event-locked aggregates per stimulus (the per-stimulus comparison view),
  * a validated-physiology summary (HR/HRV, EDA/SCR) from `analysis.physio`,
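  * an affect circumplex (valence × arousal) — both axes come from FEA (py-feat)
    when an FEA stream with "valence" and "arousal" channels was recorded;
    otherwise arousal is z-scored EDA and valence is flat (0),
  * eye-tracking views (gaze metrics, heatmap, scanpath) when an `EyeTracker`
    stream was recorded.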

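Entry points:
  * `export_html(path, out)`        — write a single-file HTML report (no server;
                                      plotly.js and fonts are fetched from CDNs).
  * `export_group_html(paths, out)` — write a multi-participant (group) HTML report.
  * `serve(path)`                   — run an interactive Dash app with a Reload button.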
"""

from __future__ import annotations

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from ..analysis import load as A
from ..analysis import physio

# Categorical trace colours (hex). The figures below index it modulo its length,
# so any number of streams/signals simply cycles through these six colours.
PALETTE = ["#4C78A8", "#F58518", "#54A24B", "#E45756", "#72B7B2", "#B279A2"]


# --------------------------------------------------------------------------
# event-locked aggregation
# --------------------------------------------------------------------------
def event_locked_table(session: dict, window: float = 1.5) -> list[dict]:
    """Per stimulus onset, mean of each continuous signal over [t, t+window].

    A stimulus onset is any marker (from ``A.markers``) whose label, as a string,
    starts with ``"stim_on"``. The text after the first ``":"`` is the stimulus
    name; the whole label is used when there is no ``":"``.

    Returns one dict per onset, in marker order, with keys ``"stimulus"``,
    ``"onset"`` and one float per non-Markers stream that has at least one sample
    inside the closed window. Multi-channel streams contribute the mean of their
    first channel only.
    """
    onsets = []
    for t, label in A.markers(session):
        text = str(label)
        if not text.startswith("stim_on"):
            continue
        _, sep, tail = text.partition(":")
        onsets.append((float(t), tail if sep else text))
    if not onsets:
        return []

    # Convert each stream to arrays (and reduce it to its first channel) once,
    # rather than re-converting the full sample lists for every onset.
    signals = {}
    for sname, s in session["streams"].items():
        if s["type"] == "Markers":
            continue
        x = np.asarray(s["x"])
        signals[sname] = (np.asarray(s["t"]), x[:, 0] if x.ndim > 1 else x)

    rows = []
    for onset, name in onsets:
        row = {"stimulus": name, "onset": onset}
        for sname, (ts, x) in signals.items():
            mask = (ts >= onset) & (ts <= onset + window)
            if mask.any():
                row[sname] = float(np.mean(x[mask]))
        rows.append(row)
    return rows


def _channel_mean(session, stream, channel, onset, window):
    s = session["streams"].get(stream)
    if not s or not len(s["t"]):
        return None
    ts = np.asarray(s["t"]); m = (ts >= onset) & (ts <= onset + window)
    if not m.any():
        return None
    x = np.asarray(s["x"]); x = x[:, channel] if x.ndim > 1 else x
    return float(np.mean(x[m]))


def _arousal_valence(session: dict, window=1.5):
    """Per-stimulus valence × arousal for the circumplex.

    Returns the `event_locked_table(session, window)` rows, each extended in
    place with float "valence" and "arousal" keys. An empty list is returned
    unchanged.

    Prefers the FEA stream (real valence + arousal from facial expression) when
    it was recorded with both a "valence" and an "arousal" channel. Each value is
    then that channel's mean over [onset, onset + window]. Otherwise it falls
    back to arousal = the per-stimulus "GSR" mean z-scored across stimuli, with a
    flat valence of 0.0.

    Missing or non-finite values (no samples in the window, NaN means, no GSR at
    all) map to 0.0, the neutral centre of the circumplex, rather than NaN.
    Plotly would silently drop NaN points from the scatter.
    """
    rows = event_locked_table(session, window)
    if not rows:
        return rows
    fea_chans = list(session["streams"].get("FEA", {}).get("channels") or [])
    if "valence" in fea_chans and "arousal" in fea_chans:
        vi, ai = fea_chans.index("valence"), fea_chans.index("arousal")

        def _fea_mean(channel, onset):
            v = _channel_mean(session, "FEA", channel, onset, window)
            return v if v is not None and np.isfinite(v) else 0.0

        for r in rows:
            r["valence"] = _fea_mean(vi, r["onset"])
            r["arousal"] = _fea_mean(ai, r["onset"])
        return rows

    # fallback: arousal from EDA, valence unavailable
    eda = np.array([r.get("GSR", np.nan) for r in rows], dtype="float64")
    finite = np.isfinite(eda)
    if finite.any():
        mu = float(eda[finite].mean())
        sd = float(eda[finite].std()) or 1.0  # zero spread: leave unscaled
    else:
        # No usable GSR at all: nanmean/nanstd would warn and return NaN, which
        # would also slip past the `or 1.0` guard (NaN is truthy).
        mu, sd = 0.0, 1.0
    for r, v, ok in zip(rows, eda, finite):
        r["arousal"] = float((v - mu) / sd) if ok else 0.0
        r["valence"] = 0.0
    return rows


# --------------------------------------------------------------------------
# figures
# --------------------------------------------------------------------------
def fig_timeline(session: dict) -> go.Figure:
    """Stacked line plots of every non-Markers stream on one shared time axis.

    There is one subplot per continuous stream (first channel only for
    multi-channel streams). The x axis is in seconds relative to the start of
    ``A.common_window(session)``. Every marker whose label starts with
    ``"stim_on"`` is drawn as a dotted vertical line.

    When the session has no continuous streams, a titled empty figure is
    returned, because ``make_subplots`` rejects ``rows=0``.
    """
    streams = [(n, s) for n, s in session["streams"].items() if s["type"] != "Markers"]
    if not streams:
        fig = go.Figure()
        fig.update_layout(height=180,
                          title="Synchronized streams (one clock): no continuous streams recorded",
                          margin=dict(l=50, r=20, t=60, b=40))
        return fig

    fig = make_subplots(rows=len(streams), cols=1, shared_xaxes=True,
                        subplot_titles=[f"{n} ({s['type']}, {s['srate']:.0f} Hz)"
                                        for n, s in streams])
    t0 = A.common_window(session)[0]
    for i, (n, s) in enumerate(streams, start=1):
        rel_t = np.asarray(s["t"]) - t0
        x = np.asarray(s["x"])
        if x.ndim > 1:
            x = x[:, 0]  # only the first channel is plotted
        fig.add_trace(go.Scattergl(x=rel_t, y=x, mode="lines", name=n,
                                   line=dict(width=1, color=PALETTE[(i - 1) % len(PALETTE)])),
                      row=i, col=1)
    for onset, label in A.markers(session):
        if str(label).startswith("stim_on"):
            fig.add_vline(x=float(onset) - t0, line=dict(color="rgba(120,120,120,.5)",
                                                         width=1, dash="dot"))
    fig.update_layout(height=180 * len(streams), showlegend=False,
                      title="Synchronized streams (one clock) with stimulus onsets",
                      margin=dict(l=50, r=20, t=60, b=40))
    fig.update_xaxes(title_text="seconds from session start", row=len(streams), col=1)
    return fig


def fig_event_bars(session: dict, window=1.5) -> go.Figure:
    """Grouped bar chart of `event_locked_table` means, one group per presentation.

    Each signal (stream) is one trace. Presentations are labelled
    ``"<stimulus>#<n>"``, with n counting from 1 in onset order. A presentation
    whose window held no samples of a signal gets a gap (None) for that signal.
    """
    rows = event_locked_table(session, window)
    fig = go.Figure()
    if rows:
        # Union of signal keys over ALL rows, in first-seen order. A stream with no
        # samples in the first window is absent from rows[0] but may appear later.
        signals = list(dict.fromkeys(
            k for r in rows for k in r if k not in ("stimulus", "onset")))
        labels = [f"{r['stimulus']}#{i+1}" for i, r in enumerate(rows)]
        for j, sig in enumerate(signals):
            fig.add_trace(go.Bar(name=sig, x=labels,
                                 y=[r.get(sig) for r in rows],
                                 marker_color=PALETTE[j % len(PALETTE)]))
    fig.update_layout(barmode="group", title=f"Event-locked means per stimulus (0–{window}s)",
                      height=380, margin=dict(l=50, r=20, t=60, b=60),
                      xaxis_title="stimulus presentation", yaxis_title="mean signal")
    return fig


def fig_circumplex(session: dict) -> go.Figure:
    """Scatter of per-stimulus valence (x) against arousal (y), coloured by arousal.

    Points come from `_arousal_valence`, labelled with the stimulus name. The
    titles state which source was used: FEA when the FEA stream has both
    "valence" and "arousal" channels, otherwise z-scored EDA arousal with no
    valence. Reference lines are drawn at valence = 0 and arousal = 0.
    """
    rows = _arousal_valence(session)
    fig = go.Figure()
    if rows and "arousal" in rows[0]:
        fig.add_trace(go.Scatter(
            x=[r.get("valence", 0) for r in rows],
            y=[r.get("arousal", 0) for r in rows],
            mode="markers+text", text=[r["stimulus"] for r in rows],
            textposition="top center",
            marker=dict(size=14, color=[r.get("arousal", 0) for r in rows],
                        colorscale="RdYlBu_r", showscale=True,
                        colorbar=dict(title="arousal"))))
    fea_chans = session["streams"].get("FEA", {}).get("channels") or []
    # Same test as _arousal_valence uses to choose FEA. Checking "valence" alone
    # would title an EDA-derived plot "from FEA" when the arousal channel is missing.
    has_fea = "valence" in fea_chans and "arousal" in fea_chans
    fig.add_hline(y=0, line=dict(color="#999", width=1))
    fig.add_vline(x=0, line=dict(color="#999", width=1))
    fig.update_layout(
        title=("Affect circumplex — valence × arousal (from FEA / py-feat)" if has_fea
               else "Affect circumplex — arousal (EDA); add an FEA source for valence"),
        height=420,
        xaxis_title="valence" if has_fea else "valence (needs FEA source)",
        yaxis_title="arousal (facial)" if has_fea else "arousal (z-scored EDA)",
        margin=dict(l=50, r=20, t=60, b=50))
    return fig


def fig_physio_table(session: dict) -> go.Figure:
    """Render `physio.summary(session)` as a (family, metric, value) table.

    `physio.summary` is iterated as family -> {metric: value}. Float values are
    shown with two decimals and every other value via `str()`. Height grows by
    28 px per row, with room for at least one row.
    """
    families, metric_names, values = [], [], []
    for family, metrics in physio.summary(session).items():
        for metric, value in metrics.items():
            families.append(family)
            metric_names.append(metric)
            values.append(f"{value:.2f}" if isinstance(value, float) else str(value))

    table = go.Table(
        header=dict(values=["family", "metric", "value"],
                    fill_color="#37626F", font=dict(color="white")),
        cells=dict(values=[families, metric_names, values]))
    fig = go.Figure(table)
    fig.update_layout(title="Validated physiology (NeuroKit2)",
                      height=60 + 28 * max(len(values), 1),
                      margin=dict(l=20, r=20, t=60, b=20))
    return fig


def fig_gaze_metrics(session: dict) -> go.Figure:
    """Eye-movement metrics table (fixations, saccades, scanpath).

    Shows only those metrics from `eyetracking.gaze_metrics(session)` that have a
    display label below, in the order they are returned. Floats are formatted
    with two decimals and other values via `str()`.
    """
    from ..analysis import eyetracking as ET
    m = ET.gaze_metrics(session)
    label = {"fixation_count": "fixations", "fixation_rate_hz": "fixation rate (Hz)",
             "mean_fixation_ms": "mean fixation (ms)", "saccade_count": "saccades",
             "mean_saccade_amp_deg": "mean saccade (deg)",
             "scanpath_len_norm": "scanpath length", "valid_ratio": "valid gaze"}
    rows = [(label.get(k, k), f"{v:.2f}" if isinstance(v, float) else str(v))
            for k, v in m.items() if k in label]
    fig = go.Figure(go.Table(
        header=dict(values=["metric", "value"], fill_color="#2f6470",
                    font=dict(color="white")),
        cells=dict(values=[[r[0] for r in rows], [r[1] for r in rows]])))
    # max(..., 1): with no labelled metrics, a bare 60 px would be smaller than
    # the 70 px of vertical margins. Same floor the physiology table uses.
    fig.update_layout(title="Eye-movement metrics (I-VT)",
                      height=60 + 28 * max(len(rows), 1),
                      margin=dict(l=20, r=20, t=50, b=20))
    return fig


def build_figures(session: dict) -> dict:
    """Build every dashboard figure for one loaded session, keyed by panel name.

    Always contains "timeline", "event_bars", "circumplex" and "physio". It also
    contains "gaze_metrics", "heatmap" and "scanpath" when the session has an
    "EyeTracker" stream.
    """
    figures = dict(
        timeline=fig_timeline(session),
        event_bars=fig_event_bars(session),
        circumplex=fig_circumplex(session),
        physio=fig_physio_table(session),
    )
    if "EyeTracker" not in session["streams"]:
        return figures
    # eye-tracking views only when a gaze stream was recorded
    from ..analysis import eyetracking as ET
    figures.update(
        gaze_metrics=fig_gaze_metrics(session),
        heatmap=ET.fig_heatmap(session),
        scanpath=ET.fig_scanpath(session),
    )
    return figures


# --------------------------------------------------------------------------
# entry points
# --------------------------------------------------------------------------
def export_html(session_path: str, out_path: str, *, back_href: str | None = None,
                subtitle: str | None = None) -> str:
    """Render a single-file HTML report from a session file.

    The figures from `build_figures` become cards in a fixed order: timeline,
    eye-tracking views (when present), event bars, circumplex, physiology.
    Viewing the report needs network access: the first card pulls plotly.js from
    the Plotly CDN, and the page loads the Inter font from Google Fonts.

    `back_href` adds the shared Back/Home bar (used when served inside the app).
    `subtitle` overrides the top-bar subtitle (default "Dashboard").
    Returns `out_path`.
    """
    from ..app.ui import THEME_CSS, topbar
    session = A.load_session(session_path)
    figs = build_figures(session)
    order = ["timeline", "heatmap", "scanpath", "gaze_metrics",
             "event_bars", "circumplex", "physio"]
    cards = []
    for key in order:
        if key not in figs:
            continue
        # Only the first card embeds the plotly.js <script>; later ones reuse it.
        body = figs[key].to_html(full_html=False,
                                 include_plotlyjs=("cdn" if not cards else False),
                                 config={"displayModeBar": False})
        cards.append(f'<div class="card plot">{body}</div>')
    head = ('<!doctype html><html lang="en"><head><meta charset="utf-8">'
            '<meta name="viewport" content="width=device-width, initial-scale=1">'
            '<title>biosync — dashboard</title>'
            '<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">'
            f'<style>{THEME_CSS}</style></head><body>')
    page = (head + topbar(subtitle=subtitle or "Dashboard", back_href=back_href)
            + '<div class="wrap">' + "\n".join(cards) + "</div></body></html>")
    # The page declares charset=utf-8 and contains non-ASCII text (em dashes, ×),
    # so it must be written as UTF-8, not in the platform's default encoding.
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(page)
    return out_path


def export_group_html(session_paths, out_path: str, *, aois=None,
                      back_href: str | None = None, subtitle: str | None = None) -> str:
    """Render a multi-participant (group) dashboard from several sessions.

    Cards, in order:
      * group heatmap,
      * group AOI view (only when `aois` is truthy),
      * group gaze,
      * group physiology.
    All are built by `analysis.aggregate`. The first card pulls plotly.js from the
    Plotly CDN, and the page loads the Inter font from Google Fonts.
    `back_href` / `subtitle` configure the shared top bar (subtitle defaults to
    "Group analysis"). Returns `out_path`.
    """
    from ..app.ui import THEME_CSS, topbar
    from ..analysis import aggregate as AG
    figs = [AG.fig_group_heatmap(session_paths), AG.fig_group_gaze(session_paths),
            AG.fig_group_physio(session_paths)]
    if aois:
        figs.insert(1, AG.fig_group_aoi(session_paths, aois))
    cards = []
    for i, fg in enumerate(figs):
        # Only the first card embeds the plotly.js <script>; later ones reuse it.
        body = fg.to_html(full_html=False, include_plotlyjs=("cdn" if i == 0 else False),
                          config={"displayModeBar": False})
        cards.append(f'<div class="card plot">{body}</div>')
    head = ('<!doctype html><html lang="en"><head><meta charset="utf-8">'
            '<meta name="viewport" content="width=device-width, initial-scale=1">'
            '<title>biosync — group</title>'
            '<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">'
            f'<style>{THEME_CSS}</style></head><body>')
    page = (head + topbar(subtitle=subtitle or "Group analysis", back_href=back_href)
            + '<div class="wrap">' + "\n".join(cards) + "</div></body></html>")
    # The page declares charset=utf-8 and contains non-ASCII text (em dash), so it
    # must be written as UTF-8, not in the platform's default encoding.
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(page)
    return out_path


def serve(session_path: str, host="127.0.0.1", port=8050):  # pragma: no cover
    """Interactive Dash app. `pip install dash` and open http://host:port."""
    from dash import Dash, dcc, html, Input, Output

    # (Graph component id, key in build_figures() output), in display order.
    panels = (("timeline", "timeline"), ("bars", "event_bars"),
              ("circumplex", "circumplex"), ("physio", "physio"))

    app = Dash(__name__)
    app.layout = html.Div(
        [html.H1("biosync session"), html.Button("Reload", id="reload")]
        + [dcc.Graph(id=graph_id) for graph_id, _ in panels])

    # Re-reads the session file from disk on every Reload click.
    @app.callback(*[Output(graph_id, "figure") for graph_id, _ in panels],
                  Input("reload", "n_clicks"))
    def _refresh(_clicks):
        figures = build_figures(A.load_session(session_path))
        return tuple(figures[key] for _, key in panels)

    app.run(host=host, port=port, debug=False)
