#!/usr/bin/env python3
"""
Memory Ingestion — Extract memories from existing sessions

Reads all session transcripts and uses an LLM to extract:
1. Facts about the user (user profile)
2. Important things to remember (agent memory)
3. Entities (people, companies, topics)

Writes to ~/.hermes/memories/ and user profile so the agent
remembers everything from past conversations going forward.
"""

import json
import logging
import ssl
import sys
import urllib.request
from datetime import datetime
from pathlib import Path

logging.basicConfig(level=logging.INFO, format="[memory] %(message)s")
logger = logging.getLogger(__name__)

HERMES_HOME = Path.home() / ".hermes"
SESSIONS_DIR = HERMES_HOME / "sessions"
MEMORIES_DIR = HERMES_HOME / "memories"
PROCESSED_FILE = MEMORIES_DIR / ".ingested_sessions"

# API config
API_URL = "https://openrouter.ai/api/v1/chat/completions"
MODEL = "qwen/qwen3.6-plus:free"


def get_api_key():
    """Return the OpenRouter API key, or "" when none can be located.

    Checks the dedicated key file under the Hermes home directory first,
    then falls back to scanning ~/hermes-agent/.env for an
    OPENROUTER_API_KEY= line.
    """
    key_path = HERMES_HOME / "secrets" / "openrouter.key"
    if key_path.exists():
        return key_path.read_text().strip()

    dotenv = Path.home() / "hermes-agent" / ".env"
    if not dotenv.exists():
        return ""
    for entry in dotenv.read_text().splitlines():
        if entry.startswith("OPENROUTER_API_KEY="):
            return entry.split("=", 1)[1].strip()
    return ""


def llm_call(prompt: str, api_key: str) -> str:
    """Send *prompt* as a single user message and return the reply text.

    Posts to the configured chat-completions endpoint (API_URL / MODEL)
    with a 45s timeout; propagates any network or JSON errors to the
    caller.
    """
    payload = {
        "model": MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 1500,
    }
    request = urllib.request.Request(
        API_URL,
        data=json.dumps(payload).encode(),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
    )
    with urllib.request.urlopen(
        request, timeout=45, context=ssl.create_default_context()
    ) as resp:
        parsed = json.loads(resp.read())
    return parsed["choices"][0]["message"]["content"].strip()


def load_session(filepath: Path) -> list:
    """Return the list of message dicts stored in *filepath*.

    Supports two on-disk layouts: ``.jsonl`` (one JSON object per line,
    unparseable lines skipped) and plain JSON (either a list, or a dict
    whose "messages" key holds the list — a dict without that key is
    wrapped in a one-element list). On any read/parse failure a warning
    is logged and whatever was collected so far is returned.
    """
    messages: list = []
    try:
        if filepath.suffix == ".jsonl":
            with open(filepath) as fh:
                for raw in fh:
                    raw = raw.strip()
                    if not raw:
                        continue
                    try:
                        messages.append(json.loads(raw))
                    except json.JSONDecodeError:
                        pass  # tolerate corrupt/partial lines
        else:
            with open(filepath) as fh:
                payload = json.load(fh)
            if isinstance(payload, list):
                messages = payload
            elif isinstance(payload, dict):
                messages = payload.get("messages", [payload])
    except Exception as exc:
        logger.warning(f"Could not load {filepath.name}: {exc}")
    return messages


def extract_conversation_text(messages: list, max_chars: int = 6000) -> str:
    """Flatten user/assistant turns into a newline-joined "role: text" transcript.

    Messages carrying tool_calls, empty contents, and non-user/assistant
    roles are skipped; each turn's content is capped at 500 characters,
    and accumulation stops before the total would exceed *max_chars*.
    """
    kept = []
    used = 0
    for msg in messages:
        if msg.get("role", "") not in ("user", "assistant"):
            continue
        content = msg.get("content", "") or ""
        if not content or msg.get("tool_calls"):
            continue
        entry = f'{msg["role"]}: {content[:500]}'
        if used + len(entry) > max_chars:
            break
        kept.append(entry)
        used += len(entry)
    return "\n".join(kept)


def get_processed():
    """Return the set of session filenames already ingested.

    Reads one filename per line from PROCESSED_FILE. Blank lines — and
    an existing-but-empty ledger — are ignored; the previous
    ``strip().split("\n")`` approach returned ``{""}`` for an empty
    file and kept empty entries for blank interior lines. splitlines()
    also tolerates CRLF line endings.
    """
    if not PROCESSED_FILE.exists():
        return set()
    return {line for line in PROCESSED_FILE.read_text().splitlines() if line.strip()}


def mark_processed(filename: str):
    """Append *filename* to the ingest ledger so it is skipped next run."""
    with PROCESSED_FILE.open("a") as ledger:
        ledger.write(f"{filename}\n")


def extract_memories(conversation: str, api_key: str) -> dict:
    """Ask the LLM to pull user facts, memories, and entities from a transcript.

    Returns a dict with keys "user_facts", "memories", "entities", each a
    list of strings. If the LLM call (or parsing) fails for any reason,
    logs the error and returns a dict of empty lists instead of raising.
    """
    prompt = f"""Analyze this conversation between a user and an AI assistant. Extract:

1. USER_FACTS: Facts about the user — their name, role, interests, preferences, location, company,
   family, anything personal they revealed. One fact per line.

2. MEMORIES: Important things the assistant should remember for future conversations — decisions made,
   promises given, topics discussed, action items, opinions expressed. One memory per line.

3. ENTITIES: People, companies, products, or topics mentioned that matter.
   Format: name — brief description. One per line.

If a section has nothing, write "none".

CONVERSATION:
{conversation}

Respond in EXACTLY this format:
USER_FACTS:
- fact 1
- fact 2

MEMORIES:
- memory 1
- memory 2

ENTITIES:
- entity — description"""

    try:
        return parse_extraction(llm_call(prompt, api_key))
    except Exception as e:
        logger.error(f"LLM extraction failed: {e}")
        return {"user_facts": [], "memories": [], "entities": []}


