"""Tests for encryption-at-rest feature."""

from __future__ import annotations

import pytest

# Skip this entire test module when the optional ``cryptography`` package is
# not installed; none of the encryption tests below can run without it.
cryptography = pytest.importorskip("cryptography")


# ------------------------------------------------------------------
# Fixtures
# ------------------------------------------------------------------


@pytest.fixture()
def fernet_key():
    """Provide a newly generated Fernet key (``bytes``) for a single test."""
    from cryptography.fernet import Fernet as _Fernet

    fresh_key = _Fernet.generate_key()
    return fresh_key


@pytest.fixture()
def env(tmp_path, monkeypatch):
    """Build a throwaway engram layout under ``tmp_path``.

    Redirects the entities directory, the SQLite index path and the
    encryption key location into ``tmp_path``, and unsets
    ``ENGRAM_ENCRYPTION_KEY`` so each test starts with no key configured.

    Returns a dict with the ``"entities"``, ``"index"`` and ``"key_dir"`` paths.
    """
    entities_dir = tmp_path / "entities"
    index_path = tmp_path / "vault_index.sqlite"
    encryption_dir = tmp_path / ".encryption"

    entities_dir.mkdir()

    # engram.versions binds these names at module level too (monkeypatch
    # would raise if the attributes were missing), so patch both modules.
    for module_name in ("engram.config", "engram.versions"):
        monkeypatch.setattr(f"{module_name}.ENTITIES_DIR", entities_dir)
        monkeypatch.setattr(f"{module_name}.INDEX_PATH", index_path)

    # Keep key material out of the real key location.
    encryption_dir.mkdir()
    monkeypatch.setattr("engram.encryption.KEY_DIR", encryption_dir)
    monkeypatch.setattr("engram.encryption.KEY_FILE", encryption_dir / "engram.key")

    # Tests opt in to the env-var key explicitly via monkeypatch.setenv.
    monkeypatch.delenv("ENGRAM_ENCRYPTION_KEY", raising=False)

    return {"entities": entities_dir, "index": index_path, "key_dir": encryption_dir}


# ------------------------------------------------------------------
# 1. generate_key returns a valid Fernet key
# ------------------------------------------------------------------


def test_generate_key_returns_valid_fernet_key():
    """generate_key() yields a key Fernet accepts and can round-trip data with."""
    from cryptography.fernet import Fernet

    from engram.encryption import generate_key

    # Fernet's constructor rejects malformed keys, so building it is itself a check.
    cipher = Fernet(generate_key())
    token = cipher.encrypt(b"probe")
    assert cipher.decrypt(token) == b"probe"


# ------------------------------------------------------------------
# 2. encrypt / decrypt roundtrip
# ------------------------------------------------------------------


def test_encrypt_decrypt_roundtrip(fernet_key):
    """encrypt() returns bytes that decrypt() maps back to the original str."""
    from engram.encryption import decrypt, encrypt

    message = "The quick brown fox jumps over the lazy dog."
    token = encrypt(message, key=fernet_key)

    assert isinstance(token, bytes)
    assert decrypt(token, key=fernet_key) == message


# ------------------------------------------------------------------
# 3. Wrong key raises InvalidToken
# ------------------------------------------------------------------


def test_wrong_key_raises(fernet_key):
    """Decrypting with a different key than the one used to encrypt fails."""
    from cryptography.fernet import Fernet, InvalidToken

    from engram.encryption import decrypt, encrypt

    token = encrypt("secret data", key=fernet_key)
    unrelated_key = Fernet.generate_key()

    # Fernet authenticates tokens, so a mismatched key surfaces as InvalidToken.
    with pytest.raises(InvalidToken):
        decrypt(token, key=unrelated_key)


# ------------------------------------------------------------------
# 4. is_enabled reads ENGRAM_ENCRYPTION_KEY env var
# ------------------------------------------------------------------


def test_is_enabled_reads_env(env, monkeypatch, fernet_key):
    """is_enabled() turns on once ENGRAM_ENCRYPTION_KEY is set."""
    from engram.encryption import is_enabled

    # The env fixture removes the variable and leaves the key directory empty.
    enabled_initially = is_enabled()
    assert not enabled_initially

    monkeypatch.setenv("ENGRAM_ENCRYPTION_KEY", fernet_key.decode())
    assert is_enabled()


# ------------------------------------------------------------------
# 5. is_enabled reads key file
# ------------------------------------------------------------------


def test_is_enabled_reads_file(env, fernet_key):
    """is_enabled() turns on once a key is written to the patched key file."""
    from engram.encryption import is_enabled

    assert not is_enabled()

    # Same filename the env fixture assigns to engram.encryption.KEY_FILE.
    (env["key_dir"] / "engram.key").write_bytes(fernet_key)
    assert is_enabled()


# ------------------------------------------------------------------
# 6. Version snapshot encrypted when key is set
# ------------------------------------------------------------------


def test_version_snapshot_encrypted_when_enabled(env, monkeypatch, fernet_key):
    """With a key configured, a snapshot is stored only as ``<vid>.enc`` ciphertext."""
    from engram.encryption import decrypt
    from engram.versions import create_snapshot

    monkeypatch.setenv("ENGRAM_ENCRYPTION_KEY", fernet_key.decode())
    vid = create_snapshot("test.md", "encrypted content")

    versions_dir = env["entities"] / ".versions"
    enc_path = versions_dir / f"{vid}.enc"
    md_path = versions_dir / f"{vid}.md"

    assert enc_path.exists(), "Expected .enc file when encryption is enabled"
    assert not md_path.exists(), ".md file should not exist when encryption is enabled"

    # The stored bytes must decrypt back to exactly the snapshot content.
    assert decrypt(enc_path.read_bytes(), key=fernet_key) == "encrypted content"


# ------------------------------------------------------------------
# 7. Version snapshot plaintext when disabled
# ------------------------------------------------------------------


def test_version_snapshot_plaintext_when_disabled(env):
    """Without any key, a snapshot is stored only as plaintext ``<vid>.md``."""
    from engram.versions import create_snapshot

    vid = create_snapshot("test.md", "plaintext content")

    versions_dir = env["entities"] / ".versions"
    md_path = versions_dir / f"{vid}.md"
    enc_path = versions_dir / f"{vid}.enc"

    assert md_path.exists(), "Expected .md file when encryption is disabled"
    assert not enc_path.exists(), ".enc file should not exist when encryption is disabled"
    assert md_path.read_text() == "plaintext content"


# ------------------------------------------------------------------
# 8. Encrypted file does not expose the plaintext
# ------------------------------------------------------------------


def test_encrypted_file_unreadable_as_text(env, monkeypatch, fernet_key):
    """An encrypted snapshot must not leak its plaintext on disk.

    Fernet tokens are URL-safe base64, so the file *is* valid ASCII/UTF-8;
    checking decodability proves nothing. The meaningful property is that
    the original content appears nowhere in the raw bytes. Comparing the
    whole file for equality would miss a leak such as plaintext stored
    alongside, or wrapped around, the ciphertext.
    """
    from engram.versions import create_snapshot

    monkeypatch.setenv("ENGRAM_ENCRYPTION_KEY", fernet_key.decode())

    secret = "this should not be readable as plain text"
    vid = create_snapshot("test.md", secret)
    enc_file = env["entities"] / ".versions" / f"{vid}.enc"

    raw = enc_file.read_bytes()
    assert raw, "encrypted snapshot should not be empty"
    # Substring (not equality) check: fails if the plaintext is embedded anywhere.
    assert secret.encode("utf-8") not in raw
