#!/usr/bin/env python3
"""kubectl-tt-driver — TenstorrentDriverPolicy operator companion.

Subcommands:
  status (default)        Per-policy / per-node installer-pod table
  logs <cr> [node]        Tail driver-installer pod logs
  restart <cr>            Delete installer pods (DS respawns; re-runs install)

Env:
  TT_OPERATOR_NAMESPACE   default: tt-operator-system
"""

from __future__ import annotations

import datetime
import json
import os
import subprocess
import sys
import time
from typing import Any

# Namespace in which the driver-installer pods are listed, logged and deleted.
# Overridable through the TT_OPERATOR_NAMESPACE environment variable.
NS = os.environ.get("TT_OPERATOR_NAMESPACE", "tt-operator-system")

# The module docstring doubles as the text printed for --help and on errors.
USAGE = __doc__


def kubectl(*args: str, check: bool = True) -> str:
    """Run ``kubectl`` with *args* and return its captured stdout.

    The argument list is passed straight to the executable (no shell), so
    user-supplied values such as policy or node names are never re-parsed.

    Args:
        *args: Arguments forwarded to kubectl verbatim.
        check: When true (the default), a non-zero exit status is fatal:
            kubectl's stderr is echoed and this process exits with the same
            status. When false, stdout is returned regardless of the status.

    Returns:
        kubectl's stdout as text.

    Exits with status 127 and a short message on stderr if no ``kubectl``
    executable is found on PATH, instead of dying with a traceback.
    """
    try:
        proc = subprocess.run(["kubectl", *args], capture_output=True, text=True)
    except FileNotFoundError:
        # 127 mirrors the shell's "command not found" status.
        sys.stderr.write("kubectl-tt-driver: kubectl not found on PATH\n")
        sys.exit(127)
    if check and proc.returncode != 0:
        sys.stderr.write(proc.stderr)
        sys.exit(proc.returncode)
    return proc.stdout


def kubectl_passthrough(*args: str) -> None:
    """Replace the current process with ``kubectl`` *args*.

    Used for interactive/streaming commands (e.g. ``logs -f``) whose output
    should go straight to the terminal. On success this never returns.

    Exits with status 127 and a short message on stderr if no ``kubectl``
    executable is found on PATH, instead of dying with a traceback.
    """
    try:
        os.execvp("kubectl", ["kubectl", *args])
    except FileNotFoundError:
        # 127 mirrors the shell's "command not found" status.
        sys.stderr.write("kubectl-tt-driver: kubectl not found on PATH\n")
        sys.exit(127)


def age(ts: str | None) -> str:
    if not ts:
        return "-"
    t = datetime.datetime.strptime(ts, "%Y-%m-%dT%H:%M:%SZ").replace(
        tzinfo=datetime.timezone.utc
    ).timestamp()
    s = max(0, int(time.time() - t))
    if s < 60:
        return f"{s}s"
    if s < 3600:
        return f"{s//60}m"
    if s < 86400:
        return f"{s//3600}h"
    return f"{s//86400}d"


# Emit ANSI colour codes only when stdout is a terminal, so piped or
# redirected output stays free of escape sequences. Evaluated once, at import.
_USE_COLOR = sys.stdout.isatty()


def color_pod(phase: str, ready: str) -> tuple[str, str]:
    """Pick the ANSI (start, reset) escape pair used to highlight a pod's phase.

    Red for any phase other than ``Running``, yellow for a running pod whose
    readiness is not exactly ``"1/1"``, green otherwise. When colour output
    is disabled (``_USE_COLOR`` false), both strings are empty.
    """
    if _USE_COLOR:
        if phase == "Running":
            start = "\033[32m" if ready == "1/1" else "\033[33m"
        else:
            start = "\033[31m"
        return start, "\033[0m"
    return "", ""


def _index_installer_pods(
    pods: list[dict[str, Any]],
) -> dict[str, list[tuple[str, dict[str, Any]]]]:
    """Group scheduled installer pods by owning policy, sorted for display.

    Pods lacking the ``driver.tenstorrent.com/cr`` label or a
    ``spec.nodeName`` are skipped. Each policy's list is ordered by node name,
    then creation time, then pod name. Every pod on a node is kept, so a pod
    still terminating after ``restart`` is listed next to its replacement
    instead of one of them silently hiding the other.
    """
    by_cr: dict[str, list[tuple[str, dict[str, Any]]]] = {}
    for p in pods:
        labels = (p.get("metadata") or {}).get("labels") or {}
        cr_name = labels.get("driver.tenstorrent.com/cr")
        node = (p.get("spec") or {}).get("nodeName")
        if cr_name and node:
            by_cr.setdefault(cr_name, []).append((node, p))

    def sort_key(entry: tuple[str, dict[str, Any]]) -> tuple[str, str, str]:
        node, p = entry
        meta = p.get("metadata") or {}
        # RFC 3339 timestamps of the same shape sort chronologically as text.
        return (node, meta.get("creationTimestamp") or "", meta.get("name") or "")

    for entries in by_cr.values():
        entries.sort(key=sort_key)
    return by_cr


def _format_pod_row(node: str, p: dict[str, Any]) -> str:
    """Render one installer pod as a row of the per-policy pod table."""
    meta = p.get("metadata") or {}
    status = p.get("status") or {}
    container_statuses = status.get("containerStatuses") or []
    phase = status.get("phase") or ""
    if meta.get("deletionTimestamp"):
        # Mirror kubectl: a pod being deleted keeps reporting its old phase
        # (often Running), which would otherwise look healthy.
        phase = "Terminating"
    # "1/1" as soon as any container reports ready. NOTE(review): this
    # assumes one container per installer pod — confirm against the DS spec.
    ready = "1/1" if any(cs.get("ready") for cs in container_statuses) else "0/1"
    restarts = sum(cs.get("restartCount") or 0 for cs in container_statuses)
    c, reset = color_pod(phase, ready)
    return (
        f"  {node:<14} {meta.get('name', ''):<36} "
        f"{c}{phase:<12}{reset} {ready:<6} {restarts:<8} "
        f"{age(meta.get('creationTimestamp'))}"
    )


