#!/usr/bin/env python3

import argparse
import itertools
import os
import re
import time

from Bio.Seq import Seq

# Command-line interface ___________________________________________________________________________
# Wall-clock reference for the "Run time" reported at the end of the script.
start_time = time.time()
version='1.1.0'

# add_help=False so that -h/--help can be re-added inside the 'Other optional arguments' group below.
# The description keeps a literal '%(prog)s', which argparse replaces with the program name.
parser = argparse.ArgumentParser(description="%s v%s: Test whether the sequences of a fasta file contain or not a set of oligonucleotides in another fasta file." % ('%(prog)s', version), add_help=False)

# Add the arguments to the parser
requirArgs = parser.add_argument_group('Required arguments')
finputArgs = parser.add_argument_group('Optional arguments related to the input')
functiArgs = parser.add_argument_group('Optional arguments related to the testing')
outputArgs = parser.add_argument_group('Optional arguments related to the output')
optionArgs = parser.add_argument_group('Other optional arguments')

requirArgs.add_argument("-t", "--target", dest="file_in", required=True,
						help="A target fasta file.")

requirArgs.add_argument("-f", "--file", dest="oligos", required=True,
						help="A fasta file containing the oligonucleotides.")

# Only the oligonucleotides file is reverse-complemented; the target file is read as-is.
finputArgs.add_argument("-p", "--revComp", dest="revComp", required=False, action="store_true",
						help="If selected, will reverse complement the oligonucleotides before testing.")

# Kept as strings here; expanded to a sorted list of integers later by getNumersFromArgs
# (single values, 'a-b' ranges, or space-separated values).
# NOTE(review): '%(default)s' renders the list as "['0']", and "(i.e. '0-3')" presumably means "e.g.".
functiArgs.add_argument("-m", "--mismatches", dest="mismatches", nargs="+", required=False, action='store', type=str,
						default=['0'],
						help="The number of mismatches allowed. Default = %(default)s, will not test for mismatches. A range can be specified with the '-' sign (i.e. '0-3'). Please bear in mind that an excessively high number of mismatches will considerably slow down the search.")

# store_false: args.indels is True unless -x is given.
functiArgs.add_argument("-x", "--indels", dest="indels", required=False, action="store_false",
						help="If selected, will not test for insertions and/or deletions.")

outputArgs.add_argument("-o", "--output", dest="file_out", required=False,
						default=None,
						help="The name of the output file. By default will remove the extension of the input file and add '_testTarget.tsv'. This will generate a table with the sequences to be tested as rows, oligos as columns and the values represent the position in the sequence at which the oligo was found.")

# Written verbatim into the table for every oligo that is not found.
outputArgs.add_argument("-a", "--absence", dest="absence", required=False,
						default="0",
						help="The value to represent absence of oligonucleotide. Default = '%(default)s'.")

# NOTE(review): with -b, absent oligos are still written with the -a value, which is '0' only by default.
outputArgs.add_argument("-b", "--binary", dest="binary", required=False, action="store_true",
						help="If selected, values exported will be replaced by '1' and '0', oligonucleotide present and absent respectively.")

# Path of an optional second output: a fasta with each target followed by the oligos found in it.
outputArgs.add_argument("-d", "--detailed", dest="detailed", required=False,
						default=None,
						help="If a file name is provided, it will export to the given output a fasta file with all oligonucleotides found aligned to each sequence. This option becomes visually impractical with many oligonucleotides (e.g., > ~100 seqs).")

# store_false: args.verbose is True unless -v is given, so -v silences the console output.
optionArgs.add_argument("-v", "--verbose", dest="verbose", required=False, action="store_false",
						help="If selected, will not print information to the console.")

optionArgs.add_argument("-h", "--help", action="help",
						help="Show this help message and exit.")

optionArgs.add_argument("-V", "--version", dest="version", action='version',
						version='oligoN-design %s v%s' % ('%(prog)s', version),
						help='Show the version number and exit.')

# Parsed at module level: there is no main() guard, so the script runs as soon as it is executed.
args = parser.parse_args()

# Define functions _________________________________________________________________________________

def seconds2string(seconds, longOut=False):
	"""Render a duration given in seconds as a short or a long human-readable string.

	The largest non-zero unit picks the layout (short form / longOut=True form):
	  hours present    -> "1h02m05s"  / "1 hours 02 minutes 05 seconds"
	  minutes present  -> "01m05s"    / "01 minutes 05 seconds"
	  10 s or more     -> "30s"       / "30 seconds"
	  under 10 s       -> "5s 250ms"  / " 5 seconds 250 milliseconds"
	"""
	hrs, leftover = divmod(seconds, 3600)
	mins, secs = divmod(leftover, 60)

	if hrs != 0:
		fmt = "%d hours %02d minutes %02d seconds" if longOut else "%dh%02dm%02ds"
		return fmt % (hrs, mins, secs)

	if mins != 0:
		fmt = "%02d minutes %02d seconds" if longOut else "%02dm%02ds"
		return fmt % (mins, secs)

	if secs < 10:
		# Short durations also show milliseconds (truncated, not rounded).
		whole, fraction = divmod(secs, 1)
		fmt = "%2d seconds %02d milliseconds" if longOut else "%ds %02dms"
		return fmt % (whole, int(fraction * 1000))

	fmt = "%02d seconds" if longOut else "%02ds"
	return fmt % secs

