#!/usr/bin/env python3
"""
Check a DADA2 stats table like the one below (six tab-separated columns: sample, input, filtered, denoised, merged, non-chimeric):
                                input   filtered   denoised  merged  non-chimeric
M0614DD20_R1.fastq.gz   1        65425  150525  150525  3228    3224
M0614DD2plus165_R1.fastq.gz     254245  225049  225049  35382   35114
M0614DD2plus45_R1.fastq.gz      296332  281027  281027  12126   12114
M0614DD3plus120_R1.fastq.gz     2879433 2706733 2706733 124381  123007
M0614DD3plus30_R1.fastq.gz      96132   89573   89573   117     113
M0614DD4min10_R1.fastq.gz       43880   39328   39328   651     651
M0614DD4plus135_R1.fastq.gz     233748  219565  219565  14805   14578
M0614DD4plus30_R1.fastq.gz      175323  164848  164848  2194    2194
M0614GD20_R1.fastq.gz   27637   23756   23756   2694    2694
M0614GD2plus165_R1.fastq.gz     10367   5095    5095    456     456
M0614GD2plus45_R1.fastq.gz      31538   28490   28490   9523    9523
M0614GD3plus120_R1.fastq.gz     644     445     445     389     389
M0614GD3plus30_R1.fastq.gz      2388    2039    2039    1997    1997
M0614GD4plus135_R1.fastq.gz     32941   27562   27562   21465   21465

"""

import os, sys, subprocess
import logging
import tempfile
import shutil
import pandas as pd
import json
from rich.logging import RichHandler

# Name passed to logging.getLogger() to obtain this script's logger.
MODULE_NAME = "CheckDADA2stats"

def eprint(*args, **kwargs):
    """
    Write a message to standard error.

    Positional arguments and keyword arguments (``sep``, ``end``, ``flush``)
    are forwarded unchanged to the built-in ``print``. The destination stream
    is always ``sys.stderr``, so ``file`` must not be passed by the caller.
    """
    # Look the stream up at call time so a redirected sys.stderr is honoured.
    error_stream = sys.stderr
    print(*args, file=error_stream, **kwargs)
 

# Default, ordered list of pipeline-step columns expected in the stats table.
DEFAULT_TABLE_HEADERS = 'input,filtered,denoised,merged,non-chimeric'


def split_keys(keys):
    """
    Turn a comma separated header string into an ordered list of column names.

    Whitespace around each name is stripped and empty entries are dropped,
    so "input, filtered," yields ["input", "filtered"].
    """
    return [name.strip() for name in keys.split(',') if name.strip()]


def check_steps(totals, keys, threshold, report_all, logger):
    """
    Compare each step's total read count with the total of the previous step.

    Args:
        totals: mapping (dict or pandas Series) from column name to the read
            count summed over all samples.
        keys: ordered list of column names, one per pipeline step.
        threshold: minimum fraction of reads that must be kept by a step.
        report_all: when true, also report the steps that pass the threshold.
        logger: logger receiving one message per compared step (ERROR when
            the step fails, INFO otherwise).

    Returns:
        dict mapping step name to the percentage of reads kept, for failing
        steps only (or for every compared step when ``report_all`` is set).
    """
    report = {}
    for previous_key, key in zip(keys, keys[1:]):
        previous, current = totals[previous_key], totals[key]
        # A step cannot be evaluated when the preceding count is zero (or NaN).
        if not previous > 0:
            continue
        kept = current / previous
        message = "{}: {:.1f}% reads kept : from  {} to {}".format(key, 100 * kept, previous, current)
        if kept < threshold:
            logger.error(message)
            report[key] = float(100 * kept)
        else:
            if report_all:
                report[key] = float(100 * kept)
            logger.info(message)
    return report


