#!/usr/bin/env python3

import argparse
import re
import sys

version = '1.1.0'

# Command-line interface. The automatic help is disabled (add_help=False) so
# that -h/--help can be declared explicitly inside the last argument group.
parser = argparse.ArgumentParser(
	description=f"%(prog)s v{version}: Select the N best oligonucleotides from a tab delimited table by ranking chosen columns and selecting the oligonucleotides with the lowest average rank. This script works best if you want to select by many columns (i.e. >~4). Otherwise a manual selection might work better.",
	add_help=False,
)

# Argument groups, in the order they are listed in the help message
requirArgs = parser.add_argument_group('Required arguments')
functiArgs = parser.add_argument_group('Optional arguments related to the selection')
outputArgs = parser.add_argument_group('Optional arguments related to the output')
optionArgs = parser.add_argument_group('Other optional arguments')

# Required arguments
requirArgs.add_argument(
	"-f", "--file",
	dest="table",
	required=True,
	help="A tab delimited table to be filtered.",
)

# Arguments controlling the selection
functiArgs.add_argument(
	"-c", "--columns",
	dest="columns",
	nargs="+",
	default=['hitsT', 'hitsE', 'mismatch1', 'mismatch2'],
	help="The columns to be ranked for selecting the best candidate probes. Default= %(default)s",
)

functiArgs.add_argument(
	"-r", "--order",
	dest="order",
	nargs="+",
	default=['defaults'],
	help="The order in which the given columns will be ranked (i.e., descending or ascending numbers). The initials can also be used. Default= %(default)s, which will apply common orders after the column names used in this pipeline (i.e., 'descending ascending ascending ascending' for the default column names).",
)

functiArgs.add_argument(
	"-i", "--ids",
	dest="ids",
	nargs="+",
	default=[1, 2],
	help="The columns bearing the 'oligonucleotide name' and the 'sequence'. It can be given either the column names or the number of the columns. Default = %(default)s (first and second columns). ",
)

functiArgs.add_argument(
	"-w", "--weights",
	dest="weights",
	type=float,
	nargs="+",
	help="A list of weights for the respective columns to calculate the weighted mean. By default will calculate an arithmetic mean.",
)

functiArgs.add_argument(
	"-T", "--thorough",
	dest="thorough",
	action="store_true",
	help="A shortcut to apply a thorough selection of the columns to be ranked: length, hitsT, hitsE, average_brightness, self-dimer_count, hairpin_count, max_consecutive, mismatch1_thorough, mismatch2_thorough.",
)

# Arguments controlling the output
outputArgs.add_argument(
	"-o", "--outFile",
	dest="outFile",
	help="The name of the output log file. By default, will print the ranked oligos to the console.",
)

outputArgs.add_argument(
	"-n", "--number",
	dest="number",
	type=int,
	help="The number of the best oligonucleotides to be shown. By default will rank all.",
)

# Other arguments
# store_false: args.verbose is True unless -v/--verbose is given
optionArgs.add_argument(
	"-v", "--verbose",
	dest="verbose",
	action="store_false",
	help="If selected, will not print information to the console.",
)

optionArgs.add_argument(
	"-h", "--help",
	action="help",
	help="Show this help message and exit.",
)

optionArgs.add_argument(
	"-V", "--version",
	action='version',
	version=f"oligoN-design %(prog)s v{version}",
	help='Show the version number and exit.',
)

args = parser.parse_args()

# Define functions _________________________________________________________________________________
def rank(unranked, order='ascending'):
	"""Return the 1-based rank of every value in *unranked*.

	Equal values share the rank of their first position in the sorted
	list, so the following rank is skipped (e.g. [5, 5, 1] -> [2, 2, 1]).

	Parameters:
		unranked: list of mutually comparable, hashable values.
		order: 'ascending' (smallest value gets rank 1) or 'descending'
			(largest value gets rank 1).

	Returns:
		A list of ranks in the same order as *unranked*.

	Raises:
		ValueError: if *order* is neither 'ascending' nor 'descending'.
	"""
	if order == 'ascending':
		reverse = False
	elif order == 'descending':
		reverse = True
	else:
		raise ValueError("order must be 'ascending' or 'descending', got %r" % (order,))
	# Remember only the first sorted position of each distinct value, so the
	# lookup below is a dictionary access instead of a linear list search.
	firstPosition = {}
	for position, value in enumerate(sorted(unranked, reverse=reverse), start=1):
		firstPosition.setdefault(value, position)
	return [firstPosition[value] for value in unranked]

