#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1767619825/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python
"""
The program combines output files generated
by pycbc_sngls_findtrigs to generate a mapping between SNR and FAP/FAR, along
with producing the combined foreground and background triggers.
"""

import argparse
import logging, numpy, copy
from pycbc.events import veto
from pycbc.events import significance
import pycbc.pnutils, pycbc.io
import pycbc.conversions as conv

class fw(object):
    """Small write-mode wrapper around an HDF file opened with pycbc.io.HFile.

    Assigning to a name that is not yet in the file creates a gzip-compressed
    dataset; assigning to a name that already exists overwrites the values of
    the existing dataset in place.
    """

    def __init__(self, name):
        """Open the file ``name`` for writing and expose its attributes."""
        handle = pycbc.io.HFile(name, 'w')
        self.f = handle
        self.attrs = handle.attrs

    def __setitem__(self, name, data):
        """Store ``data`` under ``name``, creating the dataset if required."""
        if name in self.f:
            # Already present: reassign the values of the existing item
            self.f[name][:] = data
            return

        # Not in the file yet: make a new dataset whose maximum shape is the
        # shape of the data it is created from
        self.f.create_dataset(name, data=data, compression="gzip",
                              compression_opts=9, shuffle=True,
                              maxshape=data.shape)

    def __getitem__(self, *args):
        """Forward item lookup to the underlying file object."""
        lookup = self.f.__getitem__
        return lookup(*args)

# Build the command line interface. The module docstring is reused as the
# description printed by --help.
parser = argparse.ArgumentParser(description=__doc__)
# General required options
pycbc.add_common_pycbc_options(parser)
# The input files, the ifo and the output file are all used unconditionally
# further down, so they are marked as required: leaving one out now gives a
# usage error instead of a traceback.
parser.add_argument('--sngls-files', nargs='+', required=True,
                    help='List of files containing trigger and statistic '
                         'information.')
parser.add_argument('--ifos', nargs=1, required=True,
                    help='The ifo whose single-detector triggers are '
                         'contained in the input files (exactly one value)')
parser.add_argument('--cluster-window', type=float, default=10,
                    help='Length of time window in seconds to cluster coinc '
                         'events [default=10s]')
parser.add_argument('--veto-window', type=float, default=.1,
                    help='Time around each zerolag trigger to window out '
                         '[default=.1s]')
significance.insert_significance_option_group(parser)
parser.add_argument('--hierarchical-removal-ifar-threshold',
                    type=float, default=0.5,
                    help="Threshold to hierarchically remove foreground "
                         "triggers with IFAR (years) above this value "
                         "[default=0.5yr]")
parser.add_argument('--hierarchical-removal-window', type=float, default=1.0,
                    help='Time around each hierarchically removed trigger '
                         'within which all triggers are removed '
                         '[default=1.0s]')
parser.add_argument('--max-hierarchical-removal', type=int, default=0,
                    help='Maximum number of hierarchical removals to carry '
                         'out. Choose -1 for unlimited hierarchical removal '
                         'until no foreground triggers are louder than the '
                         '(inclusive/exclusive) background. Choose 0 to do '
                         'no hierarchical removals, choose 1 to do at most '
                         '1 hierarchical removal and so on. If given, must '
                         'also provide --hierarchical-removal-against to '
                         'indicate which background to remove triggers '
                         'against. [default=0]')
parser.add_argument('--hierarchical-removal-against', type=str,
                    default='none', choices=['none', 'inclusive', 'exclusive'],
                    help='If doing hierarchical removal, remove foreground '
                         'triggers that are louder than either the "inclusive"'
                         ' (little-dogs-in) background, or the "exclusive" '
                         '(little-dogs-out) background. [default="none"]')
parser.add_argument('--output-file', required=True,
                    help='Path of the HDF file to write the results to')
args = parser.parse_args()

# Validation of the significance options is delegated to the significance
# module.
significance.check_significance_options(args, parser)

# Hierarchical removal needs to know which background (inclusive or
# exclusive) the foreground triggers are removed against. Conversely, naming
# a background while asking for zero removals is contradictory.
if args.max_hierarchical_removal != 0:
    is_bkg_inc = args.hierarchical_removal_against == 'inclusive'
    is_bkg_exc = args.hierarchical_removal_against == 'exclusive'
    if not is_bkg_inc and not is_bkg_exc:
        parser.error("--max-hierarchical-removal requires a choice of which "
                     "background to remove foreground triggers against, "
                     "inclusive or exclusive. Use with --help for more "
                     "information.")
