#!/usr/bin/env python3
"""Summarize PoC FDP WAF harness runs.

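Input: OUT directory containing one subdirectory per run (e.g. R_no_fdp_*,
R_separated_*, R_mixed_*); only subdirectories that contain a summary.json
are summarized.
Output: CSV to stdout: a header row, then one row per run ordered by mode
(no_fdp, mixed, separated, then anything else) and by directory name.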
"""

from __future__ import annotations

import csv
import json
import math
import re
import sys
from pathlib import Path
from statistics import median

# Sort rank for each run mode in the CSV output (lower sorts first). Runs whose
# summary.json has no recognized "mode" sort after all of these (run_sort_key
# falls back to rank 99).
MODE_ORDER = {"no_fdp": 0, "mixed": 1, "separated": 2}


def percentile(values: list[float], pct: float) -> float | None:
    """Return the *pct* quantile of *values* using linear interpolation.

    The quantile sits at rank ``(len(values) - 1) * pct`` in the sorted data,
    interpolating linearly between the two neighbouring samples (the same
    scheme as NumPy's default ``linear`` method).

    Args:
        values: Sample values; need not be sorted (a sorted copy is used).
        pct: Quantile as a fraction in ``[0, 1]``, e.g. ``0.99`` for p99.

    Returns:
        The interpolated quantile, or ``None`` if *values* is empty.

    Raises:
        ValueError: If *pct* is outside ``[0, 1]`` (or NaN). Without this
            check, ``pct > 1`` raised IndexError and ``pct < 0`` silently
            picked the wrong element through negative indexing.
    """
    if not 0.0 <= pct <= 1.0:
        raise ValueError(f"pct must be within [0, 1], got {pct!r}")
    if not values:
        return None
    xs = sorted(values)
    # A single sample yields rank 0.0, so no special case is needed.
    rank = (len(xs) - 1) * pct
    lo = math.floor(rank)
    hi = math.ceil(rank)
    if lo == hi:
        return xs[lo]
    return xs[lo] + (xs[hi] - xs[lo]) * (rank - lo)


def load_json(path: Path) -> dict:
    """Best-effort load of a JSON object from *path*.

    Returns:
        The decoded top-level object. Returns ``{}`` in three cases: the file
        is missing or unreadable, it is not valid UTF-8 JSON, or its top-level
        value is not a JSON object. Callers call ``.get`` on the result
        directly, so a top-level list or scalar must never be returned.
    """
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError. Anything else is a
        # genuine bug and is no longer silently swallowed.
        return {}
    return data if isinstance(data, dict) else {}


def collect_latencies(run_dir: Path) -> dict[str, object]:
    """Aggregate per-request records from a run's worker measurement logs.

    Reads every ``run_dir/worker_logs/*/measurement_*/records.jsonl`` line by
    line. Blank lines, lines that are not valid JSON, and lines that decode
    to something other than a JSON object are skipped. Every remaining line
    counts as one record, and a record counts as a failure when its
    ``failed`` field is truthy.

    Latency samples come from ``latency_ms``. A record contributes no sample
    when that value is absent, cannot be converted to ``float``, or is not
    finite (NaN/inf). Failed records that carry a latency are still sampled.

    Each sample is bucketed by case-insensitive *substring* match on the
    record's ``qualname``:
    - "get"/"read"/"load" -> read;
    - "put"/"write"/"store" -> write.
    The two checks are independent, so a qualname matching both groups is
    counted in both. One matching neither appears only in ``all_p99_ms``.

    Returns:
        Dict with ``records`` and ``failures`` counts plus read/write p50 and
        p99 and overall p99 latencies (same unit as ``latency_ms``). Each
        latency entry is ``None`` when it has no samples.
    """
    # NOTE(review): substring matching also hits names like "compute"
    # ("put") or "upload" ("load"); confirm against the harness's qualnames.
    read_keys = ("get", "read", "load")
    write_keys = ("put", "write", "store")
    read_lat: list[float] = []
    write_lat: list[float] = []
    all_lat: list[float] = []
    failures = 0
    records = 0
    for path in run_dir.glob("worker_logs/*/measurement_*/records.jsonl"):
        # Stream the file instead of read_text().splitlines() so large logs
        # are not held in memory twice.
        with path.open(encoding="utf-8", errors="replace") as fh:
            for line in fh:
                if not line.strip():
                    continue
                try:
                    rec = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(rec, dict):
                    # Valid JSON but not an object (e.g. a bare number);
                    # calling rec.get() on it would raise AttributeError.
                    continue
                records += 1
                if rec.get("failed"):
                    failures += 1
                lat = rec.get("latency_ms")
                if lat is None:
                    continue
                try:
                    lat_f = float(lat)
                except (TypeError, ValueError, OverflowError):
                    # OverflowError: integers too large for a float.
                    continue
                if not math.isfinite(lat_f):
                    # json.loads accepts NaN/Infinity. NaN breaks sorted()
                    # ordering, which the percentiles rely on.
                    continue
                all_lat.append(lat_f)
                qual = str(rec.get("qualname", "")).lower()
                if any(key in qual for key in read_keys):
                    read_lat.append(lat_f)
                if any(key in qual for key in write_keys):
                    write_lat.append(lat_f)
    return {
        "records": records,
        "failures": failures,
        "read_p50_ms": median(read_lat) if read_lat else None,
        "read_p99_ms": percentile(read_lat, 0.99),
        "write_p50_ms": median(write_lat) if write_lat else None,
        "write_p99_ms": percentile(write_lat, 0.99),
        "all_p99_ms": percentile(all_lat, 0.99),
    }


def parse_fdp_stats(path: Path) -> dict[str, int]:
    """Extract ``KEY -> value`` counters from an FDP stats text dump.

    A line contributes an entry when it begins with at least one character
    of label text followed by ``(KEY):``, optional whitespace, and digits.
    Any trailing text after the digits is ignored. Non-matching lines are
    skipped, and a later line with the same KEY replaces an earlier one.
    A *path* that does not exist yields an empty dict.
    """
    if not path.exists():
        return {}
    counter_line = re.compile(r".+?\(([^)]+)\):\s*(\d+)")
    text = path.read_text(encoding="utf-8", errors="replace")
    hits = filter(None, map(counter_line.match, text.splitlines()))
    return {hit.group(1): int(hit.group(2)) for hit in hits}


def summarize_run(run_dir: Path) -> dict[str, object]:
    """Build the CSV row describing the run stored in *run_dir*.

    The row combines three sources:
    - fields copied from ``summary.json`` (missing keys become ``None``;
      ``mode`` and ``run_id`` fall back to the directory name);
    - record/failure counts and latency percentiles from
      ``collect_latencies``;
    - the change in the HBMW, MBMW and MBE counters parsed from
      ``outer_fdp_stats_before.txt`` and ``outer_fdp_stats_after.txt``.

    An FDP delta is ``None`` (an empty CSV cell) unless the counter was
    parsed from *both* snapshots. Defaulting a missing counter to 0 would
    report the absolute "after" value as the delta when the "before" file is
    missing, and a fake 0 when both are missing.
    """
    summary = load_json(run_dir / "summary.json")
    lat = collect_latencies(run_dir)
    fdp_before = parse_fdp_stats(run_dir / "outer_fdp_stats_before.txt")
    fdp_after = parse_fdp_stats(run_dir / "outer_fdp_stats_after.txt")

    def fdp_delta(key: str) -> int | None:
        # A delta is only meaningful when both snapshots contain the counter.
        if key in fdp_before and key in fdp_after:
            return fdp_after[key] - fdp_before[key]
        return None

    return {
        "run_dir": str(run_dir),
        "mode": summary.get("mode") or run_dir.name,
        "run_id": summary.get("run_id", run_dir.name),
        "waf": summary.get("waf"),
        "waf_status": summary.get("waf_status"),
        "host_write_bytes_delta": summary.get("host_write_bytes_delta"),
        "media_write_bytes_delta": summary.get("media_write_bytes_delta"),
        "worker_count": summary.get("worker_count"),
        "warmup_iterations": summary.get("warmup_iterations"),
        "measurement_iterations": summary.get("measurement_iterations"),
        "records": lat["records"],
        "failures": lat["failures"],
        "read_p50_ms": lat["read_p50_ms"],
        "read_p99_ms": lat["read_p99_ms"],
        "write_p50_ms": lat["write_p50_ms"],
        "write_p99_ms": lat["write_p99_ms"],
        "all_p99_ms": lat["all_p99_ms"],
        "fdp_hbmw_delta": fdp_delta("HBMW"),
        "fdp_mbmw_delta": fdp_delta("MBMW"),
        "fdp_mbe_delta": fdp_delta("MBE"),
    }


def run_sort_key(run_dir: Path) -> tuple[int, str]:
    """Sort key that orders runs by mode rank, then by directory name.

    The mode is read from the run's ``summary.json``. A missing, empty or
    unrecognized mode gets rank 99, which places it after every mode listed
    in ``MODE_ORDER``.
    """
    mode_name = str(load_json(run_dir / "summary.json").get("mode") or "")
    rank = MODE_ORDER[mode_name] if mode_name in MODE_ORDER else 99
    return rank, run_dir.name


def main() -> int:
    """Print a CSV summary of every run under OUT_DIR to stdout.

    Usage: ``<script> OUT_DIR``. A run is any immediate subdirectory of
    OUT_DIR that contains ``summary.json``. Rows are ordered by
    ``run_sort_key`` and each row is produced by ``summarize_run``.

    Returns:
        Process exit status. 0 on success. 2 when the argument count is
        wrong or OUT_DIR is not an existing directory.
    """
    if len(sys.argv) != 2:
        print(f"usage: {sys.argv[0]} OUT_DIR", file=sys.stderr)
        return 2
    root = Path(sys.argv[1])
    if not root.is_dir():
        # Report the problem like a usage error rather than letting
        # iterdir() raise FileNotFoundError/NotADirectoryError.
        print(f"{sys.argv[0]}: not a directory: {root}", file=sys.stderr)
        return 2
    runs = sorted(
        (p for p in root.iterdir() if p.is_dir() and (p / "summary.json").exists()),
        key=run_sort_key,
    )
    fields = [
        "run_dir",
        "mode",
        "run_id",
        "waf",
        "waf_status",
        "host_write_bytes_delta",
        "media_write_bytes_delta",
        "worker_count",
        "warmup_iterations",
        "measurement_iterations",
        "records",
        "failures",
        "read_p50_ms",
        "read_p99_ms",
        "write_p50_ms",
        "write_p99_ms",
        "all_p99_ms",
        "fdp_hbmw_delta",
        "fdp_mbmw_delta",
        "fdp_mbe_delta",
    ]
    writer = csv.DictWriter(sys.stdout, fieldnames=fields)
    writer.writeheader()
    for run_dir in runs:
        writer.writerow(summarize_run(run_dir))
    return 0


if __name__ == "__main__":
    # Use main()'s return value (0 or 2) as the process exit status.
    raise SystemExit(main())
