#!/usr/bin/env python3
"""
Rook Key Derivation Module

Derives all system keys from a BIP39 mnemonic + optional key file + Google Drive folder ID.

Derivation chain:
    mnemonic → BIP39 seed (512-bit)
    seed + salt(keyfile_hash, folder_id) → HKDF-SHA256 → purpose-specific keys

Three independent keys from one root:
    info="bitwarden-master"  → Vaultwarden master password (Base85)
    info="luks-container"    → LUKS2 passphrase (hex)
    info="backup-aes-key"    → AES-256-GCM key for backups (raw bytes)

Usage:
    from keygen import RookKeyDerivation

    kd = RookKeyDerivation(
        mnemonic="abandon abandon abandon ... about",
        folder_id="1a2b3c4d5e6f",
        keyfile_path="/path/to/photo.jpg",  # optional
    )

    print(kd.bitwarden_password)   # Base85 string
    print(kd.luks_passphrase)      # hex string
    backup_key = kd.backup_key     # 32 raw bytes
"""

import hashlib
import sys
from pathlib import Path
from typing import Optional

# --- BIP39 ---
# We use the mnemonic library if available, otherwise a minimal implementation

try:
    from mnemonic import Mnemonic

    HAS_MNEMONIC_LIB = True
except ImportError:
    HAS_MNEMONIC_LIB = False

# --- Cryptography ---

try:
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.kdf.hkdf import HKDF

    HAS_CRYPTOGRAPHY = True
except ImportError:
    HAS_CRYPTOGRAPHY = False
    # Fallback: hand-rolled HKDF (RFC 5869) built on the stdlib hmac and hashlib modules
    import hmac


# =========================================================================
# BIP39 Seed Derivation
# =========================================================================


def generate_mnemonic(strength: int = 128) -> str:
    """Create a new English BIP39 mnemonic.

    Args:
        strength: Entropy in bits handed to the generator; 128 gives
            12 words and 256 gives 24 words.

    Returns:
        The generated mnemonic phrase.

    Raises:
        ImportError: If the ``mnemonic`` package could not be imported;
            there is no fallback generator.
    """
    if not HAS_MNEMONIC_LIB:
        raise ImportError("python-mnemonic required for mnemonic generation. Install with: pip install mnemonic")

    generator = Mnemonic("english")
    phrase = generator.generate(strength)
    return phrase  # type: ignore[no-any-return]


def mnemonic_to_seed(mnemonic: str, passphrase: str = "") -> bytes:
    """Convert a BIP39 mnemonic to its 512-bit seed via PBKDF2-HMAC-SHA512.

    Args:
        mnemonic: The BIP39 mnemonic phrase.
        passphrase: Optional BIP39 passphrase; it is appended to the
            literal ``"mnemonic"`` to form the PBKDF2 salt.

    Returns:
        The 64-byte seed.

    Raises:
        ValueError: If the ``mnemonic`` package is available and rejects
            the phrase. The fallback path performs no validation.
    """
    if HAS_MNEMONIC_LIB:
        m = Mnemonic("english")
        if not m.check(mnemonic):
            raise ValueError("Invalid BIP39 mnemonic")
        return Mnemonic.to_seed(mnemonic, passphrase)  # type: ignore[no-any-return]

    # Minimal BIP39 seed derivation. BIP39 requires both the mnemonic and the
    # passphrase to be NFKD-normalized before UTF-8 encoding; without this,
    # visually identical input in a different Unicode form would yield a
    # different seed. For pure-ASCII input normalization is a no-op.
    import unicodedata

    normalized_mnemonic = unicodedata.normalize("NFKD", mnemonic)
    normalized_passphrase = unicodedata.normalize("NFKD", passphrase)
    salt = ("mnemonic" + normalized_passphrase).encode("utf-8")
    return hashlib.pbkdf2_hmac("sha512", normalized_mnemonic.encode("utf-8"), salt, 2048)


def validate_mnemonic(mnemonic: str) -> bool:
    """Report whether *mnemonic* looks like a valid BIP39 phrase.

    With the ``mnemonic`` package available the phrase is checked by the
    library against the English word list. Without it, only the number of
    whitespace-separated words is checked (12, 15, 18, 21 or 24).
    """
    if not HAS_MNEMONIC_LIB:
        # No word list available: the phrase length is all we can verify.
        word_count = len(mnemonic.strip().split())
        return word_count in (12, 15, 18, 21, 24)

    checker = Mnemonic("english")
    return checker.check(mnemonic)  # type: ignore[no-any-return]


# =========================================================================
# Key File Handling
# =========================================================================


