#!/usr/bin/env python3

import argparse
import math
import statistics as st
import sys

from Bio import Align

version = "1.1.0"

# Command-line interface ___________________________________________________________________________
# The automatic "-h" is disabled (add_help=False) and re-added by hand below so that it is
# listed inside the 'Other optional arguments' group of the help message.
parser = argparse.ArgumentParser(
	description="%s v%s: From one or several fasta files, estimates the similarity of all sequences to the first sequences of every file by pairwise sequence alignment. See 'PairwiseAligner' from the Bio.Align module for further details. Briefly, it will output summary statistics of the alignment score, relative to the first sequence. " % ('%(prog)s', version),
	add_help=False,
)

# Groups are created, and their options registered, in the order they appear in the help text.
requirArgs = parser.add_argument_group('Required arguments')
requirArgs.add_argument(
	"-f", "--file",
	dest="filein",
	nargs="+",
	required=True,
	help="Fasta file(s) containing the region to be compared against the first sequence.",
)

finputArgs = parser.add_argument_group('Optional arguments related to the input')
finputArgs.add_argument(
	"-t", "--target",
	dest="target",
	action="store_true",
	help="If selected, will assume that the input file contains sequences from the target file, and will ignore all the sequence before that one ending with '_HMM_profile', and compare the HMM profile with all the sequences after that one.",
)

outputArgs = parser.add_argument_group('Optional arguments related to the output')
outputArgs.add_argument(
	"-o", "--output",
	dest="output",
	default=None,
	help="The name of the output file. By default will print the information to the console.",
)
outputArgs.add_argument(
	"-a", "--absolute",
	dest="absolute",
	action="store_true",
	help="If selected, it will output the statistics as absolute values, and not dividing the scores by the first sequence length.",
)
outputArgs.add_argument(
	"-p", "--plot",
	dest="plot",
	action="store_true",
	help="If selected, will prompt a histogram with the scores for every file.",
)

optionArgs = parser.add_argument_group('Other optional arguments')
# NOTE: 'store_false' means args.verbose is True by default and passing -v silences the console.
optionArgs.add_argument(
	"-v", "--verbose",
	dest="verbose",
	action="store_false",
	help="If selected, will not print information to the console.",
)
optionArgs.add_argument(
	"-h", "--help",
	action="help",
	help="Show this help message and exit.",
)
optionArgs.add_argument(
	"-V", "--version",
	action='version',
	version='oligoN-design %s v%s' % ('%(prog)s', version),
	help='Show the version number and exit.',
)

# Parse the process command line; exits with a usage message on invalid input.
args = parser.parse_args()

# Define functions _________________________________________________________________________________

def readFasta(fastafile, target=False):
	"""Read a FASTA file into a dict mapping sequence name -> sequence string.

	Header lines start with ">"; every ">" character is removed from the name.
	Sequence lines are concatenated per record, with "." replaced by "-" and
	letters upper-cased. A repeated name restarts that record from empty.

	Args:
		fastafile: Path of the FASTA file to read.
		target: If True, every line is ignored until one ending with
			"_HMM_profile" is found; that line becomes the first record and
			normal parsing continues from there. If no such line exists the
			result is an empty dict.

	Returns:
		dict of name -> sequence, in file order.

	Raises:
		ValueError: If non-blank sequence data appears before any header.
	"""
	out = {}
	name = None
	# The context manager guarantees the handle is closed even if parsing fails.
	with open(fastafile) as handle:
		for line in handle:
			line = line.replace("\n", "")
			if target:
				# Still looking for the HMM profile record; skip everything before it.
				if line.endswith("_HMM_profile"):
					target = False
					name = line.replace(">", "")
					out[name] = ""
			elif line.startswith(">"):
				name = line.replace(">", "")
				out[name] = ""
			elif name is None:
				# No header seen yet: tolerate leading blank lines, reject stray data.
				if line.strip():
					raise ValueError("Sequence data found before the first '>' header in file: %s" % fastafile)
			else:
				out[name] += line.replace(".", "-").upper()
	return out