def testIntOrStr(string):
	try:
		return int(string)
	except ValueError:
		try:
			return str(string)
		except ValueError:
			return None

def getIDs(idList, headers):
	"""Translate column identifiers into 0-based column indices.

	Parameters:
		idList: column identifiers, each either a 1-based column number
			(an int or a numeric string) or a column name.
		headers: list with the column names of the table.

	Returns:
		A list with one entry per distinct identifier, in the given order.
		A resolved identifier becomes its 0-based index (int). An
		identifier that cannot be resolved (a number outside
		1..len(headers), or a name absent from *headers*) is returned as
		a str so that the caller can detect it by type. An identifier
		resolving to an entry already in the list is reported with a
		warning and skipped.
	"""
	out = list()
	for i in idList:
		try:
			tmp = int(i) - 1
		except ValueError:
			# A column name: keep the text unless it matches a header
			# (the last matching header wins).
			tmp = str(i)
			for j, h in enumerate(headers):
				if h == i:
					tmp = j
		else:
			# A column number: valid indices are 0..len(headers)-1. Values
			# below 1 are rejected as well, otherwise they would silently
			# index the table from its end.
			if not 0 <= tmp < len(headers):
				tmp = str(i)
		if tmp in out:
			print("  Warning!", i, "is duplicated...")
		else:
			out.append(tmp)
	return out

def mean(values):
	"""Return the arithmetic mean of *values* (sum divided by count)."""
	return sum(values) / len(values)

def meanw(values, weights):
	"""Return the weighted mean of *values*.

	Each value is multiplied by the weight at the same position and the
	total is divided by the sum of all *weights*.
	"""
	weightedTotal = sum(values[k] * weights[k] for k in range(len(values)))
	return weightedTotal / sum(weights)

# Ranking direction applied to each known column name
ordering = {
	"length":           "ascending",
	"hitsT_prop":       "descending",
	"hitsT":            "descending",
	"hitsE_prop":       "ascending",
	"hitsE":            "ascending",
	"mismatch":         "ascending",
	"brightness":       "descending",
	"self-dimer_count": "ascending",
	"hairpin_count":    "ascending",
	"max_consecutive":  "ascending"
	}

def getOrdering(columns, ordering=ordering):
	"""Return the ranking direction for every column name in *columns*.

	A name found in *ordering* uses the direction stored there. Any other
	name containing "mismatch" is ranked ascending and any other name
	containing "brightness" is ranked descending.

	Parameters:
		columns: list of column names.
		ordering: dict mapping column names to 'ascending'/'descending'.

	Returns:
		A list of 'ascending'/'descending' strings, one per column.

	Exits with status 1 (after printing an error) when a column name
	matches none of the rules above.
	"""
	out = list()
	for i in columns:
		if i in ordering:
			out.append(ordering[i])
		elif "mismatch" in i:
			out.append("ascending")
		elif "brightness" in i:
			out.append("descending")
		else:
			print("  \033[91mError!\033[0m", i, "is not used for ranking oligos in this pipeline. Please remove it from this selection, or apply manual ordering.")
			# This is a failure: report a non-zero status like the other error exits
			sys.exit(1)
	return out

# Setting variables and troubleshooting ____________________________________________________________
if args.verbose:
	print("  Setting variables")

if len(args.ids) < 2:
	print("\033[91mError!\033[0m Please provide at least two values through -i/--ids. Exiting...")
	sys.exit(1)
if len(args.ids) > 2:
	print("  Warning! More than two values were passed through -i/--ids. Only the first two will be considered.")

# Every value passed through -r/--order must be one of the accepted keywords
validOrders = ("descending", "d", "ascending", "a", "defaults")
for i in args.order:
	if i not in validOrders:
		print("  \033[91mError!\033[0m '", i, "' is not a valid order. Either provide 'ascending' (or 'a'), 'descending' (or 'd') or 'defaults' as parsed values to -r/--order. Exiting...", sep="")
		sys.exit(1)

