import argparse
from datetime import datetime, timezone, timedelta
import json
import os
import socket
import sys
import time
from dateutil import parser
from decimal import Decimal

# Wei per whole RBNT (10**18); raw balances reported by the node are divided by this.
ether_in_wei = Decimal(10**18)

# Rolling window of (block number, last-commit time) samples, appended on every
# refresh by build_output; the block processing rate is derived from it.
blockChanges: list[tuple[int, datetime]] = []

# ANSI SGR escape sequences for terminal styling ("\x1b" is the ESC character).
BOLD = "\x1b[1m"
DIM = "\x1b[2m"
RED = "\x1b[31m"
GREEN = "\x1b[32m"
YELLOW = "\x1b[33m"
CYAN = "\x1b[36m"
RESET = "\x1b[0m"

def get_terminal_width() -> int:
    """Return the width of the attached terminal in columns.

    Falls back to 80 columns when ``os.get_terminal_size`` raises ``OSError``
    (for example when stdout is not a terminal).
    """
    try:
        size = os.get_terminal_size()
    except OSError:
        size = None
    return 80 if size is None else size.columns

def fmt_duration(td: timedelta) -> str:
    """Render an elapsed duration compactly, e.g. ``"1d 2h 3m 4s ago"``.

    Fractional seconds are truncated toward zero. Day, hour and minute units
    are omitted when zero; seconds are always shown. A negative duration is
    rendered as ``"in the future"``.
    """
    remaining = int(td.total_seconds())
    if remaining < 0:
        return "in the future"
    pieces = []
    # Peel off the larger units first; whatever is left over is seconds.
    for suffix, unit_seconds in (("d", 86400), ("h", 3600), ("m", 60)):
        count, remaining = divmod(remaining, unit_seconds)
        if count:
            pieces.append(f"{count}{suffix}")
    pieces.append(f"{remaining}s")
    return " ".join(pieces) + " ago"

def status_dot(ok: bool) -> str:
    """Return a coloured bullet: green when ``ok`` is truthy, red otherwise."""
    colour = GREEN if ok else RED
    return colour + "●" + RESET

def warn(msg: str) -> str:
    """Format ``msg`` as an indented red warning line prefixed with a ⚠ sign."""
    prefix = f"  {RED}⚠ "
    return f"{prefix}{msg}{RESET}"

def section(title: str, width: int) -> str:
    """Return a section header string.

    The result starts with a newline, then a dim horizontal rule ``width``
    characters wide, then ``title`` in bold on the next line.
    """
    rule = DIM + "─" * width + RESET
    heading = f"{BOLD}{title}{RESET}"
    return "\n".join(("", rule, heading))

def get_block_ps(blockChange: list[tuple[int, datetime]]) -> float:
    """Return the average block processing rate (blocks/second) over a window.

    Only the first and last ``(block_number, commit_time)`` samples of
    ``blockChange`` are used. Naive datetimes are treated as UTC. Aware
    datetimes are *converted* to UTC rather than having their offset
    overwritten, so samples that carry different UTC offsets are still
    compared correctly.

    Returns 0.0 for an empty window, and when the elapsed time between the
    first and last sample is zero or negative (no meaningful rate).
    """
    if not blockChange:
        return 0.0

    def as_utc(moment: datetime) -> datetime:
        # replace() would silently relabel e.g. 10:00+10:00 as 10:00 UTC.
        if moment.utcoffset() is None:
            return moment.replace(tzinfo=timezone.utc)
        return moment.astimezone(timezone.utc)

    startBlock, startTime = blockChange[0]
    endBlock, endTime = blockChange[-1]
    elapsed = (as_utc(endTime) - as_utc(startTime)).total_seconds()
    if elapsed <= 0:
        return 0.0
    return float(endBlock - startBlock) / elapsed