def check_samples(stats_df, keys, threshold, logger):
    """
    Apply the per-step retention check to every sample (row) of the table.

    Args:
        stats_df: pandas DataFrame indexed by sample, holding the ``keys``
            columns.
        keys: ordered list of column names, one per pipeline step.
        threshold: minimum fraction of reads that must be kept by a step.
        logger: logger receiving a WARNING per failing sample/step and an
            INFO summary per step with at least one failing sample.

    Returns:
        dict mapping every key to the list of samples failing at that step
        (the list is empty when no sample fails).
    """
    failed = {key: [] for key in keys}
    total_samples = stats_df.shape[0]

    for sample in stats_df.index:
        for previous_key, key in zip(keys, keys[1:]):
            previous = stats_df.loc[sample, previous_key]
            current = stats_df.loc[sample, key]
            # Skip steps whose preceding count is zero (or NaN).
            if not previous > 0:
                continue
            kept = current / previous
            if kept < threshold:
                logger.warning("{}: {:.1f}% reads kept {}: from  {} to {}".format(sample, 100 * kept, key, previous, current))
                failed[key].append(sample)

    for key in keys:
        if failed[key]:
            ratio = 100 * len(failed[key]) / total_samples
            logger.info("{:.1f}% failed samples at <{}> ({})".format(ratio, key, len(failed[key])))
    return failed


def main(argv=None):
    """
    Command line entry point: read the stats table, run the checks and print
    the resulting report as JSON on standard output.

    Args:
        argv: list of command line arguments; ``None`` means ``sys.argv[1:]``.

    Exits with status 1 when the input file does not exist, when no column
    names are given, or when a requested column is missing from the table.
    """
    import argparse

    args = argparse.ArgumentParser(description="Check DADA2 stats")

    # Main arguments
    main_group = args.add_argument_group("Main")
    main_group.add_argument("-i", "--input", help="DADA2 stats table", required=True)
    # The value is compared against the fraction of reads *kept* at each step.
    main_group.add_argument("-l", "--loss", help="Minimum fraction of reads to keep at each step, steps below it are reported [default: %(default)s]", type=float, default=0.33)
    main_group.add_argument("--sample", help="Also check sample by sample", action="store_true")
    main_group.add_argument("--all", help="Report loss for all the steps", action="store_true")
    main_group.add_argument("--keys", help="Comma separated headers [default: %(default)s]", default=DEFAULT_TABLE_HEADERS)
    # Accepted for command line compatibility; not used by the checks below.
    main_group.add_argument("--tmp", help="Temporary directory", default=os.environ.get("TMPDIR", "/tmp"))

    # Misc arguments
    misc_group = args.add_argument_group("Other parameters")
    misc_group.add_argument("--log", help="Log file", default=None)
    misc_group.add_argument("--verbose", help="Verbose mode", action="store_true")

    opts = args.parse_args(argv)

    # Honour --keys (it used to be parsed and then ignored).
    table_keys = split_keys(opts.keys)

    ## Set logger
    llevel = logging.DEBUG if opts.verbose else logging.WARNING
    if opts.log:
        logging.basicConfig(filename=opts.log, level=llevel, format='%(asctime)s|%(levelname)s|%(message)s')
    else:
        logging.basicConfig(level=llevel, format="", handlers=[RichHandler()])
    logger = logging.getLogger(MODULE_NAME)

    if not table_keys:
        logger.error("No column names given with --keys")
        sys.exit(1)

    if not os.path.exists(opts.input):
        logger.error("Input file does not exist: {}".format(opts.input))
        sys.exit(1)

    stats_df = pd.read_csv(opts.input, sep="\t", index_col=0)

    # Fail with a clear message instead of a KeyError traceback.
    missing = [key for key in table_keys if key not in stats_df.columns]
    if missing:
        logger.error("Column(s) not found in {}: {}".format(opts.input, ",".join(missing)))
        sys.exit(1)

    # Sum all values by column, e.g.:
    #   input           9798752
    #   filtered        9113013
    #   denoised        9113013
    #   merged           443382
    #   non-chimeric     440150
    stats_df_sum = stats_df.sum(axis=0)

    if opts.verbose:
        eprint("GLOBAL STATS\n", stats_df_sum.to_string(), sep="")

    errors = {
        "failed_steps": check_steps(stats_df_sum, table_keys, opts.loss, opts.all, logger),
    }

    # The per-sample section is only part of the report when requested.
    if opts.sample:
        errors["failed_by_sample"] = check_samples(stats_df, table_keys, opts.loss, logger)

    # Print errors as JSON pretty print
    print(json.dumps(errors, indent=4, sort_keys=False))


if __name__ == "__main__":
    main()
        

        