# 'order' is always assigned here: "defaults" when the orders have to be
# inferred from the column names, None when explicit orders were given.
if args.order == ["defaults"]:
	order = "defaults"
else:
	order = None
	# -T/--thorough replaces both the columns and the orders further down,
	# so the explicit orders only have to be consistent without it.
	if not args.thorough:
		if "defaults" in args.order:
			print("  \033[91mError!\033[0m 'defaults' cannot be combined with other orders in -r/--order. Exiting...")
			sys.exit(1)
		if len(args.columns) != len(args.order):
			print("  \033[91mError!\033[0m Different number of columns (", len(args.columns), ") and orders (", len(args.order),") have been provided. Exiting...", sep="")
			sys.exit(1)

# Reading the headers and resolving the selected columns ___________________________________________
with open(args.table) as filein:
	headers = filein.readline()
# NOTE(review): the header is split on any whitespace, as the data rows are
# further down, so column names containing spaces are not supported.
headers = headers.strip("\t").split()

# Only the name and the sequence columns are used; a warning has already
# been printed when more than two values were given.
ids = getIDs(args.ids[:2], headers)

if args.thorough:
	if args.verbose:
		print("  Applying a thorough ranking ")
	colnames = ["length", "hitsT", "hitsE", "average_brightness", "self-dimer_count", "hairpin_count", "max_consecutive", "mismatch1_thorough", "mismatch2_thorough"]
	columns = getIDs(colnames, headers)
else:
	columns = getIDs(args.columns, headers)

# getIDs returns unresolved identifiers as strings: report all of them
# (for both the default and the thorough selection) before exiting.
invalid = False
for i in ids + columns:
	if isinstance(i, str):
		if isinstance(testIntOrStr(i), int):
			print("  \033[91mError!\033[0m Headers have length", len(headers), "and", str(i), "has been provided.")
		else:
			print("  \033[91mError!\033[0m", i, "is not found in the headers table.")
		invalid = True
	elif not 0 <= i < len(headers):
		# Guard against indices that do not exist in the table
		print("  \033[91mError!\033[0m Headers have length", len(headers), "and", str(i + 1), "has been provided.")
		invalid = True
if invalid:
	print("  Please provide valid column names or column numbers.\nExiting...")
	sys.exit(1)
if len(ids) < 2:
	# getIDs drops duplicated identifiers, so both values pointed to the same column
	print("  \033[91mError!\033[0m The two values passed through -i/--ids must point to different columns. Exiting...")
	sys.exit(1)
colOligo = ids[0]
colSeq = ids[1]

if args.thorough:
	order = getOrdering(colnames)
else:
	colnames = [headers[i] for i in columns]
	if args.order == ["defaults"]:
		order = getOrdering(colnames)
	else:
		if "defaults" in args.order:
			print("  \033[91mError!\033[0m 'defaults' cannot be combined with other orders in -r/--order. Exiting...")
			sys.exit(1)
		# Expand the initials into the full keywords
		orderNames = {"a": "ascending", "ascending": "ascending", "d": "descending", "descending": "descending"}
		order = [orderNames[i] for i in args.order if i in orderNames]
		# One order per ranked column is needed (duplicated columns were dropped by getIDs)
		if len(order) != len(columns):
			print("  \033[91mError!\033[0m Different number of columns (", len(columns), ") and orders (", len(order),") have been provided. Exiting...", sep="")
			sys.exit(1)

weights = None
if args.weights is not None:
	if len(args.weights) != len(columns):
		print("  \033[91mError!\033[0m Different number of columns (", len(columns), ") than weights (", len(args.weights),"). Exiting...", sep="")
		sys.exit(1)
	weights = [float(w) for w in args.weights]

# Reading input ____________________________________________________________________________________
if args.verbose:
	print("  Ranking oligonucleotides by:")
	if args.weights is None:
		for i in range(0, len(colnames)):
			print("", colnames[i], order[i], sep="\t")
	else:
		print("\tColumn\tOrder\tWeight")
		for i in range(0, len(colnames)):
			print("", colnames[i], order[i], args.weights[i], sep="\t")

