#!/usr/bin/env python3

import tempfile
import pandas as pd
import json


def read_checksum(checksum_file):
    """Return the digest recorded in a ``*sum``-style checksum file.

    Only the first line of ``checksum_file`` is read, and its first
    whitespace-separated token is returned (for files written by
    ``md5sum``/``sha256sum`` that token is the hex digest).

    Raises:
        IndexError: if the first line contains no tokens (e.g. empty file).
    """
    with open(checksum_file, "rt") as handle:
        tokens = handle.readline().split()
    return tokens[0]


def _require_stat(table, key, source):
    """Return ``table[key]``, raising a KeyError that names the missing stat and its file."""
    try:
        return table[key]
    except KeyError:
        raise KeyError(f"statistic {key!r} not found in {source}") from None


def get_stats_params(wildcards, input):
    """Compute the read/base statistics passed to the ``output_json_stats`` rule.

    Used as a Snakemake params input function; ``wildcards`` is part of the
    required signature and is not used.

    Inputs read (attributes of ``input``):
      * ``stats_file`` -- tab-separated table with a header row (produced by
        ``seqkit stats -a -T`` on the trimmed reads); only the first data row
        is used.
      * ``raw_stats_file`` -- headerless, two-column (name, value)
        tab-separated table (the ``SN`` lines of ``samtools stats`` on the
        raw BAM, names including their trailing colon).
      * ``md5sum`` / ``sha256sum`` -- checksum files for ``reads_out``; the
        first token of each is taken as the digest.
      * ``reads_out`` -- only its file name is used, as the checksum key.

    Returns:
        dict with ``base_count``, ``read_count``, ``qc_bases_removed``,
        ``qc_reads_removed``, ``mean_gc_content``, ``n50_length`` and
        ``checksums`` (``{reads_out_name: {"sha256": ..., "md5": ...}}``).

    Raises:
        ValueError: if ``stats_file`` contains no data rows.
        KeyError: if a required statistic is missing from either table.
    """
    seqkit_src = input.stats_file
    raw_src = input.raw_stats_file

    seqkit_rows = pd.read_csv(seqkit_src, sep="\t").to_dict(orient="records")
    if not seqkit_rows:
        raise ValueError(f"No statistics rows found in {seqkit_src}")
    seqkit = seqkit_rows[0]

    # Column 0 holds the statistic name, column 1 its value.
    raw = pd.read_csv(raw_src, sep="\t", header=None).set_index(0)[1].to_dict()

    read_count = _require_stat(seqkit, "num_seqs", seqkit_src)
    base_count = _require_stat(seqkit, "sum_len", seqkit_src)

    # Everything lost between the raw BAM and the final trimmed reads.
    qc_reads_removed = _require_stat(raw, "raw total sequences:", raw_src) - read_count
    qc_bases_removed = _require_stat(raw, "total length:", raw_src) - base_count

    reads_out_name = Path(input.reads_out).name
    md5sum = read_checksum(input.md5sum)
    sha256sum = read_checksum(input.sha256sum)

    return {
        "base_count": int(base_count),
        "read_count": int(read_count),
        "qc_bases_removed": int(qc_bases_removed),
        "qc_reads_removed": int(qc_reads_removed),
        "mean_gc_content": float(_require_stat(seqkit, "GC(%)", seqkit_src)),
        "n50_length": int(_require_stat(seqkit, "N50", seqkit_src)),
        "checksums": {reads_out_name: {"sha256": str(sha256sum), "md5": str(md5sum)}},
    }


# Expose every config entry as a module-level name. The rules below read
# names such as reads_out, stats, bam and logs_directory that are not defined
# in this file; presumably they all come from config -- confirm against the
# config schema.
globals().update(config)

# Scratch directory for intermediate pipes and temp() outputs.
# NOTE(review): nothing visible in this file removes this directory afterwards.
workingdir = tempfile.mkdtemp()
logger.debug(f"Using {workingdir} for intermediate files")

if not logs_directory:
    # No (or empty) log directory configured: point logs and benchmarks at the
    # scratch directory so rules still have a writable location for them.
    logger.debug("Not keeping logs")
    logs_directory = workingdir
else:
    logger.debug(f"Saving logs to {logs_directory}")


