#!/usr/bin/env python3

import argparse
import re
import sys

# Version string shown in the parser description and by -V/--version.
version='1.1.0'

# The automatic -h/--help option is disabled (add_help=False) because the help
# option is registered by hand further down, inside the 'Optional arguments' group.
parser = argparse.ArgumentParser(description="%s v%s: Trim an alignment on a given region." % ('%(prog)s', version),
								 add_help=False)

# Argument groups: every option below is attached to one of these groups.
requirArgs = parser.add_argument_group('Required arguments')
eitherArgs = parser.add_argument_group('Required at least one of')
outputArgs = parser.add_argument_group('Optional arguments related to the output')
optionArgs = parser.add_argument_group('Optional arguments')

# Input alignment (the only option argparse itself enforces as required).
requirArgs.add_argument("-f", "--file", dest="filein", required=True,
						help="An aligned fasta file.")

# The three options below all default to None and accept one or more values.
# argparse does not enforce that at least one of them is given; the script
# checks that itself after parsing.
eitherArgs.add_argument("-r", "--region", dest="region", required=False, type=str, nargs="+",
						default=None,
						help="The oligonucleotide sequence to be searched (i.e., 'CAGTATTCTGAAGCGAGA') or the name of the sequence containing the region to be trimmed (i.e., 'oligoN4').")

eitherArgs.add_argument("-p", "--positions", dest="positions", required=False, type=str, nargs="+",
						default=None,
						help="The positions in the alignment to be kept separated by an hyphen (i.e., 316-334 914-952).")

eitherArgs.add_argument("-s", "--sequences", dest="sequences", required=False, type=str, nargs="+",
						default=None,
						help="A fasta file(s) contaning the oligonucleotides sequences to be searched.")

# Output naming options. args.file_out stays None unless -o is given.
outputArgs.add_argument("-o", "--output", dest="file_out", required=False, nargs="+",
						default=None,
						help="The name(s) of the output fasta file(s). By default will remove the extension of the input file and add '_trimmed.fasta'. If more than one trimmed regions are given, by default will add consecutive integers after '_trimmed' (i.e., '_trimmed1.fasta', '_trimmed2.fasta').")

# An empty string (the default) means no output directory was requested.
outputArgs.add_argument("-d", "--path", dest="path", required=False,
						default="",
						help="The path of the output files. Useful when default output names are given, but in a different directory. By default it will use the current directory. If names are provided through the '-o/--output' argument, this argument will be ignored.")

# store_true flags: args.name and args.keep are False unless the option is given.
outputArgs.add_argument("-n", "--name", dest="name", required=False, action="store_true",
						help="If selected, will use the name of the sequences from the fasta file parsed by '-s/--sequences' for the output name.")

outputArgs.add_argument("-k", "--keep", dest="keep", required=False, action="store_true",
						help="If selected, will not delete empty sequences or sequences composed only of the trimmed region.")

# store_false flag: args.verbose is True by default and passing -v sets it to
# False, i.e. the option silences the progress messages.
optionArgs.add_argument("-v", "--verbose", dest="verbose", required=False, action="store_false",
						help="If selected, will not print information to the console.")

# Manual replacement for the help option disabled in the parser constructor.
optionArgs.add_argument("-h", "--help", action="help",
						help="Show this help message and exit.")

optionArgs.add_argument("-V", "--version", action='version',
						version='oligoN-design %s v%s' % ('%(prog)s', version),
						help='Show the version number and exit.')

# The command line is parsed at module level, as soon as this statement runs.
args = parser.parse_args()

# Define functions _________________________________________________________________________________

def readFasta(fastafile):
	"""Read a fasta file into a dictionary of sequences.

	A line starting with ">" opens a new record; its name is the line with
	every ">" character and the newline removed. Every other line is appended
	to the current record after removing the newline, replacing "." by "-"
	and converting to upper case. A repeated name restarts that record (it
	keeps its original position in the dictionary).

	Blank lines found before the first header are ignored.

	Parameters:
		fastafile: path of the fasta file to read.

	Returns:
		A tuple (sequences, lengths) where `sequences` maps each name to its
		sequence, in file order, and `lengths` is the set of distinct sequence
		lengths (a single element when all sequences are equally long).

	Raises:
		ValueError: if sequence data appears before the first header line.
	"""
	parts = {}
	name = None
	# The context manager guarantees the handle is closed, also on errors.
	with open(fastafile) as handle:
		for line in handle:
			if line.startswith(">"):
				name = line.replace(">", "").replace("\n", "")
				parts[name] = []
				continue
			sequence = line.replace("\n", "").replace(".", "-").upper()
			if name is None:
				# No record is open yet: tolerate blank lines, reject real data.
				if sequence.strip() == "":
					continue
				raise ValueError("Sequence data found before the first header in " + str(fastafile))
			parts[name].append(sequence)
	# Join once per record instead of growing a string line by line.
	out = {key: "".join(chunks) for key, chunks in parts.items()}
	length = {len(v) for v in out.values()}
	return out, length

