#!/usr/bin/env python3

import argparse

# Script version; interpolated into the parser description and the --version output below.
version='1.1.0'

# Command-line interface. The automatic -h/--help is disabled (add_help=False) so that the help
# option can be registered by hand inside the 'Other optional arguments' group further down.
# '%(prog)s' is passed through the %-formatting untouched and expanded later by argparse.
parser = argparse.ArgumentParser(description="%s v%s: Breaks a fasta file into k-mers and outputs a tab delimited table with the number of k-mers at given thresholds." % ('%(prog)s', version), add_help=False)

# Add the arguments to the parser
# The groups only affect how options are grouped in the help text, and they are listed there in
# the order in which they are created here.
requirArgs = parser.add_argument_group('Required arguments')
functiArgs = parser.add_argument_group('Optional arguments related to k-mers')
outputArgs = parser.add_argument_group('Optional arguments related to the output')
optionArgs = parser.add_argument_group('Other optional arguments')

# Input fasta file (the only mandatory option).
requirArgs.add_argument("-f", "--file", dest="fileIn", required=True,
						help="A fasta file.")

# Output table; when left as None the name is derived from the input file name later in the script.
outputArgs.add_argument("-o", "--output", dest="output", required=False,
						default=None,
						help="The output file name. By default will replace the extension of the input file with '_kmers.tsv'.\n")

# k-mer lengths are kept as strings here; ranges such as '15-25' are expanded by getNumersFromArgs.
functiArgs.add_argument("-l", "--length", dest="length", nargs="+", required=False, action='store', type=str,
						default=['15-25'],
						help="The length(s) of the oligonucleotides to be searched. A range can be specified with the '-' sign. Default = %(default)s.")

# Thresholds: values given on the command line arrive as strings (type=str), whereas the default
# list holds floats because argparse does not apply 'type' to non-string defaults. Both forms are
# normalised with string2numeric after parsing.
functiArgs.add_argument("-m", "--minimum", dest="minimum", required=False, nargs="+", action='store', type=str,
						default=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0],
						help="The minimum percentage (or absolute number) of sequences that the kmer has to appear in the file. Default = %(default)s. If the value provided is a 'float' [0.0-1.0], it will be interpreted as a percentage. If the value provided is a 'int(eger)' [1-Inf), it will be interpreted as an absolute number.")

# Optional fasta export of the k-mers; None (the default) disables the export.
outputArgs.add_argument("-e", "--fasta", dest="fasta", required=False, action="store",
						default=None,
						help="If selected, will export a fasta file with the extracted k-mers to the given output. The name of the sequences contains the length, an arbitrary identifier and the absolute count in the given file separated by an underscore (e.g., 'length15_kmer1_264'). If the given file exists, will append the k-mers at the end.")

outputArgs.add_argument("-p", "--plot", dest="plot", required=False, action="store_true",
						help="If selected, will prompt a scatterplot with the table.")

outputArgs.add_argument("-s", "--sort", dest="sort", required=False, action="store_true",
						help="If selected, will export the fasta file sorted by abundance.")

# Export filter: the default is the integer 0; command-line values are strings that are converted
# with string2numeric after parsing.
outputArgs.add_argument("-a", "--abundance", dest="abundance", required=False, action="store", type=str,
						default=0,
						help="A minimum abundance value to export the kmer to the fasta file. Default = %(default)s")

# NOTE: action="store_false" means args.verbose is True by default and passing -v/--verbose sets
# it to False, i.e. the flag silences the console output.
optionArgs.add_argument("-v", "--verbose", dest="verbose", required=False, action="store_false",
					help="If selected, will not print information to the console.")

# Manual replacement for the help option disabled with add_help=False above.
optionArgs.add_argument("-h", "--help", action="help",
						help="Show this help message and exit.")

optionArgs.add_argument("-V", "--version", action='version',
					version='oligoN-design %s v%s' % ('%(prog)s', version),
					help='Show the version number and exit.')

# Parse the command line; argparse prints a usage error and exits when -f/--file is missing.
args = parser.parse_args()

# Setting variables and functions __________________________________________________________________
if args.verbose:
	print("  Setting variables...", flush=True)

def string2numeric(string):
	"""Convert a value to a number, preferring an integer over a float.

	The value is first rendered as text with ``str()``. That text is parsed as an
	``int``; if that fails it is parsed as a ``float``.

	Returns the parsed ``int`` or ``float``, or ``None`` when the text is not numeric.
	"""
	text = str(string)
	# Order matters: "5" must come back as an int, "5.0" as a float
	for convert in (int, float):
		try:
			return convert(text)
		except ValueError:
			continue
	return None