# Write the final stats file (`stats`, exposed to the script as output
# stats_json).
# params.stats is built by get_stats_params from:
#   - seqkit stats of the trimmed reads (stats.txt),
#   - the SN lines of samtools stats on the raw BAM (samtools_stats.txt),
#   - the .md5 / .sha256 files that the checksum rule produces for reads_out.
# The actual writing is done by scripts/output_json_stats.py, which
# presumably also uses stats_schema to validate the output -- confirm there.
rule output_json_stats:
    input:
        stats_file=Path(workingdir, "stats.txt"),
        raw_stats_file=Path(workingdir, "samtools_stats.txt"),
        md5sum=Path(reads_out.as_posix() + ".md5"),
        sha256sum=Path(reads_out.as_posix() + ".sha256"),
        reads_out=reads_out,
        stats_schema=stats_schema,
    output:
        stats_json=stats,
    log:
        Path(logs_directory, "output_json_stats.log"),
    benchmark:
        Path(logs_directory, "output_json_stats.benchmark.txt")
    params:
        stats=get_stats_params,
    script:
        "scripts/output_json_stats.py"


# Generic checksum rule: for any file F, produce F.md5 or F.sha256 by running
# md5sum / sha256sum on it (the wildcard value selects the tool). The output
# holds the tool's normal "<digest>  <path>" line, whose first token is what
# read_checksum extracts.
rule checksum:
    input:
        Path("{file}"),
    output:
        Path("{file}.{checksum}"),
    wildcard_constraints:
        checksum="md5|sha256",
    shell:
        "{wildcards.checksum}sum {input} > {output}"


# Render a read-length histogram of the final (compressed, trimmed) reads
# with `seqkit watch`, saved as an image next to the stats file.
# seqkit's stderr is redirected into the declared log file.
rule seqkit_read_length_plot:
    input:
        trimmed_fastq=reads_out,
    output:
        plot_file=stats.with_suffix(".read_lengths.png"),
    log:
        Path(logs_directory, "seqkit_read_length_plot.log"),
    benchmark:
        Path(logs_directory, "seqkit_read_length_plot.benchmark.txt")
    shell:
        "seqkit watch --fields ReadLen {input.trimmed_fastq} -O {output.plot_file} "
        "2> {log}"


# Tabular (-T) summary statistics, including the extra columns from -a
# (e.g. N50, GC(%)), for the final trimmed reads. get_stats_params reads this
# table. seqkit's stderr is redirected into the declared log file.
rule seqkit_stats:
    input:
        trimmed_fastq=reads_out,
    output:
        stats_file=temp(Path(workingdir, "stats.txt")),
    log:
        Path(logs_directory, "seqkit_stats.log"),
    benchmark:
        Path(logs_directory, "seqkit_stats.benchmark.txt")
    shell:
        "seqkit stats -a -T {input.trimmed_fastq} > {output.stats_file} 2> {log}"


# Compress the trimmed FASTQ stream (a pipe written by cutadapt) into the
# final reads_out file, using `threads` compression threads. pigz's stderr is
# redirected into the declared log file.
rule pigz:
    input:
        trimmed_fastq=Path(workingdir, "trimmed_file.fq"),
    output:
        reads_out=reads_out,
    log:
        Path(logs_directory, "pigz.log"),
    benchmark:
        Path(logs_directory, "pigz.benchmark.txt")
    threads: 3
    shell:
        "pigz --fast -p {threads} -c < {input.trimmed_fastq} > {output.reads_out} "
        "2> {log}"


