#!/usr/bin/env python3

import argparse
import math
import sys

# Script version; shown in the description and by -V/--version.
version='1.1.0'

# '%(prog)s' is passed through the %-formatting untouched so that argparse can expand it
# to the program name. add_help=False disables the automatic -h/--help so that the help
# option can be registered by hand in the "Other optional arguments" group below.
parser = argparse.ArgumentParser(description="%s v%s: From a fasta file containing different potential oligonucleotides, identifies regions of interest and groups together oligonucleotides related to the same region." % ('%(prog)s', version), add_help=False)

# Add the arguments to the parser
# The groups only control how the options are sectioned in the help output.
requirArgs = parser.add_argument_group('Required arguments')
functiArgs = parser.add_argument_group('Optional arguments related to defining regions')
outputArgs = parser.add_argument_group('Optional arguments related to the output')
optionArgs = parser.add_argument_group('Other optional arguments')

# Input fasta file (mandatory); read later with readFasta().
requirArgs.add_argument("-f", "--file", dest="file_in", required=True,
						help="A fasta file containing the oligonucleotides.")

# Overlap threshold, parsed as a float and forwarded to alignBoth() as its `cover` argument.
functiArgs.add_argument("-a", "--align", dest="align", required=False, action="store", type=float,
						default=0.5,
						help="An alignment threshold. Default = %(default)s, meaning that it is required that at least 50%% of the length of the smallest given oligonucleotide is identical to be considered within a region.")

# Output table; None means the name is derived from the input file name further down.
outputArgs.add_argument("-o", "--output", dest="output", required=False,
						default=None,
						help="The output file name. By default will replace the extension of the target file with '_regions.tsv'.")

# Optional fasta export of the regions; None (the default) disables the export.
outputArgs.add_argument("-e", "--export", dest="fasta", required=False, action="store",
						default=None,
						help="If selected, will export a fasta file with the selected regions to the given output.")

# NOTE: action="store_false" means args.verbose is True by default and passing -v/--verbose
# sets it to False, i.e. the flag silences the console output.
optionArgs.add_argument("-v", "--verbose", dest="verbose", required=False, action="store_false",
						help="If selected, will not print information to the console.")

# Manual replacement for the automatic help option disabled with add_help=False.
optionArgs.add_argument("-h", "--help", action="help",
						help="Show this help message and exit.")

optionArgs.add_argument("-V", "--version", action='version',
						version='oligoN-design %s v%s' % ('%(prog)s', version),
						help='Show the version number and exit.')

# Parsing happens at module level: argparse exits the process on -h, -V or invalid arguments.
args = parser.parse_args()

# Setting variables and functions __________________________________________________________________
if args.verbose:
	print("  Setting variables...", flush=True)

def readFasta(fastafile):
	"""Read a fasta file into a dictionary mapping record names to sequences.

	The record name is the header line with every ">" character and the
	trailing newline removed. Sequence lines are upper-cased and concatenated,
	so records spanning several lines are supported. If a name occurs more
	than once, the last record replaces the earlier ones.

	Blank lines found before the first header are ignored.

	Args:
		fastafile: Path of the fasta file to read.

	Returns:
		dict: ``{name: sequence}`` in the order the records appear in the file.

	Raises:
		ValueError: If sequence data appears before the first ">" header.

	Side effects:
		Prints a warning to stdout when any sequence line contains "-".
	"""
	out = {}
	name = None
	w = False
	# The context manager guarantees the handle is closed, also when parsing fails.
	with open(fastafile) as handle:
		for line in handle:
			if line.startswith(">"):
				name = line.replace(">", "").replace("\n", "")
				out[name] = str()
				continue
			sequence = line.replace("\n", "")
			if name is None:
				# No record has been opened yet: tolerate leading blank lines, but
				# fail with a clear message instead of an UnboundLocalError.
				if sequence.strip():
					raise ValueError("Sequence data found before the first '>' header in " + str(fastafile))
				continue
			if "-" in line:
				w = True
			out[name] = out[name] + sequence.upper()
	if w:
		print("    Warning!!", fastafile, "contains gaps", flush=True)
	return out