# data:      oligonucleotide name -> sequence, in file order
# variables: column index -> list with the numeric value of every data row
# rows:      number of data rows read (the header is not counted)
data = {}
variables = {}
rows = 0
with open(args.table) as filein:
	filein.readline()  # the first line holds the headers, not values
	for lineNumber, line in enumerate(filein, start=2):
		line = line.strip("\t").split()
		if not line:
			# Blank lines carry no oligonucleotide
			continue
		oligo = line[colOligo]
		if oligo in data:
			# Rows are later matched to their names by position and by
			# name, which only works when every name is unique.
			print("  \033[91mError!\033[0m The oligonucleotide name '", oligo, "' is duplicated (line ", lineNumber, "). Exiting...", sep="")
			sys.exit(1)
		# Only the name and the sequence columns are kept, as announced in
		# the warning printed when more than two ids are given.
		data[oligo] = line[colSeq]
		for j in columns:
			# Values must be compared as numbers: left as text, "10" would
			# be sorted before "9".
			try:
				value = float(line[j])
			except ValueError:
				print("  \033[91mError!\033[0m Non-numeric value '", line[j], "' found in column ", j + 1, " (line ", lineNumber, "). Exiting...", sep="")
				sys.exit(1)
			variables.setdefault(j, list()).append(value)
		rows += 1

# Ranking variables ________________________________________________________________________________
# Column index -> ranks of that column's values, one rank per data row
variablesR = {
	columns[k]: rank(variables[columns[k]], order[k])
	for k in range(len(variables))
}

# Calculating the mean of the ranks ________________________________________________________________
# One (optionally weighted) mean per data row, taken over the ranked columns
means = []
for row in range(rows):
	rowRanks = [variablesR[col][row] for col in columns]
	if args.weights is not None:
		means.append(meanw(rowRanks, weights))
	else:
		means.append(mean(rowRanks))

# Ordering by the mean of the ranks ________________________________________________________________
if args.number is None:
	number = rows
else:
	number = args.number
meansR = rank(means)

# out: row position -> [rank, oligonucleotide name, sequence], best rank
# first. sorted() is stable, so rows sharing a rank keep their file order.
out = {}
for n in sorted(range(len(meansR)), key=meansR.__getitem__):
	if meansR[n] > number:
		# Ranks only grow from here on
		break
	out[n] = [meansR[n]]

# Names in file order, built once instead of once per selected row.
# NOTE(review): this assumes one entry in 'data' per data row, i.e. unique
# oligonucleotide names - confirm against the input tables.
oligos = list(data)
# idOligos: oligonucleotide name -> rank, used when exporting the table
idOligos = {}
for n in out:
	oligo = oligos[n]
	idOligos.setdefault(oligo, out[n][0])
	out[n].append(oligo)
	out[n].append(data[oligo])

# Printing the ranking _____________________________________________________________________________
# The ranking goes to the console only when no output file was requested
if args.outFile is None:
	if args.number is not None:
		print("  The", number, "best oligonucleotides are:")
	else:
		print("  The ranking of the oligonucleotides is:")
	# Header line: the rank followed by the names of the identifier columns
	tmp = [headers[col] for col in ids]
	print("", "rank", *tmp, sep="\t")
	for val in out.values():
		print("", *val, sep="\t")
# Export a table with the ranking if selected ______________________________________________________
if args.outFile is not None:
	if args.verbose:
		print("  Exporting table to:", args.outFile)
	# Both files are closed when the block ends; the output is written
	# through the normal buffer instead of being flushed on every line.
	with open(args.table) as filein, open(args.outFile, "w") as outfile:
		print("\t".join(headers) + "\trank", file=outfile)
		for line in filein:
			fields = line.strip().split("\t")
			if colOligo >= len(fields):
				# Blank or truncated line: there is no name to look up
				continue
			tmp = fields[colOligo]
			# Only the selected oligonucleotides are exported, each with
			# its rank appended as the last column
			if tmp in idOligos:
				print(line.rstrip("\n") + "\t" + str(idOligos[tmp]), file=outfile)

if args.verbose:
	print("Done")
