#!/usr/bin/env python3
"""
Rook Outbound Auth Proxy

Runs on the HOST (outside the Citadel sandbox). Intercepts outbound HTTP
requests from the sandbox, matches the destination domain to a stored API
credential, and injects the appropriate Authorization header before forwarding.

Agents inside the sandbox make requests with NO API keys. This proxy adds
them transparently. If an agent is compromised, it can make API calls through
the proxy but can NEVER extract the raw keys.

Usage:
    python3 proxy.py                    # Start on :9090
    python3 proxy.py --port 9090        # Explicit port
    python3 proxy.py --config /path/to/proxy-config.json

Architecture:
    Sandbox Agent → HTTP request (no auth) → Proxy (:9090) → injects auth → Provider API
"""

import argparse
import http.client
import json
import logging
import ssl
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from urllib.parse import urlparse

# Root logging setup: INFO and above, time-of-day stamp, "[proxy]" tag.
_LOG_FORMAT = "%(asctime)s [proxy] %(levelname)s %(message)s"
_LOG_DATE_FORMAT = "%H:%M:%S"
logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT, level=logging.INFO)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)

# --- Config ---

# Defaults for the config file, the secrets directory and the listen port.
DEFAULT_CONFIG_PATH = "/etc/hermes/proxy-config.json"
DEFAULT_SECRETS_DIR = "/etc/hermes/secrets"
DEFAULT_PORT = 9090

# Destination hostname → provider ID.
DOMAIN_MAP = {
    "generativelanguage.googleapis.com": "gemini",
    "integrate.api.nvidia.com": "nim",
    "openrouter.ai": "openrouter",
    "api.groq.com": "groq",
    "router.huggingface.co": "huggingface",
    "api.anthropic.com": "anthropic",
    "api.openai.com": "openai",
    "api.x.ai": "xai",
    "api.mistral.ai": "mistral",
}

# Providers that take their key in a request header:
# (provider ID, header name, prefix placed before the key).
_HEADER_AUTH_PROVIDERS = (
    ("nim", "Authorization", "Bearer "),
    ("openrouter", "Authorization", "Bearer "),
    ("groq", "Authorization", "Bearer "),
    ("huggingface", "Authorization", "Bearer "),
    ("anthropic", "x-api-key", ""),
    ("openai", "Authorization", "Bearer "),
    ("xai", "Authorization", "Bearer "),
    ("mistral", "Authorization", "Bearer "),
)

# Provider ID → how the key is attached to a request. "query" providers get
# the key as a URL query parameter; "header" providers get it as a header.
PROVIDER_AUTH = {"gemini": {"method": "query", "param": "key"}}
PROVIDER_AUTH.update(
    (provider_id, {"method": "header", "header": header, "prefix": prefix})
    for provider_id, header, prefix in _HEADER_AUTH_PROVIDERS
)


class ProxyConfig:
    """Map destination domains to providers and load their API keys.

    Custom domain mappings come from the optional JSON config file (its
    ``domain_map`` object); built-in mappings come from ``DOMAIN_MAP``.
    Keys are read from ``<secrets_dir>/<provider_id>.key`` on every call, so
    a rotated key file is picked up without restarting the proxy.
    """

    def __init__(
        self,
        config_path: str = DEFAULT_CONFIG_PATH,
        secrets_dir: str = DEFAULT_SECRETS_DIR,
    ):
        """Load custom domain mappings from *config_path* if it exists.

        A missing config file is not an error. An unreadable or malformed one
        is logged and ignored, leaving only the built-in mappings active.
        """
        self.secrets_dir = Path(secrets_dir)
        self.config_path = Path(config_path)
        self.custom_domains: dict[str, str] = {}  # domain → provider_id from config file

        try:
            with open(self.config_path, encoding="utf-8") as f:
                cfg = json.load(f)
        except FileNotFoundError:
            return  # The config file is optional.
        except (OSError, ValueError) as e:
            # ValueError covers both invalid JSON and undecodable bytes.
            logger.warning("Failed to load config: %s", e)
            return

        domain_map = cfg.get("domain_map", {}) if isinstance(cfg, dict) else None
        if not isinstance(domain_map, dict):
            logger.warning("Failed to load config: %s", "'domain_map' must be a JSON object")
            return

        for domain, provider_id in domain_map.items():
            if not isinstance(provider_id, str) or not provider_id:
                logger.warning("Ignoring custom domain %s: provider ID must be a non-empty string", domain)
                continue
            # Lookups use the lower-cased hostname, so normalise the keys too.
            self.custom_domains[domain.strip().lower().rstrip(".")] = provider_id
        logger.info("Loaded %d custom domain mappings", len(self.custom_domains))

    def get_provider_for_domain(self, domain: str) -> str | None:
        """Return the provider ID for *domain*, or ``None`` if nothing matches.

        Custom mappings are checked first and match exactly. Built-in
        mappings match the domain itself or any of its subdomains.
        """
        # Hostnames are case-insensitive and may be written with a trailing dot.
        domain = domain.lower().rstrip(".")
        if domain in self.custom_domains:
            return self.custom_domains[domain]
        for pattern, provider_id in DOMAIN_MAP.items():
            if domain == pattern or domain.endswith("." + pattern):
                return provider_id
        return None

    def read_key(self, provider_id: str) -> str | None:
        """Return the stripped API key for *provider_id*, or ``None``.

        ``None`` is returned when the key file is missing, empty, or cannot be
        read, so a broken secrets directory never raises into the handler.
        """
        key_file = self.secrets_dir / f"{provider_id}.key"
        try:
            key = key_file.read_text(encoding="utf-8").strip()
        except FileNotFoundError:
            return None
        except (OSError, UnicodeDecodeError) as e:
            # Log only the error type so no key-file content reaches the log.
            logger.warning("Could not read key for provider %s: %s", provider_id, type(e).__name__)
            return None
        return key or None

    def get_auth_config(self, provider_id: str) -> dict | None:
        """Return the ``PROVIDER_AUTH`` entry for *provider_id*, or ``None``."""
        return PROVIDER_AUTH.get(provider_id)