def getNumersFromArgs(stringList):
	"""Expand a list of length specifications into a flat list of integers.

	Each item of ``stringList`` is a string in one of these forms:

	* ``"A-B"``: an inclusive range, ascending when A <= B and descending
	  otherwise; ``"15-15"`` yields ``[15]``. Only the first two '-' separated
	  fields are used.
	* whitespace separated integers, e.g. ``"15 18 20"``.
	* a single integer, e.g. ``"20"``.

	Returns the integers in the order given. Raises ``ValueError`` when an item
	is empty or contains something that is not an integer.
	"""
	out = list()
	for item in stringList:
		if '-' in item:
			bounds = item.split('-')
			first = int(bounds[0])
			last = int(bounds[1])
			# A single stepped range covers ascending, descending and equal bounds alike
			step = 1 if first <= last else -1
			out.extend(range(first, last + step, step))
		else:
			# split() without arguments tolerates repeated or surrounding whitespace
			tokens = item.split()
			if not tokens:
				raise ValueError("empty length specification: %r" % (item,))
			out.extend(int(token) for token in tokens)
	return out

def readFasta(fastafile):
	"""Read a fasta file into a dictionary of ``{name: sequence}``.

	The name is the header line with every '>' and the trailing newline removed.
	Sequence lines are concatenated, upper-cased and stripped of '-' gap
	characters; a warning is printed once if any gap was found. A repeated
	header name restarts that entry, so only its last record is kept.

	Blank lines before the first header are ignored.

	Raises ``ValueError`` if sequence data appears before the first header.
	"""
	chunks = {}
	name = None
	gaps = False
	# The context manager guarantees the handle is closed, even on error
	with open(fastafile) as handle:
		for line in handle:
			if line.startswith(">"):
				name = line.replace(">", "").replace("\n", "")
				chunks[name] = []
				continue
			sequence = line.replace("\n", "")
			if name is None:
				# Nothing to attach this line to yet
				if sequence.strip():
					raise ValueError("%s: sequence data found before the first '>' header" % (fastafile,))
				continue
			if "-" in sequence:
				gaps = True
			chunks[name].append(sequence.replace("-", "").upper())
	if gaps:
		print("    Warning!!", fastafile, "contains gaps, they were removed", flush=True)
	# Join once per record instead of growing a string line by line
	return {key: "".join(parts) for key, parts in chunks.items()}

def kmersBreak(dictionary, k, verbose=True):
	"""Count every k-mer of length ``k`` found in the sequences of ``dictionary``.

	Each sequence (the dictionary values) is scanned with a window of ``k``
	characters moved one position at a time, and every occurrence is counted,
	including repeated occurrences within the same sequence.

	When ``verbose`` is true, progress is reported on the console as a percentage
	of sequences processed, followed by the number of distinct k-mers.

	Returns a dictionary ``{kmer: number of occurrences}`` in order of first appearance.
	"""
	if verbose:
		print("  Extracting ", k, "-mers", end="", sep="", flush=True)
	# Progress bookkeeping (only advanced when verbose)
	done = 0
	shown = 0
	total = len(dictionary)
	counts = {}
	for sequence in dictionary.values():
		if verbose:
			done += 1
			percent = round(done/total*100)
			# Redraw the progress line only when the rounded percentage grows
			if percent > shown:
				shown = percent
				print("\r  Extracting ", k, "-mers: ", percent, "%", end="", sep="", flush=True)
		# Slide a window of width k over the sequence
		for start in range(len(sequence) - k + 1):
			word = sequence[start:start + k]
			counts[word] = counts.get(word, 0) + 1
	if verbose:
		print("\n    Total ", k, "-mers extracted: ", len(counts), sep="", flush=True)
	return counts

# Setting variables ________________________________________________________________________________
# Expand the -l/--length specifications (single values and 'A-B' ranges) into integers
lengths = getNumersFromArgs(args.length)

if args.verbose:
	print("    Lengths:    ", *lengths, flush=True)

# Thresholds: floats are treated as fractions of the number of sequences, integers as absolute
# counts. Reject non-numeric values here with a usage error instead of failing later with a
# TypeError when they are compared against the k-mer counts.
minimum = list()
for m in args.minimum:
	tmp = string2numeric(m)
	if tmp is None:
		parser.error("argument -m/--minimum: invalid numeric value: %r" % (m,))
	minimum.append(tmp)

if args.verbose:
	print("    Thresholds: ", *minimum, flush=True)

# Minimum abundance for the fasta export, with the same float/integer convention as the thresholds
abundance = string2numeric(args.abundance)
if abundance is None:
	parser.error("argument -a/--abundance: invalid numeric value: %r" % (args.abundance,))

# Reading fasta file _______________________________________________________________________________
if args.verbose:
	print("  Reading fasta file:", str(args.fileIn), flush=True)

fasta = readFasta(args.fileIn)
lenFasta = len(fasta)

if args.verbose:
	print("  Sequences in file: ", str(lenFasta), flush=True)

# Setting output file name _________________________________________________________________________
if args.output is None:
	import os
	# splitext only looks at the last path component, and always yields a name that differs from
	# the input: a file without extension gets '_kmers.tsv' appended rather than being
	# overwritten, and dots in directory names are left alone.
	output = os.path.splitext(args.fileIn)[0] + "_kmers.tsv"
