#!/usr/bin/env python3

import argparse
import re
import sys
import os
import subprocess
from Bio import Align

version = '1.1.0'

# Command-line interface. The built-in help is disabled (add_help=False) so that
# -h/--help can be listed inside the 'Other optional arguments' group below.
parser = argparse.ArgumentParser(
	description="%s v%s: Given an aligned fasta file and a (unaligned) excluding file, it will create a HMM profile from the aligned file, align the excluding file and export the aligned regions to a fasta file." % ('%(prog)s', version),
	epilog="This script uses HMMER, please cite accordingly (e.g., HMMER 3.4 (Aug 2023); http://hmmer.org/)",
	add_help=False,
)

# Argument groups, created in the order they are shown in the help message
requirArgs = parser.add_argument_group('Required arguments')
functiArgs = parser.add_argument_group('Optional arguments related to homology')
outputArgs = parser.add_argument_group('Optional arguments related to the output')
optionArgs = parser.add_argument_group('Other optional arguments')

# Required input files
requirArgs.add_argument(
	"-f", "--file",
	dest="filein",
	required=True,
	help="An aligned fasta file.",
)
requirArgs.add_argument(
	"-e", "--excluding",
	dest="excluding",
	required=True,
	help="A excluding fasta file to look against.",
)

# Homology threshold; kept as a string here and converted later in the script
functiArgs.add_argument(
	"-l", "--length",
	dest="length",
	type=str,
	default="0.25",
	help="A minimum coverage of the aligned region. Default = %(default)s, this means that the aligned sequence from the excluding file must be at least 25%% long relative to the HMM profile. If the value provided is a 'float' [0.0-1.0], it will be interpreted as a percentage. If the value provided is a 'int(eger)' [1-Inf), it will be interpreted as an absolute number of bases.",
)

# Output options
outputArgs.add_argument(
	"-o", "--output",
	dest="file_out",
	default=None,
	help="The name of the output fasta file. By default will remove the extension of the input file and add '_homologs.fasta'. If the fasta file exists, it will append the sequences at the end and detect similar sequence headings that will be ignored.",
)
outputArgs.add_argument(
	"-s", "--score",
	dest="score",
	type=str,
	default="0.90",
	help="A minimum pairwise alignment score to be exported (see 'PairwiseAligner' from the Bio.Align module for further details). Default = %(default)s, this means that the aligned sequence from the excluding file must be at least 90%% similar to the HMM profile. If the value provided is a 'float' [0.0-1.0], it will be interpreted as a percentage. If the value provided is a 'int(eger)' [1-Inf), it will be interpreted as an absolute number of bases.",
)

# Boolean output switches (all default to False), registered in help order
for _short, _long, _dest, _help in (
	("-c", "--complete", "complete",
		"If selected, will export the complete homolog sequence and not just the region."),
	("-k", "--keep", "keep",
		"If selected, will not delete the temporary files ('OUTPUT_hmm_profile.hmm' and 'OUTPUT_hmm_align.fasta') before exiting."),
	("-t", "--target", "target",
		"If selected, will also print the (unaligned) input file to the head of the output file."),
	("-d", "--delete", "delete",
		"If selected, will delete the output file (if exists) before writing it."),
):
	outputArgs.add_argument(_short, _long, dest=_dest, action="store_true", help=_help)

# Miscellaneous options. 'verbose' is True unless -v is given (store_false).
optionArgs.add_argument(
	"-v", "--verbose",
	dest="verbose",
	action="store_false",
	help="If selected, will not print information to the console.",
)
optionArgs.add_argument(
	"-h", "--help",
	action="help",
	help="Show this help message and exit.",
)
optionArgs.add_argument(
	"-V", "--version",
	action='version',
	version='oligoN-design %s v%s' % ('%(prog)s', version),
	help='Show the version number and exit.',
)

args = parser.parse_args()

# Define functions _________________________________________________________________________________

def string2numeric(string):
	"""Convert a value to a number, preferring int over float.

	The value is first turned into its string form. That text is parsed as
	an int; if that fails it is parsed as a float.

	Returns:
		The int or float parsed from the text, or None when neither
		conversion succeeds.
	"""
	text = str(string)
	for convert in (int, float):
		try:
			return convert(text)
		except ValueError:
			# Not parseable with this converter; try the next one.
			continue
	return None

