#!/usr/bin/env python3
# mkdir prova_mqc && cd prova_mqc

# rm -rf data_for_mqc/ report_by_mqc/ && ../bin/dadaist2-mqc-report-test -i ../example-output -o data_for_mqc && multiqc -f -o report_by_mqc -c data_for_mqc/mqc.yaml data_for_mqc/
# ./bin/dadaist2-mqc-report -i example-output -o example-report && multiqc -f -o example-mqc -c example-report/config.yaml example-report

# apparently -c must be specified
# try --lint
#from IPython import embed #FIXME

import argparse
import os

def read_fasta(path):
	"""Iterate over the records of a FASTA file, optionally gzip-compressed.

	Args:
		path: file name; a name ending in '.gz' is opened with gzip.

	Yields:
		(name, seq) tuples, where name is the full header line without the
		leading '>' and trailing whitespace, and seq is the concatenation of
		the record's sequence lines with trailing whitespace removed.

	Raises:
		ValueError: if sequence data appears before the first header line.
	"""
	import gzip
	opener = gzip.open if path.endswith('.gz') else open
	name = None
	chunks = []
	with opener(path, 'rt') as fasta:
		for line in fasta:
			line = line.rstrip()
			if line.startswith('>'):
				if name is not None:
					yield name, ''.join(chunks)
				name = line[1:]
				chunks = []
			elif not line:
				# blank lines carry no sequence, wherever they appear
				continue
			elif name is None:
				raise ValueError('{}: sequence data found before the first FASTA header'.format(path))
			else:
				chunks.append(line)
	# an empty file has no record at all, so there is nothing to flush
	if name is not None:
		yield name, ''.join(chunks)

# read taxonomy
def load_taxa(path):
	"""Load a taxonomy table whose header line lacks the row-name column.

	The first line holds only the rank names; every following line holds a
	row identifier followed by one value per rank. The separator is a tab if
	the header contains one, otherwise a single space.

	Args:
		path: path of the taxonomy text file.

	Returns:
		A pandas DataFrame of strings indexed by the row identifier (index
		name 'ROW'), with one column per rank name found in the header.
	"""
	# Imported locally so the function also works when this file is imported
	# as a module rather than run as a script.
	import csv
	import pandas
	with open(path, 'rt') as f:
		header = f.readline().rstrip('\r\n')
		sep = '\t' if '\t' in header else ' '
		# Parse the header with the same quoting rules used for the data rows,
		# so quoted rank names do not keep their literal quote characters.
		columns = next(csv.reader([header], delimiter=sep, quotechar='"'))
		return pandas.read_csv(f, sep=sep, index_col=0, names=['ROW'] + columns, quotechar='"', dtype=str)

def get_taxa_slice(tax_df, i, j=None):
	"""Build one '-'-joined label per row from a range of columns.

	Called with only `i`, the label spans the columns at positions [0, i);
	called with `i` and `j`, it spans positions [i, j). Missing values are
	replaced by the string 'NA' before joining. The range must be non-empty
	(checked with an assertion).

	Returns a pandas Series of labels with the same index as `tax_df`.
	"""
	if j is None:
		start, stop = 0, i
	else:
		start, stop = i, j
	assert start < stop

	def join_ranks(row):
		# `row` is a plain array of the row's values (raw=True below)
		return '-'.join(row[start:stop])

	filled = tax_df.fillna('NA')
	return filled.apply(join_ranks, axis=1, raw=True)

