"""
The sync core (hardened).

Resolves LSL streams, opens an inlet per stream, and pulls samples concurrently.
The crucial step is `inlet.time_correction()`: LSL measures the clock offset
between each device's outlet and this recorder, so a 250 Hz ECG sample, a 60 Hz
gaze sample, and an irregular event marker all land on ONE common timeline.

Robustness (so a real study doesn't lose data):
  * verify-resolve — keep resolving through the whole session, so late-appearing
    streams are still captured; optionally require an `expect` set and report what
    resolved vs is missing,
  * reconnect — inlets use recover=True (LSL re-establishes transport silently);
    a per-stream watchdog logs any gap so you know where data gaps are,
  * autosave — periodically flush the partial session to disk, so a crash keeps
    most of the recording.

Still writes one self-contained HDF5 per session (one group per stream). The
Source contract and the analysis layer are unchanged; `record(duration, out)`
works exactly as before, with the new behavior as opt-in keyword args.
"""

from __future__ import annotations

import os
import threading
import time
from dataclasses import dataclass, field

import h5py
import numpy as np
from pylsl import StreamInlet, cf_string, local_clock, resolve_streams


@dataclass
class _Buffer:
    name: str
    stype: str
    channels: list
    srate: float
    is_string: bool
    lock: threading.Lock = field(default_factory=threading.Lock)
    times: list = field(default_factory=list)
    data: list = field(default_factory=list)
    gaps: list = field(default_factory=list)        # (start, end) LSL-clock gaps
    last_recv: float = 0.0                           # wall time of last sample


