#!/usr/bin/env python

version = '1.3'

# Main script / control is at the BOTTOM of this file

# Imports from Python Standard Library
import argparse
import os
import re
import shlex
import shutil
import subprocess
import sys
import tempfile
from collections import Counter

def ShortTracks_args(version):
    """Parse the ShortTracks command line (read from ``sys.argv``).

    Args:
        version: Version string reported by the ``--version`` option.

    Returns:
        argparse.Namespace with the attributes ``mode`` (one of 'simple',
        'readgroup', 'readlength'; default 'simple'), ``stranded`` (bool)
        and ``bamfile`` (required path string).
    """
    parser = argparse.ArgumentParser(prog='ShortTracks')
    parser.add_argument(
        '--version',
        action='version',
        version='%(prog)s {0}'.format(version))

    parser.add_argument(
        '--mode',
        choices=['simple', 'readgroup', 'readlength'],
        default='simple',
        help='Parsing mode: simple, readgroup, or readlength')

    # The two halves of this help text used to be joined without a space,
    # which rendered as "intoseparate" in --help output.
    parser.add_argument(
        '--stranded',
        action='store_true',
        help='If set, stranded splits + and - coverage into '
             'separate tracks')

    parser.add_argument(
        '--bamfile',
        required=True,
        help='Path to bamfile. The corresponding .bai or .csi file '
             'must also be there.')

    return parser.parse_args()

def check_executables():
    """Verify that the external programs this script relies on are on PATH.

    Prints the resolved location of each program that is found. Stops at
    the first program that cannot be located and terminates the process
    via ``sys.exit`` with a message naming it.
    """
    for tool in ('samtools', 'wigToBigWig'):
        location = shutil.which(tool)
        if location is None:
            # sys.exit with a string prints it to stderr and exits non-zero.
            sys.exit('Required executable {0} : Not found!'.format(tool))
        print('Required executable {0} : {1}'.format(tool, location))

def get_read_groups(args):
    """Return the read-group IDs declared in the header of ``args.bamfile``.

    The header is obtained with ``samtools view -H`` and parsed here rather
    than with ``grep``/a line-anchored regex, because:
      * only records whose first field is exactly ``@RG`` are read groups
        (other header lines, e.g. ``@PG`` command lines, may merely contain
        the text "@RG");
      * the ``ID:`` tag may appear anywhere among the tab-separated fields
        of an ``@RG`` line, not only at the end of the line.

    samtools is invoked with an argument list (no shell), so BAM paths that
    contain spaces or shell metacharacters are passed through verbatim.

    Args:
        args: Object with a ``bamfile`` attribute (path to the BAM file).

    Returns:
        List of read-group ID strings in header order; empty if the header
        has no ``@RG`` line with an ``ID:`` tag or samtools printed nothing.

    Raises:
        FileNotFoundError: If the ``samtools`` executable cannot be found.
    """
    header_job = subprocess.run(
        ['samtools', 'view', '-H', args.bamfile],
        text=True, capture_output=True)

    rgs = []
    for header_line in header_job.stdout.splitlines():
        if not header_line.startswith('@RG\t'):
            continue
        for field in header_line.split('\t')[1:]:
            # Require a non-empty value after the "ID:" prefix.
            if field.startswith('ID:') and len(field) > 3:
                rgs.append(field[3:])
                break
    return rgs

# Main

# XW:i:<n> optional tag. On "condensed" BAM files the script uses its value
# as the weight of an alignment record when counting reads and depths.
_XW_PATTERN = re.compile(r'\tXW:i:(\d+)')

# Read-name signature that, together with an XW tag, marks a condensed BAM.
_CD_PATTERN = re.compile(r'_Cd\d+_\d+')

# (track label, samtools filter expression) pairs used by --mode readlength.
_LENGTH_CLASSES = (
    ('21', '"qlen==21"'),
    ('22', '"qlen==22"'),
    ('23-24', '"qlen==23 || qlen==24"'),
    ('other', '"qlen < 21 || qlen > 24"'),
)

# Multiplier that turns (depth / total reads) into reads per million.
_RPM_SCALE = 1000000


def _swap_suffix(path, old, new):
    """Return *path* with a trailing *old* replaced by *new*.

    Only a suffix at the very end of the path is replaced. If *path* does
    not end with *old*, *new* is appended instead, so the result can never
    be identical to *path* (which would make an output overwrite its input).
    """
    if path.endswith(old):
        return path[:-len(old)] + new
    return path + new


def _read_chrom_sizes(bamfile):
    """Return {chromosome name: length} from the @SQ lines of the BAM header.

    SN and LN may appear in any order within an @SQ line, so every field is
    inspected. Exits with a message if an @SQ line lacks a usable SN or LN.
    """
    job = subprocess.run(f'samtools view -H {shlex.quote(bamfile)}',
                         shell=True, capture_output=True, text=True)
    sizes = {}
    for line in job.stdout.splitlines():
        if not line.startswith('@SQ\t'):
            continue
        sn = ''
        ln = ''
        for field in line.split('\t')[1:]:
            if field.startswith('SN:'):
                sn = field[3:]
            elif field.startswith('LN:'):
                ln = field[3:]
        if not sn or not ln.isdigit():
            sys.exit(f'Malformed @SQ header line: {line}')
        sizes[sn] = int(ln)
    return sizes


