#!/usr/bin/env python3
"""Unified HYMET command-line interface."""

from __future__ import annotations

import argparse
import os
import shlex
import subprocess
import sys
from pathlib import Path
from typing import List, Optional, Sequence


def run_command(cmd: List[str], *, cwd: Optional[Path] = None, env: Optional[dict] = None, dry_run: bool = False) -> None:
    """Echo a command and, unless ``dry_run`` is set, execute it.

    Args:
        cmd: Program followed by its arguments; each part is stringified for display.
        cwd: Working directory for the child process (the current directory when omitted).
        env: Environment for the child process; ``None`` inherits the caller's environment.
        dry_run: When true, only print the command and return.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero status.
    """
    # Quote every argument so paths containing spaces or shell metacharacters
    # are displayed unambiguously and the line can be pasted into a shell.
    display = " ".join(shlex.quote(str(part)) for part in cmd)
    location = str(cwd) if cwd else os.getcwd()
    # Flush so the echo precedes the child's own output even when stdout is
    # block-buffered (e.g. redirected to a file or a pipe).
    print(f"[hymet] ({location}) $ {display}", flush=True)
    if dry_run:
        return
    subprocess.run(cmd, cwd=str(cwd) if cwd else None, env=env, check=True)


def add_common_env(args, env: dict) -> None:
    """Copy the shared CLI options from ``args`` into ``env`` (modified in place).

    Options that are absent from ``args`` or falsy leave ``env`` untouched:
    ``cache_root`` becomes an absolute ``CACHE_ROOT`` path, while
    ``force_download`` and ``keep_work`` set ``FORCE_DOWNLOAD`` and
    ``KEEP_HYMET_WORK`` to ``"1"``.
    """
    cache_root = getattr(args, "cache_root", None)
    if cache_root:
        env["CACHE_ROOT"] = str(Path(cache_root).resolve())

    # Boolean switches map one-to-one onto "1"-valued environment variables.
    switches = (
        ("force_download", "FORCE_DOWNLOAD"),
        ("keep_work", "KEEP_HYMET_WORK"),
    )
    for attribute, variable in switches:
        if getattr(args, attribute, False):
            env[variable] = "1"


def command_run(args) -> None:
    """Run ``run_hymet_cami.sh`` for one sample, configured through environment variables.

    The input is taken from ``args.contigs`` when set, otherwise from
    ``args.reads``; all paths are exported as absolute paths. The script is
    launched from the repository root.
    """
    root: Path = args.repo_root

    def absolute(value) -> str:
        # Paths are resolved here because the script runs with cwd=root.
        return str(Path(value).resolve())

    env = os.environ.copy()
    env["ROOT"] = str(root)
    if args.contigs:
        env.update(INPUT_MODE="contigs", INPUT_FASTA=absolute(args.contigs))
    else:
        env.update(INPUT_MODE="reads", INPUT_READS=absolute(args.reads))
    env["OUTDIR"] = absolute(args.out)

    # Numeric knobs are exported only when given explicitly.
    for variable, value in (("THREADS", args.threads), ("CAND_MAX", args.cand_max)):
        if value is not None:
            env[variable] = str(value)
    if args.species_dedup:
        env["SPECIES_DEDUP"] = "1"
    if args.assembly_summary_dir:
        env["ASSEMBLY_SUMMARY_DIR"] = absolute(args.assembly_summary_dir)
    add_common_env(args, env)

    run_command([str(root / "run_hymet_cami.sh")], cwd=root, env=env, dry_run=args.dry_run)