class Recorder:
    def __init__(self, resolve_timeout: float = 2.0, autosave_every: float = 0.0,
                 gap_threshold: float = 1.0):
        self.resolve_timeout = resolve_timeout
        self.autosave_every = autosave_every         # 0 = off
        self.gap_threshold = gap_threshold           # s without samples = a gap
        self._buffers: dict[str, _Buffer] = {}
        self._open_uids: set = set()                 # source uids already inlet-ed
        self._threads: list[threading.Thread] = []
        self._stop = threading.Event()
        self._t0_lsl: float | None = None
        self.summary: dict = {}

    # --- per-stream pull loop -------------------------------------------
    def _pull_loop(self, inlet: StreamInlet, buf: _Buffer):
        try:
            offset = inlet.time_correction()
        except Exception:
            offset = 0.0
        last_corr = time.time()
        buf.last_recv = time.time()
        while not self._stop.is_set():
            try:
                chunk, stamps = inlet.pull_chunk(timeout=0.2, max_samples=256)
            except Exception:
                # transport hiccup; recover=True will rebuild — keep going
                time.sleep(0.05)
                continue
            now = time.time()
            if stamps:
                if now - last_corr > 5:
                    try:
                        offset = inlet.time_correction()
                    except Exception:
                        pass
                    last_corr = now
                # gap detection: a quiet span longer than threshold
                if buf.last_recv and (now - buf.last_recv) > self.gap_threshold:
                    buf.gaps.append((buf.last_recv, now))
                buf.last_recv = now
                with buf.lock:
                    for s, ts in zip(chunk, stamps):
                        buf.times.append(ts + offset)
                        buf.data.append(s[0] if buf.is_string else s)

    # --- open any newly-resolved streams --------------------------------
    def _open_new_streams(self):
        try:
            streams = resolve_streams(self.resolve_timeout)
        except Exception:
            return
        for info in streams:
            uid = info.source_id() or f"{info.name()}|{info.uid()}"
            if uid in self._open_uids:
                continue
            inlet = StreamInlet(info, max_buflen=360, recover=True)
            full = inlet.info()
            labels, ch = [], full.desc().child("channels").child("channel")
            for _ in range(full.channel_count()):
                labels.append(ch.child_value("label") or f"ch{len(labels)}")
                ch = ch.next_sibling()
            name = full.name()
            # if a name re-appears (restarted source), suffix to avoid clobbering
            key = name if name not in self._buffers else f"{name}#{len(self._buffers)}"
            buf = _Buffer(name=name, stype=full.type(), channels=labels,
                          srate=full.nominal_srate(),
                          is_string=(full.channel_format() == cf_string))
            self._buffers[key] = buf
            self._open_uids.add(uid)
            t = threading.Thread(target=self._pull_loop, args=(inlet, buf), daemon=True)
            self._threads.append(t)
            t.start()

    def _resolver_loop(self):
        """Keep resolving through the session to catch late / restarted streams."""
        while not self._stop.is_set():
            self._open_new_streams()
            self._stop.wait(0.7)

    def _autosave_loop(self, out_path):
        while not self._stop.is_set():
            self._stop.wait(self.autosave_every)
            if self._stop.is_set():
                break
            try:
                self._write_hdf5(out_path + ".part", partial=True)
                os.replace(out_path + ".part", out_path)   # atomic-ish
            except Exception:
                pass

    # --- main entry ------------------------------------------------------
    def record(self, duration: float, out_path: str, *, expect=None,
               settle: float = 1.0) -> str:
        """
        Record `duration` seconds to `out_path`.

        expect : optional iterable of stream NAMES that must be present; the result
                 summary flags any that never resolved.
        settle : seconds to wait up front for streams to appear before timing the
                 duration (so a slow device doesn't get clipped).
        """
        self._t0_lsl = local_clock()
        self._open_new_streams()
        # brief settle so all expected streams come up before the clock starts
        t_settle = time.time()
        want = set(expect or [])
        while time.time() - t_settle < settle:
            self._open_new_streams()
            if want and want <= set(b.name for b in self._buffers.values()):
                break
            time.sleep(0.2)
        if not self._buffers:
            raise RuntimeError("No LSL streams found -- are sources running?")

        # continuous resolver + optional autosave run for the whole session
        resolver = threading.Thread(target=self._resolver_loop, daemon=True)
        resolver.start()
        if self.autosave_every > 0:
            threading.Thread(target=self._autosave_loop, args=(out_path,),
                             daemon=True).start()

        time.sleep(duration)
        self._stop.set()
        for t in self._threads:
            t.join(timeout=2)
        resolver.join(timeout=2)

        self._build_summary(expect)
        path = self._write_hdf5(out_path)
        if os.path.exists(out_path + ".part"):
            try:
                os.remove(out_path + ".part")
            except Exception:
                pass
        return path

    def _build_summary(self, expect):
        present = {b.name for b in self._buffers.values()}
        counts = {}
        gaps = {}
        for buf in self._buffers.values():
            with buf.lock:
                counts[buf.name] = len(buf.times)
            if buf.gaps:
                gaps[buf.name] = buf.gaps
        missing = sorted(set(expect or []) - present)
        self.summary = {
            "resolved": sorted(present),
            "missing": missing,
            "complete": not missing,
            "samples": counts,
            "gaps": gaps,
            "empty_streams": sorted(n for n, c in counts.items() if c == 0),
        }
        return self.summary

    # --- write -----------------------------------------------------------
    def _write_hdf5(self, out_path: str, partial: bool = False) -> str:
        with h5py.File(out_path, "w") as f:
            f.attrs["recorder_clock_origin"] = self._t0_lsl
            f.attrs["created"] = time.strftime("%Y-%m-%dT%H:%M:%S")
            f.attrs["partial"] = bool(partial)
            for name, buf in self._buffers.items():
                with buf.lock:
                    times = list(buf.times)
                    data = list(buf.data)
                    gaps = list(buf.gaps)
                g = f.create_group(name)
                g.attrs["type"] = buf.stype
                g.attrs["nominal_srate"] = buf.srate
                g.attrs["channels"] = [c.encode() for c in buf.channels]
                if gaps:
                    g.attrs["gaps"] = np.asarray(gaps, dtype="float64")
                g.create_dataset("timestamps", data=np.asarray(times, dtype="float64"))
                if buf.is_string:
                    g.create_dataset("data", data=np.asarray(data, dtype=object),
                                     dtype=h5py.string_dtype())
                else:
                    # guard against a stream that resolved but sent nothing yet
                    arr = np.asarray(data, dtype="float32") if data else \
                        np.empty((0, max(1, len(buf.channels))), dtype="float32")
                    g.create_dataset("data", data=arr)
        return out_path
