"""
Analysis layer.

Reads the synchronized session and shows the two primitives every iMotions-style
metric is built from:

1. event-locking -- given a stimulus marker, slice each continuous signal around
   it on the shared clock. This is the basis for AOI dwell time, ERP epochs, and
   stimulus-by-stimulus aggregation.
2. validated physiology -- NeuroKit2 turns the raw ECG into heart rate / HRV,
   the EDA into tonic+phasic SCR features, etc. (the R-notebook replacement).
"""

from __future__ import annotations

import h5py
import neurokit2 as nk
import numpy as np


def load_session(path: str) -> dict:
    """Load a recorded session from an HDF5 file into in-memory arrays.

    Layout read here: a root attribute ``recorder_clock_origin`` and one
    group per stream.  Each group carries ``type``, ``nominal_srate`` and
    ``channels`` attributes plus ``timestamps`` and ``data`` datasets.

    Returns:
        ``{"clock_origin": float, "streams": {name: {"type", "srate",
        "channels", "t", "x"}}}``.  ``t`` and ``x`` are fully read with
        ``[:]``, so they remain usable after the file is closed.
    """
    def _text(v):
        # h5py may return string attributes as str or bytes depending on how
        # they were written; normalise bytes to str.
        return v.decode() if isinstance(v, bytes) else v

    out = {"streams": {}}
    with h5py.File(path, "r") as f:
        out["clock_origin"] = float(f.attrs["recorder_clock_origin"])
        for name, g in f.items():
            if not isinstance(g, h5py.Group):
                # Only groups describe streams; ignore any stray root datasets.
                continue
            out["streams"][name] = {
                "type": _text(g.attrs["type"]),
                "srate": float(g.attrs["nominal_srate"]),
                "channels": [_text(c) for c in g.attrs["channels"]],
                "t": g["timestamps"][:],
                "x": g["data"][:],
            }
    return out


def common_window(session: dict) -> tuple[float, float]:
    """Return the time span covered by every non-empty stream.

    Start is the latest first timestamp and end is the earliest last
    timestamp across all streams with at least one sample.  Streams with no
    samples are ignored.

    If the streams do not overlap at all, the returned start is greater than
    the returned end; callers should check ``start <= end`` before slicing.

    Raises:
        ValueError: if no stream in the session contains any samples.
    """
    spans = [
        (float(s["t"][0]), float(s["t"][-1]))
        for s in session["streams"].values()
        if len(s["t"])
    ]
    if not spans:
        raise ValueError("common_window: no stream in the session has any samples")
    starts, ends = zip(*spans)
    return max(starts), min(ends)


def markers(session: dict) -> list[tuple[float, str]]:
    """Return the session's event markers as ``(timestamp, label)`` pairs.

    Reads the stream named ``"Markers"``.  Returns ``[]`` when that stream
    is absent (or its entry is empty).  Byte labels are decoded to str.
    When each marker sample is a row rather than a scalar (e.g. data shaped
    ``(n, 1)``), the first channel is used as the label.
    """
    m = session["streams"].get("Markers")
    if not m:
        return []
    out = []
    for t, v in zip(m["t"], m["x"]):
        if np.ndim(v):
            # Row-shaped sample: str() of the whole row would give "[b'...']".
            v = v[0]
        label = v.decode() if isinstance(v, bytes) else str(v)
        out.append((float(t), label))
    return out


def value_at(session: dict, stream: str, t_query: float, channel: int = 0):
    """Return *stream*'s sample whose timestamp is closest to *t_query*.

    On a tie, the earlier sample wins (``argmin`` returns the first index).
    For 1-D data the scalar sample is returned; for row-shaped data the
    value in column *channel* is returned.  The result is always a float.
    """
    stream_data = session["streams"][stream]
    distances = np.abs(stream_data["t"] - t_query)
    nearest = int(distances.argmin())
    sample = stream_data["x"][nearest]
    if np.ndim(sample) == 0:
        return float(sample)
    return float(sample[channel])


def heart_rate(session: dict) -> float | None:
    ecg = session["streams"].get("ECG")
    if not ecg or len(ecg["x"]) < ecg["srate"] * 5:
        return None
    sig = np.asarray(ecg["x"], dtype="float64").ravel()
    clean = nk.ecg_clean(sig, sampling_rate=int(ecg["srate"]))
    _, info = nk.ecg_peaks(clean, sampling_rate=int(ecg["srate"]))
    peaks = info["ECG_R_Peaks"]
    if len(peaks) < 2:
        return None
    rr = np.diff(peaks) / ecg["srate"]
    return float(60.0 / np.mean(rr))