def string2numeric(string):
	"""Convert text to an int if possible, otherwise to a float, otherwise return None.

	int() is tried first so "3" gives 3 rather than 3.0; text that only parses as a
	float ("3.5", "1e3") comes back as a float; anything else yields None.
	"""
	for convert in (int, float):
		try:
			return convert(string)
		except ValueError:
			pass
	return None

def getNumersFromArgs(stringList):
	"""Expand mismatch arguments into a sorted list of unique integers.

	Each element of stringList may hold one or more whitespace-separated tokens.
	Each token is either a single integer ("2") or an inclusive range written with
	'-' ("0-3"; reversed ranges such as "3-0" are accepted too).

	Example: ['0-2', '1 4'] -> [0, 1, 2, 4].

	Raises ValueError if a token is neither an integer nor an integer range.
	"""
	out = set()
	for arg in stringList:
		# split() with no argument tolerates repeated/leading/trailing whitespace, and
		# splitting before looking for '-' lets ranges and single values share one string.
		for token in arg.split():
			if '-' in token:
				first, _, last = token.partition('-')
				first, last = int(first), int(last)
				# Inclusive on both ends; a degenerate range such as "2-2" yields [2].
				out.update(range(min(first, last), max(first, last) + 1))
			else:
				out.add(int(token))
	# A set removes duplicates, which would otherwise generate redundant search patterns.
	return sorted(out)

def readFasta(fastafile, revcom=False):
	"""Read a FASTA file into an insertion-ordered {header: sequence} dict.

	Headers are the '>' lines with every '>' character and the line ending removed.
	The sequence lines of a record are concatenated and upper-cased; blank lines are
	ignored. If any sequence line contains '-', a warning that the file contains gaps
	is printed. A header that appears twice restarts that record.

	If revcom is True, each record is reverse-complemented as a whole (via Bio.Seq)
	after all of its lines have been joined.

	Returns plain str values in both modes.
	Raises ValueError if sequence data appears before the first '>' header.
	"""
	chunks = {}
	name = None
	hasGaps = False
	with open(fastafile) as handle:
		for line in handle:
			line = line.rstrip("\r\n")
			if line.startswith(">"):
				name = line.replace(">", "")
				chunks[name] = []
			elif line:
				if name is None:
					raise ValueError("%s: sequence data found before the first '>' header" % fastafile)
				if "-" in line:
					hasGaps = True
				chunks[name].append(line)
	if hasGaps:
		print("Warning!!", fastafile, "contains gaps", flush=True)

	out = {}
	for name, parts in chunks.items():
		sequence = "".join(parts).upper()
		if revcom:
			# Reverse-complement the whole record. Doing it line by line would reverse each
			# line but not their order, corrupting multi-line records.
			sequence = str(Seq(sequence).reverse_complement())
		out[name] = sequence
	return out

# IUPAC degenerate nucleotide codes mapped to a regex alternation of the bases they stand for.
# Characters that are not keys (A, C, G, T, ...) are left untouched by getMismatchedOligos.
iupac = {
	"R": "(A|G)",
	"Y": "(C|T)",
	"S": "(G|C)",
	"W": "(A|T)",
	"K": "(G|T)",
	"M": "(A|C)",
	"B": "(C|G|T)",
	"D": "(A|G|T)",
	"H": "(A|C|T)",
	"V": "(A|C|G)",
	"N": "(A|C|G|T)",
}

def getMismatchedOligos(oligo, mismatches, indels=True):
	"""Build the regex patterns used to search for an oligonucleotide.

	For every m in `mismatches` (in the given order), one pattern is produced per
	combination of m positions of the oligo:
	  * m == 0 gives the oligo itself;
	  * each selected position becomes "." (any single character). When `indels` is
	    True and the position is neither the first nor the last base, it becomes
	    ".{0,2}" instead, so the site may match 0 (deletion), 1 or 2 (insertion) characters.
	IUPAC degenerate codes in the remaining positions are expanded with `iupac`
	for every m, including 0, so exact searches honour degenerate bases too.

	Returns a list of pattern strings. An m larger than the oligo length contributes nothing.
	"""
	# Expand degenerate bases once; mismatch wildcards are then placed on top of this.
	bases = [iupac.get(base, base) for base in oligo]
	length = len(bases)
	last = length - 1
	out = list()
	for m in mismatches:
		if m == 0:
			out.append("".join(bases))
			continue
		for combo in itertools.combinations(range(length), m):
			pattern = list(bases)
			for pos in combo:
				pattern[pos] = ".{0,2}" if indels and 0 < pos < last else "."
			out.append("".join(pattern))
	return out