def testComplete(sequence, out, keep=True):
	test = True
	if keep == False:
		# if out.startswith("-") or out.endswith("-"):
		# 	test = False
		outt = out.replace("-", "")
		if outt == "":
			test = False
		else:
			tmp = sequence.replace("-", "")
			if tmp == outt:
				test = False
	return test

def unGap(seq, position, length):
	"""Collect `length` non-gap characters of `seq` starting at `position`.

	Characters "-" and "." are gaps: they are skipped and counted. Reading
	stops as soon as `length` non-gap characters have been collected or the
	end of `seq` is reached.

	Parameters:
		seq: the aligned sequence (a string).
		position: index where reading starts; interpreted like the start of
			a slice (`seq[position:]`), so negative values count from the end.
		length: number of non-gap characters wanted. With 0 nothing is read.

	Returns:
		A list [residues, gaps]: the collected characters joined in a string
		(possibly shorter than `length`) and the number of gaps skipped while
		collecting them.
	"""
	if length == 0:
		# Nothing requested: do not consume (and count) a leading gap.
		return ["", 0]
	residues = []
	gaps = 0
	# Walk the indices of seq[position:] without copying the tail.
	for index in range(*slice(position, None).indices(len(seq))):
		char = seq[index]
		if char == "-" or char == ".":
			gaps += 1
		else:
			residues.append(char)
			if len(residues) == length:
				break
	return ["".join(residues), gaps]

def getPositions(fasta, region):
	"""Locate `region` in the first sequence of `fasta` that contains it.

	The search ignores gaps ("-" and "."): `region` is looked for in the
	residues of each sequence, in dictionary order, and the first sequence
	with a match decides the result.

	Parameters:
		fasta: dictionary mapping names to aligned sequences.
		region: the ungapped sequence to search for.

	Returns:
		A list [start, end] with the half-open interval of alignment columns
		spanning the match: `start` is the column of its first residue and
		`end` is one past the column of its last residue, so gaps inside the
		match are included but gaps before it are not. [None, None] is
		returned when no sequence contains `region`. An empty `region`
		gives the zero-width interval [0, 0] when `fasta` is not empty.
	"""
	length = len(region)
	if length == 0:
		return [0, 0] if fasta else [None, None]
	for sequence in fasta.values():
		# Alignment columns that hold a residue, and the residues themselves.
		columns = [i for i, char in enumerate(sequence) if char != "-" and char != "."]
		residues = "".join(sequence[i] for i in columns)
		hit = residues.find(region)
		if hit != -1:
			# Map the residue offsets of the match back to alignment columns.
			return [columns[hit], columns[hit + length - 1] + 1]
	return [None, None]

# Setting variables ________________________________________________________________________________
import os

if args.verbose:
	print("  Setting variables...", flush=True)

# debug parsed arguments
if args.region is None and args.positions is None and args.sequences is None:
	print("\033[91mError!\033[0m Please provide a region (-r/--region) and/or positions (-p/--positions) and/or a fasta file (-s/--sequences) to trim.")
	# Exit with a non-zero status so callers can detect the failure.
	sys.exit(1)

# Everything requested through each option, as (key, label) pairs.
# Keys are the positions/regions themselves; for -s/--sequences the key is the
# sequence and the label keeps the record name ("region-<name>").
requested = {"position": [], "region": [], "sequences": []}
for j in args.positions or []:
	requested["position"].append((j, "position"))
for j in args.region or []:
	requested["region"].append((j, "region"))
for j in args.sequences or []:
	for name, value in readFasta(j)[0].items():
		requested["sequences"].append((value, "region-" + name))

# get the order of the given regions to trim: options are processed in the
# order in which they first appear in the command line
flags = {"-p": "position", "--positions": "position",
		 "-r": "region", "--region": "region",
		 "-s": "sequences", "--sequences": "sequences"}
