#!/opt/conda/conda-bld/zol_1751092618345/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold/bin/python

### Program: atpoc
### Author: Rauf Salamzade
### Kalan Lab
### UW Madison, Department of Medical Microbiology and Immunology

# BSD 3-Clause License
#
# Copyright (c) 2023, Kalan-Lab
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
os.environ["OMP_NUM_THREADS"] = "1"
import sys
import argparse
from time import sleep
from zol import util, fai
from Bio import SeqIO
from collections import defaultdict
import pandas as pd
import math 
import traceback
from rich_argparse import RawTextRichHelpFormatter	

def create_parser():
	"""
	Build the atpoc command-line interface and parse the process arguments.

	The parser is evaluated against sys.argv (argparse's default), so calling
	this function exits the program with a usage message when a required
	option (-i, -tg, -o) is missing or an option value cannot be converted.

	Returns:
		argparse.Namespace with the attributes sample_genome, vibrant_results,
		phispy_results, genomad_results, target_genomes_db, use_pyrodigal,
		fai_options, use_simple_blastp, simple_blastp_identity_cutoff,
		simple_blastp_coverage_cutoff, simple_blastp_evalue_cutoff,
		simple_blastp_sensitivity_mode, outdir and threads.
	"""
	parser = argparse.ArgumentParser(description="""
	Program: atpoc
	Author: Rauf Salamzade
	Affiliation: Kalan Lab, UW Madison, Department of Medical Microbiology and Immunology

	atpoc - Assess Temperate Phage-Ome Conservation 
	
	atpoc wraps fai to assess the conservation of a sample's integrated/temperate phage-ome 
	relative to a set of target genomes (e.g. genomes belonging to the same genus). Alternatively, 
	it can run a simple DIAMOND BLASTp analysis to just assess the presence of prophage genes
	individually - without the requirement they are co-located like in the focal sample.
	""", formatter_class=RawTextRichHelpFormatter)

	# Inputs: the focal genome plus the optional phage-prediction result directories.
	# NOTE: -i is required, so it takes no default (the previous default=True could
	# never apply and would have been a bool where a path is expected).
	parser.add_argument('-i', '--sample-genome', help='Path to sample genome in GenBank or FASTA format.', required=True)
	parser.add_argument('-vi', '--vibrant-results', help='Path to VIBRANT results directory for a single sample/genome.', required=False, default=None)
	parser.add_argument('-ps', '--phispy-results', help='Path to PhiSpy results directory for a single sample/genome.', required=False, default=None)
	parser.add_argument('-gn', '--genomad-results', help='Path to GeNomad results directory for a single sample/genome.', required=False, default=None)
	parser.add_argument('-tg', '--target-genomes-db', help='prepTG database directory for target genomes of interest.', required=True)

	# Search behaviour: gene caller, fai pass-through options and simple BLASTp mode settings.
	parser.add_argument('-up', '--use-pyrodigal', action='store_true', help="Use default pyrodigal instead of prodigal-gv to call genes in\nphage regions to use as queries in fai/simple-blast. This\nis perhaps preferable if target genomes db was created with default pyrodigal/prodigal.", required=False, default=False)
	parser.add_argument('-fo', '--fai-options', help='Provide fai options to run. Should be surrounded by quotes. [Default is "-e 1e-10 -m 0.5 -dm -sct 0.4"]', required=False, default="-e 1e-10 -m 0.5 -dm -sct 0.4")
	parser.add_argument('-s', '--use-simple-blastp', action='store_true', help='Use a simple DIAMOND BLASTp search with no requirement for co-localization of hits.', required=False, default=False)
	parser.add_argument('-si', '--simple-blastp-identity-cutoff', type=float, help='If simple BLASTp mode requested : cutoff for identity between query proteins and matches in target genomes [Default is 40.0].', required=False, default=40.0)
	parser.add_argument('-sc', '--simple-blastp-coverage-cutoff', type=float, help='If simple BLASTp mode requested : cutoff for coverage between query proteins and matches in target genomes [Default is 70.0].', required=False, default=70.0)
	parser.add_argument('-se', '--simple-blastp-evalue-cutoff', type=float, help='If simple BLASTp mode requested : cutoff for E-value between query proteins and matches in target genomes [Default is 1e-5].', required=False, default=1e-5)
	parser.add_argument('-sm', '--simple-blastp-sensitivity-mode', help='Sensitivity mode for DIAMOND BLASTp. [Default is "very-sensitive"].', required=False, default="very-sensitive")

	# Output location and resources.
	parser.add_argument('-o', '--outdir', help='Output directory.', required=True)
	parser.add_argument('-c', '--threads', type=int, help='The number of threads to use [Default is 1].', required=False, default=1)

	args = parser.parse_args()
	return args