def command_bench(args) -> None:
    """Invoke ``run_all_cami.sh`` from the bench directory with the selected options.

    The caller's working directory is exported as ``BENCH_CALLER_PWD`` and the
    manifest path is made absolute, since the script runs with the bench
    directory as its working directory.
    """
    root: Path = args.bench_root

    env = os.environ.copy()
    env["BENCH_CALLER_PWD"] = str(Path.cwd().resolve())
    if args.threads is not None:
        env["THREADS"] = str(args.threads)
    add_common_env(args, env)

    argv = [str(root / "run_all_cami.sh")]
    if args.manifest:
        argv += ["--manifest", str(Path(args.manifest).resolve())]
    if args.tools:
        argv += ["--tools", args.tools]
    if args.max_samples is not None:
        argv += ["--max-samples", str(args.max_samples)]

    # Plain on/off switches are forwarded verbatim, in this order.
    switches = (
        (args.no_build, "--no-build"),
        (args.resume, "--resume"),
        (getattr(args, "no_publish", False), "--no-publish"),
    )
    argv.extend(flag for enabled, flag in switches if enabled)

    if args.extra:
        argv.extend(args.extra)
    run_command(argv, cwd=root, env=env, dry_run=args.dry_run)


def command_case(args) -> None:
    """Invoke ``run_case.sh`` from the case directory with the selected options.

    The caller's working directory is exported as ``CASE_CALLER_PWD``; the
    manifest and output paths are made absolute before being forwarded.
    """
    root: Path = args.case_root

    env = os.environ.copy()
    env["CASE_CALLER_PWD"] = str(Path.cwd().resolve())
    if args.threads is not None:
        env["THREADS"] = str(args.threads)
    add_common_env(args, env)

    argv = [str(root / "run_case.sh")]
    # Both options carry paths, resolved against the caller's directory.
    for flag, value in (("--manifest", args.manifest), ("--out", args.out)):
        if value:
            argv += [flag, str(Path(value).resolve())]
    if args.extra:
        argv.extend(args.extra)
    run_command(argv, cwd=root, env=env, dry_run=args.dry_run)


def command_ablation(args) -> None:
    """Invoke ``run_ablation.sh`` from the case directory with the selected options.

    Plain-value options are forwarded unchanged; path-valued options are made
    absolute first. The caller's working directory is exported as
    ``CASE_CALLER_PWD``.
    """
    root: Path = args.case_root

    env = os.environ.copy()
    env["CASE_CALLER_PWD"] = str(Path.cwd().resolve())
    if args.threads is not None:
        env["THREADS"] = str(args.threads)
    add_common_env(args, env)

    argv = [str(root / "run_ablation.sh")]

    # Options whose values are passed through as given.
    literal_options = (
        ("--sample", args.sample),
        ("--taxa", args.taxa),
        ("--levels", args.levels),
    )
    for flag, value in literal_options:
        if value:
            argv += [flag, value]

    # Options whose values are filesystem paths.
    path_options = (
        ("--seqmap", args.seqmap),
        ("--fasta", args.fasta),
        ("--out", args.out),
    )
    for flag, value in path_options:
        if value:
            argv += [flag, str(Path(value).resolve())]

    if args.extra:
        argv.extend(args.extra)
    run_command(argv, cwd=root, env=env, dry_run=args.dry_run)


def command_truth_build_zymo(args) -> None:
    """Run ``case/truth/build_zymo_truth.py`` to build the Zymo truth tables.

    All input and output paths are made absolute because the helper script is
    launched from the case directory. When ``args.seqmap`` is not given, the
    map at ``truth/zymo_refs/seqid2taxid.tsv`` under the case directory is used.

    Raises:
        subprocess.CalledProcessError: If the helper script exits non-zero.
    """
    case_root: Path = args.case_root
    env = os.environ.copy()
    seqmap = args.seqmap or case_root / "truth" / "zymo_refs" / "seqid2taxid.tsv"

    # Use the interpreter running this CLI (as command_artifacts does) instead
    # of whatever "python" happens to be on PATH, which may be missing or a
    # different installation; "python" remains the fallback.
    python_cmd = sys.executable or "python"
    cmd = [
        python_cmd,
        str(case_root / "truth" / "build_zymo_truth.py"),
        "--contigs",
        str(Path(args.contigs).resolve()),
        "--seqmap",
        str(Path(seqmap).resolve()),
        "--paf",
        str(Path(args.paf).resolve()),
        "--out-contigs",
        str(Path(args.out_contigs).resolve()),
        "--out-profile",
        str(Path(args.out_profile).resolve()),
    ]
    run_command(cmd, cwd=case_root, env=env, dry_run=args.dry_run)


