#!/usr/bin/env python3

import argparse
import re
import itertools

version = '1.1.0'

# Command-line interface. The default -h/--help is disabled (add_help=False) so that it can be
# registered by hand in the 'Other optional arguments' group below.
# '%(prog)s' is left as a literal placeholder: argparse substitutes the program name itself.
parser = argparse.ArgumentParser(
	description="%(prog)s v" + version + ": Given a sequence and 'm' mismatches, it will export a fasta file with the mismatched region (default) or the complete sequence.",
	add_help=False)

# Argument groups only affect how the options are grouped in the help message
requirArgs = parser.add_argument_group('Required arguments')
functiArgs = parser.add_argument_group('Optional arguments related to mismatches search')
outputArgs = parser.add_argument_group('Optional arguments related to the output')
optionArgs = parser.add_argument_group('Other optional arguments')

requirArgs.add_argument("-s", "--sequence", dest="sequence", required=True,
						help="An oligonucleotide sequence.")

requirArgs.add_argument("-e", "--excluding", dest="excluding", required=True,
						help="An excluding fasta file to look against.")

# Values are kept as strings here because ranges such as '1-3' are expanded later on
functiArgs.add_argument("-m", "--mismatches", dest="mismatches", nargs="+",
						default=['1','2'],
						help="The number of mismatches allowed. Default = %(default)s. A range can be specified with the '-' sign (e.g. '1-3'). Please bear in mind that an excessively high number of mismatches will considerably slow down the search.")

functiArgs.add_argument("-p", "--positions", dest="positions", nargs="+",
						default=None,
						help="The positions of the mismatch(es). By default will consider all positions.")

# store_false: args.indels is True unless -x is given
functiArgs.add_argument("-x", "--indels", dest="indels", action="store_false",
						help="If selected, will not test for insertions and/or deletions.")

outputArgs.add_argument("-o", "--output", dest="file_out",
						default=None,
						help="The name of the output fasta file. By default will use the name given with '-n/--name' (or the searched sequence if no name is given) followed by '_mismatches.fasta'. If a name is provided with '-n/--name', the searched sequence will be written as the first entry.")

outputArgs.add_argument("-c", "--complete", dest="complete", action="store_true",
						help="If selected, will export the complete mismatched sequence and not just the region mismatched.")

outputArgs.add_argument("-n", "--name", dest="name",
						default=None,
						help="If selected, will include the searched sequence in the output fasta file with the given name. The name is also used to build the output file name if '-o/--output' is not given.")

# store_false: args.verbose is True unless -v is given (the flag silences the console output)
optionArgs.add_argument("-v", "--verbose", dest="verbose", action="store_false",
						help="If selected, will not print information to the console.")

optionArgs.add_argument("-h", "--help", action="help",
						help="Show this help message and exit.")

optionArgs.add_argument("-V", "--version", action='version',
						version='oligoN-design %(prog)s v' + version,
						help='Show the version number and exit.')

args = parser.parse_args()

# Define functions _________________________________________________________________________________

def getNumersFromArgs(stringList):
	"""Expand a list of number/range strings into a sorted list of integers.

	Each item is either a single integer ('3') or an inclusive range written
	with the '-' sign ('1-3'). The two ends of a range may be given in any
	order ('3-1' is the same as '1-3') and may be equal ('2-2' yields 2).

	Arguments:
		stringList: iterable of strings, as received from the command line.

	Returns:
		A list of integers sorted in ascending order. Values requested more
		than once are kept as many times as they were requested.

	Raises:
		ValueError: if an item (or one end of a range) is not an integer.
	"""
	out = []
	for item in stringList:
		if '-' in item:
			bounds = item.split('-')
			first = int(bounds[0])
			last = int(bounds[1])
			# min/max accepts reversed ranges and also covers first == last,
			# which used to be silently dropped
			out.extend(range(min(first, last), max(first, last) + 1))
		else:
			out.append(int(item))
	out.sort()
	return out

def readFasta(fastafile, revcom=False):
	"""Read a fasta file into a dictionary of {header: sequence}.

	Headers are stored without the '>' sign and without the trailing newline.
	Sequences are upper-cased and multi-line sequences are concatenated. If
	any sequence line contains a '-' a warning is printed to the console.

	Arguments:
		fastafile: path to the fasta file.
		revcom: if True, every sequence is returned reverse-complemented
			(IUPAC nucleotide codes; gaps and unknown symbols are kept as is).

	Returns:
		A dictionary mapping each header to its sequence (plain strings).

	Raises:
		ValueError: if sequence data is found before the first header.
	"""
	# IUPAC complements; S, W, N and '-' complement to themselves so they are not listed
	complement = str.maketrans("ACGTURYKMBVDH", "TGCAAYRMKVBHD")
	chunks = {}
	hasGaps = False
	name = None
	# 'with' guarantees the file is closed, also if an error is raised while parsing
	with open(fastafile) as handle:
		for line in handle:
			if line.startswith(">"):
				name = line.replace(">", "").replace("\n", "")
				chunks[name] = []
				continue
			if "-" in line:
				hasGaps = True
			sequence = line.replace("\n", "").upper()
			if name is None:
				# Blank lines before the first header are harmless, anything else is malformed
				if sequence.strip():
					raise ValueError("Sequence data found before the first header in " + str(fastafile))
				continue
			chunks[name].append(sequence)
	out = {}
	for name, parts in chunks.items():
		sequence = "".join(parts)
		if revcom:
			# Done on the whole sequence: reverse-complementing line by line
			# would leave the lines of a multi-line record in the wrong order
			sequence = sequence.translate(complement)[::-1]
		out[name] = sequence
	if hasGaps:
		print("  Warning!!", fastafile, "contains gaps", flush=True)
	return out

