#!/usr/bin/env python3
"""
Shamir's Secret Sharing for BIP39 Mnemonics

Splits a mnemonic into N shares where any K shares can reconstruct it.
Default: 2-of-3 scheme.

Share distribution:
  Share 1: User writes down (paper, sealed envelope)
  Share 2: Encrypted to user's personal email
  Share 3: Stored on device (hardware-bound via OP-TEE HUK if available)

Any 2 of 3 shares reconstruct the mnemonic.

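Uses GF(256) arithmetic for the secret sharing — each byte of the
secret is split independently. For mnemonics, the secret is the UTF-8
encoded mnemonic text (not the decoded BIP39 entropy), so each share has
the same byte length as the encoded mnemonic.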

Usage:
    from shamir import split_secret, combine_shares

    shares = split_secret(b"secret data", threshold=2, num_shares=3)
    # shares = [(1, b"..."), (2, b"..."), (3, b"...")]

    recovered = combine_shares([(1, shares[0][1]), (3, shares[2][1])])
    assert recovered == b"secret data"
"""

import secrets
from typing import List, Optional, Tuple

# GF(256) arithmetic using the irreducible polynomial x^8 + x^4 + x^3 + x + 1
# This is the same field used in AES (Rijndael)


def _gf256_add(a: int, b: int) -> int:
    """Addition in GF(256) is XOR."""
    return a ^ b


def _gf256_mul(a: int, b: int) -> int:
    """Multiplication in GF(256) using Russian peasant multiplication."""
    p = 0
    for _ in range(8):
        if b & 1:
            p ^= a
        hi = a & 0x80
        a = (a << 1) & 0xFF
        if hi:
            a ^= 0x1B  # x^8 + x^4 + x^3 + x + 1
        b >>= 1
    return p


def _gf256_inv(a: int) -> int:
    """Multiplicative inverse in GF(256) via exponentiation (a^254 = a^-1)."""
    if a == 0:
        raise ZeroDivisionError("No inverse for 0 in GF(256)")
    # a^254 = a^-1 in GF(256) since |GF(256)*| = 255
    result = a
    for _ in range(6):  # Square-and-multiply: a^2, a^4, ..., a^128, then multiply
        result = _gf256_mul(result, result)
        result = _gf256_mul(result, a)
    result = _gf256_mul(result, result)  # a^254
    return result


def _gf256_div(a: int, b: int) -> int:
    """Return a / b in GF(256), i.e. *a* times the multiplicative inverse of *b*.

    A zero divisor raises ZeroDivisionError (from _gf256_inv).
    """
    b_inverse = _gf256_inv(b)
    return _gf256_mul(a, b_inverse)


def _eval_polynomial(coeffs: List[int], x: int) -> int:
    """Evaluate polynomial at x in GF(256). coeffs[0] is the secret."""
    result = 0
    for coeff in reversed(coeffs):
        result = _gf256_add(_gf256_mul(result, x), coeff)
    return result


def _lagrange_interpolate(shares: List[Tuple[int, int]], x: int = 0) -> int:
    """Lagrange interpolation at x in GF(256) to recover the secret (f(0))."""
    k = len(shares)
    result = 0
    for i in range(k):
        xi, yi = shares[i]
        # Compute Lagrange basis polynomial L_i(x)
        num = 1
        den = 1
        for j in range(k):
            if i == j:
                continue
            xj = shares[j][0]
            num = _gf256_mul(num, _gf256_add(x, xj))
            den = _gf256_mul(den, _gf256_add(xi, xj))
        # L_i(x) * y_i
        term = _gf256_mul(yi, _gf256_div(num, den))
        result = _gf256_add(result, term)
    return result


def split_secret(
    secret: bytes,
    threshold: int = 2,
    num_shares: int = 3,
) -> List[Tuple[int, bytes]]:
    """
    Split a secret into shares using Shamir's Secret Sharing.

    Each byte of *secret* becomes the constant term of its own polynomial of
    degree ``threshold - 1``, whose other coefficients are drawn from the
    ``secrets`` CSPRNG. Share number *x* stores that polynomial evaluated at
    *x*, one output byte per secret byte.

    Args:
        secret: The secret to split (arbitrary bytes)
        threshold: Minimum shares needed to reconstruct (K)
        num_shares: Total shares to generate (N)

    Returns:
        List of (share_index, share_data) tuples.
        share_index is 1-based (1..N).

    Raises:
        ValueError: if threshold > num_shares, threshold < 2, or
            num_shares > 255 (checked in that order).
    """
    if threshold > num_shares:
        raise ValueError("Threshold cannot exceed number of shares")
    if threshold < 2:
        raise ValueError("Threshold must be at least 2")
    if num_shares > 255:
        raise ValueError("Maximum 255 shares (GF(256) limit)")

    # x-coordinates 1..N; x = 0 is reserved for the secret itself.
    xs = range(1, num_shares + 1)
    outputs = {x: bytearray() for x in xs}

    for secret_byte in secret:
        poly = [secret_byte]
        poly.extend(secrets.randbelow(256) for _ in range(threshold - 1))
        for x in xs:
            outputs[x].append(_eval_polynomial(poly, x))

    return [(x, bytes(outputs[x])) for x in xs]


def combine_shares(shares: List[Tuple[int, bytes]]) -> bytes:
    """
    Reconstruct a secret from shares.

    Args:
        shares: List of (share_index, share_data) tuples, as produced by
                split_secret. Need at least threshold shares. The shares
                do not record the threshold, so supplying too few cannot be
                detected here; it silently yields wrong bytes.

    Returns:
        The reconstructed secret bytes.

    Raises:
        ValueError: if no shares are given, the share lengths differ, an
            index lies outside 1..255, or the same index appears twice.
    """
    if not shares:
        raise ValueError("No shares provided")

    # All shares must be the same length
    length = len(shares[0][1])
    if not all(len(data) == length for _, data in shares):
        raise ValueError("All shares must be the same length")

    # split_secret only emits x = 1..255. Out-of-range x values would feed
    # non-byte values into the GF(256) helpers. A repeated x makes a
    # Lagrange denominator zero, which would otherwise surface as an
    # opaque ZeroDivisionError deep inside the interpolation.
    indices = [idx for idx, _ in shares]
    for idx in indices:
        if not 1 <= idx <= 255:
            raise ValueError(f"Share index {idx} out of range (must be 1..255)")
    if len(set(indices)) != len(indices):
        raise ValueError("Duplicate share indices provided")

    result = bytearray()
    for byte_idx in range(length):
        # Collect (x, y) pairs for this byte position
        points = [(idx, data[byte_idx]) for idx, data in shares]
        # Interpolate to find f(0) = secret byte
        result.append(_lagrange_interpolate(points, 0))

    return bytes(result)


# =========================================================================
# Mnemonic-specific helpers
# =========================================================================


def split_mnemonic(
    mnemonic: str,
    threshold: int = 2,
    num_shares: int = 3,
) -> List[Tuple[int, str]]:
    """
    Split a BIP39 mnemonic into Shamir shares.

    The mnemonic text is UTF-8 encoded and split with split_secret. Each
    share's bytes are then rendered with ``bytes.hex()``.

    Returns shares as (index, hex_string) tuples for easy storage/display.
    """
    hex_shares = []
    for index, share_bytes in split_secret(mnemonic.encode("utf-8"), threshold, num_shares):
        hex_shares.append((index, share_bytes.hex()))
    return hex_shares


def combine_mnemonic_shares(shares: List[Tuple[int, str]]) -> str:
    """
    Reconstruct a BIP39 mnemonic from Shamir shares.

    Args:
        shares: List of (index, hex_string) tuples, as returned by
                split_mnemonic.

    Returns:
        The recovered bytes decoded as UTF-8.

    Raises:
        ValueError: if a hex string is malformed (from ``bytes.fromhex``),
            or if combine_shares rejects the shares.
        UnicodeDecodeError: if the recovered bytes are not valid UTF-8.
    """
    decoded = []
    for index, hex_text in shares:
        decoded.append((index, bytes.fromhex(hex_text)))
    recovered = combine_shares(decoded)
    return recovered.decode("utf-8")


# =========================================================================
# CLI
# =========================================================================


def _cli_split(cmd_parser, args) -> int:
    """Handle `split`: obtain the mnemonic, split it and print the shares."""
    import getpass

    mnemonic = args.mnemonic
    if mnemonic is None:
        # Prompting without echo keeps the secret off the command line,
        # where it would be visible in shell history and process listings.
        mnemonic = getpass.getpass("Mnemonic to split: ")
    if not mnemonic.strip():
        cmd_parser.error("mnemonic must not be empty")

    try:
        shares = split_mnemonic(mnemonic, args.threshold, args.shares)
    except ValueError as exc:
        # Bad threshold/share counts are usage errors, not crashes.
        cmd_parser.error(str(exc))

    print(f"\nMnemonic split into {len(shares)} shares (need {args.threshold} to recover):\n")
    for idx, hex_data in shares:
        print(f"  Share {idx}: {hex_data}")
    print(f"\nStore these separately. Any {args.threshold} shares can reconstruct the mnemonic.")
    return 0


def _cli_combine(cmd_parser, args) -> int:
    """Handle `combine`: parse 'index:hex' tokens and print the recovered mnemonic."""
    import sys

    parsed = []
    # Report problems by position so share material is not echoed back.
    for position, token in enumerate(args.shares, start=1):
        idx_str, sep, hex_data = token.partition(":")
        if not sep:
            cmd_parser.error(f"share #{position} is not in 'index:hex' form")
        try:
            idx = int(idx_str)
        except ValueError:
            cmd_parser.error(f"share #{position} has a non-integer index")
        parsed.append((idx, hex_data))

    try:
        mnemonic = combine_mnemonic_shares(parsed)
    except UnicodeDecodeError:
        # Must precede ValueError: UnicodeDecodeError is a ValueError subclass.
        print(
            "Error: recovered data is not valid UTF-8 (wrong or insufficient shares?)",
            file=sys.stderr,
        )
        return 1
    except (ValueError, ZeroDivisionError) as exc:
        # ValueError: bad hex or rejected shares; ZeroDivisionError: repeated
        # share indices reaching the GF(256) inverse.
        print(f"Error: {exc}", file=sys.stderr)
        return 1

    print(f"\nRecovered mnemonic:\n  {mnemonic}\n")
    return 0


def _cli_self_test() -> int:
    """Run the built-in self-test; return 0 if every check passes, else 1.

    Uses explicit checks rather than ``assert`` so the test still verifies
    something when Python runs with ``-O`` (which strips asserts).
    """
    import sys

    print("Running Shamir's Secret Sharing self-test...")
    any_failed = False

    def report(label: str, errors: List[str]) -> None:
        # Print one PASS/FAIL line per test group and remember any failure.
        nonlocal any_failed
        if errors:
            any_failed = True
            print(f"  [FAIL] {label}: {'; '.join(errors)}")
        else:
            print(f"  [PASS] {label}")

    # Test 1: Basic 2-of-3 -- every pair of shares must recover the secret.
    secret = b"hello world test"
    shares = split_secret(secret, 2, 3)
    report(
        "2-of-3 basic",
        [
            f"Failed: shares {i + 1},{j + 1}"
            for i, j in ((0, 1), (0, 2), (1, 2))
            if combine_shares([shares[i], shares[j]]) != secret
        ],
    )

    # Test 2: 3-of-5 with a non-contiguous subset of shares.
    shares5 = split_secret(secret, 3, 5)
    ok = combine_shares([shares5[0], shares5[2], shares5[4]]) == secret
    report("3-of-5", [] if ok else ["Failed: 3-of-5"])

    # Test 3: Mnemonic round-trip through the hex-string helpers.
    mnemonic = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about"
    m_shares = split_mnemonic(mnemonic, 2, 3)
    recovered = combine_mnemonic_shares([m_shares[0], m_shares[2]])
    report("Mnemonic round-trip", [] if recovered == mnemonic else ["Failed: mnemonic recovery"])

    # Test 4: Single-byte edge cases.
    edge_errors = []
    for val in (0, 1, 127, 128, 254, 255):
        s = split_secret(bytes([val]), 2, 3)
        if combine_shares([s[0], s[2]]) != bytes([val]):
            edge_errors.append(f"Failed: byte {val}")
    report("Edge cases (0, 1, 127, 128, 254, 255)", edge_errors)

    # Test 5: every nonzero field element times its inverse must be 1.
    report(
        "GF(256) inverses",
        [f"Failed: inverse of {a}" for a in range(1, 256) if _gf256_mul(a, _gf256_inv(a)) != 1],
    )

    if any_failed:
        print("\nSelf-test FAILED", file=sys.stderr)
        return 1
    print("\nAll tests passed!")
    return 0


def main(argv: Optional[List[str]] = None) -> int:
    """Command-line entry point for splitting and recombining mnemonics.

    Args:
        argv: Arguments to parse; ``None`` means argparse reads ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if recovery or the self-test
        fails. Usage errors exit through argparse with status 2.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Shamir's Secret Sharing for Hermes")
    sub = parser.add_subparsers(dest="command")

    split_cmd = sub.add_parser("split", help="Split a mnemonic into shares")
    split_cmd.add_argument(
        "--mnemonic",
        help=(
            "BIP39 mnemonic to split. If omitted, it is prompted for without echo; "
            "passing it here exposes it to shell history and process listings."
        ),
    )
    split_cmd.add_argument("--threshold", "-k", type=int, default=2, help="Shares needed (default: 2)")
    split_cmd.add_argument("--shares", "-n", type=int, default=3, help="Total shares (default: 3)")

    combine_cmd = sub.add_parser("combine", help="Reconstruct mnemonic from shares")
    combine_cmd.add_argument(
        "--shares",
        nargs="+",
        required=True,
        help="Shares as 'index:hex' pairs (e.g., '1:ab12ef 3:cd34ef')",
    )

    sub.add_parser("test", help="Run self-test")

    args = parser.parse_args(argv)

    if args.command == "split":
        return _cli_split(split_cmd, args)
    if args.command == "combine":
        return _cli_combine(combine_cmd, args)
    if args.command == "test":
        return _cli_self_test()

    parser.print_help()
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
