#!/usr/bin/env python3
"""
Engram Node Discovery Scanner

Scans the USB-C bridge (192.168.55.1) and local subnet (192.168.1.0/24) for
Jetson Orin Nano units. Identifies nodes via NVIDIA MAC OUI (00:04:4b) or
hostnames containing 'jetson'/'orin'.

Part of the Engram Consumer Onboarding Wizard.

Usage:
    python3 discover_nodes.py
    # Prints the sorted Engram node IPs to stdout, one per line; exits with
    # code 1 if none are found, the scan fails, or it is cancelled
"""

import concurrent.futures
import socket
import subprocess
import sys

# Configuration
# Address probed in addition to the local subnet (the USB-C bridge).
USB_BRIDGE_IP = "192.168.55.1"
# First three octets of the local /24; hosts .1 through .254 are probed.
SUBNET_PREFIX = "192.168.1"
# NVIDIA OUI, lowercase and colon-separated so it can be prefix-matched
# against the lowercased MAC returned by the ARP lookup.
NVIDIA_MAC_OUI = "00:04:4b"
# Thread-pool size for the ping sweep.
MAX_WORKERS_PING = 50
# Thread-pool size for the hostname/MAC identification pass.
MAX_WORKERS_FILTER = 10
# Seconds before a ping subprocess is abandoned.
PING_TIMEOUT = 2
# NOTE(review): not referenced by any function visible in this file;
# confirm whether the reverse-DNS lookup was meant to honour it.
DNS_TIMEOUT = 1
# Seconds before an arp subprocess is abandoned.
ARP_TIMEOUT = 1


def check_ip(ip):
    """
    Ping check for IP reachability.

    Runs the system ``ping`` binary once (``-c 1``) and treats a zero exit
    status as "reachable". ``-W 0.5`` is the per-reply wait passed to ping
    (its unit depends on the local ping implementation); PING_TIMEOUT is a
    separate hard limit, in seconds, on the whole subprocess.

    Args:
        ip: IP address to ping

    Returns:
        IP string if reachable, None otherwise (ping exited non-zero, the
        subprocess exceeded PING_TIMEOUT, or ``ping`` could not be run)
    """
    try:
        subprocess.run(
            ["ping", "-c", "1", "-W", "0.5", ip],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
            timeout=PING_TIMEOUT,
        )
    except (subprocess.SubprocessError, OSError):
        # SubprocessError covers CalledProcessError (non-zero exit) and
        # TimeoutExpired; OSError covers a missing or non-executable ping.
        # Anything else is a programming error and must not be reported
        # as "host unreachable".
        return None
    return ip


def get_hostname(ip):
    """
    Attempt reverse DNS lookup to get hostname.

    No timeout is applied to the lookup; it blocks until
    ``socket.gethostbyaddr`` returns or raises.

    Args:
        ip: IP address to look up

    Returns:
        Hostname (lowercase) if found, None if the lookup fails
    """
    try:
        primary_name = socket.gethostbyaddr(ip)[0]
    except (OSError, UnicodeError):
        # socket.herror and socket.gaierror are both OSError subclasses;
        # UnicodeError is raised when the argument cannot be IDNA-encoded.
        # Other exceptions (e.g. TypeError for a non-string argument) are
        # caller bugs and are allowed to propagate.
        return None
    return primary_name.lower()


_HEX_DIGITS = frozenset("0123456789abcdef")


def _normalize_mac(token):
    """Return token as a lowercase, zero-padded MAC string, or None if it is not a MAC."""
    octets = token.lower().split(":")
    if len(octets) != 6:
        return None
    if not all(1 <= len(octet) <= 2 and _HEX_DIGITS.issuperset(octet) for octet in octets):
        return None
    # Some arp implementations drop leading zeros ("0:4:4b:..."); pad so
    # prefix comparisons against an OUI such as "00:04:4b" still work.
    return ":".join(octet.zfill(2) for octet in octets)


def _mac_from_arp_output(output, ip):
    """Return the MAC address listed for ip in arp output, or None if there is none."""
    for line in output.splitlines():
        tokens = line.split()
        # Require an exact token match (ignoring the parentheses some arp
        # variants put around the address) so that "192.168.1.1" does not
        # match a line for "192.168.1.10".
        if not any(token.strip("()") == ip for token in tokens):
            continue
        # Look for a MAC-shaped token instead of trusting a fixed column:
        # incomplete entries and BSD-style output put other words there.
        for token in tokens:
            mac = _normalize_mac(token)
            if mac is not None:
                return mac
    return None


