#!/usr/bin/env python3
"""
Run DADA2 from Python 3 by invoking the companion R script (D2-dada.R) with Rscript.
"""

import os, sys, subprocess
import logging
from string import Template
import time
import tempfile
import shutil

# Names of the positional arguments given to the R script. The entries mirror,
# position by position, the ``dada2args`` list assembled in the ``__main__``
# section of this file.
dada2list = [
    'forward_reads',          # 1
    'reverse_reads',          # 2
    'feature_table_output',   # 3
    'stats_output',           # 4
    'filt_forward',           # 5
    'filt_reverse',           # 6
    'truncLenF',              # 7
    'truncLenR',              # 8
    'trimLeftF',              # 9
    'trimLeftR',              # 10
    'maxEEF',                 # 11
    'maxEER',                 # 12
    'truncQ',                 # 13
    'chimeraMethod',          # 14
    'minFold',                # 15
    'threads',                # 16
    'nreads_learn',           # 17
    'baseDir',                # 18
    'doPlots',                # 19
    'taxonomyDb',             # 20
    'saveRDS',                # 21
    'noMerge',                # 22
    'processPool',            # 23
]

 
def eprint(*args, **kwargs):
    """
    Write a message to standard error.

    Takes the same positional and keyword arguments as the built-in
    ``print``; the ``file`` argument is always ``sys.stderr``.
    """
    stream = sys.stderr
    print(*args, file=stream, **kwargs)

def checkR():
    """
    Check whether the R interpreter can be executed.

    Runs ``R --version`` and reports whether the command succeeded. On
    failure a short diagnostic is written to standard error.

    Returns:
        int: 1 if R could be launched and exited successfully, 0 otherwise.
    """
    try:
        subprocess.check_output(["R", "--version"])
    except (OSError, subprocess.CalledProcessError):
        # OSError: the executable is missing or cannot be started.
        # CalledProcessError: R started but exited with a non-zero status.
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        print("R is not installed. Please install R and try again.", file=sys.stderr)
        return 0
    return 1

def checkMissingModules(modules):
    """
    Count the R packages in *modules* that cannot be loaded.

    For every name, ``R -e "library(<name>)"`` is executed; a failure to
    start R or a non-zero exit status counts the package as missing and a
    message is written to standard error.

    Args:
        modules: iterable of R package names. Each name is interpolated
            verbatim into an R expression, so only trusted names should be
            passed.

    Returns:
        int: number of packages that could not be loaded (0 if all are fine).
    """
    missing = 0
    for module in modules:
        cmd = ["R", "--slave", "--no-save", "-e", "library(%s)" % module]
        try:
            # R's own error output is discarded; only the exit status matters.
            subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
        except (OSError, subprocess.CalledProcessError):
            eprint("R module %s is not installed. Please install R module %s and try again." % (module, module))
            missing += 1
    return missing

