#!/opt/conda/conda-bld/zol_1751092618345/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold/bin/python

### Program: cgcg
### Author: Rauf Salamzade
### Kalan Lab
### UW Madison, Department of Medical Microbiology and Immunology

# BSD 3-Clause License
#
# Copyright (c) 2023, Kalan-Lab
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
os.environ["OMP_NUM_THREADS"] = "1"
import sys
import argparse
from rich_argparse import RawTextRichHelpFormatter
from zol import util
import traceback
import gravis as gv
from colour import Color
import json
import numpy

def create_parser():
	"""
	Build the cgcg command-line interface and parse the arguments in sys.argv.

	Returns:
		argparse.Namespace: the parsed options. Every numeric option is declared
		with an explicit type so that argparse rejects malformed values with a
		usage error and downstream code always receives numbers (previously
		--min-conservation and --min-edge-ratio arrived as strings whenever the
		user set them, which broke the numeric comparisons made later).
	"""
	parser = argparse.ArgumentParser(description="""
	Program: cgcg
	Author: Rauf Salamzade
	Affiliation: Kalan Lab, UW Madison, Department of Medical Microbiology and Immunology

	collapsed gene clusters graph (cgcg; pun: see(ಠಠ)-gcg)

	cgcg takes as input a zol results directory and uses the gravis library to summarize 
	gene clusters into a graphical representation where nodes are ortholog groups and edges
	indicate the Markovian information determined by zol as part of its algorithm for 
	determining consensus order/direction. The coloring of the node corresponds to a 
	quantitative evolutionary statistic. The size of the node is the median length in 100 bp 
	units. The border color of the node indicates whether it is in the sense or antisense 
	direction. It is inspired by various pangenome graph visualization software and the STRING 
	web application by the Bork lab.

	Note, if you want to remove border colors - you can do so by setting --node-border-size to 0,
	but then you should probably also switch to an undirected graph to avoid misintrepretation 
	of gene order/directionality. 

	Also, the "major path" which can be shown in gold is simply the path which is most supported
	across gene cluster instances, it is not all inclusive of every ortholog group like the 
	consensus path calculated in zol.
	---------------------------------------------------------------------------------------------
	Example command:
	---------------------------------------------------------------------------------------------
	$ cgcg -i zol_Results_Directory/ -o cgcg_Results/ --color-track conservation
	---------------------------------------------------------------------------------------------
	Citation notice:
	---------------------------------------------------------------------------------------------
	- gravis: Interactive graph visualizations with Python and HTML/CSS/JS. 
	  https://github.com/robert-haas/gravis by Haas 2021. 
								  
	- zol & fai: large-scale targeted detection and evolutionary investigation of gene clusters 
	  https://www.biorxiv.org/content/10.1101/2023.06.07.544063v3 by Salamzade et al. 2023

	""", formatter_class=RawTextRichHelpFormatter)

	# Input / output locations.
	parser.add_argument('-i', '--zol-results-dir', help='Path to zol results directory.', required=True)
	parser.add_argument('-o', '--outdir', help='Output directory for cgcg analysis.', required=True)

	# Node coloring. The track names listed here are the ones looked up in the
	# module-level track_name_to_full_label table.
	parser.add_argument('-c', '--color-track', help='The track from the zol spreadsheet to use for coloring.\nOptions include: conservation, tajimas_d, entropy,\nupstream_entropy, median_beta_rd, max_beta_rd, fst,\nfocal_conservation, alternative_conservation, bgc_score,\nviral_score, busted_pval.\n[Default is conservation].', required=False, default="conservation")
	parser.add_argument('-vl', '--low-value', type=float, help='Low value [Default is 0.0].', required=False, default=0.0)
	parser.add_argument('-vh', '--high-value', type=float, help='High value [Default is 1.0].', required=False, default=1.0)
	parser.add_argument('-cl', '--low-value-color', help='Hex code for color for low-value. Surround by quotes\n[Default is "#a2b7e8"].', required=False, default="#a2b7e8")
	parser.add_argument('-ch', '--high-value-color', help='Hex code for color for high-value. Surround by quotes\n[Default is "#102f75"].', required=False, default="#102f75")
	parser.add_argument('-nc', '--na-value-color', help='The hex code for setting color of NA/non-numeric values.\nSurround by quotes. [Default is "#adadad"].', required=False, default="#adadad")

	# Filters. type=float is required: without it a user-supplied value stays a
	# string and cannot be compared against the numeric conservation/edge ratios.
	parser.add_argument('-mc', '--min-conservation', type=float, help='Minimum conservation of ortholog group to be shown\n[Default is 0.25].', required=False, default=0.25)
	parser.add_argument('-mer', '--min-edge-ratio', type=float, help='Minimum ratio of weight between two ortholog groups\nto the maximum weight observed to be shown [Default is 0.05].', required=False, default=0.05)

	# Labels and custom per-orthogroup overrides.
	parser.add_argument('-sl', '--show-labels', action='store_true', help='Show orthogroup labels.', required=False, default=False)
	parser.add_argument('-cul', '--custom-labels', help='Tab-separated file with OG identifiers as first column\nand labels to use as the second column.', required=False, default=None)
	parser.add_argument('-cuc', '--custom-colors', help='Tab-separated file with OG identifiers as first column\nand hex-code for colors to use as the second column.', required=False, default=None)

	# Graph appearance.
	parser.add_argument('-as', '--arrow-size', type=int, help='Arrow size [Default is 10].', required=False, default=10)
	parser.add_argument('-nbs', '--node-border-size', type=int, help='Node border size [Default is 1].', required=False, default=1)
	parser.add_argument('-f', '--flip', action='store_true', help='Flag to flip the direction of arrows and the border\ncoloring for orthogroup consensus direction.', required=False, default=False)
	parser.add_argument('-u', '--undirected-graph', action='store_true', help='Flag for hiding arrows.', required=False, default=False)
	parser.add_argument('-bc', '--background-color', help='The background color. Surround by quotes [Default is\n"#FFFFFF"].', required=False, default="#FFFFFF")
	parser.add_argument('-sm', '--show-major-path', action='store_true', help='Flag to show edges which belong to the major path in\ngold.', required=False, default=False)

	return parser.parse_args()