def _detect_bam_type(bamfile):
    """Classify the BAM as 'condensed' or 'verbose' from its first record.

    The file is 'condensed' when the first alignment carries an XW:i tag
    and its read name matches the _Cd<digits>_<digits> pattern.
    """
    job = subprocess.run(f'samtools view {shlex.quote(bamfile)} | head -n 1',
                         shell=True, text=True, capture_output=True)
    samline = job.stdout.split('\n')[0]
    read_name = samline.split('\t')[0]
    if _XW_PATTERN.search(samline) and _CD_PATTERN.search(read_name):
        return 'condensed'
    return 'verbose'


def _count_reads(bamfile, bam_type, selector=''):
    """Return the rpm denominator: the number of non-secondary reads.

    *selector* is an optional, already shell-quoted fragment of extra
    ``samtools view`` options ending in a space (e.g. ``'-r lib1 '``).
    Verbose files are counted with ``samtools view -c``; for condensed
    files the XW values of all selected records are summed.
    """
    bam = shlex.quote(bamfile)
    if bam_type == 'verbose':
        job = subprocess.run(f'samtools view -c -F 256 {selector}{bam}',
                             shell=True, text=True, capture_output=True)
        first_line = job.stdout.split('\n')[0]
        try:
            return int(first_line)
        except ValueError:
            sys.exit('Could not count reads with samtools: '
                     + job.stderr.strip())

    job = subprocess.run(f'samtools view -F 256 {selector}{bam}',
                         shell=True, text=True, capture_output=True)
    total = 0
    for samline in job.stdout.splitlines():
        xw_match = _XW_PATTERN.search(samline)
        if not xw_match:
            # Exit with a message so the failure yields a non-zero status.
            sys.exit('Fatal parse error to count reads, no XW found')
        total += int(xw_match.group(1))
    return total


def _track_jobs(bamfile, bam_type, stranded, count, label='', selector=''):
    """Build the depth jobs for one set of tracks.

    Returns a list of ``(command, depth_file, count, scale)`` tuples, two
    (plus, minus) when *stranded* is true and one otherwise. *scale* is
    negative for the minus strand so its wiggle values are written below
    zero; carrying the sign here avoids inferring it from the file name.

    Flag masks: 256 secondary, 16 reverse strand, 4 unmapped. Verbose
    commands pipe into ``samtools depth`` and write *depth_file* themselves;
    condensed commands only emit SAM text that the caller must tally.
    """
    if stranded:
        strands = (('_p', '-F 272', '-F 276', _RPM_SCALE),
                   ('_m', '-F 256 -f 16', '-F 260 -f 16', -_RPM_SCALE))
    else:
        strands = (('', '-F 256', '-F 260', _RPM_SCALE),)

    bam = shlex.quote(bamfile)
    jobs = []
    for tag, verbose_flags, condensed_flags, scale in strands:
        depth_file = _swap_suffix(bamfile, '.bam',
                                  f'{label}{tag}_depths.txt')
        if bam_type == 'verbose':
            cmd = (f'samtools view {verbose_flags} {selector}-h {bam} | '
                   f'samtools depth - > {shlex.quote(depth_file)}')
        else:
            cmd = f'samtools view {condensed_flags} {selector}{bam}'
        jobs.append((cmd, depth_file, count, scale))
    return jobs


def _write_condensed_depths(cmd, outfile, xw_pattern=_XW_PATTERN):
    """Run *cmd* and write per-position depths for a condensed BAM.

    Each SAM record adds its XW value to every position from POS up to
    POS + len(SEQ) - 1. Depths are accumulated per reference name and
    flushed, sorted by position, whenever the reference name changes.
    Output lines are ``chrom<TAB>position<TAB>depth``.
    """
    def flush(handle, chrom, depths):
        for position in sorted(depths):
            handle.write(f'{chrom}\t{position}\t{depths[position]}\n')

    with open(outfile, 'w') as out:
        sv = subprocess.run(cmd, shell=True, text=True, capture_output=True)
        last_chr = None
        depths = Counter()
        for samline in sv.stdout.splitlines():
            sf = samline.split('\t')
            if sf[2] != last_chr:
                if last_chr is not None:
                    flush(out, last_chr, depths)
                depths = Counter()
            xw_match = xw_pattern.search(samline)
            if not xw_match:
                sys.exit('Fatal parse error to parse depths, no XW found')
            weight = int(xw_match.group(1))
            start = int(sf[3])
            for position in range(start, start + len(sf[9])):
                depths[position] += weight
            last_chr = sf[2]

        if last_chr is not None:
            flush(out, last_chr, depths)


