#!/usr/bin/env python3
"""
fTPM Seal/Unseal for LUKS Passphrase

Seals the LUKS passphrase to the fTPM's PCR values so it can only
be recovered if the boot chain is unmodified.

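Uses tpm2-tools commands:
  - tpm2_createprimary: create a storage root (primary) key
  - tpm2_pcrread / tpm2_createpolicy: snapshot the PCRs and build a PCR policy
  - tpm2_create: create a sealed data object bound to the PCR policy
  - tpm2_load: load the sealed object under the primary key
  - tpm2_unseal: recover the sealed data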

The sealed object (its public and private parts) is stored on disk. It can
only be unsealed by the same fTPM; copied to a different device, it is useless.

Requires: tpm2-tools, and sudo rights to run them (TPM access via /dev/tpm0 or /dev/tpmrm0)
"""

import logging
import os
import subprocess
import tempfile
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)

# PCRs to bind to, in tpm2-tools selection syntax ("<bank>:<indices>"):
# the sha256 bank, indices 0, 1, 2 and 7 — these represent the boot chain.
# PCR 0: firmware
# PCR 1: firmware config
# PCR 2: option ROMs
# PCR 7: secure boot state
DEFAULT_PCRS = "sha256:0,1,2,7"

# On-disk location of the sealed material. Path.home() is evaluated once, at
# import time, for the user running this module (the tpm2 tools themselves
# are run under sudo by _run).
SEAL_DIR = Path.home() / ".hermes" / ".tpm-sealed"
# NOTE(review): SEALED_BLOB is not referenced by any function in this file as
# reviewed — confirm whether it is still needed.
SEALED_BLOB = SEAL_DIR / "luks-passphrase.sealed"
# Public and private halves of the sealed object (tpm2_create -u / -r).
SEALED_PUB = SEAL_DIR / "luks-passphrase.pub"
SEALED_PRIV = SEAL_DIR / "luks-passphrase.priv"
# Saved context of the owner-hierarchy primary key (tpm2_createprimary -c);
# both seal and unseal regenerate it before loading the sealed object.
SEALED_CTX = SEAL_DIR / "primary.ctx"
# PCR policy digest written by tpm2_createpolicy -L. seal_passphrase also
# writes the PCR snapshot it was built from next to it, with a ".pcrs" suffix.
PCR_POLICY = SEAL_DIR / "pcr-policy.dat"


def _run(cmd: list[str], input_data: Optional[bytes] = None, check: bool = True) -> subprocess.CompletedProcess:
    """Run a command with sudo for TPM access."""
    full_cmd = ["sudo"] + cmd if cmd[0] != "sudo" else cmd
    result = subprocess.run(
        full_cmd,
        input=input_data,
        capture_output=True,
    )
    if check and result.returncode != 0:
        stderr = result.stderr.decode(errors="replace")
        raise RuntimeError(f"TPM command failed: {' '.join(cmd)}: {stderr}")
    return result


def is_tpm_available() -> bool:
    """
    Return True when a TPM character-device node is present.

    Either the raw device (/dev/tpm0) or the resource-manager node
    (/dev/tpmrm0) counts. Only existence is checked, not access rights.
    """
    return any(Path(node).exists() for node in ("/dev/tpm0", "/dev/tpmrm0"))


def read_pcrs(pcr_spec: str = DEFAULT_PCRS) -> dict:
    """
    Return the current values of the PCRs selected by *pcr_spec*.

    tpm2_pcrread's output is not parsed; its decoded stdout is returned
    verbatim under the "raw" key. A failing command raises RuntimeError
    (propagated from _run).
    """
    completed = _run(["tpm2_pcrread", pcr_spec])
    text = completed.stdout.decode()
    return dict(raw=text)


def seal_passphrase(passphrase: str, pcr_spec: str = DEFAULT_PCRS) -> bool:
    """
    Seal a passphrase to current PCR values.

    The passphrase can only be recovered if the PCR values match
    (i.e., the boot chain hasn't been tampered with).

    Steps:
      1. create the owner-hierarchy primary key,
      2. snapshot the selected PCRs and build a PCR policy from them,
      3. create a sealed data object bound to that policy,
      4. load and unseal it once to verify the round trip.

    Args:
        passphrase: Secret to seal; stored as its UTF-8 encoding.
        pcr_spec: PCR selection in tpm2-tools syntax (e.g. "sha256:0,1,2,7").

    Returns:
        True if the sealed object was written and a verification unseal
        returned the same passphrase; False otherwise. Failures are logged,
        not raised.
    """
    if not is_tpm_available():
        logger.error("fTPM not available")
        return False

    SEAL_DIR.mkdir(parents=True, exist_ok=True)
    os.chmod(SEAL_DIR, 0o700)  # nosemgrep: insecure-file-permissions  # TPM seal dir must be owner-only

    pcr_values_file = Path(str(PCR_POLICY) + ".pcrs")
    # Temporary files still to delete. They are tracked here so every exit
    # path (success, verification mismatch, exception) removes them in `finally`.
    secret_file: Optional[str] = None
    loaded_ctx: Optional[str] = None

    try:
        # NamedTemporaryFile creates the file with mode 0600 (owner-only).
        with tempfile.NamedTemporaryFile(suffix=".dat", delete=False) as tmp:
            secret_file = tmp.name
            tmp.write(passphrase.encode())

        # Step 1: Create primary key (storage root)
        _run(["tpm2_createprimary", "-C", "o", "-c", str(SEALED_CTX)])

        # Step 2: Create PCR policy from the current PCR values
        _run(["tpm2_pcrread", pcr_spec, "-o", str(pcr_values_file)])
        _run(
            [
                "tpm2_createpolicy",
                "--policy-pcr",
                "-l",
                pcr_spec,
                "-f",
                str(pcr_values_file),
                "-L",
                str(PCR_POLICY),
            ]
        )

        # Step 3: Create sealing object bound to PCR policy
        _run(
            [
                "tpm2_create",
                "-C",
                str(SEALED_CTX),
                "-u",
                str(SEALED_PUB),
                "-r",
                str(SEALED_PRIV),
                "-L",
                str(PCR_POLICY),
                "-i",
                secret_file,
            ]
        )

        # The plaintext copy is no longer needed once the TPM has sealed it.
        os.unlink(secret_file)
        secret_file = None

        # Verify: load and unseal to confirm
        with tempfile.NamedTemporaryFile(suffix=".ctx", delete=False) as tmp_ctx:
            loaded_ctx = tmp_ctx.name

        _run(
            [
                "tpm2_load",
                "-C",
                str(SEALED_CTX),
                "-u",
                str(SEALED_PUB),
                "-r",
                str(SEALED_PRIV),
                "-c",
                loaded_ctx,
            ]
        )
        result = _run(["tpm2_unseal", "-c", loaded_ctx, "-p", f"pcr:{pcr_spec}"])

        recovered = result.stdout.decode()
        if recovered != passphrase:
            logger.error("Seal verification failed — recovered data doesn't match")
            return False

        # Set permissions. The tpm2 tools ran under sudo (see _run), so the files
        # they created are owned by root, and a plain os.chmod() from a non-root
        # caller fails with EPERM. chmod through sudo as well. A failure here is
        # only a warning: the seal is already verified and SEAL_DIR is 0700.
        artifacts = [
            str(f)
            for f in (SEALED_PUB, SEALED_PRIV, SEALED_CTX, PCR_POLICY, pcr_values_file)
            if f.exists()
        ]
        if artifacts:
            chmod_result = _run(["chmod", "600", *artifacts], check=False)
            if chmod_result.returncode != 0:
                logger.warning(
                    "Could not restrict permissions on sealed files: %s",
                    chmod_result.stderr.decode(errors="replace").strip(),
                )

        logger.info("Passphrase sealed to fTPM with PCR policy %s", pcr_spec)
        return True

    except Exception as e:
        logger.error("Seal failed: %s", e)
        return False

    finally:
        # Clean up on every path, including a failed tpm2_load/tpm2_unseal,
        # which previously leaked the temporary context file.
        for leftover in (secret_file, loaded_ctx):
            if leftover is None:
                continue
            try:
                os.unlink(leftover)
            except FileNotFoundError:
                pass
            except OSError as err:
                logger.warning("Could not remove temporary file %s: %s", leftover, err)


def unseal_passphrase(pcr_spec: str = DEFAULT_PCRS) -> Optional[str]:
    """
    Unseal the passphrase from fTPM.

    Recreates the owner-hierarchy primary key, loads the sealed object
    (SEALED_PUB / SEALED_PRIV) under it, then unseals it with a PCR policy
    over *pcr_spec*.

    Args:
        pcr_spec: PCR selection in tpm2-tools syntax. It should match the
            selection the passphrase was sealed against.

    Returns:
        The passphrase if PCR values match. Otherwise None: TPM missing,
        nothing sealed, PCR mismatch, or any command error (the cause is logged).
    """
    if not is_tpm_available():
        logger.error("fTPM not available")
        return None

    if not SEALED_PUB.exists() or not SEALED_PRIV.exists():
        logger.error("No sealed passphrase found")
        return None

    loaded_ctx: Optional[str] = None
    try:
        # Recreate primary context
        _run(["tpm2_createprimary", "-C", "o", "-c", str(SEALED_CTX)])

        # Load sealed object
        with tempfile.NamedTemporaryFile(suffix=".ctx", delete=False) as tmp_ctx:
            loaded_ctx = tmp_ctx.name

        _run(
            [
                "tpm2_load",
                "-C",
                str(SEALED_CTX),
                "-u",
                str(SEALED_PUB),
                "-r",
                str(SEALED_PRIV),
                "-c",
                loaded_ctx,
            ]
        )

        # Unseal with PCR policy. check=False: a non-zero exit here is the
        # expected signal for a PCR mismatch, not an unexpected error.
        result = _run(
            ["tpm2_unseal", "-c", loaded_ctx, "-p", f"pcr:{pcr_spec}"],
            check=False,
        )

        if result.returncode != 0:
            logger.warning("Unseal failed — PCR values may have changed (boot tampered?)")
            return None

        return result.stdout.decode()  # type: ignore[no-any-return]

    except Exception as e:
        logger.error("Unseal error: %s", e)
        return None

    finally:
        # Remove the temporary context on every path. Previously it leaked when
        # tpm2_load raised. Swallow OSError so cleanup cannot override the
        # None-on-error contract.
        if loaded_ctx is not None:
            try:
                os.unlink(loaded_ctx)
            except FileNotFoundError:
                pass
            except OSError as err:
                logger.warning("Could not remove temporary file %s: %s", loaded_ctx, err)


def is_sealed() -> bool:
    """
    Report whether both halves of the sealed object exist on disk.

    Only file presence is checked. This does not prove the data can be
    unsealed (the PCR values may have changed since sealing).
    """
    return all(part.exists() for part in (SEALED_PUB, SEALED_PRIV))


def clear_sealed() -> None:
    """
    Remove sealed passphrase data.

    Deletes the sealed object's public and private parts, the saved
    primary-key context, the PCR policy digest, and the ".pcrs" PCR snapshot
    that seal_passphrase writes next to the policy file (previously left
    behind). Files that do not exist are skipped.
    """
    pcr_values_file = Path(str(PCR_POLICY) + ".pcrs")
    for f in (SEALED_PUB, SEALED_PRIV, SEALED_CTX, PCR_POLICY, pcr_values_file):
        # missing_ok avoids the race between a separate exists() check and unlink().
        f.unlink(missing_ok=True)
    logger.info("Sealed data cleared")


# =========================================================================
# CLI
# =========================================================================


def main():
    """
    Command-line entry point for the fTPM seal/unseal helper.

    Sub-commands:
        check   report whether a TPM device node exists and, if so, print PCRs
        pcrs    print the current values of the default PCR selection
        seal    prompt for a passphrase and seal it (--pcrs selects the PCRs)
        unseal  unseal and print the passphrase (--pcrs selects the PCRs)
        status  show TPM availability and whether sealed files exist
        clear   delete the sealed files

    ``seal`` and ``unseal`` exit with status 1 on failure so shell callers can
    detect it. With no sub-command, the help text is printed.
    """
    import argparse

    logging.basicConfig(level=logging.INFO, format="[tpm] %(message)s")

    parser = argparse.ArgumentParser(description="fTPM Seal/Unseal for Hermes")
    sub = parser.add_subparsers(dest="command")

    sub.add_parser("check", help="Check if fTPM is available")
    sub.add_parser("pcrs", help="Read current PCR values")

    seal_cmd = sub.add_parser("seal", help="Seal a passphrase")
    seal_cmd.add_argument("--pcrs", default=DEFAULT_PCRS, help="PCR spec")

    unseal_cmd = sub.add_parser("unseal", help="Unseal passphrase")
    unseal_cmd.add_argument("--pcrs", default=DEFAULT_PCRS, help="PCR spec")

    sub.add_parser("status", help="Check seal status")
    sub.add_parser("clear", help="Remove sealed data")

    args = parser.parse_args()

    if args.command == "check":
        if is_tpm_available():
            # is_tpm_available() accepts either node; name the one actually present.
            device = "/dev/tpm0" if Path("/dev/tpm0").exists() else "/dev/tpmrm0"
            print(f"fTPM is available at {device}")
            pcr_data = read_pcrs()
            print(pcr_data["raw"])
        else:
            print("fTPM NOT available")

    elif args.command == "pcrs":
        data = read_pcrs()
        print(data["raw"])

    elif args.command == "seal":
        import getpass

        pw = getpass.getpass("Passphrase to seal: ")
        if not seal_passphrase(pw, args.pcrs):
            print("Seal FAILED.")
            raise SystemExit(1)
        print("Sealed successfully.")

    elif args.command == "unseal":
        result = unseal_passphrase(args.pcrs)
        # Compare against None: an empty string is a valid unsealed passphrase,
        # and only None signals failure.
        if result is None:
            print("Unseal FAILED (PCR mismatch or no sealed data)")
            raise SystemExit(1)
        print(f"Unsealed: {result}")

    elif args.command == "status":
        print(f"fTPM available: {is_tpm_available()}")
        print(f"Passphrase sealed: {is_sealed()}")

    elif args.command == "clear":
        clear_sealed()
        print("Sealed data cleared.")

    else:
        parser.print_help()


# Script entry point: run the CLI when executed directly, not when imported.
if __name__ == "__main__":
    main()