# Maps the short track names accepted by --color-track to the column headers of
# the zol consolidated report that hold the corresponding values.
track_name_to_full_label = {'conservation': 'Proportion of Total Gene Clusters with OG', 'tajimas_d': 'Tajima\'s D', 'entropy': 'Entropy',
							'upstream_entropy': 'Upstream Region Entropy', 'median_beta_rd': 'Median Beta-RD-gc', 'max_beta_rd': 'Max Beta-RD-gc',
							'focal_conservation': 'Proportion of Focal Gene Clusters with OG', 'fst': 'Fixation Index',
							# NOTE: 'alterante_conservation' is a historical misspelling; it is kept so
							# that existing command lines keep working (aliases are added below).
							'alterante_conservation': 'Proportion of Comparator Gene Clusters with OG', 'bgc_score': 'BGC score (GECCO weights)',
							'viral_score': 'Viral score (V-Score)', 'busted_pval': 'P-value for gene-wide episodic selection by BUSTED'}

# Reverse lookup (column header -> track name). Built before the aliases are
# added so that its contents are exactly what they have always been.
track_full_label_to_name = {label: name for name, label in track_name_to_full_label.items()}

# Accept the correctly spelled name and the name advertised in the --color-track
# help text as aliases of the misspelled legacy key.
for _track_alias in ('alternate_conservation', 'alternative_conservation'):
	track_name_to_full_label[_track_alias] = track_name_to_full_label['alterante_conservation']
del _track_alias

def _exit_with_error(msg, logObject=None, show_traceback=False):
	"""Write a fatal error to stderr (and to the log when one is given), then exit with status 1."""
	if logObject is not None:
		logObject.error(msg)
	sys.stderr.write(msg + '\n')
	if show_traceback:
		# Only meaningful when called from inside an except block.
		details = traceback.format_exc()
		if logObject is not None:
			logObject.error(details)
		sys.stderr.write(details + '\n')
	sys.exit(1)

def _read_two_column_tsv(path):
	"""Return a dict mapping the first tab-separated column of a file to its second; blank lines are skipped."""
	mapping = {}
	with open(path) as handle:
		for line_number, line in enumerate(handle, start=1):
			line = line.strip()
			if not line:
				continue
			fields = line.split('\t')
			if len(fields) < 2:
				raise ValueError('Line %d of %s does not have two tab-separated columns.' % (line_number, path))
			mapping[fields[0]] = fields[1]
	return mapping