elif args.hierarchical_removal_against != 'none':
    parser.error("User Error: 0 maximum hierarchical removals chosen but "
                 "option for --hierarchical-removal-against was given. "
                 "These are conflicting options. Use with --help for more "
                 "information.")


pycbc.init_logging(args.verbose)

# Load the single-detector triggers for the one requested ifo
logging.info("Loading triggers")
ifo = args.ifos[0]
logging.info("IFO input: %s", ifo)
all_trigs = pycbc.io.MultiifoStatmapData(files=args.sngls_files,
                                         ifos=[ifo])
# Checked explicitly rather than with assert, so that the check is still
# carried out when python is run with -O
if ifo + '/time' not in all_trigs.data:
    raise ValueError("No '%s/time' data found in the input trigger files"
                     % ifo)

logging.info("We have %s triggers", len(all_trigs.stat))
logging.info("Clustering triggers")
all_trigs = all_trigs.cluster(args.cluster_window)
logging.info("%s triggers remain", len(all_trigs.stat))

fg_time = float(all_trigs.attrs['foreground_time'])

# Open the output file and record the basic metadata
logging.info("Dumping foreground triggers")
f = fw(args.output_file)
f.attrs['num_of_ifos'] = 1
f.attrs['ifos'] = ifo

f.attrs['timeslide_interval'] = all_trigs.attrs['timeslide_interval']

# Copy over the segment info
for key in all_trigs.seg.keys():
    f['segments/%s/start' % key] = all_trigs.seg[key]['start'][:]
    f['segments/%s/end' % key] = all_trigs.seg[key]['end'][:]

# A single placeholder entry is written for the foreground veto segments
f['segments/foreground_veto/start'] = numpy.array([0])
f['segments/foreground_veto/end'] = numpy.array([0])
# The same clustered triggers are written as both foreground and background
for k in all_trigs.data:
    f['foreground/' + k] = all_trigs.data[k]
    f['background/' + k] = all_trigs.data[k]

logging.info("Estimating FAN from background statistic values")
# Before any removal the foreground and the (inclusive) background are the
# very same triggers, so both names refer to one ranking statistic array
back_stat = all_trigs.stat
fore_stat = back_stat
bkg_dec_facs = all_trigs.decimation_factor

significance_dict = significance.digest_significance_options([ifo], args)

# False alarm rates of the background and foreground triggers estimated
# against the inclusive background. sig_info holds additional values
# returned by the calculation; they are stored as attributes later on.
bg_far, fg_far, sig_info = significance.get_far(
    back_stat, fore_stat, bkg_dec_facs, fg_time, **significance_dict[ifo])

# Apply the FAR limit, to the foreground first and then to the background
fg_far, bg_far = (
    significance.apply_far_limit(far_values, significance_dict, combo=ifo)
    for far_values in (fg_far, bg_far)
)

# Inverse false alarm rates
bg_ifar = 1. / bg_far
fg_ifar = 1. / fg_far

f['background/ifar'] = conv.sec_to_year(bg_ifar)

# Foreground and background share the same time
for time_attr in ('background_time', 'foreground_time'):
    f.attrs[time_attr] = fg_time

# Build the 'exclusive' background: the inclusive background with every
# event whose IFAR is above the foreground time taken out.

fg_time_exc = fg_time

# Remove trigs from 'exclusive' background if their IFAR is > livetime
to_keep = bg_ifar <= fg_time_exc

# count_nonzero counts the mask in compiled code; the builtin sum() would
# iterate over the array element by element in Python
n_removed = bg_ifar.size - numpy.count_nonzero(to_keep)
logging.info("Removing %s event(s) from exclusive background",
             n_removed)

# Indexing with a boolean mask returns new arrays, so the inclusive arrays
# are left untouched without the need for a deep copy beforehand
back_stat_exc = back_stat[to_keep]
bkg_exc_dec_facs = bkg_dec_facs[to_keep]

# Record indices into all_trigs for the exclusive background
back_exc_locs = numpy.arange(len(all_trigs.stat))[to_keep]

