#!/usr/bin/env python3
"""
Persistent Ingestion Script for Engram.
"""

import argparse
import json
import logging
import os
import shutil
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Set


def setup_logging(log_dir: Path) -> logging.Logger:
    """Configure root logging to both a persistent file and the console.

    Creates *log_dir* if needed and appends to ``ingestion.log`` inside it,
    so repeated runs accumulate into one log. Returns this module's logger.
    """
    log_dir.mkdir(exist_ok=True, parents=True)  # must exist before the FileHandler opens the file
    log_file = log_dir / "ingestion.log"
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        handlers=[
            # Explicit UTF-8 so non-ASCII file names inside log messages
            # don't raise UnicodeEncodeError under a C/POSIX locale.
            logging.FileHandler(log_file, mode="a", encoding="utf-8"),
            logging.StreamHandler(),
        ],
    )
    return logging.getLogger(__name__)


def atomic_dump(checkpoint: Path, data: Dict) -> None:
    """Atomically write *data* as pretty-printed JSON to *checkpoint*.

    Writes to a sibling temp file, fsyncs it, then renames it over the
    target so a reader (or a crashed run) never observes a half-written
    checkpoint. The temp file is removed if anything fails.
    """
    # Append ".tmp" rather than replacing the suffix: with_suffix(".tmp")
    # would map both "state.json" and "state.yaml" to the same temp path.
    temp = checkpoint.with_name(checkpoint.name + ".tmp")
    try:
        with open(temp, "w") as f:
            json.dump(data, f, indent=2)
            f.flush()
            os.fsync(f.fileno())  # make the bytes durable before the rename
        os.replace(temp, checkpoint)  # atomic on POSIX
    except BaseException:
        # Don't leave a stale temp file behind on failure.
        try:
            temp.unlink()
        except FileNotFoundError:
            pass
        raise


def get_all_files(source: Path) -> List[Path]:
    """Collect every ingestible file under *source*, sorted for a stable order."""
    extensions = {".md", ".txt", ".py", ".json", ".yaml", ".yml", ".log"}
    found = []
    for candidate in source.rglob("*"):
        if candidate.is_file() and candidate.suffix.lower() in extensions:
            found.append(candidate)
    found.sort()
    return found


def _load_checkpoint(checkpoint: Path, logger: logging.Logger):
    """Return (processed path-strings, prior start_time or None) from a previous run."""
    try:
        with open(checkpoint) as f:
            state = json.load(f)
        return set(state.get("processed", [])), state.get("start_time")
    except (FileNotFoundError, json.JSONDecodeError):
        logger.info("No prior checkpoint, starting fresh.")
        return set(), None


def _unique_target(target: Path, name: str) -> Path:
    """Return a non-existing path for *name* under *target*.

    On collision appends _1, _2, ... to the ORIGINAL stem, so repeated
    collisions yield "a_2.txt" rather than the compounding "a_1_2.txt".
    """
    candidate = target / name
    stem, suffix = candidate.stem, candidate.suffix
    counter = 1
    while candidate.exists():
        candidate = target / f"{stem}_{counter}{suffix}"
        counter += 1
    return candidate


def main():
    """Copy eligible files from source to target with a resumable checkpoint.

    Returns a process exit code: 0 when every discovered file is accounted
    for, 1 when the source is missing or the run is incomplete.
    """
    parser = argparse.ArgumentParser(description="Persistent ingest with checkpoints.")
    parser.add_argument("--dry-run", action="store_true", help="Dry run without copying.")
    parser.add_argument(
        "--source", default=str(Path.home() / ".local" / "share" / "engram" / "entities"), help="Source dir"
    )
    parser.add_argument("--target", default="entities", help="Target dir relative to cwd")
    args = parser.parse_args()

    source = Path(args.source)
    target = Path.cwd() / args.target
    checkpoint = Path.cwd() / "checkpoint_state.json"
    log_dir = Path.cwd() / "logs"

    logger = setup_logging(log_dir)
    target.mkdir(exist_ok=True, parents=True)

    if not source.exists():
        logger.error(f"Source {source} does not exist!")
        return 1

    all_files = get_all_files(source)
    total = len(all_files)
    logger.info(f"[START] Source: {source}, Total: {total}")

    if total == 0:
        logger.info("No files to process.")
        return 0

    processed, prior_start = _load_checkpoint(checkpoint, logger)
    # Keep the first run's start time across resumes instead of clobbering it.
    start_time_str = prior_start or datetime.now().isoformat()

    pending = [f for f in all_files if str(f) not in processed]
    processed_count = len(all_files) - len(pending)
    logger.info(f"Previously processed: {processed_count}/{total}. Pending: {len(pending)}")

    last_status = time.time()
    for i, src_file in enumerate(pending, processed_count + 1):
        try:
            # NOTE(review): oversize files are never marked processed, so the
            # run keeps reporting "Incomplete" while any such file exists.
            if src_file.stat().st_size > 10 * 1024**2:  # 10MB cap per file
                logger.warning(f"Skipping large: {src_file}")
                continue

            if args.dry_run:
                # No copy and — crucially — NO checkpoint write: a dry run
                # must not make a later real run believe these files were
                # already ingested. Track progress in memory only.
                logger.info(f"DRY [{i}/{total}]: {src_file.name} -> {src_file.name}")
                processed.add(str(src_file))
            else:
                target_file = _unique_target(target, src_file.name)
                shutil.copy2(src_file, target_file)
                logger.info(f"OK  [{i}/{total}]: {src_file.name} -> {target_file.name}")
                processed.add(str(src_file))
                # Persist after every file so a crash loses at most one copy.
                atomic_dump(checkpoint, {
                    "processed": list(processed),
                    "total": total,
                    "start_time": start_time_str,
                    "complete": len(processed) == total,
                })

        except Exception as e:
            # Best-effort per file: log and move on; the file stays pending
            # for the next run because it was never added to `processed`.
            logger.error(f"FAIL [{i}/{total}] {src_file}: {type(e).__name__}: {e}")
            continue

        # Periodic progress line: every 50 files or every 5 minutes.
        now = time.time()
        if i % 50 == 0 or now - last_status > 300:
            logger.info(f"[PROGRESS] {i}/{total} ({i / total * 100:.1f}%). Current: {src_file.name}")
            last_status = now

    final_count = sum(1 for f in all_files if str(f) in processed)
    if final_count == total:
        logger.info("100% Complete!")
        return 0
    logger.error(f"Incomplete: {final_count}/{total}")
    return 1


if __name__ == "__main__":
    # `sys` is imported with the other stdlib modules at the top of the
    # file; exit with main()'s return code so shells/cron see failures.
    sys.exit(main())
