#!/usr/bin/env python3
"""Engram retrieval latency benchmark with percentile reporting.

Measures ripgrep query latency against the entity vault at varying scales.
Reports p50 (median) and p99 (tail) latencies — the metrics that matter
for real-time LLM tool-call budgets.

Usage:
    python scripts/benchmark.py                        # benchmark live vault
    python scripts/benchmark.py --synthetic 50000      # generate + benchmark
    python scripts/benchmark.py --rounds 20 --json     # JSON output for CI
"""

from __future__ import annotations

import argparse
import json as json_mod
import math
import subprocess
import sys
import time
from pathlib import Path

from engram.config import ENTITIES_DIR

# Patterns benchmarked by main(), in this order. Each entry is handed to
# ripgrep unchanged by run_query(), so entries are interpreted as regexes.
# NOTE(review): ripgrep matches case-sensitively unless told otherwise, so
# "jetson" and "cron" presumably do not hit the capitalised "Jetson"/"Cron"
# lines written by generate_synthetic(), and "sandbox"/"yoga" never appear in
# synthetic files at all — confirm whether those misses are intended.
QUERIES: list[str] = [
    "memory",
    "jetson",
    "sandbox",
    "cron",
    "yoga",
    r"def\s+\w+",  # regex: function definitions
    "nonexistent_token",  # guaranteed miss — measures worst-case full scan
]


def percentile(sorted_data: list[float], pct: float) -> float:
    """Return the *pct*-th percentile from pre-sorted data.

    The rank ``(len - 1) * pct / 100`` is computed as a float and, when it
    falls between two samples, the result is linearly interpolated between
    them.

    Args:
        sorted_data: Samples in ascending order. The list is not sorted here;
            passing unsorted data yields a meaningless result.
        pct: Percentile rank, expected in the closed range [0, 100].

    Returns:
        The interpolated percentile value, or ``0.0`` when *sorted_data* is
        empty.

    Raises:
        ValueError: If *pct* is outside [0, 100] (including NaN) and
            *sorted_data* is non-empty.
    """
    if not sorted_data:
        return 0.0
    # Without this guard a negative pct silently indexes from the END of the
    # list (Python negative indexing) and pct > 100 dies with a bare
    # IndexError. The chained comparison is also False for NaN.
    if not 0.0 <= pct <= 100.0:
        raise ValueError(f"pct must be within [0, 100], got {pct!r}")

    k = (len(sorted_data) - 1) * (pct / 100.0)
    f = math.floor(k)
    c = math.ceil(k)
    if f == c:
        # The rank lands exactly on a sample; no interpolation needed.
        return sorted_data[f]
    # Weight each neighbour by its distance from the fractional rank.
    return sorted_data[f] * (c - k) + sorted_data[c] * (k - f)


def run_query(query: str, target: Path) -> tuple[float, int]:
    """Run a single ripgrep query and return (elapsed_ms, match_count).

    Args:
        query: Pattern handed to ripgrep as its search expression.
        target: File or directory that ripgrep searches.

    Returns:
        A tuple of the wall-clock duration of the ``rg`` subprocess in
        milliseconds and the sum of the per-file counts it printed. The count
        is 0 whenever ``rg`` exits with a non-zero status.

    Raises:
        FileNotFoundError: If the ``rg`` executable cannot be launched.
    """
    # "-e" marks the next argument as the pattern even when it begins with
    # "-", and "--" ends option parsing so the path cannot be read as a flag.
    command = ["rg", "--no-heading", "--count", "-e", query, "--", str(target)]

    t0 = time.perf_counter()
    result = subprocess.run(command, capture_output=True, text=True)
    elapsed_ms = (time.perf_counter() - t0) * 1000

    matches = 0
    if result.returncode == 0:
        for line in result.stdout.splitlines():
            # Lines are "path:count" when several files are searched, but a
            # bare "count" when the target is a single file. rpartition
            # handles both shapes: with no ":" the whole line is the tail.
            count = line.rpartition(":")[2].strip()
            if count.isdigit():
                matches += int(count)

    return elapsed_ms, matches


def count_files(target: Path) -> int:
    """Return how many paths below *target* (recursively) match ``*.md``."""
    markdown_paths = list(target.rglob("*.md"))
    return len(markdown_paths)


def generate_synthetic(target: Path, n: int) -> None:
    """Write *n* synthetic ``synthetic_NNNNNN.md`` files into *target*.

    *target* is created (including parents) when missing. Every file gets a
    small front-matter header and a closing body line; extra lines are added
    depending on which of 3, 5, 7 and 11 divide the file index, so different
    queries hit different fractions of the corpus.
    """
    target.mkdir(parents=True, exist_ok=True)
    print(f"Generating {n:,} synthetic files in {target}...")

    # (period, text): the text is included in every file whose index is a
    # multiple of the period. Order here is the order within the file.
    periodic_lines = (
        (3, "Discussing memory and engram recall patterns.\n"),
        (5, "Jetson nano board setup with jetpack SDK.\n"),
        (7, "Cron daemon export sync schedule.\n"),
        (11, "def process_data(batch):\n    return transform(batch)\n"),
    )

    for index in range(n):
        pieces = [f"---\ntopics: [Memory]\nsummary: Synthetic session {index}\n---\n"]
        pieces.extend(text for period, text in periodic_lines if index % period == 0)
        pieces.append(f"Session {index} body content for benchmarking.\n")

        destination = target / f"synthetic_{index:06d}.md"
        destination.write_text("".join(pieces))

    print(f"  Done: {n:,} files created.")