class AuthProxyHandler(BaseHTTPRequestHandler):
    """HTTP forward proxy that injects provider credentials into outbound requests.

    The request line must carry an absolute ``http://`` or ``https://`` URL.
    The destination host is matched to a provider through ``config``; when a
    key is available it is added as a query parameter or a header, depending
    on the provider's auth config. Requests to unmatched hosts are forwarded
    unchanged.

    Nothing sent back to the client may contain the injected key: upstream
    failures are reported with a generic message and logged with the key
    redacted.
    """

    # Shared configuration, assigned on the class before the server starts.
    # Quoted so the annotation is not evaluated when the class body executes.
    config: "ProxyConfig" = None  # type: ignore[assignment]

    # Request headers that are never copied to the upstream request.
    _STRIPPED_REQUEST_HEADERS = frozenset({"host", "proxy-connection", "proxy-authorization"})

    def do_CONNECT(self):
        """Reject HTTPS CONNECT tunnelling with 501.

        A CONNECT tunnel is opaque to the proxy, so no header could be
        injected. Clients must send the absolute URL to the proxy instead.
        """
        # The status line is encoded as latin-1, so the message must stay
        # within that range; a non-latin-1 character raises instead of replying.
        self.send_error(501, "CONNECT not supported - use HTTP_PROXY mode")

    def _proxy_request(self, method: str):
        """Forward the current request upstream with credentials injected.

        Replies 400 for a target URL, port or Content-Length that cannot be
        parsed, and 502 when the upstream request fails before any response
        bytes were sent. On success the upstream status, headers and body are
        relayed to the client.
        """
        url = self.path
        try:
            parsed = urlparse(url)
            domain = parsed.hostname
            explicit_port = parsed.port  # Raises ValueError for a non-numeric port.
        except ValueError:
            self.send_error(400, "Could not parse target URL")
            return
        if parsed.scheme not in ("http", "https"):
            # Relative URL or unsupported scheme — shouldn't happen in proxy mode.
            self.send_error(400, "Absolute URL required in proxy mode")
            return
        if not domain:
            self.send_error(400, "Could not parse hostname from URL")
            return

        port = explicit_port or (443 if parsed.scheme == "https" else 80)
        # An empty path must become "/", otherwise a query-only target would
        # yield an invalid request line.
        path = parsed.path or "/"
        if parsed.query:
            path += "?" + parsed.query

        try:
            content_length = int(self.headers.get("Content-Length", 0))
        except ValueError:
            self.send_error(400, "Invalid Content-Length header")
            return

        # Resolve the provider, its key and its auth config once per request.
        api_key = None
        auth_cfg = None
        provider_id = self.config.get_provider_for_domain(domain)
        if provider_id:
            api_key = self.config.read_key(provider_id)
            if api_key:
                auth_cfg = self.config.get_auth_config(provider_id)
                if auth_cfg is None:
                    logger.warning("No auth config for provider %s (domain: %s)", provider_id, domain)
            else:
                logger.warning("No key for provider %s (domain: %s)", provider_id, domain)
        else:
            logger.debug("No provider match for domain: %s — forwarding without auth", domain)

        # NOTE(review): credentials are injected for plain-http targets too,
        # which sends the key unencrypted — confirm that is intended.
        injected_header = None
        if auth_cfg:
            if auth_cfg["method"] == "query":
                separator = "&" if "?" in path else "?"
                path += f"{separator}{auth_cfg['param']}={api_key}"
                logger.info("AUTH %s → %s (query param)", domain, provider_id)
            else:
                injected_header = auth_cfg["header"]
                logger.info("AUTH %s → %s (header)", domain, provider_id)

        body = self.rfile.read(content_length) if content_length > 0 else None

        # Copy client headers, dropping proxy-only ones and any client-supplied
        # copy of the header about to be injected. The comparison ignores case
        # so that only one credential header is ever sent upstream.
        injected_lower = injected_header.lower() if injected_header else None
        out_headers = {}
        for name in self.headers:
            lowered = name.lower()
            if lowered in self._STRIPPED_REQUEST_HEADERS or lowered == injected_lower:
                continue
            out_headers[name] = self.headers[name]
        if injected_header:
            out_headers[injected_header] = f"{auth_cfg['prefix']}{api_key}"

        conn = None
        response_started = False
        try:
            if parsed.scheme == "https":
                ctx = ssl.create_default_context()
                conn = http.client.HTTPSConnection(domain, port, context=ctx, timeout=60)
            else:
                conn = http.client.HTTPConnection(domain, port, timeout=60)

            conn.request(method, path, body=body, headers=out_headers)
            resp = conn.getresponse()

            # Forward response back to client
            response_started = True
            self.send_response(resp.status)
            for key, value in resp.getheaders():
                # http.client has already decoded any chunked body.
                if key.lower() != "transfer-encoding":
                    self.send_header(key, value)
            self.end_headers()

            # Stream response body
            while chunk := resp.read(8192):
                self.wfile.write(chunk)

        except Exception as exc:
            # Exception text can quote the request target, which may contain
            # the injected key, so redact it in the log and never send it back.
            detail = str(exc)
            if api_key:
                detail = detail.replace(api_key, "***")
            logger.error("Proxy error for %s: %s", url, detail)
            if response_started:
                # Part of the response is already out; a second status line
                # would corrupt it, so just drop the connection.
                self.close_connection = True
            else:
                self.send_error(502, f"Proxy error: {type(exc).__name__}")
        finally:
            if conn is not None:
                conn.close()

    def do_GET(self):
        """Proxy a GET request."""
        self._proxy_request("GET")

    def do_POST(self):
        """Proxy a POST request."""
        self._proxy_request("POST")

    def do_PUT(self):
        """Proxy a PUT request."""
        self._proxy_request("PUT")

    def do_DELETE(self):
        """Proxy a DELETE request."""
        self._proxy_request("DELETE")

    def do_PATCH(self):
        """Proxy a PATCH request."""
        self._proxy_request("PATCH")

    def log_message(self, format, *args):
        """Suppress default access log — we use our own logging."""
        pass