def _closest_color(color_map, value):
	"""Return the color whose scale position is nearest to value, or None if no position is comparable (e.g. NaN)."""
	best_color = None
	best_dist = float('inf')
	for increment, hex_color in color_map.items():
		dist = abs(increment - value)
		if dist < best_dist:
			best_color = hex_color
			best_dist = dist
	return best_color

def cgcg():
	"""
	Run the full cgcg workflow and exit the process.

	Steps:
	  0. Parse the command line and the optional custom color/label tables.
	  1. Check that the zol results directory holds the three required tables.
	  2. Build the 100-step color scale (and, unless custom colors are used,
	     the legend image via createColorLegend).
	  3. Turn the consolidated report into graph nodes and the Markovian order
	     table into graph edges, and save the graph as JSON in the output directory.
	  4. Render the graph with gravis and export it as an HTML file.

	The function never returns: it ends with sys.exit(0) on success and
	sys.exit(1), after reporting the problem, on any failure.
	"""
	myargs = create_parser()

	zol_results_dir = os.path.abspath(myargs.zol_results_dir) + '/'
	outdir = myargs.outdir
	color_track = myargs.color_track
	low_value = myargs.low_value
	high_value = myargs.high_value
	low_value_color = myargs.low_value_color
	high_value_color = myargs.high_value_color
	na_value_color = myargs.na_value_color
	show_labels_flag = myargs.show_labels
	custom_labels_file = myargs.custom_labels
	custom_colors_file = myargs.custom_colors
	flip_flag = myargs.flip
	arrow_size = myargs.arrow_size
	node_border_size = myargs.node_border_size
	undirected_graph = myargs.undirected_graph
	show_consensus_flag = myargs.show_major_path
	background_color = myargs.background_color

	# These two thresholds are compared against floats below; coerce them here so
	# that a value which reached us as a string cannot cause a TypeError later.
	try:
		min_conservation = float(myargs.min_conservation)
		min_edge_ratio = float(myargs.min_edge_ratio)
	except (TypeError, ValueError):
		_exit_with_error('The values provided for --min-conservation and --min-edge-ratio must be numeric.')

	if not os.path.isdir(zol_results_dir):
		_exit_with_error('The zol results directory required as input does not exist.')

	if custom_labels_file is not None and not os.path.isfile(custom_labels_file):
		_exit_with_error('A custom labels file was provided but the path could not be validated to exist.')

	if custom_colors_file is not None and not os.path.isfile(custom_colors_file):
		_exit_with_error('A custom colors file was provided but the path could not be validated to exist.')

	# Reject an unknown track before any output is produced, instead of failing
	# with a KeyError while the nodes are being built.
	if color_track not in track_name_to_full_label:
		_exit_with_error('The color track "%s" is not recognized. Options are: %s.' % (color_track, ', '.join(sorted(track_name_to_full_label))))
	color_track_column = track_name_to_full_label[color_track]

	if os.path.isdir(outdir):
		sys.stderr.write("Output already directory exists. Overvwriting...\n ")

	outdir = os.path.abspath(outdir) + '/'
	findir = outdir + 'Final_Results/'

	util.setupReadyDirectory([outdir, findir])

	# create logging object
	log_file = outdir + 'Progress.log'
	logObject = util.createLoggerObject(log_file)
	version_string = util.getVersion()

	sys.stdout.write('Running cgcg version %s\n' % version_string)
	logObject.info('Running cgcg version %s' % version_string)

	# log command used
	parameters_file = outdir + 'Command_Issued.txt'
	with open(parameters_file, 'a+') as parameters_handle:
		parameters_handle.write(' '.join(sys.argv) + '\n')

	# Step 0: Parse custom colors/labels files
	og_to_color = None
	if custom_colors_file is not None:
		try:
			og_to_color = _read_two_column_tsv(custom_colors_file)
		except Exception:
			_exit_with_error('Issue parsing the custom colors file.', logObject, show_traceback=True)

	og_to_label = None
	if custom_labels_file is not None:
		try:
			og_to_label = _read_two_column_tsv(custom_labels_file)
		except Exception:
			_exit_with_error('Issue parsing the custom labels file.', logObject, show_traceback=True)

	# Step 1: Parse relationships between gene clusters detected by fai and prepTG genomes
	msg = '------------------Step 1------------------\nAssessing input provided.'
	logObject.info(msg)
	sys.stdout.write(msg + '\n')

	zol_spreadsheet_tsv = zol_results_dir + 'Final_Results/Consolidated_Report.tsv'
	order_information_tsv = zol_results_dir + 'Markovian_Order_Information.txt'
	consensus_path_tsv = zol_results_dir + 'Consensus_Path_Information.txt'
	required_tables = [(zol_spreadsheet_tsv, 'Final_Results/Consolidated_Report.tsv'),
					   (order_information_tsv, 'Markovian_Order_Information.txt'),
					   (consensus_path_tsv, 'Consensus_Path_Information.txt')]
	for table_path, table_name in required_tables:
		if not os.path.isfile(table_path):
			_exit_with_error('The "%s" table results from zol analysis is not available in the zol directory provided.' % table_name, logObject)

	# Step 2: Generate color mapping
	msg = '------------------Step 2------------------\nDetermining color scale.'
	logObject.info(msg)
	sys.stdout.write(msg + '\n')

	# "not >" (rather than "<=") so that a NaN bound is rejected as well.
	if not high_value > low_value:
		_exit_with_error('The high value (--high-value) must be greater than the low value (--low-value).', logObject)

	# color_map: position on the value scale -> hex color, 100 evenly spaced steps.
	color_map = {}
	try:
		low_color_obj = Color(low_value_color)
		color_list = list(low_color_obj.range_to(high_value_color, 100))
		for i, val in enumerate(numpy.linspace(low_value, high_value, num=100)):
			color_map[float(val)] = color_list[i].hex
	except Exception:
		_exit_with_error('Issue creating color palette using the python colour package.', logObject, show_traceback=True)

	if custom_colors_file is None:
		createColorLegend(low_value, high_value, low_value_color, high_value_color, outdir, findir, logObject)

	# Step 3: Create gJCF input for gravis
	msg = '------------------Step 3------------------\nCreating gJCF input for gravis.'
	logObject.info(msg)
	sys.stdout.write(msg + '\n')

	if not (background_color.startswith('#') and len(background_color) == 7):
		_exit_with_error('background color provided is not valid - either does not start with "#" or length != 7.', logObject)

	metadata = {'background_color': background_color, 'node_border_size': node_border_size,
			    'arrow_size': arrow_size, 'arrow_color': '#999696', 'node_label_size': 14}
	graph_object = {'graph': {'directed': (not undirected_graph), 'metadata': metadata, 'nodes': {}, 'edges': []}}
	node_name_to_id = {}
	retained_ogs = set()
	# dom_mode is switched on when the report has the
	# 'Single-Linkage Full Protein Cluster' column; identifiers read from the
	# order/consensus tables are then prefixed with 'D' to match the report.
	dom_mode = False
	try:
		header = []
		with open(zol_spreadsheet_tsv) as ozst:
			for i, line in enumerate(ozst):
				line = line.strip('\n')
				if not line:
					continue
				ls = line.split('\t')
				if i == 0:
					header = ls
					if 'Single-Linkage Full Protein Cluster' in header:
						dom_mode = True
					continue

				og = ls[0]
				og_id = i
				node_name_to_id[og] = og_id
				# The third column is used as the conservation value for filtering.
				row_conservation = float(ls[2])
				if row_conservation < min_conservation:
					continue

				og_metadata = {'label': og}
				if og_to_label is not None:
					# With a custom label table, unlisted orthogroups get an empty label.
					og_metadata['label'] = og_to_label.get(og, "")
				retained_ogs.add(og)

				median_length = 'NA'
				conservation = 'NA'
				tajimas_d = 'NA'
				pgap_annotation = 'NA'
				pfam_domains = 'NA'
				og_con_dir = '+'
				og_con_ord = -1
				for j, val in enumerate(ls):
					col_name = header[j]
					if col_name == color_track_column:
						try:
							# val is deliberately rebound to the float so that the column
							# checks below see the same value they always have.
							val = float(val)
							closest_color_val = _closest_color(color_map, val)
							if closest_color_val is None:
								raise ValueError('No color could be matched to value %s.' % str(val))
							og_metadata['color'] = closest_color_val
							og_metadata['hover'] = str(val)
						except Exception:
							# Non-numeric entries (e.g. NA) get the dedicated NA color.
							og_metadata['color'] = na_value_color
					if col_name == "OG Median Length (bp)":
						median_length = float(val)
						# Node size is the median length expressed in 100 bp units.
						og_metadata[col_name] = median_length/100.0
					elif col_name == 'Proportion of Total Gene Clusters with OG':
						conservation = val
					elif col_name == 'Tajima\'s D':
						tajimas_d = val
					elif col_name == 'Pfam Domains':
						pfam_domains = val
					elif col_name in ('PGAP Annotation (E-value)', 'GAP Annotation (E-value)'):
						# The original lookup used 'GAP Annotation (E-value)' although the value
						# is presented as "PGAP Annotation (E-value)"; both spellings are accepted.
						# NOTE(review): confirm the exact header written by zol.
						pgap_annotation = val
					elif col_name == 'OG Consensus Direction':
						og_con_dir = val.replace('"', '')
					elif col_name == 'OG Consensus Order':
						og_con_ord = int(val)

				# Border color encodes the consensus direction; --flip inverts the encoding.
				dark_border_direction = '-' if flip_flag else '+'
				og_con_color = '#242323' if og_con_dir == dark_border_direction else '#c94962'
				og_metadata['border_color'] = og_con_color

				median_length = str(median_length)
				if og_to_color is not None:
					# Custom colors replace the track-based coloring entirely.
					og_metadata['color'] = og_to_color.get(og, na_value_color)

				og_details = f"""
<table style="width:100%">
  <tr>
    <td>Category</td>
    <td>Value</td>
  </tr>
  <tr>
    <td>Ortholog Group (OG) ID</td>
    <td>{og}</td>
  </tr>
  <tr>
    <td>Ortholog Group Consensus Direction</td>
    <td>{og_con_dir}</td>
  </tr>
  <tr>
    <td>Ortholog Group Consensus Order</td>
    <td>{og_con_ord}</td>
  </tr>
  <tr>
    <td>Median Length (bp)</td>
    <td>{median_length}</td>
  </tr>
  <tr>
    <td>Conservation</td>
    <td>{conservation}</td>
  </tr>
  <tr>
    <td>Tajima's D</td>
    <td>{tajimas_d}</td>
  </tr>
  <tr>
    <td>PGAP Annotation (E-value)</td>
    <td>{pgap_annotation}</td>
  </tr>
  <tr>
    <td>Pfam Domains</td>
    <td>{pfam_domains}</td>
  </tr>
</table>
				"""

				og_metadata['click'] = og_details
				graph_object['graph']['nodes'][og_id] = {'label': og_metadata['label'], 'metadata': og_metadata}

	except Exception:
		_exit_with_error('Issue appending node information to gJCF.', logObject, show_traceback=True)

	# Unordered pairs of orthogroups that are adjacent on the path described by
	# the consensus path table (rows mentioning 'start' or 'end' are ignored).
	consensus_og_pairs = set()
	try:
		with open(consensus_path_tsv) as ocpt:
			for line in ocpt:
				line = line.strip()
				if not line:
					continue
				ls = line.split('\t')
				if 'start' in ls or 'end' in ls:
					continue
				if dom_mode:
					og_pair = ['D' + ls[0], 'D' + ls[1]]
				else:
					og_pair = [ls[0], ls[1]]
				consensus_og_pairs.add(tuple(sorted(og_pair)))
	except Exception:
		_exit_with_error('Issue reading in consensus path information.', logObject, show_traceback=True)

	try:
		# Single pass over the order table: keep only edges whose two ends were retained as nodes.
		retained_edges = []
		with open(order_information_tsv) as ooit:
			for i, line in enumerate(ooit):
				line = line.strip('\n')
				if i == 0 or not line:
					continue
				og, og_after, weight, _og_dir, _og_after_dir = line.split('\t')
				if og == 'start' or og_after == 'end':
					continue
				if dom_mode:
					og = 'D' + og
					og_after = 'D' + og_after
				if og not in retained_ogs or og_after not in retained_ogs:
					continue
				retained_edges.append((og, og_after, int(weight)))

		max_edge_weight = max((weight for _, _, weight in retained_edges), default=0)

		if max_edge_weight > 0:
			for og, og_after, weight in retained_edges:
				if flip_flag:
					# Swap the already-prefixed names. Re-reading the raw columns here
					# would drop the 'D' prefix and lose every edge in domain mode.
					og, og_after = og_after, og

				edge_weight_ratio = weight/max_edge_weight
				if edge_weight_ratio < min_edge_ratio:
					continue

				edge_color = '#999696'
				if show_consensus_flag and tuple(sorted([og, og_after])) in consensus_og_pairs:
					edge_color = '#b57c0b'

				edge_dict = {'source': node_name_to_id[og], 'target': node_name_to_id[og_after],
							 'metadata': {'edge_size': edge_weight_ratio,
										  'color': edge_color}}
				graph_object['graph']['edges'].append(edge_dict)
	except Exception:
		_exit_with_error('Issue appending edge information to gJCF.', logObject, show_traceback=True)

	try:
		gJCF_json_file = outdir + 'gJCF_input.json'
		with open(gJCF_json_file, 'w') as f:
			json.dump(graph_object, f)
	except Exception:
		_exit_with_error('Issue saving gJCF graph object (input for gravis) to json file.', logObject, show_traceback=True)

	# Step 4: Run gravis for plot generation and save to HTML report
	msg = '------------------Step 4------------------\nRunning gravis for final graph creation.'
	logObject.info(msg)
	sys.stdout.write(msg + '\n')
	html_file = findir + 'cgcg_gravis_visual.html'
	try:
		fig = gv.d3(graph_object, use_node_size_normalization=False,
			        node_size_data_source='OG Median Length (bp)', node_label_data_source='label',
					use_edge_size_normalization=False, edge_curvature=0.3,
					show_node_label=show_labels_flag, show_menu=True, show_details=True,
					graph_height=400, details_height=200, edge_size_factor=2,
					edge_size_data_source='edge_size')
		fig.export_html(html_file)
	except Exception:
		_exit_with_error('Issue running gravis or saving to HTML.', logObject, show_traceback=True)

	msg = '------------------DONE !------------------\ncgcg finished successfully!\nResulting files can be found at: %s' % findir
	logObject.info(msg)
	sys.stdout.write(msg + '\n')
	sys.exit(0)