# Setting variables and troubleshooting arguments __________________________________________________
if args.verbose:
	print("  Setting variables...", flush=True)
	if args.binary:
		print("    Binary output format")

# Setting output name
if args.file_out is None:
	# splitext only strips an extension of the file name itself (never part of a directory
	# such as "./data/seqs"). An extension-less input ("seqs") gets the suffix appended
	# instead of the output path being identical to, and overwriting, the input file.
	outFile = os.path.splitext(args.file_in)[0] + "_testTarget.tsv"
else:
	outFile = args.file_out

# Reading mismatches arguments
mismatches = getNumersFromArgs(args.mismatches)
if args.verbose and set(mismatches) != {0}:
	print("    Mismatches:", *mismatches, flush=True)
	if args.indels is False:
		print("    Info: -x/--indels has been disabled")

# Reading input file _______________________________________________________________________________
if args.verbose:
	print("  Reading file to be tested:", str(args.file_in), flush=True)

target = readFasta(args.file_in)
lenTarget = len(target)

if args.verbose:
	print("    Sequences in file to be tested:", lenTarget, flush=True)

# Reading oligos file ______________________________________________________________________________
if args.verbose:
	print("  Reading oligonucleotides file:", str(args.oligos), flush=True)

# str() guards against Bio.Seq values (reverse-complemented oligos): both the substring
# search and the regex engine below require plain strings.
oligos = {key: str(value) for key, value in readFasta(args.oligos, revcom=args.revComp).items()}
lenOligos = len(oligos)

if args.verbose:
	print("    Number of oligonucleotides:", lenOligos, flush=True)

# Preparing the searches ___________________________________________________________________________
def findOligo(seq, oligo, patterns):
	"""Return the 0-based start of the first hit of an oligo in seq, or None if absent.

	patterns is None for a plain substring search of `oligo`. Otherwise it is the list
	of compiled patterns tried in order, and the first one that matches wins.
	"""
	if patterns is None:
		start = seq.find(oligo)
		return start if start >= 0 else None
	for pattern in patterns:
		hit = pattern.search(seq)
		if hit is not None:
			return hit.start()
	return None

# The search patterns depend only on the oligo, so they are generated and compiled once
# here rather than again for every target sequence.
searchPatterns = {}
for keyo, oligo in oligos.items():
	# dict.fromkeys drops repeated patterns while keeping their order (first match still wins).
	variants = list(dict.fromkeys(getMismatchedOligos(oligo, mismatches, indels=args.indels)))
	if variants == [oligo] and re.escape(oligo) == oligo:
		# Only the oligo itself, with no regex metacharacters: a substring search is equivalent.
		searchPatterns[keyo] = None
	else:
		searchPatterns[keyo] = [re.compile(variant) for variant in variants]

# Testing __________________________________________________________________________________________
if args.verbose:
	print("  Testing...", flush=True, end="")
# Progress counters; only advanced (and only read) when verbose.
i = 0
pcti = 0
detailed = open(args.detailed, 'w', buffering=1) if args.detailed is not None else None
try:
	with open(outFile, "w", buffering=1) as outfile:
		print("identifier\t" + "\t".join(oligos.keys()), file=outfile)
		# n numbers the target sequences (1-based) to tag the aligned oligos in the detailed output.
		for n, (keyt, seq) in enumerate(target.items(), start=1):
			if detailed is not None:
				detailed.write(">" + keyt + "\n" + seq + "\n")
			hits = list()
			for keyo, oligo in oligos.items():
				if args.verbose:
					i += 1
					pct = round((i / (lenTarget * lenOligos)) * 100)
					if pct > pcti:
						pcti = pct
						print("\r  Testing...  ", pct, "%  ", sep="", end="", flush=True)
				start = findOligo(seq, oligo, searchPatterns[keyo])
				if start is None:
					position = args.absence
				else:
					# Positions are reported 1-based; binary mode only records presence.
					position = "1" if args.binary else start + 1
					if detailed is not None:
						# Oligo placed at its hit, padded with '-'; the row spans len(seq) only when
						# the hit has the oligo's length (indel hits may be shorter or longer).
						tail = len(seq) - start - len(oligo)
						detailedOut = "-" * start + oligo + "-" * tail
						detailed.write(">" + keyo + "_" + str(n) + "\n" + detailedOut + "\n")
				hits.append(str(position))
			print(keyt + "\t" + "\t".join(hits), file=outfile)
finally:
	if detailed is not None:
		detailed.close()

if args.verbose:
	t = time.time() - start_time
	t = seconds2string(t, longOut=True)
	print("\n  Run time:", t, flush=True)
	print("  Output file written to: \033[1m", outFile, "\033[0m", sep="")
	if args.detailed is not None:
		print("  Aligned oligos written to: \033[1m", args.detailed, "\033[0m", sep="")
	print("Done")
	