def command_legacy(args) -> None:
    """Forward the remaining command-line arguments to ``main.pl`` via ``perl``.

    The script is run from the repository root with the inherited environment.
    """
    root: Path = args.repo_root
    perl_argv = ["perl", str(root / "main.pl")]
    perl_argv.extend(args.legacy_args)
    run_command(perl_argv, cwd=root, env=None, dry_run=args.dry_run)


def command_artifacts(args) -> None:
    """Regenerate benchmark, case-study and ablation figures from existing outputs.

    Three independent stages run in order; each is skipped with a message when
    its inputs are absent:

    1. ``bench/out`` present: run ``aggregate_metrics.py`` then
       ``plot/make_figures.py``.
    2. ``case/out`` present: run ``plot_case.py``. Otherwise look under
       ``results/cases/<suite>/run_*`` (next to the case directory) and plot
       the run whose directory name sorts last, provided it has a ``raw``
       sub-directory.
    3. ``case/ablation_summary.tsv`` present: run ``plot_ablation.py``.
    """
    interpreter = sys.executable or "python"
    bench_dir: Path = args.bench_root
    case_dir: Path = args.case_root

    # Stage 1: CAMI benchmark aggregation and figures.
    if (bench_dir / "out").is_dir():
        bench_scripts = (
            bench_dir / "aggregate_metrics.py",
            bench_dir / "plot" / "make_figures.py",
        )
        for script in bench_scripts:
            run_command(
                [interpreter, str(script), "--bench-root", str(bench_dir), "--outdir", "out"],
                cwd=bench_dir,
                dry_run=args.dry_run,
            )
    else:
        print("[hymet] bench/out/ missing; skipping CAMI aggregation")

    # Stage 2: case-study figures.
    plot_case = str(case_dir / "plot_case.py")
    case_argv: Optional[List[str]] = None
    if (case_dir / "out").is_dir():
        case_argv = [interpreter, plot_case]
    else:
        published = case_dir.parent / "results" / "cases"
        run_dirs = []
        if published.is_dir():
            run_dirs = [
                run_dir
                for suite_dir in published.iterdir()
                if suite_dir.is_dir()
                for run_dir in suite_dir.iterdir()
                if run_dir.is_dir() and run_dir.name.startswith("run_")
            ]
        if run_dirs:
            # Runs are ranked by directory name only, across all suites.
            newest = sorted(run_dirs, key=lambda path: path.name)[-1]
            raw = newest / "raw"
            if raw.is_dir():
                case_argv = [
                    interpreter,
                    plot_case,
                    "--case-root",
                    str(raw),
                    "--figures-dir",
                    str(newest / "figures"),
                ]
    if case_argv:
        run_command(case_argv, cwd=case_dir, dry_run=args.dry_run)
    else:
        print("[hymet] No case-study outputs detected; skipping case-study figures")

    # Stage 3: ablation figures.
    summary = case_dir / "ablation_summary.tsv"
    if not summary.is_file():
        print("[hymet] case/ablation_summary.tsv missing; skipping ablation figures")
        return
    ablation_argv = [
        interpreter,
        str(case_dir / "plot_ablation.py"),
        "--summary",
        str(summary),
        "--outdir",
        str(case_dir / "ablation" / "figures"),
    ]
    evaluation = case_dir / "ablation_eval_summary.tsv"
    if evaluation.is_file():
        # The evaluation summary is optional input for the plot script.
        ablation_argv += ["--eval", str(evaluation)]
    run_command(ablation_argv, cwd=case_dir, dry_run=args.dry_run)