def align(seqA, seqB, cover):
	"""Try to extend ``seqA`` with ``seqB`` (one direction only).

	Two situations produce a result:
	  * ``seqB`` is contained in ``seqA``: ``seqA`` is returned unchanged.
	  * the end of ``seqA`` is identical to the beginning of ``seqB`` over at
	    least ``cover`` times the length of the shorter sequence: the merged
	    sequence (``seqA`` followed by the non-overlapping tail of ``seqB``)
	    is returned.

	When several overlap lengths match (repetitive sequences), the longest
	overlap is used.

	Args:
		seqA: Sequence whose end is compared.
		seqB: Sequence whose beginning is compared.
		cover: Minimum overlap, as a fraction of the shorter sequence length.

	Returns:
		The merged sequence, or None if the sequences do not align.
	"""
	if seqB in seqA:
		return seqA
	length = min(len(seqA), len(seqB))
	# Smallest overlap that still satisfies "at least cover * length", as promised by the
	# --align help text. The inner round() only removes floating point noise such as
	# 3.0000000000000004; max(1, ...) avoids i == 0, where seqA[-0:] would be the whole
	# sequence rather than an empty suffix.
	minlen = max(1, math.ceil(round(length * cover, 9)))
	# The stop of a descending range is exclusive, hence "minlen - 1" to include minlen.
	for i in range(length, minlen - 1, -1):
		if seqA[-i:] == seqB[:i]:
			# Stop at the first hit: it is the longest overlap. Continuing would let a
			# shorter, repeat-induced overlap overwrite it with a longer merged sequence.
			return seqA + seqB[i:]
	return None

def alignBoth(seqA, seqB, cover):
	"""Try to merge two sequences in either orientation.

	``align(seqA, seqB)`` is attempted first; only if it fails is
	``align(seqB, seqA)`` attempted, so the first orientation wins when both
	would succeed.

	Args:
		seqA: First sequence.
		seqB: Second sequence.
		cover: Minimum overlap fraction, forwarded to ``align``.

	Returns:
		The merged sequence, or None if neither orientation aligns.
	"""
	forward = align(seqA=seqA, seqB=seqB, cover=cover)
	if forward is not None:
		return forward
	# The reverse orientation is only computed when needed; its result (possibly None)
	# is returned as is, using the same "is not None" convention as the forward case.
	return align(seqA=seqB, seqB=seqA, cover=cover)

def renameKeyDict(old_name, new_name, dictionary):
	"""Return a new dictionary where the key ``old_name`` is called ``new_name``.

	The input dictionary is not modified and the order of the entries is
	kept. Values are not copied. If ``old_name`` is absent the result is a
	plain shallow copy. If ``new_name`` already exists as another key, both
	entries end up under the same key and the value seen last wins.

	Args:
		old_name: Key to replace.
		new_name: Key to use instead.
		dictionary: Source dictionary.

	Returns:
		dict: The rebuilt dictionary.
	"""
	renamed = {}
	for key, value in dictionary.items():
		if key == old_name:
			key = new_name
		renamed[key] = value
	return renamed

# Resolve the name of the output table.
if args.output is None:
	import os
	# os.path.splitext only considers an extension in the last path component and returns
	# the name unchanged when there is none. The suffix is therefore always appended, so
	# the default output can never be the input file itself (an extension-less input used
	# to be overwritten), and dots in directory names are left alone.
	output = os.path.splitext(args.file_in)[0] + "_regions.tsv"
else:
	output = args.output

# Reading input file _______________________________________________________________________________
fastaPath = args.file_in
if args.verbose:
	print("  Reading input file:", str(fastaPath), flush=True)

# Dictionary of {name: sequence}; its size is reused below for the progress percentage.
oligos = readFasta(fastaPath)
lenOligos = len(oligos)