# Trim adapters from the quality-filtered reads with cutadapt.
# - Input is the FASTQ stream from samtools_fastq. The trimmed FASTQ goes to
#   stdout, into a pipe() that the pigz rule consumes.
# - Adapter sequences come from the pacbio_adapters file and may match
#   anywhere in a read (--anywhere).
# - revcomp / match_read_wildcards / discard_trimmed are module-level values
#   (from config). They are tested for truthiness and become either the
#   corresponding switch or an empty string. error_rate, overlap and
#   min_length are passed through unchanged.
# - The input FIFO is read via `<( cat ... )` process substitution.
#   NOTE(review): the reason is not visible here (presumably so cutadapt sees
#   an anonymous stream rather than the named-pipe path) -- confirm.
# - cutadapt's stderr goes to the log.
# - NOTE(review): no rule in this file consumes cutadapt_json. As a temp()
#   output it may be deleted right after this job -- confirm it is needed.
rule cutadapt:
    input:
        fastq=Path(workingdir, "file.fq"),
        pacbio_adapters=pacbio_adapters,
    output:
        trimmed_fastq=pipe(Path(workingdir, "trimmed_file.fq")),
        cutadapt_json=temp(Path(workingdir, "cutadapt.json")),
    log:
        Path(logs_directory, "cutadapt.log"),
    benchmark:
        Path(logs_directory, "cutadapt.benchmark.txt")
    threads: 3
    resources:
        mem="4GB",
    params:
        revcomp=lambda wildcards: "--revcomp" if revcomp else "",
        match_read_wildcards=lambda wildcards: (
            "--match-read-wildcards" if match_read_wildcards else ""
        ),
        discard_trimmed=lambda wildcards: "--discard-trimmed" if discard_trimmed else "",
        error_rate=error_rate,
        overlap=overlap,
        min_length=min_length,
    shell:
        "cutadapt "
        "--cores {threads} "
        "--anywhere 'file:{input.pacbio_adapters}' "
        "--minimum-length {params.min_length} "
        "--error-rate {params.error_rate} "
        "--overlap {params.overlap} "
        "{params.revcomp} "
        "{params.match_read_wildcards} "
        "{params.discard_trimmed} "
        "--json {output.cutadapt_json} "
        "<( cat {input.fastq} ) "
        "> {output.trimmed_fastq} "
        "2> {log}"


# Convert the quality-filtered BAM stream (a pipe from bamtools_filter) into
# a FASTQ stream (a pipe consumed by cutadapt); stderr goes to the log.
rule samtools_fastq:
    input:
        filtered_bam=Path(workingdir, "filtered.bam"),
    output:
        fastq=pipe(Path(workingdir, "file.fq")),
    log:
        Path(logs_directory, "samtools_fastq.log"),
    benchmark:
        Path(logs_directory, "samtools_fastq.benchmark.txt")
    resources:
        mem="4GB",
    shell:
        "samtools fastq {input.filtered_bam} > {output.fastq} 2>{log}"


## Quality-filter the raw BAM (keep reads with rq >= 0.99). Despite the rule name, this uses samtools view, not bamtools.
# Keep only reads whose rq tag (PacBio read quality) is >= 0.99, using a
# samtools view filter expression. The result is streamed to samtools_fastq
# through a pipe().
# -u writes uncompressed BAM. samtools recommends this when the output is
# piped into another samtools command, because it skips a pointless
# compress/decompress round trip.
rule bamtools_filter:
    input:
        bam=bam,
    output:
        filtered_bam=pipe(Path(workingdir, "filtered.bam")),
    log:
        Path(logs_directory, "bamtools_filter.log"),
    benchmark:
        Path(logs_directory, "bamtools_filter.benchmark.txt")
    resources:
        mem="4GB",
    shell:
        "samtools view "
        "-u -e '[rq]>=0.99' "
        "{input.bam} "
        "> {output.filtered_bam} "
        "2>{log}"


## Summary statistics (samtools stats SN section) on the raw, unfiltered BAM
# Keep only the SN (summary numbers) section of `samtools stats` for the raw
# BAM, reduced to two tab-separated columns (name, value). get_stats_params
# reads this file.
# The stderr redirect is attached to `samtools stats` itself. A trailing
# `2>{log}` would only capture the stderr of the last command in the pipeline
# (`cut`), so samtools' diagnostics would never reach the log.
rule samtools_stats:
    input:
        bam=bam,
    output:
        stats=temp(Path(workingdir, "samtools_stats.txt")),
    log:
        Path(logs_directory, "samtools_stats.log"),
    benchmark:
        Path(logs_directory, "samtools_stats.benchmark.txt")
    resources:
        mem="4GB",
    shell:
        "samtools stats {input.bam} 2> {log} "
        "| grep '^SN' "
        "| cut -f 2-3 "
        "> {output.stats}"


# Default target. It requests the compressed trimmed reads (pigz), the
# read-length plot and the stats output, so building it runs every rule
# those outputs depend on.
rule target:
    default_target: True
    input:
        rules.pigz.output,
        rules.seqkit_read_length_plot.output,
        rules.output_json_stats.output,