else:
	output = args.output

# Start the search _________________________________________________________________________________
# For every requested length: extract the k-mers, count how many of them reach each threshold,
# write one row of the table and, if requested, append the k-mers to the fasta file.
if args.plot:
	plotdata = {}
	plotdata["x"] = list()
# Number of k-mers exported to the fasta file, accumulated over all the lengths (and defined even
# when there is no length to process) so that the final report is complete.
outFasta = 0
with open(output, "w", buffering=2) as logfile:
	# Header: one column per threshold; floats are fractions ('pct'), integers are counts ('seqs')
	lineout = ["length"]
	for m in minimum:
		if isinstance(m, float):
			lineout.append("pct" + str(m))
			if args.plot:
				plotdata["x"].append(m*100)
		else:
			lineout.append("seqs" + str(m))
			if args.plot:
				plotdata["x"].append(m)
	print("\t".join(lineout), file=logfile)
	# Loop through the different lengths -----------------------------------------------------------
	for length in lengths:
		# Extract kmers from fasta file ------------------------------------------------------------
		kmers = kmersBreak(fasta, length, verbose=args.verbose)
		lenKmers = len(kmers)
		if args.verbose:
			i = 0
			pcti = 0
		# One counter per threshold, kept in header order. A positional list (rather than a
		# dictionary keyed by the threshold) keeps rows aligned with the header and avoids double
		# counting when the same threshold is given more than once.
		counts = [0] * len(minimum)
		# Check minimum criteria -------------------------------------------------------------------
		# NOTE: countTarget is the total number of occurrences returned by kmersBreak, so a k-mer
		# repeated within one sequence is counted each time and the ratio below can exceed 1.
		for countTarget in kmers.values():  # Loop through all unique oligonucleotides
			if args.verbose:
				i += 1
				pct = round(i/lenKmers*100)
				if pct > pcti:
					pcti = pct
					print("\r    Checking minimum presence ", pct, "%", sep="", end="", flush=True)
			# lenFasta cannot be 0 here: k-mers only exist if at least one sequence was read
			ratio = countTarget/lenFasta
			for idx, m in enumerate(minimum):
				R = ratio if isinstance(m, float) else countTarget
				if R >= m:
					counts[idx] += 1
		if args.verbose:
			print("")
		lineout = [str(length)] + [str(val) for val in counts]
		if args.plot:
			plotdata["l" + str(length)] = list(counts)
		print("\t".join(lineout), file=logfile)
		# Export kmers to fasta if selected --------------------------------------------------------
		if args.fasta is not None:
			if args.verbose:
				print("    Exporting kmers")
			if args.sort:
				kmers = dict(sorted(kmers.items(), key=lambda x: x[1], reverse=True))
			# Identifier of the exported k-mer, restarted for every length
			i = 0
			# Opened in append mode: k-mers of every length (and of previous runs) accumulate
			with open(args.fasta, "a", buffering=1) as fasfile:
				for kmer, countTarget in kmers.items():
					# Float abundance is a fraction of the sequences, integer an absolute count
					if isinstance(abundance, float):
						export = countTarget/lenFasta >= abundance
					else:
						export = countTarget >= abundance
					if export:
						i += 1
						outFasta += 1
						print(">length" + str(len(kmer)) + "_kmer" + str(i) + "_" + str(countTarget) + "\n" + str(kmer), file=fasfile)

# Plot if selected _________________________________________________________________________________
# Scatterplot of the table: thresholds on the x axis, one labelled series per k-mer length.
if args.plot:
	if args.verbose:
		print("  Plotting table")
	# Imported here so that matplotlib is only needed when a plot is requested
	import matplotlib.pyplot as plt
	x = plotdata["x"]
	for key, value in plotdata.items():
		# The "x" entry holds the shared x coordinates, not a series
		if key == "x":
			continue
		plt.scatter(x, value, label=key, alpha=0.5)
	plt.legend()
	plt.show()

# Print concluding information _____________________________________________________________________
# Summary of what was written; skipped entirely when the console output is silenced.
if args.verbose:
	print("  Kmers count written to:", output)
	if args.fasta is not None:
		print("  Kmer sequences written to:", args.fasta)
		print("    Exported kmers:", outFasta)
		if abundance > 0:
			if isinstance(abundance, float):
				# Fractional abundance: also report the equivalent number of sequences
				print("    Excluding kmers with an abundance < ", abundance*100 , "% (", str(round(abundance*lenFasta)), " sequences)", sep="")
			else:
				# Absolute abundance: also report the equivalent percentage. An input without
				# sequences has no meaningful percentage, and dividing by lenFasta would raise
				# ZeroDivisionError at the very end of the run.
				share = (str(round(abundance/lenFasta*100)) + "%") if lenFasta else "n/a"
				print("    Excluding kmers with an abundance < ", abundance , " sequences (", share, ")", sep="")
	print("Done")