def getMismatchedOligos(oligo, mismatches, positions=None, indels=True):
	"""Build the regular expressions matching 'oligo' with a given number of mismatches.

	For every number of mismatches 'm' and every combination of 'm' allowed
	positions, one pattern is produced in which the selected positions are
	replaced by a wildcard: '.' (any single character) or, when 'indels' is
	True and the position is neither the first nor the last one, '.{0,2}'
	(zero to two characters, so that a deletion or an insertion also matches).

	Arguments:
		oligo: the sequence to be searched.
		mismatches: iterable of integers, the numbers of mismatches to build.
			Values lower than 1 are skipped (there is nothing to replace).
		positions: 1-based positions allowed to mismatch, or None for all of
			them. Positions outside the oligo are ignored (and reported to
			the console if the global 'args.verbose' is set).
		indels: whether internal positions may also match indels.

	Returns:
		A list of patterns (strings), in order of 'mismatches' and, within
		each, in the order given by itertools.combinations.
	"""
	length = len(oligo)
	# The allowed positions do not depend on the number of mismatches, so they are
	# resolved (and out-of-range ones reported) once rather than once per 'm'
	if positions is None:
		candidates = list(range(1, length + 1))
	else:
		candidates = []
		for p in positions:
			if 1 <= p <= length:
				candidates.append(p)
			elif args.verbose:
				print("    Ignoring mismatches in position ", p, " (length of oligo: ", length,")", sep="", flush=True)
	out = []
	for m in mismatches:
		if m < 1:
			continue
		# combinations() with m == 1 yields 1-tuples, so single and multiple
		# mismatches share the same code path
		for combo in itertools.combinations(candidates, m):
			moligo = list(oligo)
			for p in combo:
				if indels and p != 1 and p != length:
					moligo[p - 1] = ".{0,2}"
				else:
					moligo[p - 1] = "."
			out.append("".join(moligo))
	return out

# Setting variables ________________________________________________________________________________
if args.verbose:
	print("  Setting variables...", flush=True)

# Output file name: an explicit -o/--output wins, otherwise it is built from the name or the sequence
if args.file_out is not None:
	outFile = args.file_out
elif args.name is not None:
	outFile = args.name + "_mismatches.fasta"
else:
	outFile = args.sequence + "_mismatches.fasta"

# Informative only, so it is silenced together with the rest of the console output
if args.indels is False and args.verbose:
	print("    Info: -x/--indels has been selected, so insertions and deletions won't be searched.", flush=True)

# Reading mismatches and positions _________________________________________________________________
# Zero (or negative) mismatches cannot be turned into a pattern, so they are dropped here
# wherever they appear (e.g. '0-2'), not only when 0 is the single value given
mismatches = [m for m in getNumersFromArgs(args.mismatches) if m > 0]
if not mismatches:
	print("  \033[91mError!\033[0m Please set mismatches to a value higher than 0...")
	import sys
	sys.exit(1)
if args.verbose:
	print("    Mismatches:    ", *mismatches, flush=True)

if args.positions is None:
	positions = None
	if args.verbose:
		print("    Considering all positions", flush=True)
else:
	positions = getNumersFromArgs(args.positions)
	if args.verbose:
		print("    Positions:     ", *positions, flush=True)

# Building mismatched patterns _____________________________________________________________________
if args.verbose:
	print("  Sequence to be tested:", str(args.sequence), flush=True)

oligos = getMismatchedOligos(args.sequence, mismatches, positions=positions, indels=args.indels)

# Reading excluding file ___________________________________________________________________________
if args.verbose:
	print("  Reading excluding file:", str(args.excluding), flush=True)

exc = readFasta(args.excluding)
lenExc = len(exc)

if args.verbose:
	print("    Sequences in excluding file:", lenExc, flush=True)

# Testing __________________________________________________________________________________________
if args.verbose:
	print("  Exporting...", end="", flush=True)
nOligos = len(oligos)
pcti = 0
# buffering=1: the output is line-buffered, every entry reaches the file as soon as it is written
with open(outFile, "w", buffering=1) as outfile:
	if args.name is not None:
		print(">" + args.name + "\n", args.sequence, sep="", file=outfile)
	identityHits = set()	# headers already exported (each sequence is written only once)
	hits = 0				# total number of (pattern, sequence) matches
	out = 0					# number of different sequences exported
	for i, oligo in enumerate(oligos, start=1):
		if args.verbose:
			pct = round(i/nOligos*100)
			# Refresh the console only when the rounded percentage actually changes
			if pct > pcti:
				pcti = pct
				print("\r  Exporting...\t", pct, "%", end="", sep="", flush=True)
		# Compiled once per pattern instead of once per sequence
		pattern = re.compile(oligo)
		for key, val in exc.items():
			tmp = pattern.search(val)
			# NOTE(review): the matched text is compared against the pattern itself, which always
			# contains a wildcard, so this looks like it is always true. If the intention was to
			# skip exact matches of the searched sequence, it should compare against args.sequence
			# - to be confirmed, left unchanged.
			if tmp is not None and tmp.group(0) != oligo:
				hits += 1
				if key not in identityHits:
					out += 1
					identityHits.add(key)
					print(">" + key, sep="", file=outfile)
					if args.complete:
						print(val, sep="", file=outfile)
					else:
						print(tmp.group(0), sep="", file=outfile)

if args.verbose:
	print("\n  \033[1mThere were", hits, "total hits\033[0m in", out, "unique sequences")
	print("  Output file written to:", outFile)
	print("Done")