def fetch_status_udp(host: str, port: int, *, timeout: float = 5.0, bufsize: int = 65535) -> dict:
    """Query the node's UDP status server and return the decoded JSON status.

    Sends a single ``getStatus`` datagram and waits for one reply datagram.

    Args:
        host: hostname or IPv4 address of the node (the socket is AF_INET).
        port: UDP port of the status server.
        timeout: seconds to wait for the reply (default 5).
        bufsize: receive buffer size. The default of 65535 is large enough
            for any UDP datagram. A smaller buffer (the previous 4096) cannot
            receive a larger status reply intact, which then fails to decode.

    Returns:
        The status object parsed from the reply's UTF-8 JSON payload.

    Raises:
        OSError: on socket errors, including ``socket.timeout`` when no reply
            arrives within ``timeout``.
        ValueError: if the reply is not valid UTF-8 JSON.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.settimeout(timeout)
        sock.sendto(b"getStatus", (host, port))
        payload, _ = sock.recvfrom(bufsize)
    return json.loads(payload.decode("utf-8"))

def fetch_status_http(address: str) -> dict:
    """Fetch the node status from the deprecated HTTP ``/status`` endpoint.

    Args:
        address: base URL of the node, e.g. ``http://localhost:6539``. A
            trailing slash is tolerated; it no longer produces ``//status``.

    Returns:
        The decoded JSON status object.

    Raises:
        Whatever ``requests`` raises for connection failures, the 5-second
        timeout, or (via ``raise_for_status``) a non-2xx HTTP response.
    """
    # Imported here so `requests` is only needed when HTTP mode is used.
    import requests
    url = address.rstrip("/") + "/status"
    response = requests.get(url, timeout=5)
    response.raise_for_status()
    return response.json()

def run_loop(args: argparse.Namespace):
    """Redraw the status screen every ``args.refreshSeconds`` seconds until Ctrl-C.

    A failed poll is shown on screen and retried on the next refresh instead of
    terminating the monitor. Failures covered are network errors and timeouts
    (``OSError``), malformed JSON or timestamps (``ValueError``), and a missing
    status field (``KeyError``). UDP does not guarantee delivery, so a single
    lost datagram must not end a long-running watch.
    """
    try:
        while True:
            try:
                output = build_output(args)
            except (OSError, ValueError, KeyError) as err:
                output = (
                    f"{RED}⚠ Could not read node status: {type(err).__name__}: {err}{RESET}\n"
                    f"{DIM}Retrying in {args.refreshSeconds}s (Ctrl-C to exit){RESET}\n"
                )
            # Clear only once the next frame is ready, so the screen is never
            # left blank while a slow poll is in flight.
            sys.stdout.write("\033[H\033[J")
            sys.stdout.write(output)
            sys.stdout.flush()
            time.sleep(args.refreshSeconds)
    except KeyboardInterrupt:
        print("\nExiting")

def _parse_status_time(raw: str) -> datetime:
    """Parse a timestamp field of the node status into an aware UTC datetime.

    A missing value (empty string or null) maps to the Unix epoch. The "time
    since" values derived from it then come out very large, so the staleness
    checks in build_output fire. Naive timestamps are taken to be UTC.
    Timestamps that carry a UTC offset are *converted* to UTC rather than
    having the offset overwritten, which would shift them by that offset.
    """
    if not raw:
        return datetime.fromtimestamp(0, timezone.utc)
    parsed = parser.parse(raw)
    if parsed.utcoffset() is None:
        return parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)

def build_output(args: argparse.Namespace) -> str:
    """Poll the node once and render the whole monitor screen as one string.

    Fetches the status over UDP (the default) or the deprecated HTTP endpoint
    (``args.http``). Renders a sync-status line followed by the Blocks,
    Superblocks, Signing, Certificates and Node Info sections, and adds a red
    warning line for each threshold that is exceeded.

    Side effect: appends ``(currentBlock, lastCommittedBlockAt)`` to the
    module-level ``blockChanges`` window, which keeps the last 10 samples and
    is used to compute the block processing rate.

    Args:
        args: parsed CLI options; reads ``http``, ``address``, ``port`` and
            ``minBalance``.

    Returns:
        The rendered screen, ending with a newline.

    Raises:
        Whatever the fetch or parsing raises. Typically that is ``OSError`` on
        network failure or timeout, ``ValueError`` on malformed JSON or
        timestamps, or ``KeyError`` when an expected status field is missing.
    """
    lines: list[str] = []
    w = get_terminal_width()

    if args.http:
        lines.append(f"{YELLOW}⚠ DEPRECATED{RESET}: HTTP mode is deprecated, switch to UDP (default)")
        data = fetch_status_http(args.address)
        target = f"{args.address.rstrip('/')}/status"
    else:
        data = fetch_status_udp(args.address, args.port)
        target = f"{args.address}:{args.port}"

    now = datetime.now(timezone.utc)
    now_str = now.strftime("%H:%M:%S UTC")

    title = f"{BOLD}Redbelly Node Monitor{RESET}"
    meta = f"{DIM}{target}  ·  {now_str}{RESET}"
    lines.append(title + "  " + meta)

    # --- Recovery ---
    isRecoveryComplete = bool(data["isRecoveryComplete"])
    dot = status_dot(isRecoveryComplete)
    sync_label = "Synced" if isRecoveryComplete else "Initial sync in progress"
    lines.append(f"{dot} {sync_label}")

    # --- Blocks ---
    lines.append(section("Blocks", w))

    lastCommittedBlockAt = _parse_status_time(data["lastCommittedBlockAt"])
    timeSinceLastBlock = now - lastCommittedBlockAt
    currentBlock = int(data["currentBlock"])
    lastBlockFromGovernors = int(data["lastBlockFromGovernors"])
    lastSyncedWithGovernorNodes = _parse_status_time(data["lastSyncedWithGovernorNodes"])
    timeSinceLastSyncWithGovs = now - lastSyncedWithGovernorNodes
    blocksBehind = lastBlockFromGovernors - currentBlock

    # Keep only the most recent 10 polls; the rate is measured across them.
    blockChanges.append((currentBlock, lastCommittedBlockAt))
    if len(blockChanges) > 10:
        blockChanges.pop(0)
    bps = get_block_ps(blockChanges)

    block_ok = timeSinceLastBlock <= timedelta(minutes=5)
    lines.append(f"  {status_dot(block_ok)} Current Block    {BOLD}{currentBlock:,}{RESET}  {DIM}({fmt_duration(timeSinceLastBlock)}){RESET}")
    lines.append(f"    Governor Block  {lastBlockFromGovernors:,}  {DIM}(synced {fmt_duration(timeSinceLastSyncWithGovs)}){RESET}")
    if blocksBehind > 0:
        behind_color = RED if blocksBehind > 100 else YELLOW if blocksBehind > 10 else DIM
        lines.append(f"    Blocks Behind   {behind_color}{blocksBehind:,}{RESET}")
    lines.append(f"    Process Rate    {bps:.2f} blocks/s")

    if not block_ok:
        lines.append(warn("No block processed in 5+ minutes — may be out of sync"))
    if blocksBehind > 100:
        lines.append(warn(f"Node is {blocksBehind:,} blocks behind governors"))
    if timeSinceLastSyncWithGovs > timedelta(minutes=1):
        lines.append(warn("Governor sync stale for 1+ minute"))

    # --- Superblocks ---
    lines.append(section("Superblocks", w))

    currentSuperblock = int(data["currentSuperblock"])
    lastSyncedWithBootnodes = _parse_status_time(data["lastSyncedWithBootnodes"])
    lastSuperblockFromBootnodes = int(data["lastSuperblockFromBootnodes"])
    timeSinceLastSyncWithBootnodes = now - lastSyncedWithBootnodes
    superblocksBehind = lastSuperblockFromBootnodes - currentSuperblock

    sb_ok = superblocksBehind <= 100
    lines.append(f"  {status_dot(sb_ok)} Current Superblock  {BOLD}{currentSuperblock:,}{RESET}")
    lines.append(f"    Bootnode Superblock  {lastSuperblockFromBootnodes:,}  {DIM}(synced {fmt_duration(timeSinceLastSyncWithBootnodes)}){RESET}")
    if superblocksBehind > 0:
        behind_color = RED if superblocksBehind > 100 else YELLOW if superblocksBehind > 10 else DIM
        lines.append(f"    Behind              {behind_color}{superblocksBehind:,}{RESET}")

    if superblocksBehind > 100:
        lines.append(warn(f"Node is {superblocksBehind:,} superblocks behind bootnodes"))
    if timeSinceLastSyncWithBootnodes > timedelta(minutes=2):
        lines.append(warn("Bootnode sync stale for 2+ minutes"))

    # --- Signing ---
    lines.append(section("Signing", w))

    signingAddress = str(data["signingAddress"])
    # The raw balance is in wei; scale it to whole RBNT for display and comparison.
    signingAddressBalance = Decimal(data["signingAddressBalance"]) / ether_in_wei
    bal_ok = signingAddressBalance >= args.minBalance

    lines.append(f"  {status_dot(bal_ok)} Address  {CYAN}{signingAddress}{RESET}")
    lines.append(f"    Balance  {BOLD}{signingAddressBalance:.4f}{RESET} RBNT")

    if not bal_ok:
        lines.append(warn(f"Balance below minimum of {args.minBalance} RBNT"))

    # --- Certificates ---
    lines.append(section("Certificates", w))

    # Treat a null name list as empty so the join below cannot fail.
    certificateDnsNames: list[str] = data["certificateDnsNames"] or []
    certificatesValidUpto = _parse_status_time(data["certificatesValidUpto"])
    certificateValidDuration = certificatesValidUpto - now

    cert_ok = certificateValidDuration > timedelta(days=7)
    # The epoch sentinel (missing expiry) is shown as N/A rather than as a 1970 date.
    cert_expiry_str = certificatesValidUpto.strftime("%Y-%m-%d %H:%M UTC") if certificatesValidUpto.year > 1970 else "N/A"
    if certificateValidDuration > timedelta():
        # fmt_duration describes elapsed time ("… ago"); reword it for a future expiry.
        cert_remaining = "in " + fmt_duration(certificateValidDuration).removesuffix(" ago")
    else:
        cert_remaining = "EXPIRED"

    lines.append(f"  {status_dot(cert_ok)} DNS Names  {', '.join(certificateDnsNames)}")
    lines.append(f"    Expires    {cert_expiry_str}  {DIM}({cert_remaining}){RESET}")

    if certificateValidDuration <= timedelta():
        lines.append(warn("Certificate has EXPIRED"))
    elif certificateValidDuration <= timedelta(days=7):
        lines.append(warn("Certificate expires within 7 days"))

    # Identity-certificate fields are optional; the block is skipped without a CN.
    identityCertCN = str(data.get("identityCertificateCommonName", ""))
    if identityCertCN:
        identityCertValidUpto = str(data.get("identityCertificatesValidUpto", ""))
        identityCertPubKey = str(data.get("identityCertificatePubKey", ""))
        lines.append(f"    Identity CN  {identityCertCN}")
        lines.append(f"    ID Expires   {identityCertValidUpto}")
        if identityCertPubKey:
            # Abbreviate long keys to their first 20 and last 8 characters.
            lines.append(f"    ID PubKey    {identityCertPubKey[:20]}...{identityCertPubKey[-8:]}" if len(identityCertPubKey) > 32 else f"    ID PubKey    {identityCertPubKey}")

    # --- Node Info ---
    lines.append(section("Node Info", w))

    version = str(data["version"])
    stateScheme = str(data.get("stateScheme", ""))
    gcMode = str(data.get("gcMode", ""))

    # The version string is split at its first space; the remainder is shown dimmed.
    parts = version.split(" ", 1)
    ver_num = parts[0]
    ver_hash = parts[1] if len(parts) > 1 else ""
    lines.append(f"  Version      {BOLD}{ver_num}{RESET}  {DIM}{ver_hash}{RESET}")
    if stateScheme:
        lines.append(f"  State Scheme {stateScheme}")
    if gcMode:
        lines.append(f"  GC Mode      {gcMode}")

    lines.append("")
    return "\n".join(lines) + "\n"

def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the monitor's command-line options.

    Args:
        argv: arguments to parse; ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        A namespace with ``address``, ``port``, ``http``, ``minBalance`` and
        ``refreshSeconds``.
        - ``minBalance`` accepts fractional RBNT amounts and is parsed as a
          ``Decimal``.
        - ``refreshSeconds`` accepts any finite positive number of seconds.

    Invalid values produce an argparse usage error (exit status 2) instead of
    a traceback or a crash later in ``time.sleep``.
    """
    def rbnt_amount(text: str) -> Decimal:
        """argparse type: a finite decimal RBNT amount (fractions allowed)."""
        try:
            value = Decimal(text)
        except ArithmeticError:
            # decimal.InvalidOperation is not a ValueError, so argparse would not
            # report it as a usage error by itself.
            raise argparse.ArgumentTypeError(f"invalid RBNT amount: {text!r}") from None
        if not value.is_finite():
            raise argparse.ArgumentTypeError(f"invalid RBNT amount: {text!r}")
        return value

    def positive_seconds(text: str) -> float:
        """argparse type: a finite refresh interval greater than zero."""
        try:
            value = float(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid number of seconds: {text!r}") from None
        # Written this way so NaN (which fails every comparison) is rejected too.
        if not 0 < value < float("inf"):
            raise argparse.ArgumentTypeError(f"refresh interval must be a positive number of seconds, got {text!r}")
        return value

    arg_parser = argparse.ArgumentParser(description="Watch the stats of a local Redbelly node")
    arg_parser.add_argument("-a", "--address", type=str, default="localhost",
                            help="Hostname of the Redbelly node (default: localhost). When using --http, provide the full URL (e.g. http://localhost:6539)")
    arg_parser.add_argument("-p", "--port", type=int, default=6540,
                            help="UDP port of the Redbelly node's status server (default: 6540)")
    arg_parser.add_argument("--http", action="store_true",
                            help="(Deprecated) Use HTTP instead of UDP. When set, --address should be a full URL like http://localhost:6539")
    arg_parser.add_argument("-mb", "--minBalance", type=rbnt_amount, default=10,
                            help="Minimum signing address balance in RBNT before warning (fractions allowed, default: 10)")
    arg_parser.add_argument("-r", "--refreshSeconds", type=positive_seconds, default=5,
                            help="Frequency to refresh values, in seconds (default: 5)")
    return arg_parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: parse the CLI options, then poll and redraw until Ctrl-C.
    args = parse_args()
    run_loop(args)
