#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1767619825/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python
"""
Compute source probabilities using mchirp estimation method for all events
in a chunk with an IFAR above certain threshold.
"""
import argparse
import json
import logging
import os

import numpy as np
import tqdm

import pycbc
from pycbc import mchirp_area
from pycbc.io import hdf
from pycbc.pnutils import mass1_mass2_to_mchirp_eta

# Command-line interface. The module docstring doubles as the --help
# description.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument('--trigger-file', required=True,
                    help='Coincident trigger file holding the candidate '
                         'events and their IFAR. Its base name must be '
                         'made of dash-separated fields: the third and '
                         'fourth fields are used to name the output '
                         'directory.')
parser.add_argument('--bank-file', required=True,
                    help='Template bank file from which mass1 and mass2 '
                         'of each candidate event are read.')
parser.add_argument('--single-detector-triggers', nargs='+', required=True,
                    help='Single-detector trigger files from which the '
                         'SNR and sigmasq of each candidate event are '
                         'read.')
parser.add_argument('--search-tag', required=True,
                    help='String to add to the output file names '
                         'identifying the search. eg: PYCBC_AllSky, '
                         'PYCBC_HighMass')
parser.add_argument('--ifar-threshold', type=float, default=None,
                    help='Select only candidate events with IFAR '
                         'above threshold. If not given, all candidate '
                         'events are processed.')
# Options consumed by mchirp_area.from_cli() right after parsing.
mchirp_area.insert_args(parser)
args = parser.parse_args()

# Settings handed to mchirp_area.calc_probabilities() for every event.
mc_area_args = mchirp_area.from_cli(args, parser)

# NOTE(review): args.verbose is presumably registered by
# pycbc.add_common_pycbc_options() -- it is not added in this script.
pycbc.init_logging(args.verbose)

# Only the base name of the trigger file is needed: it is reported in the
# log and its dash-separated fields name the output directory.
TRIGGER_FILE = os.path.basename(args.trigger_file)
name_fields = TRIGGER_FILE.split('-')
if len(name_fields) < 4:
    # Fail with a readable message instead of a bare IndexError below.
    parser.error('Cannot build the output directory name from trigger file '
                 'name %r: expected at least four dash-separated fields'
                 % TRIGGER_FILE)
GPS_START_TIME = name_fields[2]
# Fourth field, truncated at its first dot (drops the file extension).
GPS_START_TIME_NS = name_fields[3].split('.')[0]

logging.info('Using files: %s, %s, %s', TRIGGER_FILE,
             os.path.basename(args.bank_file),
             ','.join(os.path.basename(name)
                      for name in args.single_detector_triggers))

# Keep this a plain string ending in '/': output file names are appended
# directly to it when the results are written.
dir_path = 'source_probability_results/CHUNK_%s-%s/' % (GPS_START_TIME,
                                                        GPS_START_TIME_NS)
pycbc.makedir(dir_path)
logging.info('Saving results in %s', dir_path)

# Gather the candidate events together with the template-bank and
# single-detector information attached to them.
fortrigs = hdf.ForegroundTriggers(args.trigger_file, args.bank_file,
                                  sngl_files=args.single_detector_triggers)

ifar = fortrigs.get_coincfile_array('ifar')
N_original = len(ifar)

# idx is a boolean mask over the original event list marking the events to
# process. Compare against None explicitly so that a threshold of 0.0 is
# honoured instead of being mistaken for "no threshold given".
if args.ifar_threshold is not None:
    idx = ifar > args.ifar_threshold
    ifar = ifar[idx]
    logging.info('%i triggers out of %i with IFAR > %s',
                 len(ifar), N_original, args.ifar_threshold)
else:
    idx = np.full(N_original, True)

# Per-event quantities restricted to the selected events.
mass1 = fortrigs.get_bankfile_array('mass1')[idx]
mass2 = fortrigs.get_bankfile_array('mass2')[idx]
mchirp, _ = mass1_mass2_to_mchirp_eta(mass1, mass2)
end_time = fortrigs.get_end_time()[idx]
# The single-detector dictionaries are NOT restricted by idx: their arrays
# are still laid out like the original event list and have to be indexed
# with positions in that list.
sngl_snr = fortrigs.get_snglfile_array_dict('snr')
sngl_sigmasq = fortrigs.get_snglfile_array_dict('sigmasq')

# mchirp and end_time only hold the events selected by idx, whereas the
# single-detector arrays still cover every event. Translate each position in
# the filtered arrays back to its position in the original event list so
# that both sets of arrays always describe the same event, whatever the
# ordering of the events is.
selected_events = np.flatnonzero(idx)

# Write one JSON file of source probabilities per selected event.
for event, orig_event in enumerate(tqdm.tqdm(selected_events)):
    # sngl_snr[ifo] contains 2 arrays: first contains SNR values and
    # second (True/False) boolean values. If boolean is False, the ifo was
    # not present in the coincidence
    ifos_event = [ifo for ifo in sngl_snr if sngl_snr[ifo][1][orig_event]]
    snrs_event = [sngl_snr[ifo][0][orig_event] for ifo in ifos_event]
    # Quadrature sum of the SNRs of the participating detectors.
    coinc_snr = sum(snr**2 for snr in snrs_event) ** 0.5
    # Smallest sqrt(sigmasq) / SNR among the participating detectors.
    min_eff_dist = min(sngl_sigmasq[ifo][0][orig_event] ** 0.5
                       / sngl_snr[ifo][0][orig_event]
                       for ifo in ifos_event)
    probs = mchirp_area.calc_probabilities(mchirp[event], coinc_snr,
                                           min_eff_dist, mc_area_args)
    # We do not expect a mass gap entry, but just in case
    probs.pop("Mass Gap", None)
    # File name: <sorted ifo names>-<search tag>-<integer end time>-1.json
    ifo_names = ''.join(sorted(ifos_event))
    out_name = dir_path + '%s-%s-%i-1.json' % (ifo_names, args.search_tag,
                                               int(end_time[event]))
    with open(out_name, 'w') as outfile:
        json.dump(probs, outfile)