def cmd_status(argv: list[str]) -> None:
    """Print one section per TenstorrentDriverPolicy with its installer pods.

    For every ``ttdp`` object this prints the spec version and the status
    summary counters. It then prints a table of the pods in NS labelled with
    that policy's name: node, pod, phase, readiness, restarts and age.

    Args:
        argv: Extra command-line arguments; currently ignored.
    """
    crs = json.loads(kubectl("get", "ttdp", "-o", "json") or "{}").get("items") or []
    pods = json.loads(
        kubectl("-n", NS, "get", "pods", "-l", "driver.tenstorrent.com/cr", "-o", "json")
        or "{}"
    ).get("items") or []

    if not crs:
        print("(no TenstorrentDriverPolicy resources)")
        return

    pods_by_cr = _index_installer_pods(pods)

    for cr in crs:
        name = cr["metadata"]["name"]
        spec = cr.get("spec") or {}
        st = cr.get("status") or {}
        summary = st.get("summary") or {}
        paused = " [PAUSED]" if spec.get("paused") else ""

        # The `or` fallbacks also cover keys that are present but JSON null:
        # None would make the width format specs below raise TypeError.
        version = spec.get("version") or "?"
        matched = summary.get("matched") or 0
        desired = summary.get("desired") or 0
        ready_count = summary.get("ready") or 0
        failed = summary.get("failed") or 0
        daemonset = st.get("daemonSet") or "-"

        print(f"POLICY: {name}{paused}")
        print(
            f"  version: {version:<10} "
            f"matched: {matched:<3} "
            f"scheduled: {desired:<3} "
            f"ready: {ready_count:<3} "
            f"failed: {failed:<3} "
            f"ds: {daemonset}"
        )

        cr_pods = pods_by_cr.get(name, [])
        if not cr_pods:
            print("  (no installer pods scheduled)\n")
            continue

        print()
        # PHASE is 12 wide so "Terminating" (11 chars) keeps columns aligned.
        print(
            f"  {'NODE':<14} {'POD':<36} {'PHASE':<12} "
            f"{'READY':<6} {'RESTARTS':<8} AGE"
        )
        for node, p in cr_pods:
            print(_format_pod_row(node, p))
        print()


def cmd_logs(argv: list[str]) -> None:
    """Follow driver-installer logs for one policy, optionally on one node.

    usage: logs <cr> [node]

    With *node*, follows the first installer pod for *cr* scheduled on that
    node. Without it, follows every installer pod for *cr* at once, prefixing
    each line with its pod name so interleaved output stays attributable.

    Exits with a message if no policy is given or no matching pod exists.
    Otherwise the process is replaced by ``kubectl logs`` and does not return.
    """
    if not argv:
        sys.exit("usage: kubectl tt driver logs <cr> [node]")
    cr = argv[0]
    node = argv[1] if len(argv) > 1 else None
    sel = f"driver.tenstorrent.com/cr={cr}"

    get_args = ["-n", NS, "get", "pods", "-l", sel, "-o", "name"]
    if node:
        get_args += ["--field-selector", f"spec.nodeName={node}"]
    # `-o name` yields one "pod/<name>" per line; drop blank lines.
    pods = [line for line in kubectl(*get_args).splitlines() if line.strip()]
    if not pods:
        sys.exit(f"no pod for {cr} on {node}" if node else f"no pods for {cr}")

    if node:
        kubectl_passthrough("-n", NS, "logs", "-f", pods[0], "--tail=200")
    else:
        # kubectl refuses to follow more concurrent streams than
        # --max-log-requests, and the installer may have one pod per node.
        # Size the limit to the actual pod count rather than a fixed 8.
        kubectl_passthrough(
            "-n", NS, "logs", "-f", "-l", sel, "--tail=200", "--prefix",
            f"--max-log-requests={max(8, len(pods))}",
        )


def cmd_restart(argv: list[str]) -> None:
    """Re-trigger the driver install for one policy by deleting its pods.

    usage: restart <cr>

    Deletes every pod in NS labelled ``driver.tenstorrent.com/cr=<cr>`` and
    echoes kubectl's output. The module docstring states that the DaemonSet
    then respawns them, re-running the install. Extra arguments are ignored.
    """
    if len(argv) == 0:
        sys.exit("usage: kubectl tt driver restart <cr>")
    selector = f"driver.tenstorrent.com/cr={argv[0]}"
    result = kubectl("-n", NS, "delete", "pods", "-l", selector)
    sys.stdout.write(result)


def main() -> None:
    """CLI entry point: route ``sys.argv`` to the matching subcommand handler.

    With no arguments, runs ``status``. ``-h``, ``--help`` and ``help`` print
    the usage text and return. An unrecognised subcommand prints an error plus
    usage to stderr and exits with status 1. Everything after the subcommand
    is handed to the handler as a list.
    """
    args = sys.argv[1:]
    if args:
        sub, rest = args[0], args[1:]
    else:
        sub, rest = "status", []

    if sub in ("-h", "--help", "help"):
        print(USAGE)
        return

    handlers = {
        "status": cmd_status,
        "": cmd_status,  # an explicit empty-string argument also means status
        "logs": cmd_logs,
        "restart": cmd_restart,
    }
    try:
        handler = handlers[sub]
    except KeyError:
        sys.stderr.write(f"unknown subcommand: {sub}\n{USAGE}\n")
        sys.exit(1)
    handler(rest)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