order = []
for token in sys.argv[1:]:
	# split("=") also recognises the '--option=value' spelling
	kind = flags.get(token.split("=", 1)[0])
	if kind is not None and kind not in order:
		order.append(kind)
# Options written in a form not recognised above (abbreviated long options,
# '-pVALUE', ...) were still parsed by argparse: append them in a fixed order
# instead of silently dropping them.
for kind in requested:
	if kind not in order:
		order.append(kind)

toTrim = {}
for kind in order:
	for key, label in requested[kind]:
		toTrim[key] = label

if not toTrim:
	print("\033[91mError!\033[0m No region to trim was found in the given arguments.")
	sys.exit(1)

if args.path != "":
	if not os.path.exists(args.path):
		os.makedirs(args.path)

# Get output file name
if args.file_out is None:
	outFile = list()
	path = args.path
	if path != "" and not path.endswith("/"):
		path = path + "/"
	# splitext leaves names without extension untouched, so the suffix is
	# always appended and the default output can never be the input file.
	base = os.path.splitext(args.filein)[0]
	if path != "":
		# With an output directory only the file name of the input is used.
		base = os.path.basename(base)
	numbered = len(toTrim) > 1
	i = 0
	for v in toTrim.values():
		if args.name and v.startswith("region-"):
			# Strip only the leading tag, not occurrences inside the name.
			outFile.append(path + v[len("region-"):] + ".fasta")
		elif numbered:
			i += 1
			outFile.append(path + base + "_trimmed" + str(i) + ".fasta")
		else:
			outFile.append(path + base + "_trimmed.fasta")
else:
	outFile = list(args.file_out)

# debug parsed arguments
if len(toTrim) != len(outFile):
	print("\033[91mError!\033[0m Please provide the same number of trimming regions and output files.")
	sys.exit(1)

# Reading input file _______________________________________________________________________________
if args.verbose:
	print("  Reading fasta:", str(args.filein), flush=True)
fasta, lenFas = readFasta(args.filein)

# An alignment has exactly one distinct sequence length.
if len(lenFas) != 1:
	print("\033[91mError!\033[0m", args.filein, "must be aligned. Please provide an aligned fasta file.")
	# Exit with a non-zero status so callers can detect the failure.
	sys.exit(1)

if args.verbose:
	print("    Sequences in file:", len(fasta), flush=True)
	print("    Aligned positions:", *lenFas, flush=True)
	# -k/--keep disables the filtering of trimmed sequences.
	if args.keep:
		print("  All sequences will be kept")
	else:
		print("  Empty sequences or sequences composed only of the trimmed region will be deleted")

# Loop through the alignment searching for the region to trim ______________________________________
import os

for outFilei, (region, value) in zip(outFile, toTrim.items()):
	if args.verbose:
		print("  Trimming region:", region)

	# Resolve the region to a half-open interval of columns [start, end).
	# 'start' stays None when the region cannot be located in the alignment.
	start = None
	end = None
	invalid = False
	if value == "position":
		# Positions are given as 'first-last', 1-based and inclusive.
		try:
			first, last = [int(x) for x in region.split("-")]
		except ValueError:
			invalid = True
		else:
			if 1 <= first <= last:
				start, end = first - 1, last
			else:
				invalid = True
	elif region in fasta:
		# The region is the name of a sequence of the alignment: trim from
		# its first to its last residue, keeping any gap in between.
		columns = [i for i, char in enumerate(fasta[region]) if char != "-"]
		if columns:
			start, end = columns[0], columns[-1] + 1
	else:
		start, end = getPositions(fasta, region)

	count = 0
	width = 0
	with open(outFilei, "w") as outfile:
		# Nothing is written when the region was not located; slicing with
		# start=None would otherwise copy the whole alignment.
		if start is not None:
			for name, sequence in fasta.items():
				out = sequence[start:end]
				if testComplete(sequence, out, keep=args.keep):
					count += 1
					width = len(out)
					outfile.write(">" + name + "\n" + out + "\n")

	if count == 0:
		if invalid:
			print("    Warning! Positions", region, "must be given as 'start-end' with 1 <= start <= end", flush=True)
		else:
			print("    Warning! Region", region, "not found", flush=True)
		print("      Output file", outFilei, "not written", flush=True)
		# The file was created empty by open(): remove it again.
		os.remove(outFilei)
	else:
		print("    Output file written to:", outFilei, flush=True)
		print("      Sequences out:    ", str(count), flush=True)
		print("      Aligned positions:", width, flush=True)

if args.verbose:
	print("Done")