def get_mac_address(ip):
    """
    Get MAC address via ARP lookup.

    Runs ``arp -n <ip>`` and extracts the hardware address from the line
    that refers to ``ip``.

    Args:
        ip: IP address to query

    Returns:
        MAC address (lowercase, colon-separated, zero-padded octets) if
        found, None otherwise (no entry, incomplete entry, non-zero exit
        status, timeout, or ``arp`` could not be run)
    """
    try:
        result = subprocess.run(
            ["arp", "-n", ip],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            check=False,
            timeout=ARP_TIMEOUT,
            text=True,
        )
    except (subprocess.SubprocessError, OSError, ValueError):
        # TimeoutExpired is a SubprocessError; OSError covers a missing arp
        # binary; ValueError covers output that cannot be decoded as text.
        return None
    if result.returncode != 0:
        return None
    return _mac_from_arp_output(result.stdout, ip)


def is_engram_node(ip):
    """
    Decide whether the host at ``ip`` is an Engram node (Jetson Orin Nano).

    A host qualifies when either of the following holds:
    1. Its reverse-DNS hostname contains 'jetson' or 'orin'
    2. Its MAC address begins with the NVIDIA OUI (00:04:4b)

    Args:
        ip: IP address to check

    Returns:
        True if the host is identified as an Engram device, False otherwise
    """
    # Hostname test runs first; the ARP lookup only happens if it fails.
    name = get_hostname(ip)
    if name and any(marker in name for marker in ("jetson", "orin")):
        return True

    # Fall back to the vendor prefix of the MAC address.
    hw_addr = get_mac_address(ip)
    return bool(hw_addr and hw_addr.startswith(NVIDIA_MAC_OUI))


def scan_network():
    """
    Scan 192.168.55.1 (USB-C bridge) and 192.168.1.0/24 (local subnet)
    for Engram nodes.

    Works in two passes: a concurrent ping sweep to find reachable hosts,
    then a concurrent hostname/MAC check on those hosts only. Progress
    messages go to stderr.

    Returns:
        List of IP addresses of discovered Engram nodes, sorted in numeric
        address order (so 192.168.1.2 precedes 192.168.1.10); empty if
        nothing was found
    """
    # Build target list: hosts .1-.254 of the subnet plus the bridge address
    targets = [f"{SUBNET_PREFIX}.{host}" for host in range(1, 255)]
    targets.append(USB_BRIDGE_IP)

    print(
        f"Scanning {len(targets)} addresses (this may take 30-60 seconds)...",
        file=sys.stderr,
    )

    # First pass: identify reachable IPs via concurrent pings
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS_PING) as executor:
        reachable_ips = [ip for ip in executor.map(check_ip, targets) if ip is not None]

    if not reachable_ips:
        return []

    print(
        f"Found {len(reachable_ips)} reachable node(s). Identifying Engram nodes...",
        file=sys.stderr,
    )

    # Second pass: identify Engram-specific nodes via hostname/MAC filtering
    engram_nodes = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS_FILTER) as executor:
        futures = {executor.submit(is_engram_node, ip): ip for ip in reachable_ips}
        for future in concurrent.futures.as_completed(futures):
            ip = futures[future]
            try:
                if future.result():
                    engram_nodes.append(ip)
            except Exception as e:
                # One failing host must not abort the whole scan.
                print(f"Warning: Error checking {ip}: {e}", file=sys.stderr)

    # A plain string sort would put "192.168.1.10" before "192.168.1.2".
    # inet_aton yields the 4-byte big-endian form, which compares numerically.
    return sorted(engram_nodes, key=socket.inet_aton)


def main():
    """
    Command-line entry point.

    Prints each discovered node IP on its own line to stdout. Exits with
    status 1 (after a message on stderr) when no node is found, the scan
    is interrupted from the keyboard, or the scan raises an error.
    """

    def bail(message):
        # Report on stderr and terminate with a failure status.
        print(message, file=sys.stderr)
        sys.exit(1)

    try:
        discovered = scan_network()
        if not discovered:
            bail("No Engram nodes found. Ensure Jetson nodes are powered on and reachable.")
        for address in discovered:
            print(address)
    except KeyboardInterrupt:
        bail("\nScan cancelled.")
    except Exception as err:
        bail(f"Scan failed: {err}")


# Run the scanner only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