def parse_extraction(response: str) -> dict:
    """Turn the LLM's sectioned reply into lists of extracted items.

    Recognizes USER_FACTS / MEMORIES / ENTITIES headers (matched
    case-insensitively by prefix) and collects the "- " bullet lines
    beneath each. Bullets that read "none" (any case) and bullets
    appearing before any header are dropped.
    """
    buckets = {"user_facts": [], "memories": [], "entities": []}
    section = None
    headers = (
        ("USER_FACTS", "user_facts"),
        ("MEMORIES", "memories"),
        ("ENTITIES", "entities"),
    )

    for raw in response.split("\n"):
        text = raw.strip()
        upper = text.upper()
        matched = next((key for prefix, key in headers if upper.startswith(prefix)), None)
        if matched is not None:
            section = matched
        elif section and text.startswith("- "):
            item = text[2:].strip()
            if item and item.lower() != "none":
                buckets[section].append(item)

    return buckets


def save_memories(extraction: dict, session_name: str):
    """Persist an extraction to the Hermes memory store.

    Writes each memory to its own markdown file under MEMORIES_DIR,
    appends previously-unseen user facts to ~/.hermes/user_profile.md
    (duplicates detected case-insensitively against the existing
    profile), and appends each entity as a JSON line to entities.jsonl.

    Returns a (memories_saved, new_facts_saved, entities_saved) tuple.
    Fix: the second element is now the number of facts actually
    appended to the profile — previously it was the total number of
    extracted facts, so main's "{f} new facts" log line over-reported
    whenever duplicates were skipped.
    """
    MEMORIES_DIR.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # One file per memory — Hermes reads these individually.
    for idx, note in enumerate(extraction["memories"]):
        (MEMORIES_DIR / f"{timestamp}_{session_name}_{idx}.md").write_text(note)

    # Append only facts not already present in the user profile.
    new_fact_count = 0
    if extraction["user_facts"]:
        profile_path = HERMES_HOME / "user_profile.md"
        existing_lower = (
            profile_path.read_text().lower() if profile_path.exists() else ""
        )
        new_facts = [
            fact for fact in extraction["user_facts"]
            if fact.lower() not in existing_lower
        ]
        if new_facts:
            with open(profile_path, "a") as fh:
                fh.write(f"\n# From session {session_name} ({timestamp})\n")
                fh.writelines(f"- {fact}\n" for fact in new_facts)
        new_fact_count = len(new_facts)

    # Append entities to the central JSONL index.
    if extraction["entities"]:
        with open(MEMORIES_DIR / "entities.jsonl", "a") as fh:
            for entity in extraction["entities"]:
                record = {
                    "entity": entity,
                    "source": session_name,
                    "timestamp": timestamp,
                }
                fh.write(json.dumps(record) + "\n")

    return (
        len(extraction["memories"]),
        new_fact_count,
        len(extraction["entities"]),
    )


def main():
    """CLI entry point: discover sessions, extract memories, persist them."""
    import argparse

    parser = argparse.ArgumentParser(description="Ingest memories from sessions")
    parser.add_argument(
        "--all",
        action="store_true",
        help="Process all sessions (ignore already-processed)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be extracted without saving",
    )
    parser.add_argument("--session", help="Process a specific session file")
    args = parser.parse_args()

    api_key = get_api_key()
    if not api_key:
        logger.error("No OpenRouter API key found")
        sys.exit(1)

    MEMORIES_DIR.mkdir(parents=True, exist_ok=True)

    if args.session:
        candidates = [Path(args.session)]
        already_done = set()
    else:
        candidates = sorted(SESSIONS_DIR.glob("*"))
        already_done = set() if args.all else get_processed()

    # Keep only unprocessed files that contain a real conversation.
    pending = []
    for path in candidates:
        if path.name.startswith(".") or path.name in already_done:
            continue
        transcript = extract_conversation_text(load_session(path))
        if len(transcript) > 100:  # Skip near-empty sessions
            pending.append((path, transcript))

    if not pending:
        logger.info("No new sessions to process.")
        return

    logger.info(f"Processing {len(pending)} sessions...")

    totals = {"memories": 0, "facts": 0, "entities": 0}

    for path, transcript in pending:
        logger.info(f"\n  Processing: {path.name} ({len(transcript)} chars)")

        extraction = extract_memories(transcript, api_key)

        if args.dry_run:
            print(f"\n  === {path.name} ===")
            print(f"  User facts: {extraction['user_facts']}")
            print(f"  Memories: {extraction['memories']}")
            print(f"  Entities: {extraction['entities']}")
            continue

        n_mem, n_fact, n_ent = save_memories(extraction, path.stem)
        totals["memories"] += n_mem
        totals["facts"] += n_fact
        totals["entities"] += n_ent
        mark_processed(path.name)
        logger.info(f"    Saved: {n_mem} memories, {n_fact} new facts, {n_ent} entities")

    logger.info(
        f"\nDone: {totals['memories']} memories, {totals['facts']} user facts, "
        f"{totals['entities']} entities from {len(pending)} sessions"
    )


if __name__ == "__main__":
    main()