def createColorLegend(low_value, high_value, low_value_color, high_value_color, workspace, findir, logObject):
	"""
	Write an R script that draws the color-scale legend and run it with Rscript.

	Args:
		low_value: lower limit of the color scale.
		high_value: upper limit of the color scale.
		low_value_color: color used at the low end of the gradient.
		high_value_color: color used at the high end of the gradient.
		workspace: path prefix under which 'plot_legend.R' is written; it is
			concatenated as-is, so it must end with a path separator.
		findir: path prefix under which 'plot_legend.png' is expected; also
			concatenated as-is.
		logObject: logger used for the command run and for error reporting.

	On any failure the problem (including the traceback) is reported to stderr
	and the log, and the process exits with status 1.
	"""
	rscript_path = workspace + 'plot_legend.R'
	try:
		legend_png_file = findir + 'plot_legend.png'
		scale_limits = str(low_value) + ', ' + str(high_value)
		rscript_lines = [
			'library(ggplot2)\n',
			'library(grid)\n',
			'library(cowplot)\n\n',
			'png("' + legend_png_file + '", height=10, width=5, units="in", res=600)\n',
			'V1 <- c(' + scale_limits + ')\n',
			'V2 <- c("low_value", "high_value")\n',
			'dat <- data.frame(V1, V2)\n',
			# A throw-away bar plot is built only so that its gradient legend can be extracted.
			'my_hist <- ggplot(dat, aes(x=V1, y=1, fill = V1)) + geom_bar(stat="identity") + theme(legend.title=element_blank()) +\n',
			'scale_fill_gradient(low="' + low_value_color + '", high="' + high_value_color + '", limits=c(' + scale_limits + '))\n',
			'legend <- cowplot::get_legend(my_hist)\n',
			'grid.newpage()\n',
			'grid.draw(legend)\n',
			'dev.off()\n\n',
		]
		# The context manager guarantees the script is flushed and closed even if writing fails.
		with open(rscript_path, 'w') as rout:
			rout.writelines(rscript_lines)
		rscript_cmd = ['Rscript', rscript_path]
		util.runCmdViaSubprocess(rscript_cmd, logObject=logObject, check_files=[legend_png_file])
	except Exception:
		msg = 'Issues creating legend for coloring - check out the Rscript for creating it for more details: %s' % rscript_path
		sys.stderr.write(msg + '\n')
		logObject.error(msg)
		# Report the underlying cause as the other error handlers in this file do.
		logObject.error(traceback.format_exc())
		sys.stderr.write(traceback.format_exc() + '\n')
		sys.exit(1)

# Script entry point: run the cgcg workflow only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
	cgcg()