def initInput(dir, outdir, forwardTag="_R1", reverseTag="_R2", sampleSeparator="_", sampleExtension=".fastq.gz", copy=False):
    """
    Split one directory of paired-end reads into forward and reverse directories.

    Files in *dir* whose name ends with *sampleExtension* are considered.
    Those containing *forwardTag* are placed in ``<outdir>/forward`` and those
    containing *reverseTag* in ``<outdir>/reverse``, either as symbolic links
    (default) or as copies.

    Args:
        dir: directory holding both forward and reverse files.
        outdir: directory under which ``forward`` and ``reverse`` are created;
            it is created if it does not exist.
        forwardTag: substring identifying a forward file.
        reverseTag: substring identifying a reverse file.
        sampleSeparator: the sample name is the part of the file name that
            precedes the first occurrence of this string.
        sampleExtension: suffix a file must have to be considered.
        copy: copy the files instead of creating symbolic links.

    Returns:
        tuple: ``(forDir, revDir)``, the paths of the two created directories.

    Raises:
        FileNotFoundError: *dir* does not exist, or no forward file was found.
        NotADirectoryError: *dir* exists but is not a directory.
        ValueError: a sample has more than two files, or the number of
            forward and reverse files differ.
        SystemExit: an output directory could not be created (exit status 1).
    """
    # Same named logger that the command-line entry point configures; fetching
    # it here keeps the function usable when the module is imported.
    log = logging.getLogger("runDADA2")
    log.debug("Checking input directory: {} with extension={}, tags={},{}".format(dir, sampleExtension, forwardTag, reverseTag))

    # The input must exist and be a directory.
    if not os.path.exists(dir):
        message = "Input directory not found: {}".format(dir)
        log.error(message)
        raise FileNotFoundError(message)
    if not os.path.isdir(dir):
        message = "Input directory is not a directory: {}".format(dir)
        log.error(message)
        raise NotADirectoryError(message)

    # Only files with the requested extension take part in the analysis.
    files = [f for f in os.listdir(dir) if f.endswith(sampleExtension)]

    # A sample may contribute at most two files (one forward, one reverse).
    counts = {}
    for f in files:
        sample = f.split(sampleSeparator)[0]
        counts[sample] = counts.get(sample, 0) + 1
        if counts[sample] > 2:
            message = "Sample {} has more than two files. Last: {}".format(sample, f)
            log.error(message)
            raise ValueError(message)

    forFiles = [f for f in files if forwardTag in f]
    revFiles = [f for f in files if reverseTag in f]
    if not forFiles:
        message = "No forward files found in {}".format(dir)
        log.error(message)
        raise FileNotFoundError(message)
    if len(revFiles) != len(forFiles):
        message = "Different number of FOR and REV files in {}.".format(dir)
        log.error(message)
        raise ValueError(message)

    # Create the output directory (tolerating an existing one) ...
    try:
        os.makedirs(outdir, exist_ok=True)
    except OSError:
        log.error("Could not create output directory: {}".format(outdir))
        sys.exit(1)

    # ... and the two sub-directories, which must not exist yet.
    forDir = os.path.join(outdir, "forward")
    revDir = os.path.join(outdir, "reverse")
    for target in (forDir, revDir):
        try:
            os.makedirs(target)
        except OSError:
            log.error("Could not create output directory: {}".format(target))
            sys.exit(1)

    # Populate both directories; the arrow only distinguishes the direction
    # in the debug output.
    for names, target, arrow in ((forFiles, forDir, ">>"), (revFiles, revDir, "<<")):
        for f in names:
            source = os.path.join(dir, f)
            if copy:
                log.debug("{}: {} Copying {} to {}".format(f.split(sampleSeparator)[0], arrow, f, target))
                shutil.copy(source, target)
            else:
                # Link to the absolute path so the link stays valid
                # regardless of the current working directory.
                os.symlink(os.path.abspath(source), os.path.join(target, f))

    log.info("Temporary input directories: {};{}".format(forDir, revDir))
    return forDir, revDir

def checkDir(dir, extension, sampleseparator):
    """
    Validate a directory of reads (forward or reverse).

    Every file whose name ends with *extension* is considered; the sample
    name is the part of the file name preceding the first *sampleseparator*
    and must be unique within the directory.

    Args:
        dir: directory to scan.
        extension: suffix a file must have to be considered.
        sampleseparator: string separating the sample name from the rest of
            the file name.

    Returns:
        int: number of matching files (always at least 1).

    Raises:
        ValueError: *dir* is None, or two files yield the same sample name.
        FileNotFoundError: no file with the requested extension was found
            (also raised by ``os.listdir`` if *dir* does not exist).
    """
    # Same named logger that the command-line entry point configures; fetching
    # it here keeps the function usable when the module is imported.
    log = logging.getLogger("runDADA2")
    log.debug("Checking directory: {}".format(dir))

    # os.listdir(None) would silently scan the current working directory.
    if dir is None:
        message = "No directory specified"
        log.error(message)
        raise ValueError(message)

    files = [f for f in os.listdir(dir) if f.endswith(extension)]
    if not files:
        message = "No files found in {}".format(dir)
        log.error(message)
        raise FileNotFoundError(message)

    seen = set()
    for file in files:
        sample = file.split(sampleseparator)[0]
        if sample in seen:
            message = "Sample name from {} would be {}, but it's not unique.".format(file, sample)
            log.error(message)
            raise ValueError(message)
        seen.add(sample)

    return len(files)

def saveTextToFile(text, output):
    """
    Write *text* to the file *output*, replacing any previous content.

    The file is written as UTF-8. Failures are logged rather than raised.

    Args:
        text: string to write.
        output: path of the destination file.

    Returns:
        int: 1 on success, 0 if the file could not be written.
    """
    try:
        with open(output, "w", encoding="utf-8") as f:
            f.write(text)
    except (OSError, UnicodeError):
        # Best effort: report the problem and let the caller decide.
        logging.getLogger("runDADA2").error("Could not write to file {}".format(output))
        return 0
    return 1