def _write_wiggle(depth_file, wigfile, count, scale, chr_sizes):
    """Convert a depth file to a variableStep wiggle file in rpm.

    Every value is ``round(depth / count * scale, 4)``. Explicit zeros are
    written on both sides of each covered run (this makes line plots render
    cleanly), except where that would fall outside the chromosome.
    """
    def pad_right(handle, chrom, pos):
        # Only pad if there is still room before the end of the chromosome.
        if chrom is not None and pos < chr_sizes.get(chrom, 0):
            handle.write(f'{pos + 1}\t0\n')

    with open(depth_file) as fh, open(wigfile, 'w') as wigfh:
        last_chr = None
        last_pos = 0
        for line in fh:
            fields = line.rstrip().split('\t')
            chrom = fields[0]
            this_pos = int(fields[1])

            if chrom != last_chr:
                pad_right(wigfh, last_chr, last_pos)
                wigfh.write(f'variableStep chrom={chrom}\n')
                last_pos = 0

            if this_pos > last_pos + 1:
                # Zero right after the previous run (none at chrom start).
                if last_pos != 0:
                    wigfh.write(f'{last_pos + 1}\t0\n')
                # Zero right before this run, unless the zero just written
                # already occupies that position.
                if last_pos == 0 or (this_pos - 1) > (last_pos + 1):
                    wigfh.write(f'{this_pos - 1}\t0\n')

            rpm = round((int(fields[2]) / count) * scale, 4)
            wigfh.write(f'{fields[1]}\t{rpm}\n')

            last_chr = chrom
            last_pos = this_pos

        # end of the final chromosome
        pad_right(wigfh, last_chr, last_pos)


if __name__ == '__main__':
    args = ShortTracks_args(version)

    # Talk to user
    print(f'ShortTracks version {version}')
    print('')

    # Check for required executables.
    check_executables()

    # wigToBigWig needs a file of chrom names / lengths; the same sizes are
    # kept in memory for padding decisions while writing wiggle files.
    chr_sizes = _read_chrom_sizes(args.bamfile)
    with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp:
        print(f"Temporary file for chrom sizes: {tmp.name}")
        for sn, ln in chr_sizes.items():
            tmp.write(f'{sn}\t{ln}\n')

    try:
        print('')
        print('Preparing commands, counting overall reads')

        bam_type = _detect_bam_type(args.bamfile)

        # Each job is (command, depth_file, count, scale).
        jobs = []
        if args.mode == 'simple':
            count = _count_reads(args.bamfile, bam_type)
            jobs.extend(_track_jobs(args.bamfile, bam_type, args.stranded,
                                    count))

        elif args.mode == 'readgroup':
            # get read groups or die tryin
            rgs = get_read_groups(args)
            if not rgs:
                sys.exit('No readgroups found in bamfile!')
            for rg in rgs:
                selector = f'-r {shlex.quote(rg)} '
                # rpm denominator is per read group in this mode
                count = _count_reads(args.bamfile, bam_type, selector)
                jobs.extend(_track_jobs(args.bamfile, bam_type,
                                        args.stranded, count,
                                        label=f'_{rg}', selector=selector))

        elif args.mode == 'readlength':
            # counts from the whole file, like for simple
            count = _count_reads(args.bamfile, bam_type)
            for name, expression in _LENGTH_CLASSES:
                jobs.extend(_track_jobs(args.bamfile, bam_type,
                                        args.stranded, count,
                                        label=f'_{name}',
                                        selector=f'-e {expression} '))

        # Execute jobs
        print('')
        print('Getting raw depths')
        for cmd, depth_file, _count, _scale in jobs:
            print(cmd)
            if bam_type == 'verbose':
                # The command itself redirects into depth_file.
                subprocess.run(cmd, shell=True)
            else:
                _write_condensed_depths(cmd, depth_file)
            print('')

        # Write wiggle files, converting to rpm, negative to - strand
        print('')
        print('Writing wiggle files')
        wigfiles = []
        for _cmd, depth_file, count, scale in jobs:
            wigfile = _swap_suffix(depth_file, '_depths.txt', '.wig')
            print(wigfile)
            wigfiles.append(wigfile)
            _write_wiggle(depth_file, wigfile, count, scale, chr_sizes)

        # Make bigwigs
        print('')
        print('Writing bigwig files')
        for wigfile in wigfiles:
            bwfile = _swap_suffix(wigfile, '.wig', '.bw')
            cmd = (f'wigToBigWig {shlex.quote(wigfile)} '
                   f'{shlex.quote(tmp.name)} {shlex.quote(bwfile)}')
            print(cmd)
            subprocess.run(cmd, shell=True)
    finally:
        # Depth and wiggle intermediates are left in place; only the
        # temporary chrom-sizes file is removed, even on early exit.
        os.remove(tmp.name)