def main():
    """Parse command-line options and run the proxy until interrupted.

    The loaded ``ProxyConfig`` is installed on ``AuthProxyHandler`` as a class
    attribute so every request handler instance shares it. Ctrl-C stops the
    server, and the listening socket is closed on the way out.
    """
    parser = argparse.ArgumentParser(description="Rook Outbound Auth Proxy")
    parser.add_argument(
        "--port", type=int, default=DEFAULT_PORT, help=f"Listen port (default: {DEFAULT_PORT})"
    )
    parser.add_argument("--host", default="127.0.0.1", help="Listen host (default: 127.0.0.1)")
    parser.add_argument("--config", default=DEFAULT_CONFIG_PATH, help="Config file path")
    parser.add_argument("--secrets-dir", default=DEFAULT_SECRETS_DIR, help="Secrets directory")
    args = parser.parse_args()

    config = ProxyConfig(args.config, args.secrets_dir)
    AuthProxyHandler.config = config

    # NOTE(review): HTTPServer handles one request at a time, so a slow
    # upstream call blocks every other client — confirm this is acceptable.
    # The context manager calls server_close() on exit, releasing the socket.
    with HTTPServer((args.host, args.port), AuthProxyHandler) as server:
        logger.info("Auth proxy listening on %s:%d", args.host, args.port)
        logger.info("Secrets dir: %s", args.secrets_dir)  # nosemgrep: python-logger-credential-disclosure
        logger.info("Known providers: %s", ", ".join(DOMAIN_MAP.values()))

        try:
            server.serve_forever()
        except KeyboardInterrupt:
            # serve_forever() has already returned here, so there is no
            # running loop for shutdown() to stop.
            logger.info("Shutting down.")


if __name__ == "__main__":
    main()