if __name__ == '__main__':
	# Script entry point. Reads feature-table.tsv, taxonomy.txt, dada2_stats.tsv
	# and rep-seqs.fasta from the input directory and writes octaveplot.tsv,
	# braycurtis.tsv, seqs.html and the MultiQC config.yaml to the output directory.
	parser = argparse.ArgumentParser(description='Produce multiqc report')
	parser.add_argument('-i', '--input-dir', required=True)
	parser.add_argument('-t', '--toptaxa', type=int, default=10)
	parser.add_argument('-o', '--output-dir', required=True)
	args = parser.parse_args()

	# create output dir if needed
	os.makedirs(args.output_dir, exist_ok=True)

	# Deferred until after argument parsing; kept at module scope so that
	# module-level helpers can see the name.
	import pandas
	feat_df = pandas.read_csv(os.path.join(args.input_dir, 'feature-table.tsv'), sep='\t', index_col=0)
	# Snapshot of the sample columns, taken before the helper 'TAXA' column is added.
	samples = feat_df.columns

	# TAXONOMY
	tax_df = load_taxa(os.path.join(args.input_dir, 'taxonomy.txt'))
	if len(tax_df) != len(feat_df):
		raise SystemExit('taxonomy.txt has {} rows but feature-table.tsv has {}'.format(len(tax_df), len(feat_df)))
	# Rows are matched by position: both tables are expected to list the features in the same order.
	tax_df.index = feat_df.index
	# Snapshot of the rank columns, taken before the helper columns are added in the loop.
	rank_names = list(tax_df.columns)

	# Show every group when there are at most toptaxa + 1 of them;
	# otherwise show the top `toptaxa` groups plus an 'OTHER' bucket.
	max_groups = args.toptaxa + 1

	taxa_labels = []
	taxa_data = []
	for level, level_name in enumerate(rank_names, 1):
		tax_df['FULL_TAXA'] = get_taxa_slice(tax_df, level)
		# Find the shortest suffix of the lineage that still identifies each full lineage uniquely.
		for start in range(level, 0, -1):
			tax_df['SHORT_TAXA'] = get_taxa_slice(tax_df, start - 1, level)
			short2full = tax_df.groupby('SHORT_TAXA')['FULL_TAXA'].agg(set)
			if all(short2full.str.len() == 1):
				break

		feat_df['TAXA'] = tax_df['SHORT_TAXA']
		feat_by_taxa_df = feat_df.groupby('TAXA').sum()
		main_taxa = feat_by_taxa_df.sum(axis=1).sort_values(ascending=False).index
		if len(main_taxa) > max_groups:
			feat_df.loc[feat_df['TAXA'].isin(main_taxa[max_groups - 1:]), 'TAXA'] = 'OTHER'
			main_taxa = main_taxa[:max_groups - 1].tolist() + ['OTHER']
		feat_by_taxa_df = feat_df.groupby('TAXA').sum().loc[main_taxa]

		taxa_labels.append(level_name)
		taxa_data.append({sample: counts.to_dict() for sample, counts in feat_by_taxa_df.items()})

	# Count matrix restricted to the sample columns (feat_df now also holds the string 'TAXA' column).
	counts_df = feat_df[samples]

	# DATA STATS
	dadastats_df = pandas.read_csv(os.path.join(args.input_dir, 'dada2_stats.tsv'), sep='\t', index_col=0)
	dadastats_raw_data = {sample: counts.to_dict() for sample, counts in dadastats_df.iterrows()}

	# FASTA REPORT
	max_seqs = 3
	# Sum only the sample columns: including the string 'TAXA' column raises TypeError on pandas >= 2.
	top_feats = counts_df.sum(axis=1).sort_values(ascending=False)
	total_counts = top_feats.sum()
	# Classify on the full lineage of the last rank. The 'TAXA' column cannot be used here
	# because minor groups have been overwritten with 'OTHER'.
	is_unassigned = tax_df.loc[top_feats.index, 'FULL_TAXA'].str.endswith('NA').to_numpy(dtype=bool)
	# Boolean masks keep the abundance ordering of top_feats (Index.difference would sort by label).
	top_feats_unassigned = top_feats.index[is_unassigned][:max_seqs]
	top_feats_assigned = top_feats.index[~is_unassigned][:max_seqs]
	wanted_feats = set(top_feats_assigned) | set(top_feats_unassigned)
	seqs = {name: seq for name, seq in read_fasta(os.path.join(args.input_dir, 'rep-seqs.fasta')) if name in wanted_feats}
	with open(os.path.join(args.output_dir, 'seqs.html'), 'wt') as f:
		def feat2html(feat):
			# Two table rows per feature: a header row with badges, then the sequence itself.
			print('''<tr class="table-primary">
					<td><strong>&gt;{name}</strong> <span class="badge badge-pill badge-info">counts: {tot}</span> <span class="badge badge-pill badge-primary">ratio: {ratio:.1%}</span> <span style="background-color: #5bc0de;" class="badge badge-pill badge-primary">taxonomy: {taxonomy}</span></td>
					<!--<td>Tax</td> -->
				</tr>
				<tr>
					<td style="word-wrap: break-word;min-width: 160px;max-width: 160px;">{seq}</td>
				</tr>'''.format(name=feat, taxonomy=tax_df.loc[feat, 'FULL_TAXA'], seq=seqs[feat], tot=top_feats.at[feat], ratio=top_feats.at[feat]/total_counts), file=f)

		def opentable(title, filehandle):
			print('<h3>' + title + '</h3>', file=filehandle)
			print('<div class="table-responsive"><table class="table table-hover table-sm">', file=filehandle)
			print('<tbody><!-- <tr><th>Sequence name</th> <th>Taxonomy</th></tr>-->\n', file=filehandle)

		def closetable(filehandle):
			print('</tbody></table></div><div class="w-100 my-4"></div>', file=filehandle)

		opentable('Most common sequences (with taxonomy)', f)
		for feat in top_feats_assigned:
			feat2html(feat)
		closetable(f)

		opentable('Most common sequences (unclassified)', f)
		for feat in top_feats_unassigned:
			feat2html(feat)
		closetable(f)

	# OCTAVE PLOT
	# Bins are [1, 2), [2, 4), [4, 8), ... The last break must be strictly greater than the
	# largest count, otherwise a maximum that is an exact power of two falls outside every bin.
	max_count = max(int(counts_df.max().max()), 1)
	octave_breaks = [2**i for i in range(max_count.bit_length() + 1)]
	# sort=False keeps the bins in increasing order instead of ordering them by frequency.
	octave_plot = counts_df.apply(lambda sample: pandas.cut(sample, octave_breaks, right=False).value_counts(sort=False))
	octave_plot.index = octave_plot.index.map(lambda interval: '{}-{}'.format(interval.left, interval.right - 1))
	octave_plot.to_csv(os.path.join(args.output_dir, 'octaveplot.tsv'), sep='\t')

	# BRAY-CURTIS dissimilarity
	import itertools
	bray_curtis = pandas.DataFrame(0.0, index=samples, columns=samples)
	for i, j in itertools.combinations(samples, 2):
		cij = counts_df[[i, j]] # FIXME: raw counts are used here; maybe they should be normalized first?
		S = cij.min(axis=1).sum()
		T = cij.sum(axis=1).sum()
		d = 1 - 2*S/float(T)
		bray_curtis.at[i, j] = d
		bray_curtis.at[j, i] = d
	bray_curtis.to_csv(os.path.join(args.output_dir, 'braycurtis.tsv'), sep='\t')

	# mqc configuration
	# for plot config see https://github.com/ewels/MultiQC/blob/master/docs/plots.md
	# Key order is preserved in the YAML output (sort_keys=False below).
	mqc_config = {
		'title': 'Dadaist2',
		'intro_text': 'MultiQC report generated from the Dadaist2 pipeline',
		'run_modules': [
			'fastp',
			'custom_content',
		],
		'module_order': [
			'fastp',
			'custom_content',
		],
		'custom_content': {
			'order': [
				'qc', 'samples'
			],
		},
		'custom_data': {
			'dada2stats_raw': {
				'parent_id': 'qc',
				'parent_name': 'Quality Checks',
				'section_name': 'QC',
				'plot_type': 'bargraph',
				'data': dadastats_raw_data,
				'pconfig': {
					'id': 'cleaning',
					'title': 'Data2 Stats: read filtering',
					'stacking': False,
					'cpswitch': False,
				},
			},
			'octaveplot': {
				'parent_id': 'qc',
				'description': 'Aboundance distribution of sequences per sample. Sequences are binned by their log2 counts for each sample. See https://doi.org/10.1101/389833',
				# FIXME: 'linegraph' would probably suit this better, but the report crashes when it is used
				'plot_type': 'bargraph',
				'pconfig': {
					'xlab': 'Number of sequences',
					'ylab': 'Count bin',
					'cpswitch': False,
					'stacking': False,
					'id': 'octaveplot',
					'title': 'Dadaist2: Octave plot',
				},
			},
			'braycurtis': {
				'parent_id': 'samples',
				'parent_name': 'Samples',
				'description': 'Bray-Curtis dissimilarity between samples',
				'plot_type': 'heatmap',
				'pconfig': {
					'title': 'Dadaist2: Bray-Curtis dissimilarity',
					'id': 'braycurtis',
				},
			},
			'seqs': {
				'parent_id': 'qc',
				'parent_name': 'Quality Checks',
				'section_name': 'Most common sequences',
				'plot_type': 'html',
			},
			'taxa': {
				'parent_id': 'samples',
				'plot_type': 'bargraph',
				'pconfig': {
					'id': 'taxonomy',
					'title': 'Dadaist2: Taxonomy',
					'data_labels': taxa_labels,
					'cpswitch_c_active': False,
				},
				'data': taxa_data,
			},
		},
		# search patterns: map each custom_data id to the file written above
		'sp': {
			'octaveplot': {'fn': 'octaveplot.tsv'},
			'braycurtis': {'fn': 'braycurtis.tsv'},
			'seqs': {'fn': 'seqs.html'},
		}
	}
	import yaml
	with open(os.path.join(args.output_dir, 'config.yaml'), 'wt') as f:
		yaml.dump(mqc_config, f, sort_keys=False)

	# TODO: also add qc/A01.json, A02.json, F99.json