def _is_usable_dir(path):
	"""Return True when a results directory was provided and exists on disk."""
	return path is not None and os.path.isdir(path)

def _iter_result_files(results_dir):
	"""Yield (basename, full_path) for every file found anywhere beneath results_dir."""
	for root, _dirs, files in os.walk(results_dir):
		for basename in files:
			yield basename, os.path.join(root, basename)

def _load_scaffold_lengths(fasta_file):
	"""
	Return a dict mapping each record id in fasta_file to its sequence length.

	Parsing problems are reported on stderr (with traceback) but are not fatal;
	whatever was read before the failure is returned.
	"""
	scaff_lengths = {}
	try:
		with open(fasta_file) as osg:
			for rec in SeqIO.parse(osg, 'fasta'):
				scaff_lengths[rec.id] = len(rec.seq)
	except Exception:
		msg = 'Issue processing genome file.'
		sys.stderr.write(msg + '\n')
		sys.stderr.write(traceback.format_exc() + '\n')
	return scaff_lengths

def _parse_vibrant_predictions(results_dir):
	"""
	Collect prophage records from VIBRANT_integrated_prophage_coordinates_*.tsv files.

	The first line of each file is skipped as a header; remaining lines must have
	exactly eight tab-separated fields.
	"""
	records = []
	for basename, filename in _iter_result_files(results_dir):
		if not (basename.startswith('VIBRANT_integrated_prophage_coordinates_') and basename.endswith('.tsv')):
			continue
		with open(filename) as ofn:
			for i, line in enumerate(ofn):
				line = line.strip()
				if i == 0 or not line:
					continue
				scaff, phage_id, _, _, _, start_coord, end_coord, phage_length = line.split('\t')
				records.append(['VIBRANT', 'VB-' + phage_id, scaff, start_coord, end_coord, phage_length, ''])
	return records

def _parse_phispy_predictions(results_dir):
	"""
	Collect prophage records from PhiSpy prophage_coordinates.tsv files.

	Every non-blank line is treated as a record (no header line is skipped) and
	must have exactly eleven tab-separated fields.
	"""
	records = []
	for basename, filename in _iter_result_files(results_dir):
		if basename != 'prophage_coordinates.tsv':
			continue
		with open(filename) as ofn:
			for line in ofn:
				# Only strip the newline: trailing fields may legitimately be empty.
				line = line.strip('\n')
				if not line.strip():
					continue
				phage_id, scaff, start_coord, end_coord, _, _, _, _, attL_seq, attR_seq, att_explanation = line.split('\t')
				phage_length = abs(int(start_coord) - int(end_coord))
				additional_info = '; '.join(['attL_sequence=' + attL_seq, 'attR_sequence=' + attR_seq, 'att_explanation=' + att_explanation])
				records.append(['PhiSpy', 'PS-' + phage_id, scaff, start_coord, end_coord, str(phage_length), additional_info])
	return records