# False alarm rates of the exclusive background triggers, and of the
# foreground triggers estimated against the exclusive background
bg_far_exc, fg_far_exc, exc_sig_info = significance.get_far(
    back_stat_exc,
    fore_stat,
    bkg_exc_dec_facs,
    fg_time_exc,
    **significance_dict[ifo])

fg_far_exc = significance.apply_far_limit(
    fg_far_exc,
    significance_dict,
    combo=ifo)

bg_far_exc = significance.apply_far_limit(
    bg_far_exc,
    significance_dict,
    combo=ifo)

bg_ifar_exc = 1. / bg_far_exc
fg_ifar_exc = 1. / fg_far_exc

# Remove a small amount of time from the exclusive fore/background
# time to account for this removal
# NOTE(review): the exclusive FARs above were computed with fg_time_exc
# before this reduction, while the hierarchical removal loop further down
# uses the reduced value - confirm that this ordering is intended.
fg_time_exc -= n_removed * args.veto_window

for k in all_trigs.data:
    f['background_exc/' + k] = all_trigs.data[k][back_exc_locs]

f['background_exc/ifar'] = conv.sec_to_year(bg_ifar_exc)
f.attrs['background_time_exc'] = fg_time_exc
f.attrs['foreground_time_exc'] = fg_time_exc

logging.info("calculating foreground ifar/fap values")

# False alarm probabilities against the inclusive and exclusive backgrounds
fap = 1 - numpy.exp(- fg_time / fg_ifar)
fap_exc = 1 - numpy.exp(- fg_time_exc / fg_ifar_exc)

# Written in the order ifar, fap, ifar_exc, fap_exc; IFARs are converted
# with conv.sec_to_year before being stored
f['foreground/ifar'] = conv.sec_to_year(fg_ifar)
f['foreground/fap'] = fap
f['foreground/ifar_exc'] = conv.sec_to_year(fg_ifar_exc)
f['foreground/fap_exc'] = fap_exc

# Attach the significance information to the foreground group; the values
# belonging to the exclusive background get an "_exc" suffix
for key, value in sig_info.items():
    f['foreground'].attrs[key] = value
for key, value in exc_sig_info.items():
    f['foreground'].attrs[key + '_exc'] = value

# Carry the name attribute of the input over, if there is one
if 'name' in all_trigs.attrs:
    f.attrs['name'] = all_trigs.attrs['name']

# Incorporate hierarchical removal for any other loud triggers
logging.info("Beginning hierarchical removal of foreground triggers")

# Counter for the number of hierarchical removals carried out so far.
h_iterations = 0

# Step 1: Keep a reference to the foreground ranking statistic as it is
#         before any removal. The removal loop uses it to locate the entries
#         of the original foreground whose ifar and fap it updates.
orig_fore_stat = fore_stat

# ifar_louder is the foreground IFAR that decides whether another removal is
# needed: the exclusive one when removing against the exclusive background,
# the inclusive one otherwise. When no removals were requested the choice is
# irrelevant, as the loop below stops before removing anything, but the name
# has to exist; is_bkg_inc is not defined in that case, hence the
# short-circuit on the first condition.
if args.max_hierarchical_removal != 0 and not is_bkg_inc:
    ifar_louder = fg_ifar_exc
else:
    ifar_louder = fg_ifar

# Step 2 : Loop until we don't have to hierarchically remove anymore. This
#          will happen when ifar_louder has no elements that are
#          above the set threshold, or a set maximum.

# Convert threshold into seconds
years_per_second = conv.sec_to_year(1)
hier_ifar_thresh_s = (args.hierarchical_removal_ifar_threshold
                      / years_per_second)

