#!/usr/bin/env python3
"""
Memory Deduplication — Keeps user profile and memories clean

Runs after ingestion to:
1. Deduplicate user_profile.md using LLM to merge similar facts
2. Deduplicate memory files (remove near-identical entries)
3. Deduplicate entities.jsonl

Can be run standalone; it is also called automatically by ingest_memories.py
"""

import json
import logging
import ssl
import sys
import urllib.request
from pathlib import Path

# Configure logging at import time so every message is prefixed with "[dedup]".
# NOTE: this runs on import as well as when executed as a script.
logging.basicConfig(level=logging.INFO, format="[dedup] %(message)s")
logger = logging.getLogger(__name__)

# On-disk layout: everything lives under ~/.hermes.
HERMES_HOME = Path.home() / ".hermes"
MEMORIES_DIR = HERMES_HOME / "memories"  # *.md memory files and entities.jsonl
PROFILE_FILE = HERMES_HOME / "user_profile.md"  # bullet list of user facts

# Chat-completions endpoint and model identifier used by llm_call().
API_URL = "https://openrouter.ai/api/v1/chat/completions"
MODEL = "qwen/qwen3.6-plus:free"


def get_api_key():
    """Return the OpenRouter API key, or "" when none can be found.

    Lookup order:
      1. ``~/.hermes/secrets/openrouter.key`` (whole file, whitespace stripped)
      2. the first non-empty ``OPENROUTER_API_KEY=`` line in
         ``~/hermes-agent/.env``

    An empty key file no longer hides a valid ``.env`` entry, and quotes
    wrapped around the ``.env`` value are removed so they do not end up in
    the Authorization header.
    """
    key_file = HERMES_HOME / "secrets" / "openrouter.key"
    if key_file.exists():
        key = key_file.read_text().strip()
        if key:
            return key
        # Empty key file: fall through to the .env lookup instead of giving up.

    env_file = Path.home() / "hermes-agent" / ".env"
    if env_file.exists():
        for line in env_file.read_text().splitlines():
            if not line.startswith("OPENROUTER_API_KEY="):
                continue
            _, _, value = line.partition("=")
            # dotenv-style files commonly quote values: KEY="sk-..."
            value = value.strip().strip("'\"")
            if value:
                return value
    return ""


def llm_call(prompt: str, api_key: str) -> str:
    """Send *prompt* as a single user message to the chat endpoint and return the reply text.

    Posts a JSON payload (model ``MODEL``, at most 2000 completion tokens) to
    ``API_URL`` with a bearer token, waits up to 45 seconds, and returns the
    stripped ``content`` of the first choice. No exceptions are caught here;
    network, JSON and response-shape errors propagate to the caller.
    """
    payload = {
        "model": MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 2000,
    }
    request = urllib.request.Request(
        API_URL,
        data=json.dumps(payload).encode(),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
    )
    tls_context = ssl.create_default_context()
    with urllib.request.urlopen(request, timeout=45, context=tls_context) as resp:
        reply = json.loads(resp.read())
    first_choice = reply["choices"][0]
    return first_choice["message"]["content"].strip()


def dedup_profile(api_key: str):
    """Deduplicate user_profile.md using LLM to merge similar facts.

    Every line of the profile that starts with "- " is treated as a fact. The
    facts are sent to the LLM with instructions to merge duplicates, and the
    profile is rewritten as a fixed header followed by the returned facts.
    Lines of the existing profile that are not "- " bullets are not carried
    over.

    The original file is kept untouched when:
      * the profile is missing, empty, or has fewer than two facts;
      * the LLM reply contains no "- " lines;
      * the LLM reply contains MORE facts than were sent (a merge can only
        shrink the list, so a longer reply means the model added content);
      * any exception occurs (it is logged, not raised).

    Args:
        api_key: bearer token forwarded to ``llm_call``.
    """
    if not PROFILE_FILE.exists():
        logger.info("No user profile to deduplicate.")
        return

    current = PROFILE_FILE.read_text().strip()
    if not current:
        return

    # Extract all facts (bullet lines), dropping the "- " marker.
    facts = [
        stripped[2:]
        for stripped in (raw.strip() for raw in current.split("\n"))
        if stripped.startswith("- ")
    ]

    if len(facts) < 2:
        logger.info(f"Only {len(facts)} facts, nothing to dedup.")
        return

    logger.info(f"Deduplicating {len(facts)} user profile facts...")

    fact_list = "\n".join(f"- {f}" for f in facts)
    prompt = f"""Here are facts collected about a user from multiple conversations.
Many are duplicates or say the same thing in different words.

FACTS:
{fact_list}

Merge these into a clean, deduplicated list. Rules:
- Keep the most specific/detailed version of each fact
- Combine similar facts into one
- Remove exact or near duplicates
- Keep the same format: one fact per line starting with "- "
- Do NOT add any facts that aren't in the original list
- Do NOT add headers, commentary, or explanation

Output ONLY the deduplicated facts, one per line starting with "- "."""

    try:
        response = llm_call(prompt, api_key)

        # Parse the clean facts; anything that is not a bullet is ignored.
        clean_facts = [
            stripped[2:]
            for stripped in (raw.strip() for raw in response.split("\n"))
            if stripped.startswith("- ")
        ]

        if not clean_facts:
            logger.warning("LLM returned no facts — keeping original.")
            return

        if len(clean_facts) > len(facts):
            # Deduplication must never grow the list; do not overwrite the
            # profile with content the model invented or split apart.
            logger.warning(
                f"LLM returned {len(clean_facts)} facts for {len(facts)} inputs — keeping original."
            )
            return

        new_text = (
            "# User Profile (auto-maintained)\n\n" + "\n".join(f"- {f}" for f in clean_facts) + "\n"
        )

        # Write to a sibling temp file and swap it in, so a failed write
        # (e.g. an encoding error on LLM text) cannot truncate the profile.
        tmp_file = PROFILE_FILE.with_name(PROFILE_FILE.name + ".tmp")
        try:
            tmp_file.write_text(new_text)
            tmp_file.replace(PROFILE_FILE)
        finally:
            # After a successful replace the temp file is gone; this only
            # cleans up after a failure.
            if tmp_file.exists():
                tmp_file.unlink()

        logger.info(f"Profile: {len(facts)} facts → {len(clean_facts)} facts ({len(facts) - len(clean_facts)} removed)")

    except Exception as e:
        logger.error(f"Profile dedup failed: {e}")