def readFasta(fastafile, lengthOut=False):
	"""Read a (possibly multi-line) fasta file into a dictionary.

	Header lines (starting with ">") become the keys: every ">" character and
	the trailing newline are removed, the rest of the line is kept as is.
	Sequence lines are concatenated per header after removing the newline,
	replacing "." by "-" and converting to upper case. A header that appears
	more than once keeps only the sequence following its last occurrence.

	Args:
		fastafile: Path of the fasta file to read.
		lengthOut: If True, also return the set of distinct sequence lengths
			(a single value means all sequences have the same length).

	Returns:
		A dict mapping header -> sequence, or the tuple (dict, set of lengths)
		when lengthOut is True.

	Raises:
		ValueError: If sequence data is found before the first header line.
	"""
	chunks = {}
	name = None
	# The context manager guarantees the file handle is closed.
	with open(fastafile) as handle:
		for line in handle:
			if line.startswith(">"):
				name = line.replace(">", "").replace("\n", "")
				chunks[name] = []
			elif name is None:
				# Blank lines before the first header are harmless; anything
				# else means the file is not valid fasta.
				if line.strip():
					raise ValueError("Sequence data found before the first header line in " + str(fastafile))
			else:
				chunks[name].append(line.replace("\n", "").replace(".", "-").upper())
	# Join once per sequence instead of growing a string line by line.
	out = {key: "".join(parts) for key, parts in chunks.items()}
	if lengthOut:
		return out, {len(seq) for seq in out.values()}
	return out

def hmmBuild(hmmOut, alignment):
	"""Build a DNA HMM profile from an alignment by running 'hmmbuild'.

	The command is passed as an argument list (no shell), so file names
	containing spaces or shell metacharacters are handed over verbatim.

	Args:
		hmmOut: Path of the HMM profile file to write.
		alignment: Path of the alignment file to build the profile from.

	Raises:
		subprocess.CalledProcessError: If hmmbuild exits with a non-zero status.
		FileNotFoundError: If the 'hmmbuild' executable cannot be found.
	"""
	command = ["hmmbuild", "--dna", hmmOut, alignment]
	# Standard output of hmmbuild is discarded; errors still reach stderr.
	subprocess.run(command, check=True, stdout=subprocess.DEVNULL)

def hmmAlign(alignment, hmmProfile, fasta, output, trim=False):
	"""Align sequences to a HMM profile by running 'hmmalign'.

	The output is requested in aligned-fasta format ('--outformat afa') and
	the alignment given to '--mapali' is passed along with the new sequences.
	The command is passed as an argument list (no shell), so file names
	containing spaces or shell metacharacters are handed over verbatim.

	Args:
		alignment: Path of the alignment passed to '--mapali'.
		hmmProfile: Path of the HMM profile file.
		fasta: Path of the fasta file with the sequences to align.
		output: Path of the aligned fasta file to write ('-o').
		trim: If True, add the '--trim' option to the command.

	Raises:
		subprocess.CalledProcessError: If hmmalign exits with a non-zero status.
		FileNotFoundError: If the 'hmmalign' executable cannot be found.
	"""
	command = ["hmmalign", "--outformat", "afa"]
	if trim:
		command.append("--trim")
	command += ["--mapali", alignment, "-o", output, hmmProfile, fasta]
	# Standard output of hmmalign is discarded; errors still reach stderr.
	subprocess.run(command, check=True, stdout=subprocess.DEVNULL)

# Setting variables ________________________________________________________________________________
if args.verbose:
	print("  Setting variables...", flush=True)

# Derive the output and temporary file names by replacing the extension.
# os.path.splitext is used so that a name without extension still gets the
# suffix appended: otherwise the temporary files would share the name of the
# input/output file and that file would be overwritten and later deleted.
if args.file_out is None:
	outBase = os.path.splitext(args.filein)[0]
	outFile = outBase + "_homologs.fasta"
	outProfile = outBase + "_hmmProfile.hmm"
	outAlign = outBase + "_hmmAlign.fasta"
else:
	outBase = os.path.splitext(args.file_out)[0]
	outFile = args.file_out
	outProfile = outBase + "_hmm_profile.hmm"
	outAlign = outBase + "_hmm_align.fasta"

# Minimum aligned length: a float is a fraction of the profile length,
# an integer is an absolute number of bases.
length = string2numeric(args.length)
if isinstance(length, float):
	lengthAlign = "percentage"
