#!/usr/bin/env python3

import argparse
import re
import sys

version='1.1.0'

parser = argparse.ArgumentParser(description="%s v%s: Builds a consensus sequence of an alignment." % ('%(prog)s', version), add_help=False)

# Add the arguments to the parser
requirArgs = parser.add_argument_group('Required arguments')
functiArgs = parser.add_argument_group('Optional arguments related to consensus building')
outputArgs = parser.add_argument_group('Optional arguments related to the output')
optionArgs = parser.add_argument_group('Other optional arguments')

requirArgs.add_argument("-f", "--file", dest="inFile", required=True,
						help="An aligned fasta file.")

outputArgs.add_argument("-o", "--output", dest="outFile", required=False, default=None,
						help="The output fasta file name. If the file exists, the consensus sequence will be appended at the end of the file. By default, the consensus sequence will be printed to the console.")

functiArgs.add_argument("-t", "--threshold", dest="threshold", required=False, action='store', type=float,
						default=0.7,
						help="The threshold to build a consensus given in percentage (0-100). Default=%(default)s.")

functiArgs.add_argument("-b", "--base", dest="baseThreshold", required=False, action='store', type=float,
						default=0.3,
						help="The threshold for a base to be considered. Default=%(default)s")

functiArgs.add_argument("-g", "--gaps", dest="gaps", required=False, action='store', type=float,
						default=0.8,
						help="If the proportion of gaps is more than 'g', the consensus is a gap ('-'). Default=%(default)s")

outputArgs.add_argument("-a", "--ambiguities", dest="ambiguities", required=False, action="store_true",
						help="If selected, ambiguites will be set to 'N' and not the IUPAC bases.")

outputArgs.add_argument("-m", "--most", dest="most", required=False, action="store_true",
						help="If selected, the most abundant base at each position will be used instead of the IUPAC code.")

outputArgs.add_argument("-r", "--removeGaps", dest="removeGaps", required=False, action="store_true",
						help="If selected, gaps in the consensus sequence will be remove.")

optionArgs.add_argument("-v", "--verbose", dest="verbose", required=False, action="store_false",
						help="If selected, will not print information to the console.")

optionArgs.add_argument("-h", "--help", action="help",
						help="Show this help message and exit.")

optionArgs.add_argument("-V", "--version", action='version',
						version='oligoN-design %s v%s' % ('%(prog)s', version),
						help='Show the version number and exit.')

args = parser.parse_args()

# Setting variables and functions __________________________________________________________________
def readFasta(fastafile):
	out = {}
	for line in open(fastafile):
		if line.startswith(">"):
			name = line.replace(">", "")
			name = name.replace("\n", "")
			out[name] = str()
		else:
			sequence = line.replace("\n", "")
			sequence = sequence.upper()
			out[name] = (out[name] + sequence)
	return out

def parseAlignment(dictionary, length):
	out = {}
	for i in range(1, length+1):
		out[i] = list()
	for vals in fasta.values():
		i = 0
		for position in list(vals):
			i += 1
			out[i].append(position)
	return out

if args.verbose:
	print("  Setting variables")
	print("    Alignment:  ", args.inFile)
	print("    Threshold:  ", args.threshold, sep="")
	print("    Gaps:       ", args.gaps, sep="")
	print("    Base:       ", args.baseThreshold, sep="")
	if args.outFile is not None:
		print("    Output file: ", args.outFile, sep="")
	if args.ambiguities:
		print("    Ambiguities will be set to 'N'")
	if args.most:
		print("    The most abundant base will be taken for consensus")
	if args.removeGaps:
		print("    Gaps will be removed in output sequence")

iupac = {"A":"A", "C":"C", "G":"G", "T":"T", "AG":"R", "CT":"Y", "CG":"S", "AT":"W", "GT":"K", "AC":"M", "CGT":"B", "AGT":"D", "ACT":"H", "ACG":"V","ACGT":"N"}

# Reading input file and parsing ___________________________________________________________________
if args.verbose:
	print("  Reading and parsing alignment")

fasta = readFasta(args.inFile)

length = set()
for vals in fasta.values():
	length.add(len(vals))

if len(length) != 1:
	print("  \033[91mError!\033[0m Input file is not aligned. Exiting...")
	sys.exit(1)
else:
	length = next(iter(length))

sequences = len(fasta)

fasta = parseAlignment(fasta, length)

if args.verbose:
	print("    Sequences: ", sequences, sep="")
	print("    Positions: ", length, sep="")

# Building consensus _______________________________________________________________________________
if args.verbose:
	print("  Building consensus", end="")
	pcti = 0
consensus = list()
for position, values in fasta.items():
	if args.verbose:
		pct = round(position/length*100)
		if pct > pcti:
			pcti = pct
			print("\r  Building consensus\t", pct, "%", sep="", end="")
	bases = {}
	for b in values:
		if b not in bases.keys():
			bases[b] = 1
		else:
			bases[b] += 1
	# Initialize variables
	base = None
	candidate = list()
	most = None
	mostAbun = 0
	total = sequences
	# Check for gaps
	if "-" in bases.keys():
		c = bases["-"]
		C = c / sequences
		if C >= args.gaps:
			base = "-"
		else:
			total = sequences - c
			del bases["-"]
	# If consensus is not gap, now deal with the bases
	if base is None:
		for b, c in bases.items():
			C = c/total
			if C >= args.threshold:
				base = b
			elif C >= args.baseThreshold:
				candidate.append(b)
			if C > mostAbun:
				most = b
				mostAbun = C
	# Now finally calling the base
	if base is not None:
		if args.removeGaps:
			if base != "-":
				consensus.append(base)
		else:
			consensus.append(base)
	elif args.most:
		consensus.append(most)
	elif len(candidate) > 0:
		if "N" in candidate:
			base = "N"
		else:
			candidate.sort()
			base = "".join(candidate)
			base = iupac[base]
		if args.ambiguities:
			if base not in "ACGT":
				base = "N"
		consensus.append(base)

consensusOut = "".join(consensus)

if args.verbose:
	print("\n  Consensus positions:     ", len(consensusOut)-consensusOut.count("-"))
	if args.most is False:
		print("    Of which are ambiguous:", len(consensusOut)-len(re.sub("[^ACTG-]", "", consensusOut)))
	if args.removeGaps is False:
		print("    Of which are gaps:     ", consensusOut.count("-"))
	if args.outFile is None:
		print("  Consensus sequence:")
		print("")

if args.outFile is not None:
	f = open(args.outFile, "a")
	tmp = re.sub("\\.[^\\.]+$", "", args.inFile)
	ext = ""
	if args.most:
		ext = "_mostAbundant"
	f.write(str(">" + str(tmp) + "_consensus_t" + str(round(args.threshold*100)) + "_b" + str(round(args.baseThreshold*100)) + "_g" + str(round(args.gaps*100)) + ext + "\n"))
	f.write(str(str(consensusOut) + "\n"))
	f.close()
	if args.verbose:
		print("  Consensus sequence exported to:", args.outFile)
else:
	print(str(consensusOut))

# __________________________________________________________________________________________________
if args.verbose:
	if args.outFile is None:
		print("")
	print("Done")