# Hierarchical removal loop. Every pass removes the foreground trigger with
# the largest IFAR together with all triggers within the removal window
# around it, re-clusters what is left, recomputes the significance and writes
# the result to the background_hN / foreground_hN groups, where N is the
# number of removals done so far.
while numpy.any(ifar_louder > hier_ifar_thresh_s):
    # If the user wants to stop doing hierarchical removals after a set
    # number of iterations then break when that happens. The counter only
    # counts upwards from 0, so a maximum of -1 is never reached.
    if h_iterations == args.max_hierarchical_removal:
        break

    # Write foreground trigger info before hierarchical removals for
    # downstream codes.
    if h_iterations == 0:
        f['background_h%s/stat' % h_iterations] = back_stat
        f['background_h%s/ifar' % h_iterations] = conv.sec_to_year(bg_ifar)
        for k in all_trigs.data:
            f['background_h%s/' % h_iterations + k] = all_trigs.data[k]
        f['foreground_h%s/stat' % h_iterations] = fore_stat
        f['foreground_h%s/ifar' % h_iterations] = conv.sec_to_year(fg_ifar)
        f['foreground_h%s/ifar_exc' % h_iterations] = \
            conv.sec_to_year(fg_ifar_exc)
        f['foreground_h%s/fap' % h_iterations] = fap
        for key, value in sig_info.items():
            f['foreground_h%s' % h_iterations].attrs[key] = value
        for key, value in exc_sig_info.items():
            f['foreground_h%s' % h_iterations].attrs[key + "_exc"] = value
        for k in all_trigs.data:
            f['foreground_h%s/' % h_iterations + k] = all_trigs.data[k]
    # Add the iteration number of hierarchical removals done.
    h_iterations += 1

    # Among foreground triggers, find the one with the largest ifar
    # and mark it for removal.
    max_stat_idx = ifar_louder.argmax()

    # Step 3: Remove that trigger from the list of zerolag triggers

    # Find the index of the loud foreground trigger to remove next. And find
    # the index in the list of original foreground triggers. Both are found
    # by matching the ranking statistic value (first match wins).
    loudest_stat = fore_stat[max_stat_idx]
    rm_trig_idx = numpy.where(all_trigs.stat[:] == loudest_stat)[0][0]
    orig_fore_idx = numpy.where(orig_fore_stat == loudest_stat)[0][0]

    # Store any foreground trigger's information that we want to
    # hierarchically remove.
    f['foreground/ifar'][orig_fore_idx] = \
        conv.sec_to_year(fg_ifar[max_stat_idx])
    f['foreground/fap'][orig_fore_idx] = fap[max_stat_idx]

    # Report the background that is actually being removed against
    logging.info("Removing foreground trigger that is louder than the %s "
                 "background.", args.hierarchical_removal_against)

    # Remove the foreground trigger and all of the triggers that lie within
    # the removal window around its time.
    trig_times = all_trigs.data['%s/time' % ifo]
    ave_rm_time = trig_times[rm_trig_idx]
    indices_to_rm = numpy.asarray(
        veto.indices_within_times(
            trig_times,
            [ave_rm_time - args.hierarchical_removal_window],
            [ave_rm_time + args.hierarchical_removal_window]),
        dtype=int)

    all_trigs = all_trigs.remove(indices_to_rm)
    logging.info("We have %s triggers after hierarchical removal.",
                 len(all_trigs.stat))

    # Step 4: Re-cluster the triggers and calculate the inclusive ifar/fap
    logging.info("Clustering coinc triggers (inclusive of zerolag)")
    all_trigs = all_trigs.cluster(args.cluster_window)

    logging.info("%s clustered foreground triggers", len(all_trigs))
    logging.info("%s hierarchically removed foreground trigger(s)",
                 h_iterations)

    logging.info("Dumping foreground triggers")
    logging.info("Dumping background triggers (inclusive of zerolag)")
    for k in all_trigs.data:
        f['background_h%s/' % h_iterations + k] = all_trigs.data[k]

    logging.info("Calculating FAN from background statistic values")
    back_stat = fore_stat = all_trigs.stat
    bkg_dec_facs = all_trigs.decimation_factor

    bg_far, fg_far, sig_info = significance.get_far(
        back_stat,
        fore_stat,
        bkg_dec_facs,
        fg_time,
        **significance_dict[ifo])

    fg_far = significance.apply_far_limit(
        fg_far,
        significance_dict,
        combo=ifo,
    )

    bg_far = significance.apply_far_limit(
        bg_far,
        significance_dict,
        combo=ifo,
    )

    bg_ifar = 1. / bg_far
    fg_ifar = 1. / fg_far

    # Update the ifar_louder criteria depending on whether foreground
    # triggers are being removed via inclusive or exclusive background.
    if is_bkg_inc:
        ifar_louder = fg_ifar
        exc_sig_info = {}

    # Exclusive background doesn't change when removing foreground triggers.
    # So we don't have to take bg_far_exc, just repopulate fg_ifar_exc
    else:
        _, fg_far_exc, exc_sig_info = significance.get_far(
            back_stat_exc,
            fore_stat,
            bkg_exc_dec_facs,
            fg_time_exc,
            **significance_dict[ifo])

        fg_far_exc = significance.apply_far_limit(
            fg_far_exc,
            significance_dict,
            combo=ifo,
        )

        fg_ifar_exc = 1. / fg_far_exc

        ifar_louder = fg_ifar_exc

    # ifar_louder has been updated and the code can continue.

    logging.info("Calculating ifar/fap values")
    f['background_h%s/ifar' % h_iterations] = conv.sec_to_year(bg_ifar)
    f.attrs['background_time_h%s' % h_iterations] = fg_time
    f.attrs['foreground_time_h%s' % h_iterations] = fg_time

    if len(all_trigs) > 0:
        # Write ranking statistic to file just for downstream plotting code
        f['foreground_h%s/stat' % h_iterations] = fore_stat

        for key, value in sig_info.items():
            f['foreground_h%s' % h_iterations].attrs[key] = value
        for key, value in exc_sig_info.items():
            f['foreground_h%s' % h_iterations].attrs[key + "_exc"] = value
        fap = 1 - numpy.exp(- fg_time / fg_ifar)
        f['foreground_h%s/ifar' % h_iterations] = conv.sec_to_year(fg_ifar)
        f['foreground_h%s/fap' % h_iterations] = fap

        # The exclusive values go into their own "_exc" datasets, as in the
        # _h0 group, instead of overwriting the inclusive ifar/fap that were
        # just written. They are only written when removing against the
        # exclusive background: otherwise fg_ifar_exc was not recomputed
        # above and still describes the foreground from before the removal.
        if not is_bkg_inc:
            fap_exc = 1 - numpy.exp(- fg_time_exc / fg_ifar_exc)
            f['foreground_h%s/ifar_exc' % h_iterations] = \
                conv.sec_to_year(fg_ifar_exc)
            f['foreground_h%s/fap_exc' % h_iterations] = fap_exc

        # Update ifar and fap for other foreground triggers
        for stat_val, ifar_val, fap_val in zip(fore_stat, fg_ifar, fap):
            orig_fore_idx = numpy.where(orig_fore_stat == stat_val)[0][0]
            f['foreground/ifar'][orig_fore_idx] = conv.sec_to_year(ifar_val)
            f['foreground/fap'][orig_fore_idx] = fap_val

        # Save trigger ids for foreground triggers for downstream plotting
        # code. These don't change with the iterations but should be written
        # at every level.

        f['foreground_h%s/template_id' % h_iterations] = \
            all_trigs.data['template_id']
        trig_id = all_trigs.data['%s/trigger_id' % ifo]
        trig_time = all_trigs.data['%s/time' % ifo]
        f['foreground_h%s/%s/time' % (h_iterations, ifo)] = trig_time
        f['foreground_h%s/%s/trigger_id' % (h_iterations, ifo)] = trig_id
    else:
        # No triggers are left: write empty arrays for this level
        f['foreground_h%s/stat' % h_iterations] = numpy.array([])
        f['foreground_h%s/ifar' % h_iterations] = numpy.array([])
        f['foreground_h%s/fap' % h_iterations] = numpy.array([])
        f['foreground_h%s/template_id' % h_iterations] = numpy.array([])
        f['foreground_h%s/%s/time' % (h_iterations, ifo)] = numpy.array([])
        f['foreground_h%s/%s/trigger_id' % (h_iterations, ifo)] = \
            numpy.array([])

# Write to file how many hierarchical removals were implemented.
f.attrs['hierarchical_removal_iterations'] = h_iterations

# Write whether hierarchical removals were removed against the
# inclusive background or the exclusive background. Have to use
# numpy.bytes_ datatype.
if h_iterations != 0:
    hrm_method = args.hierarchical_removal_against
    f.attrs['hierarchical_removal_method'] = numpy.bytes_(hrm_method)

# Close the output file explicitly instead of leaving it to interpreter
# shutdown; nothing is written to it after this point.
f.f.close()

logging.info("Done")