def add_common_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the options shared by the workflow sub-commands on ``parser``.

    Adds ``--threads`` (int), ``--cache-root`` and the boolean switches
    ``--force-download``, ``--keep-work`` and ``--dry-run``, in that order.
    """
    parser.add_argument("--threads", type=int, help="Thread count to pass to HYMET")
    parser.add_argument("--cache-root", help="Override cache root (CACHE_ROOT)")

    boolean_switches = (
        ("--force-download", "Set FORCE_DOWNLOAD=1 for HYMET runs"),
        ("--keep-work", "Set KEEP_HYMET_WORK=1 to retain intermediates"),
        ("--dry-run", "Show commands without executing them"),
    )
    for flag, description in boolean_switches:
        parser.add_argument(flag, action="store_true", help=description)


def git_version(repo_root: Path) -> str:
    """Return a version string for the checkout at ``repo_root``.

    The short commit hash reported by git is preferred, with ``-dirty``
    appended when ``git status --porcelain`` prints anything. If git yields no
    commit (not a checkout) or cannot be run, the stripped ``HYMET_VERSION``
    environment variable is used, and finally ``"unknown"``.
    """

    def git_stdout(*git_args: str) -> str:
        # check=False: a failing git command just yields empty output.
        completed = subprocess.run(
            ["git", "-C", str(repo_root), *git_args],
            capture_output=True, text=True, check=False,
        )
        return completed.stdout.strip()

    try:
        # Prefer git metadata when available (source checkout)
        commit = git_stdout("rev-parse", "--short", "HEAD")
        if commit:
            # Only ask for the work-tree state once a commit is known.
            dirty = git_stdout("status", "--porcelain")
            return f"{commit}{'-dirty' if dirty else ''}"
    except (OSError, subprocess.SubprocessError, ValueError):
        # Best effort: git missing or not executable (OSError), a subprocess
        # failure, or undecodable output / invalid path (ValueError). Other
        # exception types indicate programming errors and are not hidden.
        pass
    # Fallback to environment or unknown
    env_ver = os.environ.get("HYMET_VERSION", "").strip()
    return env_ver or "unknown"


def command_version(args) -> None:
    """Print the version string of the repository attached to ``args``.

    The current working directory is used when ``args`` has no ``repo_root``.
    """
    root: Path = getattr(args, "repo_root", Path.cwd())
    version = git_version(root)
    print(version)


def _run_taxonomy_setup(repo_root: Path) -> None:
    """Run ``config.pl`` with perl if it exists; failures are reported as warnings only."""
    config_pl = repo_root / "config.pl"
    if not config_pl.exists():
        return
    print("[hymet init] Taxonomy files missing. Setting up from NCBI...")
    print("  (this downloads ~60MB and generates the hierarchy file)")
    # Flush before the child starts writing to the same stream.
    print(flush=True)
    try:
        subprocess.run(["perl", str(config_pl)], cwd=repo_root, check=True)
        print()
    except subprocess.CalledProcessError as e:
        print(f"[hymet init] Warning: config.pl failed (exit {e.returncode})")
        print("  You can retry with: ./config.pl")
        print()
    except FileNotFoundError:
        print("[hymet init] Warning: perl not found. Install perl or run ./config.pl manually.")
        print()


def _ensure_detailed_taxonomy_stub(data_dir: Path, created: List[str], verified: List[str]) -> None:
    """Create a header-only ``detailed_taxonomy.tsv`` if it is missing, empty or a dangling symlink."""
    header = "GCF\tTaxID\tIdentifiers\n"
    detailed_tax = data_dir / "detailed_taxonomy.tsv"
    if detailed_tax.is_symlink() and not detailed_tax.exists():
        # A dangling symlink must be detected first: exists() follows links, so
        # it looks like a missing file, and writing through it would either
        # fail or create a file at the link target instead of in data/.
        detailed_tax.unlink()
        detailed_tax.write_text(header, encoding="utf-8")
        created.append("data/detailed_taxonomy.tsv (replaced broken symlink)")
    elif not detailed_tax.exists() or detailed_tax.stat().st_size == 0:
        detailed_tax.write_text(header, encoding="utf-8")
        created.append("data/detailed_taxonomy.tsv (stub with header)")
    else:
        verified.append("data/detailed_taxonomy.tsv")


def command_init(args) -> None:
    """Initialize HYMET environment: create required stub files and verify installation.

    Steps, in order:

    1. Ensure ``data/`` exists.
    2. If taxonomy files are missing and neither ``--skip-taxonomy`` nor
       ``--quiet`` is set, try to generate them by running ``config.pl``.
    3. Create the ``data/detailed_taxonomy.tsv`` stub when needed.
    4. Check for Mash sketches, the taxonomy hierarchy, the taxonomy dump
       files and the required scripts, then print a report.

    Raises:
        SystemExit: With status 1 when issues were found and ``--quiet`` is not set.
    """
    repo_root: Path = args.repo_root
    data_dir = repo_root / "data"
    taxonomy_dir = repo_root / "taxonomy_files"

    # Ensure data directory exists
    data_dir.mkdir(parents=True, exist_ok=True)

    # Auto-setup taxonomy if missing (unless --skip-taxonomy)
    hierarchy = data_dir / "taxonomy_hierarchy.tsv"
    needs_taxonomy = (
        not taxonomy_dir.is_dir()
        or not (taxonomy_dir / "nodes.dmp").exists()
        or not hierarchy.exists()
        or hierarchy.stat().st_size == 0
    )
    if needs_taxonomy and not getattr(args, "skip_taxonomy", False) and not args.quiet:
        _run_taxonomy_setup(repo_root)

    issues: List[str] = []
    created: List[str] = []
    verified: List[str] = []

    # Create stub detailed_taxonomy.tsv if missing
    # This file gets overwritten during runtime by downloadDB.py with actual taxonomy data
    _ensure_detailed_taxonomy_stub(data_dir, created, verified)

    # Check for required Mash sketches
    for sketch in ("sketch1.msh", "sketch2.msh", "sketch3.msh"):
        sketch_path = data_dir / sketch
        if sketch_path.exists() and sketch_path.stat().st_size > 0:
            verified.append(f"data/{sketch}")
        else:
            issues.append(f"data/{sketch} missing - run: tools/fetch_sketches.sh")

    # Check taxonomy_hierarchy.tsv (re-check after potential auto-setup)
    if hierarchy.exists() and hierarchy.stat().st_size > 0:
        verified.append("data/taxonomy_hierarchy.tsv")
    else:
        issues.append("data/taxonomy_hierarchy.tsv missing - run: ./config.pl")

    # Check taxonomy_files directory
    if taxonomy_dir.is_dir():
        nodes = taxonomy_dir / "nodes.dmp"
        names = taxonomy_dir / "names.dmp"
        if nodes.exists() and names.exists():
            verified.append("taxonomy_files/ (nodes.dmp, names.dmp)")
        else:
            issues.append("taxonomy_files/ incomplete - run: ./config.pl")
    else:
        issues.append("taxonomy_files/ missing - run: ./config.pl")

    # Check required scripts
    scripts = [
        "scripts/mash.sh",
        "scripts/minimap2.sh",
        "scripts/classification_cami.py",
        "tools/hymet2cami.py",
        "tools/build_id_map.py",
        "tools/mini_classify.py",
    ]
    for script in scripts:
        if (repo_root / script).exists():
            verified.append(script)
        else:
            issues.append(f"{script} missing")

    # Print results
    print(f"[hymet init] Repository: {repo_root}")
    print()

    if created:
        print("Created:")
        for item in created:
            print(f"  ✓ {item}")
        print()

    if verified:
        print("Verified:")
        for item in verified:
            print(f"  ✓ {item}")
        print()

    if issues:
        print("Issues requiring attention:")
        for issue in issues:
            print(f"  ✗ {issue}")
        print()
        print("Run 'bin/hymet init' again after resolving issues.")
        # --quiet reports the problems but keeps the exit status at zero.
        if not args.quiet:
            raise SystemExit(1)
    else:
        print("✓ HYMET environment is ready. Run: bin/hymet run --contigs <file> --out <dir>")


def build_parser() -> argparse.ArgumentParser:
    """Assemble the top-level parser with one sub-parser per HYMET sub-command.

    A sub-command is mandatory. Each leaf sub-parser stores its handler under
    ``func`` through ``set_defaults``.
    """
    root = argparse.ArgumentParser(description="HYMET unified command-line interface")
    root.add_argument(
        "--hymet-root",
        help="Path to the HYMET repository. Defaults to HYMET_ROOT env or auto-discovery.",
    )
    commands = root.add_subparsers(dest="command", required=True)

    # version
    version = commands.add_parser("version", help="Print HYMET version information")
    version.set_defaults(func=command_version)

    # init
    init = commands.add_parser("init", help="Initialize HYMET: create stub files and verify installation")
    init.add_argument("--quiet", "-q", action="store_true", help="Don't exit with error on missing files")
    init.add_argument("--skip-taxonomy", action="store_true", help="Skip automatic NCBI taxonomy download")
    init.set_defaults(func=command_init)

    # run: exactly one of --contigs / --reads must be supplied
    run = commands.add_parser("run", help="Run HYMET on a single sample")
    source = run.add_mutually_exclusive_group(required=True)
    source.add_argument("--contigs", help="Input contigs FASTA")
    source.add_argument("--reads", help="Input reads FASTQ/FASTA")
    run.add_argument("--out", required=True, help="Output directory")
    run.add_argument("--cand-max", type=int, help="Maximum Mash candidates (CAND_MAX)")
    run.add_argument("--species-dedup", action="store_true", help="Enable species-level candidate deduplication")
    run.add_argument("--assembly-summary-dir", help="Directory holding assembly_summary files")
    add_common_arguments(run)
    run.set_defaults(func=command_run)

    # bench
    bench = commands.add_parser("bench", help="Run the CAMI benchmark harness")
    bench.add_argument("--manifest", help="Manifest TSV (default bench/cami_manifest.tsv)")
    bench.add_argument("--tools", help="Comma-separated tool list")
    bench.add_argument("--max-samples", type=int, help="Limit number of samples processed")
    for flag, description in (
        ("--no-build", "Skip database build step"),
        ("--resume", "Resume without clearing runtime log"),
        ("--no-publish", "Skip publishing under results/"),
    ):
        bench.add_argument(flag, action="store_true", help=description)
    bench.add_argument("extra", nargs=argparse.REMAINDER, help="Extra args forwarded to run_all_cami.sh")
    add_common_arguments(bench)
    bench.set_defaults(func=command_bench)

    # case
    case = commands.add_parser("case", help="Run the case-study harness")
    case.add_argument("--manifest", help="Manifest TSV (default case/manifest.tsv)")
    case.add_argument("--out", help="Output root directory")
    case.add_argument("extra", nargs=argparse.REMAINDER, help="Extra args forwarded to run_case.sh")
    add_common_arguments(case)
    case.set_defaults(func=command_case)

    # ablation
    ablation = commands.add_parser("ablation", help="Run the curated reference ablation workflow")
    for flag, description in (
        ("--sample", "Sample ID to ablate"),
        ("--taxa", "Comma-separated TaxIDs to remove at each level"),
        ("--levels", "Comma-separated ablation fractions (e.g. 0,0.5,1.0)"),
        ("--seqmap", "Sequence-to-taxid map"),
        ("--fasta", "Reference FASTA to ablate"),
        ("--out", "Output directory for ablation results"),
    ):
        ablation.add_argument(flag, help=description)
    ablation.add_argument("extra", nargs=argparse.REMAINDER, help="Extra args forwarded to run_ablation.sh")
    add_common_arguments(ablation)
    ablation.set_defaults(func=command_ablation)

    # truth -> build-zymo (nested sub-command, also mandatory)
    truth = commands.add_parser("truth", help="Truth-set utilities")
    truth_commands = truth.add_subparsers(dest="truth_command", required=True)
    build_zymo = truth_commands.add_parser("build-zymo", help="Build Zymo mock community truth tables")
    for flag, mandatory, description in (
        ("--contigs", True, "Input contigs FASTA"),
        ("--paf", True, "PAF alignment against curated references"),
        ("--seqmap", False, "SeqID→TaxID map (default case/truth/zymo_refs/seqid2taxid.tsv)"),
        ("--out-contigs", True, "Output contig truth TSV"),
        ("--out-profile", True, "Output CAMI profile TSV"),
    ):
        build_zymo.add_argument(flag, required=mandatory, help=description)
    build_zymo.add_argument("--dry-run", action="store_true", help="Show command without executing it")
    build_zymo.set_defaults(func=command_truth_build_zymo)

    # legacy
    legacy = commands.add_parser("legacy", help="Bridge to legacy entry points (main.pl)")
    legacy.add_argument("legacy_args", nargs=argparse.REMAINDER, help="Arguments passed to main.pl")
    legacy.add_argument("--dry-run", action="store_true", help="Show command without executing it")
    legacy.set_defaults(func=command_legacy)

    # artifacts
    artifacts = commands.add_parser("artifacts", help="Regenerate benchmark and case-study artefacts")
    artifacts.add_argument("--dry-run", action="store_true", help="Show commands without executing them")
    artifacts.set_defaults(func=command_artifacts)

    return root


def locate_repo_root(explicit: Optional[str]) -> Path:
    """Find the HYMET repository root.

    Locations are tried in this order: ``explicit`` (if given), the
    ``HYMET_ROOT`` environment variable (if set), this file's directory and
    its parent, then the current working directory and each of its ancestors.
    The first location containing both ``bench/run_all_cami.sh`` and
    ``case/run_case.sh`` is returned as a resolved path.

    Raises:
        SystemExit: If no location contains both marker scripts.
    """
    here = Path(__file__).resolve().parent
    working_dir = Path.cwd()
    env_root = os.environ.get("HYMET_ROOT")

    search_order: List[Path] = [Path(value) for value in (explicit, env_root) if value]
    search_order += [here, here.parent, working_dir, *working_dir.parents]

    visited = set()
    for location in search_order:
        root = location.resolve()
        if root in visited:
            # The same directory can be reachable through several candidates.
            continue
        visited.add(root)
        has_bench = (root / "bench" / "run_all_cami.sh").exists()
        if has_bench and (root / "case" / "run_case.sh").exists():
            return root

    raise SystemExit(
        "Unable to locate HYMET repository root. Set --hymet-root or HYMET_ROOT to the cloned repository path."
    )


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Parse ``argv``, locate the repository and dispatch to the chosen sub-command.

    Returns:
        ``0`` on success, or the exit status of a child process that failed
        with ``subprocess.CalledProcessError``.
    """
    namespace = build_parser().parse_args(argv)

    # Handlers read the repository layout from the parsed namespace.
    root = locate_repo_root(namespace.hymet_root)
    namespace.repo_root = root
    namespace.bench_root = root / "bench"
    namespace.case_root = root / "case"

    try:
        namespace.func(namespace)
    except subprocess.CalledProcessError as failure:
        return failure.returncode
    else:
        return 0


# Script entry point: exit with the status returned by main().
if __name__ == "__main__":
    raise SystemExit(main())