if args.verbose:
	print("    Oligonucleotides in input file:", lenOligos, flush=True)

# Start the identification _________________________________________________________________________
# First pass: every oligonucleotide is either added to the first region it aligns with
# (the region sequence is replaced by the merged sequence) or it starts a new region.
# regions maps a region sequence to the list of oligonucleotide names assigned to it.
if args.verbose:
	print("  Identifying regions...", flush=True, end="")
	pct = 0
regions = {}
for i, (name, oligo) in enumerate(oligos.items(), start=1):
	if args.verbose:
		# Only redraw the progress line when the rounded percentage increases.
		pcti = round(i/lenOligos*100)
		if pcti > pct:
			pct = pcti
			print("\r  Identifying regions...\t", pct, "%", sep="", end="", flush=True)
	# Very first oligonucleotide, or a sequence identical to an existing region.
	if not regions or oligo in regions:
		regions.setdefault(oligo, []).append(name)
		continue
	# Look for the first region (in insertion order) that aligns with this sequence.
	for region in regions:
		tmp = alignBoth(oligo, region, args.align)
		if tmp is not None:
			break
	if tmp is None:
		regions[oligo] = [name]
	else:
		# Rebuilding the dictionary keeps the position of the region while renaming its key.
		regions = renameKeyDict(region, tmp, regions)
		regions[tmp].append(name)

# Loop again over the identified regions in case there are new regions that overlap
# Each region is compared against the others; on the first match the merged sequence is
# stored with the identifiers of both regions, otherwise the region is kept unchanged.
regionsOut = {}
for region, values in regions.items():
	# NOTE(review): this set is recreated for every region, so the "reg not in clustered"
	# filter below never excludes anything. Hoisting it out of the loop would change which
	# regions get merged, so it is left untouched until the intended behaviour is confirmed.
	clustered = set()
	# Reset the match for every region. Without this, a region that has nothing to be
	# compared with (a single region in total) reused the value left by earlier code:
	# either a NameError, or a stale match that duplicated the identifiers.
	tmp = None
	for reg, val in regions.items():
		if region != reg and reg not in clustered:
			tmp = alignBoth(reg, region, args.align)
			if tmp is not None:
				break
	if tmp is None:
		regionsOut[region] = values
	else:
		# Symmetric pairs usually produce the same merged key, so they collapse
		# into a single entry.
		regionsOut[tmp] = values + val
		clustered.add(reg)

if args.verbose:
	print("\n  \033[1mIn total", len(regionsOut), "regions have been identified\033[0m", flush=True)

# Exporting regions ________________________________________________________________________________
# One tab-separated row per region: running name, sequence, sequence length, number of
# oligonucleotides and their identifiers joined with "|".
if args.verbose:
	print("  Exporting regions to:", output, flush=True)
with open(output, "w", buffering=2) as logfile:
	print("region\tsequence\tlength\toligos\tidentifiers", file=logfile, flush=True)
	for i, (region, identifiers) in enumerate(regionsOut.items(), start=1):
		row = (
			"region" + str(i),
			region,
			str(len(region)),
			str(len(identifiers)),
			"|".join(identifiers),
		)
		print("\t".join(row), file=logfile, flush=True)

# Export fasta file if selected ____________________________________________________________________
# Regions are numbered in the same order as in the table; the header also carries the
# number of oligonucleotides grouped in the region.
if args.fasta is not None:
	if args.verbose:
		print("  Exporting a fasta file with the regions to:", args.fasta, flush=True)
	with open(args.fasta, "w", buffering=1) as fasfile:
		for i, (seq, ids) in enumerate(regionsOut.items(), start=1):
			header = ">" + "region" + str(i) + "_size" + str(len(ids))
			print(header, str(seq), sep="\n", file=fasfile, flush=True)

# Print concluding information _____________________________________________________________________
if args.verbose:
	print("Done")