def _parse_genomad_predictions(results_dir, scaff_lengths):
	"""
	Collect prophage records from geNomad *_virus_summary.tsv files.

	Rows whose coordinates field is 'NA' are taken to span the whole scaffold,
	from 1 to the scaffold length found in scaff_lengths.

	Raises:
		ValueError: if a whole-scaffold prediction names a scaffold that is not
			present in scaff_lengths.
	"""
	records = []
	for basename, filename in _iter_result_files(results_dir):
		if not basename.endswith('_virus_summary.tsv'):
			continue
		with open(filename) as ofn:
			for i, line in enumerate(ofn):
				# seq_name        length  topology        coordinates     n_genes genetic_code    virus_score     fdr     n_hallmarks     marker_enrichment       taxonomy
				if i == 0:
					continue
				line = line.strip('\n')
				if not line.strip():
					continue
				phage_id, phage_length, topology, coords, n_genes, genetic_code, virus_score, fdr, n_hallmarks, marker_enrichment, taxonomy = line.split('\t')
				# The scaffold name is the part of seq_name before the first '|'.
				scaff = phage_id.split('|')[0]
				phage_id = phage_id.replace('|', '_')
				if coords != 'NA':
					start_coord, end_coord = coords.split('-')
				else:
					if scaff not in scaff_lengths:
						raise ValueError('Unable to determine length of scaffold %s for geNomad prediction %s.' % (scaff, phage_id))
					# Keep coordinates as strings so they can be placed in command lists as-is.
					start_coord = '1'
					end_coord = str(scaff_lengths[scaff])
				additional_info = '; '.join(['virus_score=' + virus_score, 'topology=' + topology, 'genetic_code=' + genetic_code, 'n_hallmarks=' + n_hallmarks, 'marker_enrichment=' + marker_enrichment, 'taxonomy=' + taxonomy])
				records.append(['geNomad', 'GN-' + phage_id, scaff, start_coord, end_coord, phage_length, additional_info])
	return records

