#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1767619825/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python
""" Plot the rate of triggers across the template bank
"""
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import numpy, argparse, pycbc.pnutils
from pycbc.io.hdf import HFile

from pycbc import init_logging, add_common_pycbc_options

# Command-line interface.  Every option below is used unconditionally further
# down the script, so each one is marked as required: a missing option is then
# reported by argparse immediately instead of failing later with an obscure
# error.
parser = argparse.ArgumentParser(description=__doc__)
add_common_pycbc_options(parser)
parser.add_argument('--trigger-files', nargs='+', required=True,
                    help='HDF trigger files; each must provide the '
                         'template_id, snr and chisq datasets')
parser.add_argument('--bank-file', required=True,
                    help='HDF template bank file providing the mass1 and '
                         'mass2 datasets')
parser.add_argument('--output-file', required=True,
                    help='Path of the eta versus total mass scatter plot '
                         'to write')
# Parsed as a float so that a malformed value is rejected while parsing the
# command line; the value is only ever used in floating-point arithmetic.
parser.add_argument('--chisq-bins', required=True, type=float,
                    help='Number of chisq bins; chisq values are divided by '
                         '2 * CHISQ_BINS - 2 before being plotted')
args = parser.parse_args()

init_logging(args.verbose)

# Read the component masses of every template.  The datasets are copied into
# memory with [:], so the bank file can be closed as soon as they are read.
with HFile(args.bank_file, 'r') as bf:
    m1 = bf['mass1'][:]
    m2 = bf['mass2'][:]

# Only eta is plotted; the chirp mass returned alongside it is not needed.
_, et = pycbc.pnutils.mass1_mass2_to_mchirp_eta(m1, m2)

# Number of triggers produced by each template, indexed by template id
num_templates = len(m1)
template_num = numpy.zeros(num_templates)

# Accumulators filled in while looping over the trigger files
num_trigs = 0
snrs = []
chisqs = []

for trig_filename in args.trigger_files:
    # Copy the datasets into memory and close the file straight away, so that
    # file handles do not pile up when many trigger files are given.
    with HFile(trig_filename, 'r') as f:
        tid = f['template_id'][:]
        snr = f['snr'][:]
        chisq = f['chisq'][:]

    # Store this file's triggers ordered by template id
    tsort = tid.argsort()
    snrs.append(snr[tsort])
    chisqs.append(chisq[tsort])

    # Count triggers produced by each template.  The ids returned by unique
    # contain no repeats, so the indexed in-place addition updates each
    # template's counter exactly once.
    u, n = numpy.unique(tid, return_counts=True)
    template_num[u] += n

    num_trigs += len(tid)

# Pool the triggers of all files and divide chisq by 2 * chisq_bins - 2
chisq_dof = float(args.chisq_bins) * 2 - 2
chisq = numpy.concatenate(chisqs) / chisq_dof
snr = numpy.concatenate(snrs)

# Scatter chisq against SNR for at most the first million pooled triggers,
# restricted to a fixed window, and write it to a fixed file name.
first_million = slice(0, 1000000)
snrchi_fig = plt.figure()
snrchi_ax = snrchi_fig.gca()
snrchi_ax.scatter(snr[first_million], chisq[first_million])
snrchi_ax.set_xlim(6, 8)
snrchi_ax.set_ylim(.8, 3)
snrchi_fig.savefig('snrchi.png')
   
# Marker area grows linearly with the trigger count, reaching 50 for the
# template with the most triggers.  With an empty bank or no triggers at all
# the peak count is zero, so use zero-sized markers instead of dividing by
# zero (which would fill the sizes with NaN).
peak_count = template_num.max() if template_num.size else 0
if peak_count > 0:
    marker_sizes = template_num / peak_count * 50
else:
    marker_sizes = numpy.zeros_like(template_num)

# One point per template, coloured by its trigger count
plt.figure()
plt.scatter(et, m1+m2, c=template_num, s=marker_sizes)
plt.ylabel('Total Mass')
plt.xlabel('Eta')
plt.colorbar()
plt.savefig(args.output_file)