elif isinstance(length, int):
	lengthAlign = "absolute"
else:
	print("Error: argument -l/--length needs a numerical value: either a float [0.0-1.0] or an integer [1-Inf] (i.e., '0.9')")
	sys.exit(1)

# Minimum alignment score: same float/integer convention as the length.
minScore = string2numeric(args.score)
if isinstance(minScore, float):
	scoreType = "percentage"
elif isinstance(minScore, int):
	scoreType = "absolute"
else:
	# Report the option that actually failed (-s/--score).
	print("Error: argument -s/--score needs a numerical value: either a float [0.0-1.0] or an integer [1-Inf] (i.e., '0.9')")
	sys.exit(1)

# Reading input file _______________________________________________________________________________
if args.verbose:
	print("  Reading region:", str(args.filein), flush=True)
regions, lenRegs = readFasta(args.filein, lengthOut=True)

# An aligned file has exactly one distinct sequence length. Anything else
# (including an empty file) is an error and must exit with a non-zero status
# so that calling scripts can detect the failure.
if len(lenRegs) != 1:
	print("\033[91mError!\033[0m", args.filein, "must be aligned. Please provide an aligned fasta file.")
	sys.exit(1)

if args.verbose:
	print("    Sequences in region file:   ", len(regions), flush=True)
	print("    Aligned positions:          ", *lenRegs, flush=True)

# Check if output file exists, and if it so read and store sequence headers ________________________
# Headers already present in the output file are collected in 'toIgnore' so
# that those sequences are not exported again when appending to the file.
toIgnore = set()
if os.path.isfile(outFile):
	if args.delete:
		if args.verbose:
			print("  Output file already exists. Deleting it.")
		os.remove(outFile)
	else:
		if args.verbose:
			print("  Output file already exists: ", outFile, sep="", flush=True)
			print("    Reading and storing identifiers", flush=True)
		# The identifiers must be read regardless of the verbosity, otherwise
		# running quietly would append duplicated sequences.
		with open(outFile) as previous:
			for line in previous:
				if line.startswith(">"):
					# Same header clean-up as readFasta so the names are comparable
					toIgnore.add(line.replace(">", "").replace("\n", ""))

# Create a HMM profile _____________________________________________________________________________
if args.verbose:
	print("  Creating a HMM profile", flush=True)
hmmBuild(outProfile, args.filein)

# Rebuild the profile sequence from the HMM file: for every line that starts
# with spaces followed by an integer and a space, the 7th whitespace-separated
# field is taken as one position of the profile.
# NOTE(review): presumably that field is the consensus residue column of the
# match-emission lines for a 4-letter (DNA) alphabet; verify against the HMMER
# file format if the alphabet ever changes.
profile = list()
with open(outProfile) as hmmFile:
	for line in hmmFile:
		if re.search("^ +[0-9]+ ", line):
			fields = line.strip().split()
			profile.append(fields[6])
profile = "".join(profile)

# Without any position the thresholds below cannot be computed (division by
# the profile length), so stop here with an explicit message.
if not profile:
	print("\033[91mError!\033[0m No profile positions could be read from", outProfile)
	sys.exit(1)

if args.verbose:
	print("    HMM profile:", profile, flush=True)
	print("    Length:     ", len(profile), flush=True)

# From here on 'length' is always an absolute number of bases.
if lengthAlign == "percentage":
	if args.verbose:
		print("    Minimum alignment length: ", round(length*100), "% (", round(length*len(profile)), " bases)", sep="", flush=True)
	length = round(length*len(profile))
elif lengthAlign == "absolute":
	if args.verbose:
		print("    Minimum alignment length: ", length, " bases (", round(length/len(profile)*100), "%)", sep="", flush=True)

# From here on 'minScore' is always an absolute number of identical bases.
if scoreType == "percentage":
	if args.verbose:
		print("    Minimum alignment similarity: ", round(minScore*100), "% (", round(minScore*len(profile)), " identical bases)", sep="", flush=True)
	minScore = round(minScore*len(profile))
elif scoreType == "absolute":
	if args.verbose:
		print("    Minimum alignment similarity: ", minScore, " identical bases (", round(minScore/len(profile)*100), "%)", sep="", flush=True)

# Align profile to excluding file __________________________________________________________________
if args.verbose:
	print("  Alignining all sequences from excluding file to the HMM profile", flush=True)
hmmAlign(args.filein, outProfile, args.excluding, outAlign)