def percentile(listvals, qs):
	"""Return nearest-rank percentiles of a list of numbers as strings.

	For each fraction q the value at 1-based rank ceil(q * n) of the sorted
	data is selected (rank clamped to 1..n), rounded to 3 decimals and
	converted to a string. With this rule q=0.5 of [1, 2, 3, 4, 5] is 3.

	Args:
		listvals: Non-empty iterable of numbers.
		qs: Iterable of fractions, expected within 0..1 (e.g. 0.25 for the 25th percentile).

	Returns:
		List of strings, one per entry of qs, in the same order.

	Raises:
		ValueError: If listvals is empty.
	"""
	order = sorted(listvals)
	n = len(order)
	if n == 0:
		raise ValueError("percentile() requires at least one value")
	out = []
	for q in qs:
		# ceil(q * n) is a 1-based rank; subtract 1 for the 0-based index and
		# clamp so that q=0 maps to the minimum and q=1 to the maximum.
		index = min(max(math.ceil(q * n) - 1, 0), n - 1)
		out.append(str(round(order[index], 3)))
	return out

# Reading input file _______________________________________________________________________________
# Driver: for every input fasta, score each sequence against the first (reference) sequence,
# summarise the scores and report them to the console or to a tab-separated output file.
if args.plot:
	# Imported lazily so matplotlib is only needed when a plot is requested.
	import matplotlib.pyplot as plt
	plotData = {}

# With verbose on: per-file details go to the console when no output file is given,
# otherwise only a one-line progress indicator is shown.
toConsole = args.verbose and args.output is None
showProgress = args.verbose and args.output is not None

# Every comparison uses a default-configured aligner, so a single instance is shared
# instead of building a new one for each sequence.
aligner = Align.PairwiseAligner()

f = open(args.output, "w") if args.output is not None else None
try:
	if f is not None:
		f.write("file\tsequenceExc\tname\tseqNumber\tmin\t5pct\t25pct\tmedian\t75pct\t95pct\tmax\tmean\tsd\n")

	for j, filei in enumerate(args.filein, start=1):
		if showProgress:
			# Once per file; '\r' rewrites the same console line.
			print("\r  Reading file ", j, "/", len(args.filein), " and getting statistics", sep="", end="")

		fasta = readFasta(filei, target=args.target)
		records = iter(fasta.items())
		first = next(records, None)
		if first is None:
			print("  Warning: no sequences found in '%s', skipping." % filei, file=sys.stderr)
			continue

		# The first record is the reference every other sequence is compared against.
		excname, excluding = first
		if toConsole:
			print("  File:", filei, flush=True)
			print("    Excluding sequence: ", excluding, " (", excname, ")", sep="", flush=True)
		if not excluding:
			# An empty reference cannot be aligned nor used to normalise the scores.
			print("  Warning: first sequence of '%s' is empty, skipping." % filei, file=sys.stderr)
			continue

		scores = [aligner.score(excluding, val) for _, val in records]
		if not scores:
			print("  Warning: '%s' has no sequences to compare against the first one, skipping." % filei, file=sys.stderr)
			continue
		if not args.absolute:
			# Relative scores: divide by the length of the reference sequence.
			scores = [score / len(excluding) for score in scores]

		mins = str(round(min(scores), 3))
		pcts = percentile(scores, [0.05, 0.25, 0.5, 0.75, 0.95])
		maxs = str(round(max(scores), 3))
		mean = str(round(st.mean(scores), 3))
		# The sample standard deviation is undefined for a single value.
		sd = str(round(st.stdev(scores), 3)) if len(scores) > 1 else "NA"

		if toConsole:
			print("\tMin\t5pct\t25pct\tmedian\t75pct\t95pct\tmax\t\tmean\tsd\tseqs", flush=True)
			print("", mins, *pcts, maxs, "", mean, sd, len(scores), sep="\t", flush=True)
			print(flush=True)
		if f is not None:
			f.write("\t".join([filei, excluding, excname, str(len(scores)), mins, *pcts, maxs, mean, sd]) + "\n")
		if args.plot:
			plotData[filei] = scores
finally:
	# Close the output file even if reading or aligning one of the inputs failed.
	if f is not None:
		f.close()

if args.plot:
	for key, value in plotData.items():
		# NOTE(review): an input literally named "x" is left out of the plot; kept as in the
		# original, purpose unclear - confirm whether this filter is still needed.
		if key != "x":
			plt.hist(value, label=key, alpha=0.5)
	plt.legend()
	plt.show()

if args.verbose:
	if args.output is not None:
		print("\n  Stats printed to:", str(args.output), flush=True)
	print("Done")