def main() -> None:
    """Parse CLI options, benchmark every pattern in QUERIES, and report.

    Each query is executed ``--rounds`` times via run_query(); the sorted
    timings are reduced to p50/p99/avg/min/max. Results are printed either as
    a text table or, with ``--json``, as a single JSON document.

    Exit status:
        1 if ripgrep is unavailable or unusable, or if the target holds no
        ``.md`` files; 2 (argparse's usage-error status) for invalid
        ``--rounds`` / ``--synthetic`` values.
    """
    parser = argparse.ArgumentParser(description="Engram retrieval benchmark with percentile latencies")
    parser.add_argument(
        "--synthetic",
        type=int,
        metavar="N",
        help="Generate N synthetic .md files before benchmarking",
    )
    parser.add_argument(
        "--target",
        type=Path,
        default=ENTITIES_DIR,
        help=f"Directory to benchmark (default: {ENTITIES_DIR})",
    )
    parser.add_argument(
        "--rounds",
        type=int,
        default=10,
        help="Number of rounds per query (default: 10)",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Output results as JSON (for CI integration)",
    )
    args = parser.parse_args()

    # With fewer than one round `timings` stays empty and the average below
    # would divide by zero, so reject it up front with a usage error.
    if args.rounds < 1:
        parser.error("--rounds must be at least 1")
    if args.synthetic is not None and args.synthetic < 0:
        parser.error("--synthetic must be non-negative")

    # --- preflight ----------------------------------------------------------
    try:
        rg_version = subprocess.run(["rg", "--version"], capture_output=True, text=True, check=True)
    except FileNotFoundError:
        print("Error: ripgrep (rg) not found on PATH.", file=sys.stderr)
        sys.exit(1)
    except subprocess.CalledProcessError as exc:
        # check=True raises when rg exists but exits non-zero.
        print(f"Error: 'rg --version' failed with exit code {exc.returncode}.", file=sys.stderr)
        sys.exit(1)
    version_lines = rg_version.stdout.strip().splitlines()
    # Guard against empty output so indexing the first line cannot fail.
    rg_ver = version_lines[0] if version_lines else "rg (unknown version)"

    # NOTE: synthetic files are written into args.target, which defaults to
    # ENTITIES_DIR when --target is not given. A value of 0 skips generation.
    if args.synthetic:
        generate_synthetic(args.target, args.synthetic)

    file_count = count_files(args.target)
    if file_count == 0:
        print(f"No .md files found in {args.target}.", file=sys.stderr)
        sys.exit(1)

    # --- run benchmark ------------------------------------------------------
    results: list[dict] = []

    if not args.json:
        print("\nEngram Retrieval Benchmark")
        print(f"{'=' * 72}")
        print(f"Target:     {args.target}")
        print(f"Files:      {file_count:,}")
        print(f"Rounds:     {args.rounds}")
        print(f"Engine:     {rg_ver}")
        print(f"{'=' * 72}\n")
        print(f"{'Query':<22} {'p50 (ms)':>9} {'p99 (ms)':>9} {'Avg (ms)':>9} {'Min':>8} {'Max':>8} {'Matches':>9}")
        print(f"{'-' * 22} {'-' * 9} {'-' * 9} {'-' * 9} {'-' * 8} {'-' * 8} {'-' * 9}")

    for query in QUERIES:
        timings: list[float] = []
        last_matches = 0

        for _ in range(args.rounds):
            elapsed, matches = run_query(query, args.target)
            timings.append(elapsed)
            # Only the match count of the final round is reported.
            last_matches = matches

        # percentile() requires ascending input.
        timings.sort()
        p50 = percentile(timings, 50)
        p99 = percentile(timings, 99)
        avg = sum(timings) / len(timings)
        fastest = timings[0]
        slowest = timings[-1]

        entry = {
            "query": query,
            "rounds": args.rounds,
            "p50_ms": round(p50, 1),
            "p99_ms": round(p99, 1),
            "avg_ms": round(avg, 1),
            "min_ms": round(fastest, 1),
            "max_ms": round(slowest, 1),
            "matches": last_matches,
        }
        results.append(entry)

        if not args.json:
            # Keep the first column within its 22-character width.
            label = query if len(query) <= 21 else query[:18] + "..."
            print(
                f"{label:<22} {p50:>9.1f} {p99:>9.1f} {avg:>9.1f} "
                f"{fastest:>8.1f} {slowest:>8.1f} {last_matches:>9,}"
            )

    # --- aggregate stats ----------------------------------------------------
    # Aggregates are derived from the per-query (already rounded) figures:
    # the p50 aggregate is the mean of the per-query medians, the p99
    # aggregate is the worst per-query p99.
    all_p50s = [r["p50_ms"] for r in results]
    all_p99s = [r["p99_ms"] for r in results]

    summary = {
        "target": str(args.target),
        "file_count": file_count,
        "rounds_per_query": args.rounds,
        "engine": rg_ver,
        "aggregate_p50_ms": round(sum(all_p50s) / len(all_p50s), 1),
        "aggregate_p99_ms": round(max(all_p99s), 1),
        "queries": results,
    }

    if args.json:
        print(json_mod.dumps(summary, indent=2))
    else:
        print(f"\n{'=' * 72}")
        print(f"Aggregate:  p50 = {summary['aggregate_p50_ms']:.1f} ms  |  p99 = {summary['aggregate_p99_ms']:.1f} ms")
        print(f"{'=' * 72}")
        print("Done.")


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