# Extract aligned region ___________________________________________________________________________
if args.verbose:
	if args.complete:
		print("  Extracting aligned regions", flush=True)
	else:
		print("  Extracting and completing aligned regions", flush=True)
aligned = readFasta(outAlign)

# Locate the region in the new alignment: the columns between the first and
# the last residue of the first input (region) sequence that has residues.
# This is done before looking at any other sequence so the result does not
# depend on the order of the sequences in the aligned file.
regionStart = None
regionEnd = None
regionAligned = None
for key, value in aligned.items():
	if key in regions and value.strip("-"):
		regionStart = len(value) - len(value.lstrip("-"))
		regionEnd = len(value.rstrip("-"))
		regionAligned = value[regionStart:regionEnd]
		break

if regionAligned is None:
	print("\033[91mError!\033[0m None of the sequences from", args.filein, "were found in", outAlign)
	sys.exit(1)

out = {}   # sequence to export for every hit
sort = {}  # pairwise alignment score of every hit against the profile
aligner = Align.PairwiseAligner()  # created once, reused for every sequence
profileUpper = profile.upper()
for key, value in aligned.items():
	# Skip the input sequences themselves and those already in the output file
	if key in regions or key in toIgnore:
		continue
	regioni = value[regionStart:regionEnd]
	regionUngapped = regioni.replace("-", "")
	if len(regionUngapped) < length:
		continue
	# Leading gap columns of the hit, counted as reference residues: this is
	# how many bases the hit lacks at the start of the region.
	start = len(regioni) - len(regioni.lstrip('-'))
	start = len(regionAligned[:start].replace("-", ""))
	# Same for the trailing gap columns ([-0:] would take the whole string)
	end = len(regioni) - len(regioni.rstrip('-'))
	if end != 0:
		end = len(regionAligned[-end:].replace("-", ""))
	# Complete the region with the columns flanking it. The start of the
	# slice is clamped at 0 so a negative index cannot wrap around.
	prefix = value[max(0, regionStart - start):regionStart].replace("-", "")
	suffix = value[regionEnd:regionEnd + end].replace("-", "")
	regionOut = prefix + regionUngapped + suffix
	if args.complete:
		out[key] = value.replace("-", "")
	else:
		out[key] = regionOut
	# The score is always computed on the (completed) region
	sort[key] = aligner.score(profileUpper, regionOut)

# Write the output file: hits are exported from highest to lowest score
sort = dict(sorted(sort.items(), key=lambda x: x[1], reverse=True))
# Input name without extension, used for the header of the profile sequence.
# os.path.splitext only looks at the last path component, so a dot in a
# directory name is never mistaken for the extension.
inFile = os.path.splitext(args.filein)[0]
count = 0  # number of exported hits, reported in the final summary
# Opened in append mode: an existing output file is extended, not replaced.
# The file is flushed and closed when leaving the 'with' block.
with open(outFile, "a") as outfile:
	if args.target:
		# Input sequences first, with the alignment gaps removed
		for key, val in regions.items():
			print(">" + key, file=outfile)
			print(val.replace("-", ""), file=outfile)
	print(">" + inFile + "_HMM_profile", file=outfile)
	print(profile, file=outfile)
	for key, val in sort.items():
		if val >= minScore:
			count += 1
			print(">" + key, file=outfile)
			print(out[key], file=outfile)

# Remove the temporary HMM profile and alignment unless -k/--keep was given
if not args.keep:
	if args.verbose:
		print("  Deleting temporary files", flush=True)
	os.remove(outProfile)
	os.remove(outAlign)

# Final summary
if args.verbose:
	print("  \033[1mThere were", len(out), "total hits\033[0m", flush=True)
	if scoreType == "percentage":
		# 'minScore' holds an absolute number of bases by now, so the requested
		# fraction is parsed again and shown as a percentage (0.9 -> 90%).
		scorePercent = round(string2numeric(args.score)*100)
		print("    Of which ", count, " were at least ", scorePercent, "% similar", sep="", flush=True)
	if scoreType == "absolute":
		print("    Of which ", count, " had at least ", minScore, " identical aligned bases", sep="", flush=True)
	if args.keep:
		print("  Aligned regions exported to:", outAlign, flush=True)
		print("  HMM profile written to:", outProfile, flush=True)
	print("  Output file written to:", outFile, flush=True)
	print("Done")