def atpoc():
	"""
	Run the atpoc workflow end-to-end.

	Steps:
		1. Parse command-line arguments and prepare the output directory.
		2. Validate inputs (at least one usable VIBRANT/PhiSpy/geNomad results
		   directory; sample genome in GenBank or FASTA format).
		3. Gather prophage predictions from the provided result directories and
		   write them to TemperatePhageOme_Info.tsv.
		4. For each prophage, either run fai or the simple DIAMOND BLASTp search
		   against the prepTG target genomes database.
		5. Tabulate per-target-genome AAI / proportion of shared genes into
		   TemperatePhageOme_to_Target_Genomes_Similarity.tsv and build the
		   ATPOC_Results.xlsx spreadsheet.

	The function terminates the process: sys.exit(1) on input validation
	problems and sys.exit(0) after a successful run.
	"""
	myargs = create_parser()

	sample_genome = myargs.sample_genome
	vibrant_results_dir = myargs.vibrant_results
	phispy_results_dir = myargs.phispy_results
	genomad_results_dir = myargs.genomad_results
	preptg_db_dir = myargs.target_genomes_db
	outdir = os.path.abspath(myargs.outdir) + '/'
	fai_options = myargs.fai_options
	use_simple_blastp_flag = myargs.use_simple_blastp
	use_pyrodigal_flag = myargs.use_pyrodigal
	simple_blastp_identity_cutoff = myargs.simple_blastp_identity_cutoff
	simple_blastp_coverage_cutoff = myargs.simple_blastp_coverage_cutoff
	simple_blastp_evalue_cutoff = myargs.simple_blastp_evalue_cutoff
	simple_blastp_sensitivity_mode = myargs.simple_blastp_sensitivity_mode
	threads = myargs.threads

	# create output directory if needed, or warn of over-writing
	if os.path.isdir(outdir):
		sys.stderr.write("Output directory exists. Overwriting in 5 seconds ...\n ")
		sleep(5)
	util.setupReadyDirectory([outdir])

	# At least one of the three prediction result directories must be usable.
	if not any(_is_usable_dir(d) for d in (vibrant_results_dir, phispy_results_dir, genomad_results_dir)):
		sys.stderr.write("Neither PhiSpy, VIBRANT, nor geNomad sample results directory provided as input or the paths are faulty!\n")
		sys.exit(1)

	if util.is_genbank(sample_genome):
		try:
			sample_genome_fasta = outdir + '.'.join(sample_genome.split('/')[-1].split('.')[:-1]) + '.fna'
			util.convertGenomeGenBankToFasta([sample_genome, sample_genome_fasta])
			# Explicit check rather than assert so validation survives python -O.
			if not (os.path.isfile(sample_genome_fasta) and util.is_fasta(sample_genome_fasta)):
				raise RuntimeError('GenBank to FASTA conversion did not produce a valid FASTA file.')
			sample_genome = sample_genome_fasta
		except Exception:
			sys.stderr.write("Issue converting sample's genome from GenBank to FASTA format.\n")
			sys.exit(1)

	if not util.is_fasta(sample_genome):
		sys.stderr.write("Issue with processing/validating focal sample's genome. Was it in FASTA or GenBank format?\n")
		sys.exit(1)

	# create logging object
	log_file = outdir + 'Progress.log'
	logObject = util.createLoggerObject(log_file)
	version_string = util.getVersion()

	sys.stdout.write('Running atpoc version %s\n' % version_string)
	logObject.info('Running atpoc version %s' % version_string)

	# log command used
	parameters_file = outdir + 'Command_Issued.txt'
	with open(parameters_file, 'a+') as parameters_handle:
		parameters_handle.write(' '.join(sys.argv) + '\n')

	sys.stdout.write('Parsing phage prediction results\n')
	logObject.info('Parsing phage prediction results')

	# Scaffold lengths are needed for geNomad predictions without explicit coordinates.
	scaff_lengths = _load_scaffold_lengths(sample_genome)

	# Each record: [software, phage_id, scaffold, start, end, length, additional_info]
	sample_phages = []
	if _is_usable_dir(vibrant_results_dir):
		sample_phages.extend(_parse_vibrant_predictions(vibrant_results_dir))
	if _is_usable_dir(phispy_results_dir):
		sample_phages.extend(_parse_phispy_predictions(phispy_results_dir))
	if _is_usable_dir(genomad_results_dir):
		sample_phages.extend(_parse_genomad_predictions(genomad_results_dir, scaff_lengths))

	sys.stdout.write('Found %d prophage predictions!\n' % len(sample_phages))
	logObject.info('Found %d prophage predictions!' % len(sample_phages))

	fai_results_dir = outdir + 'fai_or_blast_Results/'
	util.setupReadyDirectory([fai_results_dir], delete_if_exist=True)

	# Write the overview table of all predictions before launching any searches.
	phage_info_file = outdir + 'TemperatePhageOme_Info.tsv'
	phage_lengths = []
	numeric_columns_t1 = ['phage_length']
	with open(phage_info_file, 'w') as phage_info_handle:
		phage_info_handle.write('\t'.join(['phage_id', 'phage_prediction_software', 'scaffold_id', 'start', 'end', 'phage_length', 'additional_attributes']) + '\n')
		for phage in sample_phages:
			phage_software, phage_id, scaff, start_coord, end_coord, phage_length, additional_info = phage
			phage_info_handle.write('\t'.join([str(x) for x in [phage_id, phage_software, scaff, start_coord, end_coord, phage_length, additional_info]]) + '\n')
			phage_lengths.append(float(phage_length))

	# Search the target genomes for each prophage.
	for phage in sample_phages:
		phage_software, phage_id, scaff, start_coord, end_coord, phage_length, additional_info = phage
		phage_fai_res_dir = fai_results_dir + phage_id + '/'

		if use_simple_blastp_flag:
			# run simple blastp
			util.setupReadyDirectory([phage_fai_res_dir], delete_if_exist=True)
			reference_genome_gbk = phage_fai_res_dir + 'reference.gbk'
			gene_call_cmd = ['runProdigalAndMakeProperGenbank.py', '-i', sample_genome, '-s', 'reference', '-l',
								'OG', '-o', phage_fai_res_dir]
			if use_pyrodigal_flag:
				gene_call_cmd += ['-gcm', 'pyrodigal']
			else:
				gene_call_cmd += ['-gcm', 'prodigal-gv']
			util.runCmdViaSubprocess(gene_call_cmd, logObject, check_files=[reference_genome_gbk])
			locus_genbank = phage_fai_res_dir + 'Reference_Query.gbk'
			locus_proteome = phage_fai_res_dir + 'Reference_Query.faa'
			fai.subsetGenBankForQueryLocus(reference_genome_gbk, locus_genbank, locus_proteome, scaff, int(start_coord), int(end_coord), logObject)
			util.diamondBlastAndGetBestHits(phage_id, locus_proteome, locus_proteome, preptg_db_dir, phage_fai_res_dir, logObject,
											identity_cutoff=simple_blastp_identity_cutoff, coverage_cutoff=simple_blastp_coverage_cutoff,
											blastp_mode=simple_blastp_sensitivity_mode, prop_key_prots_needed=0.0,
											evalue_cutoff=simple_blastp_evalue_cutoff, threads=threads)
		else:
			# run fai (every list element must be a string)
			fai_cmd = ['fai', '-r', sample_genome, '-rc', scaff, '-rs', str(start_coord), '-re', str(end_coord),
					   '-tg', preptg_db_dir, fai_options, '-o', phage_fai_res_dir, '-c', str(threads)]
			if not use_pyrodigal_flag:
				fai_cmd += ['-rgv']
			genome_wide_tsv_result_file = phage_fai_res_dir + 'Spreadsheet_TSVs/total_gcs.tsv'
			util.runCmdViaSubprocess(fai_cmd, logObject, check_files=[genome_wide_tsv_result_file])

	# Gather per-target-genome results: columns 0, 2 and 4 of each total_gcs.tsv
	# are read as target genome, AAI and proportion of shared genes.
	tg_phage_aai = defaultdict(lambda: defaultdict(lambda: 'NA'))
	tg_phage_sgp = defaultdict(lambda: defaultdict(lambda: 'NA'))
	phages = set([])
	tg_total_phages_found = defaultdict(int)
	for phage in sample_phages:
		phage_id = phage[1]
		phages.add(phage_id)
		genome_wide_tsv_result_file = fai_results_dir + phage_id + '/Spreadsheet_TSVs/total_gcs.tsv'
		if use_simple_blastp_flag:
			genome_wide_tsv_result_file = fai_results_dir + phage_id + '/total_gcs.tsv'

		with open(genome_wide_tsv_result_file) as ogwtrf:
			for i, line in enumerate(ogwtrf):
				line = line.strip()
				if i == 0 or not line:
					continue
				ls = line.split('\t')
				tg = ls[0]
				aai = ls[2]
				sgp = ls[4]
				tg_phage_aai[tg][phage_id] = aai
				tg_phage_sgp[tg][phage_id] = sgp
				if float(aai) > 0.0:
					tg_total_phages_found[tg] += 1

	sorted_phages = sorted(phages)

	# os.path.join works whether or not the database directory was given with a trailing slash.
	target_genome_listing_file = os.path.join(preptg_db_dir, 'Target_Genome_Annotation_Files.txt')
	result_file = outdir + 'TemperatePhageOme_to_Target_Genomes_Similarity.tsv'
	numeric_columns_t2 = ['shared_phages', 'aggregate_score'] + [x + ' - AAI' for x in sorted_phages] + [x + ' - Proportion Shared Genes' for x in sorted_phages]
	tgs = set([])
	all_scores = []
	with open(result_file, 'w') as result_handle, open(target_genome_listing_file) as otglf:
		header = ['target_genome', 'shared_phages', 'aggregate_score']
		for x in sorted_phages:
			header += [x + ' - AAI', x + ' - Proportion Shared Genes']
		result_handle.write('\t'.join(header) + '\n')
		for line in otglf:
			line = line.strip()
			if not line:
				continue
			ls = line.split('\t')
			tg = ls[0]
			tgs.add(tg)
			shared_phages = tg_total_phages_found[tg]
			# Aggregate score: sum over phages of AAI * proportion of shared genes.
			total_score = 0.0
			for phage in sorted_phages:
				if tg_phage_aai[tg][phage] != "NA":
					total_score += float(tg_phage_aai[tg][phage])*float(tg_phage_sgp[tg][phage])
			printlist = [tg, str(shared_phages), str(total_score)]
			all_scores.append(total_score)
			for phage in sorted_phages:
				printlist.append(tg_phage_aai[tg][phage])
				printlist.append(tg_phage_sgp[tg][phage])
			result_handle.write('\t'.join(printlist) + '\n')

	# construct spreadsheet
	atpoc_spreadsheet_file = outdir + 'ATPOC_Results.xlsx'
	writer = pd.ExcelWriter(atpoc_spreadsheet_file, engine='xlsxwriter')
	workbook = writer.book
	header_format = workbook.add_format({'bold': True, 'text_wrap': True, 'valign': 'top', 'fg_color': '#D7E4BC', 'border': 1})

	t1_results_df = util.loadTableInPandaDataFrame(phage_info_file, numeric_columns_t1)
	t1_results_df.to_excel(writer, sheet_name='Sample ProphageOme Overview', index=False, na_rep="NA")

	t2_results_df = util.loadTableInPandaDataFrame(result_file, numeric_columns_t2)
	t2_results_df.to_excel(writer, sheet_name='Similarity to Target Genomes', index=False, na_rep="NA")

	t1_worksheet = writer.sheets['Sample ProphageOme Overview']
	t2_worksheet = writer.sheets['Similarity to Target Genomes']

	# The overview sheet has seven columns (A-G), so format all seven header cells.
	t1_worksheet.conditional_format('A1:G1', {'type': 'cell', 'criteria': '!=', 'value': 'NA', 'format': header_format})
	t2_worksheet.conditional_format('A1:HA1', {'type': 'cell', 'criteria': '!=', 'value': 'NA', 'format': header_format})

	row_count = len(phages) + 1
	t2_row_count = len(tgs) + 1

	# phage length (t1); default guards against the no-prophage case
	t1_worksheet.conditional_format('F2:F' + str(row_count), {'type': '2_color_scale', 'min_color': "#d5f0ea", 'max_color': "#83a39c", 'min_type': 'num', 'max_type': 'num', "min_value": 0, "max_value": max(phage_lengths, default=0)})

	# total phage shared (t2)
	t2_worksheet.conditional_format('B2:B' + str(t2_row_count), {'type': '2_color_scale', 'min_color': "#d9d9d9", 'max_color': "#949494", 'min_type': 'num', 'max_type': 'num', "min_value": 0, "max_value": len(phages)})

	# aggregate score (t2); default guards against an empty target genome listing
	t2_worksheet.conditional_format('C2:C' + str(t2_row_count), {'type': '2_color_scale', 'min_color': "#f7db94", 'max_color': "#ad8f40", 'min_type': 'num', 'max_type': 'num', "min_value": 0, "max_value": max(all_scores, default=0)})

	# Each phage owns two adjacent columns starting at zero-based column index 3:
	# AAI followed by proportion of shared genes.
	for phage_index in range(3, 3 + 2*len(phages), 2):
		columnid_pid = util.determineColumnNameBasedOnIndex(phage_index)
		columnid_sgp = util.determineColumnNameBasedOnIndex(phage_index+1)

		# aai
		t2_worksheet.conditional_format(columnid_pid + '2:' + columnid_pid + str(t2_row_count), {'type': '2_color_scale', 'min_color': "#d8eaf0", 'max_color': "#83c1d4", "min_value": 0.0, "max_value": 100.0, 'min_type': 'num', 'max_type': 'num'})

		# prop genes found
		t2_worksheet.conditional_format(columnid_sgp + '2:' + columnid_sgp + str(t2_row_count), {'type': '2_color_scale', 'min_color': "#e6bec0", 'max_color': "#f26168", "min_value": 0.0, "max_value": 1.0, 'min_type': 'num', 'max_type': 'num'})

	# Freeze the first row of both sheets
	t1_worksheet.freeze_panes(1, 0)
	t2_worksheet.freeze_panes(1, 0)

	# close workbook
	workbook.close()

	sys.stdout.write('Done running atpoc!\nFinal results can be found at:\n%s\n' % (atpoc_spreadsheet_file))
	logObject.info('Done running atpoc!\nFinal results can be found at:\n%s\n' % (atpoc_spreadsheet_file))

	# Close logging object and exit
	util.closeLoggerObject(logObject)
	sys.exit(0)

def determineColumnNameBasedOnIndex(index):
	"""
	Function to determine spreadsheet column name for a given index

	The index is zero-based: 0 -> 'A', 25 -> 'Z', 26 -> 'AA', 701 -> 'ZZ',
	702 -> 'AAA', and so on for names of any length.

	Args:
		index: non-negative integer column index.

	Returns:
		The spreadsheet-style column name as a string of uppercase letters.

	Raises:
		ValueError: if index is negative.
	"""
	if index < 0:
		raise ValueError('Column index must be non-negative, got %d' % index)
	alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
	columname = ''
	# Column names are a bijective base-26 numbering (there is no zero digit),
	# so work with the one-based position and shift by one before each division.
	remaining = index + 1
	while remaining > 0:
		remaining, letter_index = divmod(remaining - 1, 26)
		columname = alphabet[letter_index] + columname
	return columname

# Script entry point: launch the workflow only when executed directly, not on import.
if __name__ == "__main__":
	atpoc()
