#!/usr/bin/env python3
"""
Rook Hyperagent — Phase 1: Skill Self-Improvement

Analyzes agent session data to identify underperforming skills,
proposes improvements, tests them against historical sessions,
and promotes successful variants.

This is the entry point for self-improving behavior. The agent
doesn't just use skills — it makes them better over time.

Architecture:
    Training DBs → Identify weak skills → LLM proposes fix →
    Test against saved sessions → Promote or archive

Runs inside Citadel sandbox. Uses the same free-tier models
the agent uses for chat — no additional cost.

Usage:
    python3 skill_improver.py run           # One improvement cycle
    python3 skill_improver.py status        # Show skill performance
    python3 skill_improver.py history       # Show improvement history
    python3 skill_improver.py rollback <id> # Revert a skill change
    python3 skill_improver.py enable        # Turn the improvement loop on
    python3 skill_improver.py disable       # Turn the improvement loop off
"""

import json
import logging
import shutil
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

logging.basicConfig(level=logging.INFO, format="[hyperagent] %(message)s")
logger = logging.getLogger(__name__)

# --- Paths ---

HERMES_HOME = Path.home() / ".hermes"
SKILLS_DIR = HERMES_HOME / "skills"
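# SQLite databases of recorded sessions and tool calls that SkillAnalyzer reads from.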
TRAINING_DB = Path("/opt/hermes-data/training/default/user-sessions.db")
TOOL_DB = Path("/opt/hermes-data/training/default/tool-calls.db")
HYPERAGENT_DIR = HERMES_HOME / ".hyperagent"
ARCHIVE_DIR = HYPERAGENT_DIR / "archive"
HISTORY_FILE = HYPERAGENT_DIR / "history.jsonl"
CONFIG_FILE = HYPERAGENT_DIR / "config.json"

# --- Defaults ---

DEFAULT_CONFIG = {
    "max_improvements_per_cycle": 3,
    "min_sessions_for_evaluation": 5,
    "improvement_model": "qwen/qwen3.6-plus:free",
    "improvement_provider": "openrouter",
    "test_threshold": 0.6,  # New skill must score > this to be promoted
    "daily_budget": 10,  # Max improvement iterations per day
    "enabled": True,
}
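
# Any subset of DEFAULT_CONFIG can be overridden via ~/.hermes/.hyperagent/config.json, e.g.:
#   {"daily_budget": 5, "test_threshold": 0.7, "enabled": false}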


# =========================================================================
# Skill Performance Analysis
# =========================================================================


class SkillAnalyzer:
    """Analyzes skill performance from training databases."""

    def __init__(self):
        self.conv_db = None
        self.tool_db = None
        if TRAINING_DB.exists():
            self.conv_db = sqlite3.connect(TRAINING_DB)
        if TOOL_DB.exists():
            self.tool_db = sqlite3.connect(TOOL_DB)

    def get_skill_usage(self) -> Dict[str, Dict]:
        """Get usage stats for each skill from tool call data."""
        if not self.tool_db:
            return {}

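        # "Success" is a heuristic: any call whose tool_response came back non-empty.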
        rows = self.tool_db.execute("""
            SELECT tool_name, COUNT(*) as uses,
                   SUM(CASE WHEN tool_response IS NOT NULL AND tool_response != '' THEN 1 ELSE 0 END) as successes
            FROM tool_calls
            GROUP BY tool_name
            ORDER BY uses DESC
        """).fetchall()

        return {
            row[0]: {
                "uses": row[1],
                "successes": row[2],
                "success_rate": row[2] / row[1] if row[1] > 0 else 0,
            }
            for row in rows
        }

    def get_session_quality_signals(self) -> List[Dict]:
        """Extract quality signals from conversation data."""
        if not self.conv_db:
            return []

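        # Rebuild each session's messages in turn order, then derive coarse quality signals per session.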
        rows = self.conv_db.execute("""
            SELECT session_id, role, content
            FROM conversations
            ORDER BY session_id, turn_index
        """).fetchall()

        sessions = {}
        for session_id, role, content in rows:
            if session_id not in sessions:
                sessions[session_id] = {"messages": [], "signals": {}}
            sessions[session_id]["messages"].append({"role": role, "content": content})

        # Analyze signals per session
        results = []
        for session_id, data in sessions.items():
            signals = {
                "session_id": session_id,
                "turn_count": len(data["messages"]),
                "has_retry": any(
                    "retry" in (m["content"] or "").lower() for m in data["messages"] if m["role"] == "user"
                ),
                "has_thanks": any(
                    w in (m["content"] or "").lower()
                    for m in data["messages"]
                    if m["role"] == "user"
                    for w in ["thanks", "thank you", "perfect", "great", "awesome"]
                ),
                "has_error": any(
                    "error" in (m["content"] or "").lower() for m in data["messages"] if m["role"] == "assistant"
                ),
                "avg_response_len": sum(len(m["content"] or "") for m in data["messages"] if m["role"] == "assistant")
                / max(1, sum(1 for m in data["messages"] if m["role"] == "assistant")),
            }
            results.append(signals)

        return results

    def get_installed_skills(self) -> List[Dict]:
        """List installed skills with their metadata."""
        skills = []
        if not SKILLS_DIR.exists():
            return skills

        for skill_dir in SKILLS_DIR.iterdir():
            if not skill_dir.is_dir():
                continue

            skill_info = {
                "name": skill_dir.name,
                "path": str(skill_dir),
                "has_skill_md": (skill_dir / "SKILL.md").exists(),
                "has_description": (skill_dir / "DESCRIPTION.md").exists(),
                "files": [f.name for f in skill_dir.rglob("*") if f.is_file()],
                "size_bytes": sum(f.stat().st_size for f in skill_dir.rglob("*") if f.is_file()),
            }

            # Read SKILL.md for metadata
            skill_md = skill_dir / "SKILL.md"
            if skill_md.exists():
                try:
                    content = skill_md.read_text()
                    skill_info["content_preview"] = content[:500]
                except Exception:
                    pass

            skills.append(skill_info)

        return skills

    def identify_improvement_targets(self) -> List[Dict]:
        """Identify skills that could benefit from improvement."""
        tool_usage = self.get_skill_usage()
        session_signals = self.get_session_quality_signals()

        targets = []

        # Target 1: Tools with low success rates
        for tool_name, stats in tool_usage.items():
            if stats["uses"] >= 3 and stats["success_rate"] < 0.7:
                targets.append(
                    {
                        "type": "low_success_rate",
                        "target": tool_name,
                        "reason": (
                            f"Tool '{tool_name}' has {stats['success_rate']:.0%} success rate"
                            f" ({stats['successes']}/{stats['uses']} calls)"
                        ),
                        "priority": 1 - stats["success_rate"],
                        "data": stats,
                    }
                )

        # Target 2: Sessions with retries (user dissatisfaction)
        retry_sessions = [s for s in session_signals if s["has_retry"]]
        if retry_sessions:
            targets.append(
                {
                    "type": "user_retry",
                    "target": "response_quality",
                    "reason": f"{len(retry_sessions)} sessions had user retries (dissatisfaction signal)",
                    "priority": len(retry_sessions) / max(1, len(session_signals)),
                    "data": {
                        "retry_count": len(retry_sessions),
                        "total_sessions": len(session_signals),
                    },
                }
            )

        # Target 3: Sessions with errors
        error_sessions = [s for s in session_signals if s["has_error"]]
        if error_sessions:
            targets.append(
                {
                    "type": "error_responses",
                    "target": "error_handling",
                    "reason": f"{len(error_sessions)} sessions contained error responses",
                    "priority": len(error_sessions) / max(1, len(session_signals)),
                    "data": {"error_count": len(error_sessions)},
                }
            )

        # Sort by priority (highest first)
        targets.sort(key=lambda t: t["priority"], reverse=True)
        return targets

    def close(self):
        if self.conv_db:
            self.conv_db.close()
        if self.tool_db:
            self.tool_db.close()


# =========================================================================
# Skill Modifier (LLM-based)
# =========================================================================


class SkillModifier:
    """Uses an LLM to propose skill improvements."""

    def __init__(self, model: str, provider: str, api_key: Optional[str] = None):
        self.model = model
        self.provider = provider
        self.api_key = api_key or self._load_api_key()

    def _load_api_key(self) -> str:
        """Load API key from secrets."""
        key_file = HERMES_HOME / "secrets" / f"{self.provider}.key"
        if key_file.exists():
            return key_file.read_text().strip()
        # Fallback to .env
        env_file = Path.home() / "hermes-agent" / ".env"
        if env_file.exists():
            for line in env_file.read_text().splitlines():
                if line.startswith("OPENROUTER_API_KEY="):
                    return line.split("=", 1)[1].strip()
        return ""

    def propose_improvement(self, skill_content: str, target: Dict, session_examples: List[str]) -> Optional[str]:
        """Ask the LLM to propose an improved version of a skill."""
        if not self.api_key:
            logger.error("No API key available for improvement model")
            return None

        examples_text = "\n\n".join(session_examples[:3])  # Limit context

        prompt = f"""You are a meta-agent improving an AI assistant's skills.

## Current Skill Content
```
{skill_content[:2000]}
```

## Problem Identified
{target["reason"]}

## Relevant Session Examples
{examples_text[:2000]}

## Your Task
Propose an improved version of this skill that addresses the identified problem.
Return ONLY the improved skill content — no explanation, no markdown wrapping.
The output should be a drop-in replacement for the current skill file.

Focus on:
- Clearer instructions that reduce errors
- Better handling of edge cases
- More specific guidance for the agent
- Removing ambiguity that might cause retries

Keep the same format and structure. Make targeted improvements, not a complete rewrite."""

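        # Call OpenRouter's chat completions endpoint via stdlib urllib, so no extra HTTP client dependency is needed.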
        try:
            import ssl
            import urllib.request

            url = "https://openrouter.ai/api/v1/chat/completions"
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.api_key}",
            }
            body = json.dumps(
                {
                    "model": self.model,
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": 2000,
                }
            ).encode()

            req = urllib.request.Request(url, data=body, headers=headers)
            ctx = ssl.create_default_context()

            with urllib.request.urlopen(req, timeout=30, context=ctx) as resp:
                data = json.loads(resp.read())
                return data["choices"][0]["message"]["content"].strip()

        except Exception as e:
            logger.error(f"LLM call failed: {e}")
            return None


# =========================================================================
# Skill Versioning & Archive
# =========================================================================


class SkillArchive:
    """Manages skill versions and rollback capability."""

    def __init__(self):
        ARCHIVE_DIR.mkdir(parents=True, exist_ok=True)
        HYPERAGENT_DIR.mkdir(parents=True, exist_ok=True)

    def backup_current(self, skill_name: str, skill_path: str) -> str:
        """Backup current skill before modification. Returns backup ID."""
        backup_id = f"{skill_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        backup_dir = ARCHIVE_DIR / backup_id

        if Path(skill_path).exists():
            shutil.copytree(skill_path, backup_dir)
            logger.info(f"Backed up {skill_name} → {backup_id}")

        return backup_id

    def rollback(self, backup_id: str) -> bool:
        """Restore a skill from backup."""
        backup_dir = ARCHIVE_DIR / backup_id
        if not backup_dir.exists():
            logger.error(f"Backup not found: {backup_id}")
            return False

        skill_name = backup_id.rsplit("_", 2)[0]
        skill_dir = SKILLS_DIR / skill_name

        if skill_dir.exists():
            shutil.rmtree(skill_dir)
        shutil.copytree(backup_dir, skill_dir)
        logger.info(f"Rolled back {skill_name} from {backup_id}")
        return True

    def list_backups(self) -> List[Dict]:
        """List all archived skill versions."""
        backups = []
        if ARCHIVE_DIR.exists():
            for d in sorted(ARCHIVE_DIR.iterdir(), reverse=True):
                if d.is_dir():
                    backups.append(
                        {
                            "id": d.name,
                            "skill": d.name.rsplit("_", 2)[0],
                            "timestamp": d.stat().st_mtime,
                            "size": sum(f.stat().st_size for f in d.rglob("*") if f.is_file()),
                        }
                    )
        return backups

    def record_improvement(self, entry: Dict):
        """Append to improvement history."""
        entry["timestamp"] = datetime.now().isoformat()
        with open(HISTORY_FILE, "a") as f:
            f.write(json.dumps(entry) + "\n")

    def get_history(self, limit: int = 20) -> List[Dict]:
        """Read improvement history."""
        if not HISTORY_FILE.exists():
            return []
        lines = HISTORY_FILE.read_text().strip().split("\n")
        entries = []
        for line in lines[-limit:]:
            try:
                entries.append(json.loads(line))
            except json.JSONDecodeError:
                pass
        return entries

    def get_daily_count(self) -> int:
        """Count improvements made today."""
        today = datetime.now().strftime("%Y-%m-%d")
        return sum(1 for entry in self.get_history(100) if entry.get("timestamp", "").startswith(today))


# =========================================================================
# Improvement Evaluator
# =========================================================================


class SkillEvaluator:
    """Evaluates whether a proposed skill improvement is actually better."""

    def __init__(self, modifier: SkillModifier):
        self.modifier = modifier

    def evaluate_improvement(
        self,
        original_content: str,
        improved_content: str,
        target: Dict,
        session_examples: List[str],
    ) -> Dict:
        """Ask the LLM to evaluate if the improvement is better."""
        prompt = f"""You are evaluating a proposed improvement to an AI skill.

## Original Skill
```
{original_content[:1000]}
```

## Proposed Improvement
```
{improved_content[:1000]}
```

## Problem Being Addressed
{target["reason"]}

## Evaluate
Rate the improvement on a scale of 0.0 to 1.0:
- 0.0 = worse than original
- 0.5 = about the same
- 1.0 = significantly better

Consider:
1. Does it address the identified problem?
2. Could it introduce new problems?
3. Is it clear and unambiguous?
4. Is it a targeted fix, not an unnecessary rewrite?

Respond with ONLY a JSON object: {{"score": 0.X, "reason": "brief explanation"}}"""

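        # Same raw OpenRouter call as SkillModifier.propose_improvement(), but with a small
        # max_tokens since only a short JSON verdict is expected.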
        try:
            import ssl
            import urllib.request

            url = "https://openrouter.ai/api/v1/chat/completions"
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.modifier.api_key}",
            }
            body = json.dumps(
                {
                    "model": self.modifier.model,
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": 200,
                }
            ).encode()

            req = urllib.request.Request(url, data=body, headers=headers)
            ctx = ssl.create_default_context()

            with urllib.request.urlopen(req, timeout=30, context=ctx) as resp:
                data = json.loads(resp.read())
                response = data["choices"][0]["message"]["content"].strip()

                # Parse JSON from response
                # Handle potential markdown wrapping
                if "```" in response:
                    response = response.split("```")[1].strip()
                    if response.startswith("json"):
                        response = response[4:].strip()

                result = json.loads(response)
                return {
                    "score": float(result.get("score", 0)),
                    "reason": result.get("reason", ""),
                }

        except Exception as e:
            logger.error(f"Evaluation failed: {e}")
            return {"score": 0.5, "reason": f"Evaluation error: {e}"}


# =========================================================================
# Main Improvement Loop
# =========================================================================


def load_config() -> Dict:
    """Load hyperagent configuration."""
    if CONFIG_FILE.exists():
        try:
            with open(CONFIG_FILE) as f:
                return {**DEFAULT_CONFIG, **json.load(f)}
        except Exception:
            pass
    return DEFAULT_CONFIG.copy()


def save_config(config: Dict):
    """Save hyperagent configuration."""
    HYPERAGENT_DIR.mkdir(parents=True, exist_ok=True)
    with open(CONFIG_FILE, "w") as f:
        json.dump(config, f, indent=2)


def run_improvement_cycle():
    """Run one cycle of skill self-improvement."""
    config = load_config()

    if not config.get("enabled", True):
        logger.info("Hyperagent is disabled. Enable in config.")
        return

    archive = SkillArchive()

    # Check daily budget
    daily_count = archive.get_daily_count()
    if daily_count >= config["daily_budget"]:
        logger.info(f"Daily budget reached ({daily_count}/{config['daily_budget']}). Skipping.")
        return

    # Analyze current performance
    logger.info("Analyzing skill performance...")
    analyzer = SkillAnalyzer()
    targets = analyzer.identify_improvement_targets()

    if not targets:
        logger.info("No improvement targets identified. Everything looks good.")
        analyzer.close()
        return

    logger.info(f"Found {len(targets)} improvement targets:")
    for t in targets[:5]:
        logger.info(f"  [{t['priority']:.2f}] {t['reason']}")

    # Set up modifier and evaluator
    modifier = SkillModifier(config["improvement_model"], config["improvement_provider"])
    evaluator = SkillEvaluator(modifier)

    # Get session examples for context
    session_examples = []
    if analyzer.conv_db:
        rows = analyzer.conv_db.execute(
            "SELECT role, content FROM conversations ORDER BY turn_index DESC LIMIT 20"
        ).fetchall()
        session_examples = [f"{role}: {content[:200]}" for role, content in rows]

    # Process top targets
    improvements_made = 0
    max_improvements = min(config["max_improvements_per_cycle"], config["daily_budget"] - daily_count)

    for target in targets[:max_improvements]:
        logger.info(f"\nProcessing: {target['reason']}")

        # Find relevant skill to improve
        skills = analyzer.get_installed_skills()
        if not skills:
            logger.info("No installed skills to improve.")
            break

        # Pick the most relevant skill (simple heuristic: first skill with SKILL.md)
        skill = next((s for s in skills if s["has_skill_md"]), skills[0])
        skill_path = Path(skill["path"])
        skill_file = skill_path / "SKILL.md"

        if not skill_file.exists():
            continue

        original_content = skill_file.read_text()

        # Propose improvement
        logger.info(f"  Proposing improvement for skill: {skill['name']}")
        improved_content = modifier.propose_improvement(original_content, target, session_examples)

        if not improved_content or improved_content == original_content:
            logger.info("  No improvement proposed. Skipping.")
            archive.record_improvement(
                {
                    "skill": skill["name"],
                    "target": target["reason"],
                    "status": "no_change",
                }
            )
            continue

        # Evaluate improvement
        logger.info("  Evaluating proposed improvement...")
        evaluation = evaluator.evaluate_improvement(original_content, improved_content, target, session_examples)

        logger.info(f"  Score: {evaluation['score']:.2f} — {evaluation['reason']}")

        if evaluation["score"] >= config["test_threshold"]:
            # Backup and apply
            backup_id = archive.backup_current(skill["name"], str(skill_path))
            skill_file.write_text(improved_content)
            logger.info(f"  ✓ PROMOTED — skill '{skill['name']}' improved (backup: {backup_id})")

            # Notify owner
            try:
                from notify import notify_owner

                notify_owner(
                    f"Improved skill <b>{skill['name']}</b>\n"
                    f"Reason: {target['reason'][:100]}\n"
                    f"Score: {evaluation['score']:.2f}\n"
                    f"Rollback: <code>rook hyperagent rollback {backup_id}</code>"
                )
            except Exception:
                pass

            archive.record_improvement(
                {
                    "skill": skill["name"],
                    "target": target["reason"],
                    "status": "promoted",
                    "score": evaluation["score"],
                    "reason": evaluation["reason"],
                    "backup_id": backup_id,
                }
            )
            improvements_made += 1
        else:
            logger.info(f"  ✗ REJECTED — score {evaluation['score']:.2f} below threshold {config['test_threshold']}")

            archive.record_improvement(
                {
                    "skill": skill["name"],
                    "target": target["reason"],
                    "status": "rejected",
                    "score": evaluation["score"],
                    "reason": evaluation["reason"],
                }
            )

    analyzer.close()
    logger.info(f"\nCycle complete: {improvements_made} improvements promoted.")


def show_status():
    """Show current skill performance and hyperagent status."""
    config = load_config()
    archive = SkillArchive()
    analyzer = SkillAnalyzer()

    print("\n⚔ Rook Hyperagent — Status\n")
    print(f"Enabled:          {config.get('enabled', True)}")
    print(f"Model:            {config.get('improvement_model')}")
    print(f"Daily budget:     {archive.get_daily_count()}/{config.get('daily_budget')} iterations used today")
    print(f"Threshold:        {config.get('test_threshold')}")

    skills = analyzer.get_installed_skills()
    print(f"\nInstalled skills: {len(skills)}")
    for s in skills[:10]:
        print(f"  {s['name']:30} {len(s['files']):3} files  {s['size_bytes']:>8} bytes")

    tool_usage = analyzer.get_skill_usage()
    if tool_usage:
        print("\nTool performance:")
        for name, stats in sorted(tool_usage.items(), key=lambda x: -x[1]["uses"])[:10]:
            bar = "█" * int(stats["success_rate"] * 10) + "░" * (10 - int(stats["success_rate"] * 10))
            print(f"  {name:25} {bar} {stats['success_rate']:.0%}  ({stats['uses']} calls)")

    targets = analyzer.identify_improvement_targets()
    if targets:
        print(f"\nImprovement targets: {len(targets)}")
        for t in targets[:5]:
            print(f"  [{t['priority']:.2f}] {t['reason']}")

    backups = archive.list_backups()
    if backups:
        print(f"\nArchived versions: {len(backups)}")
        for b in backups[:5]:
            print(f"  {b['id']}")

    analyzer.close()


def show_history():
    """Show improvement history."""
    archive = SkillArchive()
    history = archive.get_history(30)

    if not history:
        print("No improvement history yet.")
        return

    print("\n⚔ Rook Hyperagent — Improvement History\n")
    for entry in reversed(history):
        status_icon = "✓" if entry.get("status") == "promoted" else "✗" if entry.get("status") == "rejected" else "—"
        score = entry.get("score", "")
        score_str = f" ({score:.2f})" if isinstance(score, (int, float)) else ""
        print(
            f"  {status_icon} [{entry.get('timestamp', '?')[:16]}] {entry.get('skill', '?')} — "
            f"{entry.get('target', '?')[:60]}{score_str}"
        )


# =========================================================================
# CLI
# =========================================================================


def main():
    import argparse

    parser = argparse.ArgumentParser(description="Rook Hyperagent — Skill Self-Improvement")
    sub = parser.add_subparsers(dest="command")

    sub.add_parser("run", help="Run one improvement cycle")
    sub.add_parser("status", help="Show skill performance and hyperagent status")
    sub.add_parser("history", help="Show improvement history")

    rollback_cmd = sub.add_parser("rollback", help="Revert a skill change")
    rollback_cmd.add_argument("backup_id", help="Backup ID to restore")

    sub.add_parser("enable", help="Enable hyperagent")
    sub.add_parser("disable", help="Disable hyperagent")

    args = parser.parse_args()

    if args.command == "run":
        run_improvement_cycle()
    elif args.command == "status":
        show_status()
    elif args.command == "history":
        show_history()
    elif args.command == "rollback":
        archive = SkillArchive()
        archive.rollback(args.backup_id)
    elif args.command == "enable":
        config = load_config()
        config["enabled"] = True
        save_config(config)
        print("Hyperagent enabled.")
    elif args.command == "disable":
        config = load_config()
        config["enabled"] = False
        save_config(config)
        print("Hyperagent disabled.")
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