def dedup_memories():
    """Deduplicate memory files by removing near-identical content.

    All ``*.md`` files in ``MEMORIES_DIR`` are read, stripped and lower-cased
    for comparison (empty files are ignored), then compared pairwise in sorted
    filename order. For each pair:

      * if one file's text is contained in the other's, the contained file is
        removed (for identical text, the earlier file is removed);
      * otherwise, if word-overlap similarity exceeds 0.8, the shorter file is
        removed (the later file on a tie).

    A file that has been marked for removal is never used as the reference
    for removing other files: once the current file is marked, comparison
    moves on to the next one. Marked files are deleted from disk at the end.
    """
    if not MEMORIES_DIR.exists():
        return

    mem_files = sorted(MEMORIES_DIR.glob("*.md"))
    if len(mem_files) < 2:
        return

    # Normalised content per file; empty files take no part in comparison.
    memories = {}
    for mf in mem_files:
        content = mf.read_text().strip().lower()
        if content:
            memories[mf] = content

    to_remove = set()
    mem_list = list(memories.items())

    for i, (a_path, a_content) in enumerate(mem_list):
        if a_path in to_remove:
            continue
        for b_path, b_content in mem_list[i + 1:]:
            if b_path in to_remove:
                continue

            # Decide which side of the pair (if any) is the duplicate.
            if a_content in b_content:
                drop_a = True
            elif b_content in a_content:
                drop_a = False
            elif _similarity(a_content, b_content) > 0.8:
                # Keep the longer one; on a tie keep the earlier file.
                drop_a = len(a_content) < len(b_content)
            else:
                continue

            if drop_a:
                to_remove.add(a_path)
                # The reference file itself is going away: stop comparing
                # against it so it cannot cause further removals.
                break
            to_remove.add(b_path)

    # Remove duplicates
    for mf in to_remove:
        mf.unlink()

    if to_remove:
        logger.info(f"Memories: removed {len(to_remove)} duplicates, {len(memories) - len(to_remove)} remain")
    else:
        logger.info(f"Memories: {len(memories)} files, no duplicates found")


def _similarity(a: str, b: str) -> float:
    """Simple word-overlap similarity (Jaccard)."""
    words_a = set(a.split())
    words_b = set(b.split())
    if not words_a or not words_b:
        return 0.0
    intersection = words_a & words_b
    union = words_a | words_b
    return len(intersection) / len(union)


def dedup_entities():
    """Deduplicate entities.jsonl.

    Each non-blank line is parsed as JSON. The dedup key is the record's
    ``entity`` string, stripped and lower-cased, truncated at the first
    em dash ("—") if present. The FIRST line seen for each key is kept
    verbatim; later lines with the same key are dropped.

    Lines that are not valid JSON, are not JSON objects, or have a missing,
    non-string or empty ``entity`` are also dropped, but they are counted and
    reported separately instead of being logged as duplicates. The file is
    rewritten only when its content actually changes.
    """
    entities_file = MEMORIES_DIR / "entities.jsonl"
    if not entities_file.exists():
        return

    raw = entities_file.read_text()
    # Split on "\n" only: str.splitlines() would also break on characters
    # such as U+2028 that may legally appear unescaped inside a JSON string.
    lines = [line for line in raw.split("\n") if line.strip()]

    seen = set()
    unique = []
    duplicates = 0
    invalid = 0

    for line in lines:
        try:
            entry = json.loads(line)
        except json.JSONDecodeError:
            invalid += 1
            continue

        # Valid JSON is not necessarily an object with a string "entity".
        name = entry.get("entity") if isinstance(entry, dict) else None
        if not isinstance(name, str):
            invalid += 1
            continue

        # "name — description" and "name" share the same key.
        key = name.strip().lower().split("—")[0].strip()
        if not key:
            invalid += 1
            continue

        if key in seen:
            duplicates += 1
            continue
        seen.add(key)
        unique.append(line)

    new_text = "".join(f"{line}\n" for line in unique)
    if new_text != raw:
        entities_file.write_text(new_text)

    if invalid:
        logger.warning(f"Entities: dropped {invalid} malformed or unnamed entries")
    if duplicates:
        logger.info(f"Entities: {len(lines)} → {len(unique)} ({duplicates} duplicates removed)")
    else:
        logger.info(f"Entities: {len(unique)} entries, no duplicates")


def main():
    """Run the full deduplication pass: profile, then memory files, then entities.

    Exits the process with status 1 (after logging an error) when no API key
    can be found; in that case none of the deduplication steps run.
    """
    api_key = get_api_key()
    if api_key:
        logger.info("Running deduplication...")
        dedup_profile(api_key)
        # The remaining steps are purely local and take no arguments.
        for step in (dedup_memories, dedup_entities):
            step()
        logger.info("Deduplication complete.")
        return

    logger.error("No API key found")
    sys.exit(1)


# Run the deduplication pass only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