def hash_keyfile(path: str) -> bytes:
    """Return the SHA-256 digest of the file at *path*.

    The file is opened in binary mode and consumed in 64 KiB reads, so the
    whole file is never held in memory at once.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # The sentinel form of iter() stops at the first empty read (EOF).
        for block in iter(lambda: stream.read(65536), b""):
            digest.update(block)
    return digest.digest()


# =========================================================================
# HKDF Key Derivation
# =========================================================================


def _hkdf_derive(ikm: bytes, salt: bytes, info: bytes, length: int = 32) -> bytes:
    """Derive *length* bytes from *ikm* with HKDF-SHA256.

    Args:
        ikm: Input keying material.
        salt: HKDF salt; in the fallback an empty salt is replaced by
            32 zero bytes.
        info: Context label that separates one derived key from another.
        length: Number of output bytes.

    Returns:
        The derived key material.

    Uses the ``cryptography`` package when it was importable, otherwise a
    local extract-and-expand implementation built on ``hmac``.
    """
    if not HAS_CRYPTOGRAPHY:
        # RFC 5869 HKDF with HMAC-SHA256 (fallback)
        # Extract: PRK = HMAC(salt, IKM)
        extract_key = salt if salt else b"\x00" * 32
        prk = hmac.new(extract_key, ikm, hashlib.sha256).digest()

        # Expand: T(i) = HMAC(PRK, T(i-1) || info || i), concatenated and truncated
        block_count = (length + 31) // 32
        blocks = []
        previous = b""
        for counter in range(1, block_count + 1):
            previous = hmac.new(prk, previous + info + bytes([counter]), hashlib.sha256).digest()
            blocks.append(previous)
        return b"".join(blocks)[:length]

    deriver = HKDF(algorithm=hashes.SHA256(), length=length, salt=salt, info=info)
    return deriver.derive(ikm)  # type: ignore[no-any-return]


def build_salt(folder_id: str, keyfile_hash: Optional[bytes] = None) -> bytes:
    """Combine the folder ID and an optional key-file hash into the HKDF salt.

    Args:
        folder_id: Folder identifier; it is UTF-8 encoded.
        keyfile_hash: Digest of the key file, or ``None`` / empty when no
            key file is used.

    Returns:
        ``SHA256(keyfile_hash || folder_id_bytes)`` when a key-file hash is
        given, otherwise the UTF-8 bytes of the folder ID unchanged.
    """
    folder_bytes = folder_id.encode("utf-8")
    if not keyfile_hash:
        # No key file: the folder ID alone is the salt (not hashed).
        return folder_bytes
    return hashlib.sha256(keyfile_hash + folder_bytes).digest()


# =========================================================================
# Base85 Encoding (for Bitwarden password)
# =========================================================================


def base85_encode(data: bytes) -> str:
    """Encode *data* with ``base64.b85encode`` and return the result as text.

    Note that this is the ``b85`` alphabet, which is distinct from the
    Ascii85 variant provided by ``base64.a85encode``. Every 4 input bytes
    become 5 ASCII characters, so 32 bytes encode to 40 characters.
    """
    from base64 import b85encode

    return str(b85encode(data), "ascii")


# =========================================================================
# Main Class
# =========================================================================


class RookKeyDerivation:
    """
    Derives the system keys from a BIP39 mnemonic.

    The BIP39 seed is the HKDF input keying material. The HKDF salt is built
    from the Google Drive folder ID and, when a key file is given, the
    SHA-256 of that file. Each key is derived with its own ``info`` label,
    so the three keys are independent of one another.

    Attributes:
        bitwarden_password: str — Base85-encoded Vaultwarden master password
        luks_passphrase: str — Hex-encoded LUKS2 container passphrase
        backup_key: bytes — 32-byte AES-256-GCM key for backup encryption
    """

    def __init__(
        self,
        mnemonic: str,
        folder_id: str,
        keyfile_path: Optional[str] = None,
    ) -> None:
        """Validate the inputs and derive all keys eagerly.

        Args:
            mnemonic: BIP39 mnemonic phrase.
            folder_id: Google Drive folder ID; must be non-empty.
            keyfile_path: Optional path to a key file whose SHA-256 is
                mixed into the salt.

        Raises:
            ValueError: If the mnemonic is invalid or ``folder_id`` is empty.
            FileNotFoundError: If ``keyfile_path`` is given but does not exist.
        """
        # Validate inputs
        if not validate_mnemonic(mnemonic):
            raise ValueError("Invalid BIP39 mnemonic")
        if not folder_id:
            raise ValueError("Google Drive folder ID is required")

        # Step 1: BIP39 seed
        self._seed = mnemonic_to_seed(mnemonic)

        # Step 2: Key file hash (optional)
        keyfile_hash = None
        if keyfile_path:
            if not Path(keyfile_path).exists():
                raise FileNotFoundError(f"Key file not found: {keyfile_path}")
            keyfile_hash = hash_keyfile(keyfile_path)

        # Step 3: Build salt
        self._salt = build_salt(folder_id, keyfile_hash)

        # Step 4: Derive keys (same seed and salt, distinct info labels)
        self._bw_key = _hkdf_derive(self._seed, self._salt, b"bitwarden-master", 32)
        self._luks_key = _hkdf_derive(self._seed, self._salt, b"luks-container", 32)
        self._backup_key = _hkdf_derive(self._seed, self._salt, b"backup-aes-key", 32)

    @property
    def bitwarden_password(self) -> str:
        """Vaultwarden master password (Base85-encoded, 32 bytes → 40 chars)."""
        return base85_encode(self._bw_key)

    @property
    def luks_passphrase(self) -> str:
        """LUKS2 container passphrase (hex-encoded, 32 bytes → 64 chars)."""
        return self._luks_key.hex()

    @property
    def backup_key(self) -> bytes:
        """AES-256-GCM key for backup encryption (32 raw bytes)."""
        return self._backup_key

    @property
    def backup_key_hex(self) -> str:
        """AES-256-GCM key as hex string (for display/verification)."""
        return self._backup_key.hex()

    def wipe(self) -> None:
        """Best-effort overwrite of sensitive material in memory.

        On CPython the payload of each stored ``bytes`` object is zeroed in
        place, then every attribute is reset to ``None``. Because the
        objects are overwritten in place, any other reference to the same
        object (for example a value previously returned by ``backup_key``)
        is zeroed too. Copies made elsewhere, such as the encoded strings
        returned by the other properties, are not affected. The key
        properties must not be used after calling this method.
        """
        import ctypes

        # id() is only a usable memory address on CPython; on any other
        # interpreter writing through it would corrupt unrelated memory.
        zero_in_place = sys.implementation.name == "cpython"

        for attr in ("_seed", "_salt", "_bw_key", "_luks_key", "_backup_key"):
            val = getattr(self, attr, None)
            # Empty and single-byte bytes objects are interpreter-wide
            # singletons on CPython, so they must never be modified.
            if zero_in_place and isinstance(val, bytes) and len(val) > 1:
                # __basicsize__ is the object header plus the trailing NUL
                # byte, so the payload starts one byte before that offset.
                # (Using getsizeof(val) - len(val) alone lands one byte too
                # far and leaves the first secret byte intact.)
                payload_offset = bytes.__basicsize__ - 1
                ctypes.memset(id(val) + payload_offset, 0, len(val))
            setattr(self, attr, None)


# =========================================================================
# CLI
# =========================================================================


def main():
    """Command-line entry point for exercising key derivation.

    Subcommands:
        generate: print a new 12- or 24-word mnemonic, one numbered word per line.
        derive:   derive the keys for a mnemonic, folder ID and optional key
                  file; the keys are printed only when ``--show-keys`` is given.
        validate: report whether a mnemonic is valid.

    With no subcommand the parser's help text is printed.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Rook Key Derivation")
    commands = parser.add_subparsers(dest="command")

    # "generate" subcommand
    generate_cmd = commands.add_parser("generate", help="Generate a new BIP39 mnemonic")
    generate_cmd.add_argument("--words", type=int, choices=[12, 24], default=12)

    # "derive" subcommand
    derive_cmd = commands.add_parser("derive", help="Derive keys from mnemonic")
    derive_cmd.add_argument("--mnemonic", required=True, help="BIP39 mnemonic (quoted)")
    derive_cmd.add_argument("--folder-id", required=True, help="Google Drive folder ID")
    derive_cmd.add_argument("--keyfile", help="Optional key file path")
    derive_cmd.add_argument("--show-keys", action="store_true", help="Print derived keys (DANGER)")

    # "validate" subcommand
    validate_cmd = commands.add_parser("validate", help="Check if a mnemonic is valid")
    validate_cmd.add_argument("mnemonic", help="BIP39 mnemonic to validate")

    args = parser.parse_args()
    command = args.command

    if command == "generate":
        word_count = args.words
        # 12 words correspond to 128 bits of entropy, 24 words to 256 bits.
        bits = 256 if word_count != 12 else 128
        phrase = generate_mnemonic(bits)
        print(f"\nGenerated {word_count}-word mnemonic:\n")
        for position, word in enumerate(phrase.split(), start=1):
            print(f"  {position:2}. {word}")
        print("\nWRITE THESE DOWN. Do not store digitally.\n")
        return

    if command == "derive":
        derivation = RookKeyDerivation(
            mnemonic=args.mnemonic,
            folder_id=args.folder_id,
            keyfile_path=args.keyfile,
        )
        print("\nKey derivation successful.")
        print(f"  Folder ID:    {args.folder_id}")
        print(f"  Key file:     {args.keyfile or 'none'}")
        if not args.show_keys:
            print("\n  Keys derived but not displayed. Use --show-keys to reveal.")
        else:
            print("\n  *** SENSITIVE — DO NOT SHARE ***")
            print(f"  Bitwarden pw: {derivation.bitwarden_password}")
            print(f"  LUKS phrase:  {derivation.luks_passphrase}")
            print(f"  Backup key:   {derivation.backup_key_hex}")
        # Clear the derived material once it is no longer needed.
        derivation.wipe()
        return

    if command == "validate":
        verdict = "VALID" if validate_mnemonic(args.mnemonic) else "INVALID"
        print(f"Mnemonic is {verdict}")
        return

    # No subcommand given.
    parser.print_help()


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