if __name__ == "__main__":
    # Command-line entry point: parse the options, validate/stage the input
    # reads, run the D2-dada.R script through Rscript and move its outputs to
    # the requested output directory.
    import argparse
    script_dir = os.path.dirname(os.path.realpath(__file__))
    dada2script = os.path.join(script_dir, "D2-dada.R")

    args = argparse.ArgumentParser(description="Run DADA2")

    # Main arguments
    main = args.add_argument_group("Main")
    main.add_argument("-i", "--input-dir", help="Input directory with both R1 and R2")
    main.add_argument("-f", "--for-dir", help="Input directory with R1 reads")
    main.add_argument("-r", "--rev-dir", help="Input directory with R2 reads")
    main.add_argument("-o", "--output-dir", help="Output directory", required=True)
    main.add_argument("--tmp", help="Temporary directory", default=os.environ.get("TMPDIR", "/tmp"))

    # Input selection
    inopt = args.add_argument_group("Input filtering")
    inopt.add_argument("--fortag", help="String defining a file as forward [default: _R1]", default="_R1")
    inopt.add_argument("--revertag", help="String defining a file as reverse [default: _R2]", default="_R2")
    inopt.add_argument("--sample-separator", help="String acting as samplename separator [default: _]", default="_")
    inopt.add_argument("--sample-extension", help="String acting as samplename extension [default: .fastq.gz]", default=".fastq.gz")

    # DADA2 parameters; the values are forwarded to the R script as strings.
    dada2_params = args.add_argument_group("DADA2 parameters")
    dada2_params.add_argument("-q", "--trunc-qual", help="Truncate at the first occurrence of a base with Q lower [default: 8]", default=8.0)
    dada2_params.add_argument("-j", "--join", help="Join without merging", action="store_true")
    dada2_params.add_argument("-p", "--pool", help="Pool samples", action="store_true")

    dada2_params.add_argument("--trunc-len-1", help="Position at which to truncate forward reads [default: 0]", default=0)
    dada2_params.add_argument("--trunc-len-2", help="Position at which to truncate reverse reads [default: 0]", default=0)
    dada2_params.add_argument("--trim-left-1", help="Number of nucleotide to trim from the beginning of forward reads [default: 0]", default=0)
    dada2_params.add_argument("--trim-left-2", help="Number of nucleotide to trim from the beginning of reverse reads [default: 0]", default=0)
    dada2_params.add_argument("--maxee-1", help="Maximum expected errors in forward reads [default: 1.0]", default=1.0)
    dada2_params.add_argument("--maxee-2", help="Maximum expected errors in reverse reads [default: 1.00]", default=1.0)

    # --chimera selects the chimera handling: "none", "pooled" or "consensus"
    dada2_params.add_argument("--chimera", help="Chimera handling can be none, pooled or consensus [default: pooled]", choices=["none", "pooled", "consensus"], default="pooled")
    dada2_params.add_argument("--min-parent-fold", help="Minimum abundance of parents of a chimeric sequence (>1.0) [default: 1.0]", default=1.0)
    dada2_params.add_argument("--n-learn", help="Number of reads to learn the model, 0 for all [default: 0]", default=0)

    # Misc arguments
    m = args.add_argument_group("Other parameters")
    m.add_argument("-t", "--threads", help="Number of threads", type=int, default=1)
    m.add_argument("--keep-temp", help="Keep temporary files", action="store_true")
    m.add_argument("--log", help="Log file", default=None)
    m.add_argument("--copy", help="Copy input files instead of symbolic linking", action="store_true")
    m.add_argument('--skip-checks', help="Do not check installation of dependencies", action="store_true")
    m.add_argument("--verbose", help="Verbose mode", action="store_true")

    opts = args.parse_args()

    ## Set logger
    llevel = logging.DEBUG if opts.verbose else logging.WARNING
    if opts.log:
        logging.basicConfig(filename=opts.log, level=llevel, format='%(asctime)s|%(levelname)s|%(message)s')
    else:
        logging.basicConfig(level=llevel, format='%(asctime)s\t%(levelname)s\t%(message)s')

    # Module-level name: the helper functions of this file may read it as a global.
    logger = logging.getLogger("runDADA2")

    ## Check DADA2 Script
    if not os.path.exists(dada2script):
        logger.error("DADA2 script not found: {}".format(dada2script))
        sys.exit(1)

    ## Initial checks
    if not opts.skip_checks:
        if not checkR():
            eprint("R is not installed. Please install R and try again.")
            sys.exit(1)

        # Run the (slow) R check once and reuse the count for the message.
        missing_modules = checkMissingModules(["dada2"])
        if missing_modules:
            eprint("%i modules are missing" % missing_modules)
            sys.exit(1)
    else:
        logger.info("Skipping dependency checks")

    ## Input options: either -i alone, or both -f and -r
    if opts.input_dir and (opts.for_dir or opts.rev_dir):
        logger.error("Specify either input directory or forward/reverse directories.")
        sys.exit(1)
    if not opts.input_dir and not (opts.for_dir and opts.rev_dir):
        logger.error("Specify an input directory (-i) or both forward (-f) and reverse (-r) directories.")
        sys.exit(1)

    ## Make random temporary directory (only once the cheap checks have passed,
    ## so that an early exit does not leave an empty directory behind)
    tmpdir = tempfile.mkdtemp(dir=opts.tmp)
    logger.debug("Temporary directory: {}".format(tmpdir))

    # Stage (-i) or validate (-f/-r) the input reads. The helpers may signal
    # invalid input with OSError, ValueError or RuntimeError.
    input_error = None
    try:
        if opts.input_dir:
            forDirectory, revDirectory = initInput(opts.input_dir, tmpdir, opts.fortag, opts.revertag, opts.sample_separator, opts.sample_extension, opts.copy)
        else:
            f = checkDir(opts.for_dir, opts.sample_extension, opts.sample_separator)
            r = checkDir(opts.rev_dir, opts.sample_extension, opts.sample_separator)
            if f != r:
                input_error = "Number of forward and reverse files are not equal: {} vs {}".format(f, r)
            forDirectory = opts.for_dir
            revDirectory = opts.rev_dir
    except (OSError, ValueError, RuntimeError) as err:
        input_error = "Invalid input: {}".format(err)

    if input_error:
        logger.error(input_error)
        if not opts.keep_temp:
            shutil.rmtree(tmpdir, ignore_errors=True)
        sys.exit(1)

    # Directories that receive the filtered reads; they must not exist yet.
    forFiltered = os.path.join(forDirectory, "filtered")
    revFiltered = os.path.join(revDirectory, "filtered")
    try:
        os.makedirs(forFiltered)
        os.makedirs(revFiltered)
    except OSError as err:
        logger.error("Could not create directory for filtered reads: {}".format(err))
        if not opts.keep_temp:
            shutil.rmtree(tmpdir, ignore_errors=True)
        sys.exit(1)

    tsvFile = os.path.join(tmpdir, "dada2.tsv")
    statsFile = os.path.join(tmpdir, "dada2.stats")

    joinPairs = 1 if opts.join else 0
    poolSamples = 1 if opts.pool else 0

    # Positional arguments of the R script, in the order listed in dada2list.
    dada2args = [
        forDirectory,     # 1
        revDirectory,     # 2
        tsvFile,          # 3
        statsFile,        # 4
        forFiltered,      # 5
        revFiltered,      # 6
        opts.trunc_len_1, # 7
        opts.trunc_len_2, # 8
        opts.trim_left_1, # 9
        opts.trim_left_2, # 10
        opts.maxee_1,     # 11
        opts.maxee_2,     # 12
        opts.trunc_qual,  # 13
        opts.chimera,     # 14
        opts.min_parent_fold, # 15
        opts.threads,     # 16
        opts.n_learn,     # 17
        tmpdir,           # 18
        'do_plots',       # 19
        'skip',           # 20 taxonomy
        'save',           # 21 save RDS
        joinPairs,        # 22
        poolSamples,      # 23
    ]

    # Every element must be a string to be usable as a process argument.
    cmd = [str(x) for x in ["Rscript", "--vanilla", "--no-save", dada2script] + dada2args]
    logger.info("Running DADA2 wrapper")

    # Execute cmd capturing stdout and stderr, and remember the exit status
    try:
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = process.communicate()
        returncode = process.returncode
        # stdout goes to dada2.execution.log, stderr to dada2.execution.txt;
        # undecodable bytes are replaced so the logs are always saved.
        saveTextToFile(stdout.decode("utf-8", errors="replace"), os.path.join(tmpdir, "dada2.execution.log"))
        saveTextToFile(stderr.decode("utf-8", errors="replace"), os.path.join(tmpdir, "dada2.execution.txt"))
    except Exception as e:
        logger.error("Error running DADA2 wrapper: {}".format(" ".join(cmd)))
        logger.error("Exception: {}".format(e))
        raise

    if returncode != 0:
        # Keep going so that the execution logs still reach the output
        # directory, but report the failure through the exit status below.
        logger.error("DADA2 script exited with status {}: {}".format(returncode, " ".join(cmd)))

    # Move files and cleanup
    if not os.path.isdir(opts.output_dir):
        try:
            os.makedirs(opts.output_dir)
        except OSError:
            logger.error("Error creating output directory: {}".format(opts.output_dir))
            # Warning level so the hint is visible without --verbose.
            logger.warning("Output files are in temp dir: " + tmpdir)
            raise

    for name in ["dada2.stats", "dada2.tsv", "quality_R1.pdf", "quality_R2.pdf", "dada2.rds", "dada2.execution.txt", "dada2.execution.log"]:
        produced = os.path.join(tmpdir, name)
        if os.path.exists(produced):
            logger.debug("Moving file: {}".format(name))
            shutil.move(produced, os.path.join(opts.output_dir, name))
        else:
            logger.warning("File not found: {}".format(produced))

    # Remove temp dir
    if not opts.keep_temp:
        logger.info("Removing temporary directory: {}".format(tmpdir))
        shutil.rmtree(tmpdir)

    # Propagate a failure of the R script to the caller.
    if returncode != 0:
        sys.exit(1)