#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

"""
This executable is used to get the output of pycbc_multiifo_sngls_findtrigs
and separate the triggers into foreground - times when coincs are not
possible - and background - times when coincs are possible. It also performs
foreground removal, creating a background (written to the `background` group
of the output file) that contains no triggers within a certain window of
those which form foreground coincidences, and calculates p_astro for the
foreground triggers using the method chosen with --pastro-method.

"""

import pycbc, pycbc.io, copy
import argparse, logging, numpy as np
from igwn_ligolw import lsctables, utils as ligolw_utils
from pycbc import conversions as conv
from pycbc.events import veto
from pycbc.io.ligolw import LIGOLWContentHandler
from igwn_segments import segment, segmentlist
import matplotlib
matplotlib.use('agg')
from matplotlib import pyplot as plt
from scipy.stats import gaussian_kde as gk


# Tables of exponents keyed by the names accepted by --distribution.
# NOTE: once the command line has been parsed, the script rebinds both of
# these names to the single value selected by --distribution.

# Exponent for the injection distance term of the injection weights.
d_power = dict(
    log=3.0,
    uniform=2.0,
    distancesquared=1.0,
    volume=0.0,
)

# Exponent for the injection chirp-mass term of the injection weights.
mchirp_power = dict(
    log=0.0,
    uniform=5.0 / 6.0,
    distancesquared=5.0 / 3.0,
    volume=15.0 / 6.0,
)

# Command-line interface. Options are grouped by the stage of the script that
# consumes them; the parsed result is stored in the module-level `args`.
parser = argparse.ArgumentParser()
pycbc.add_common_pycbc_options(parser)
parser.add_argument("--single-statmap-files", nargs='+', required=True,
                    help="Single statmap files for which p_astro is "
                         "calculated.")

# Files to help remove foreground events
parser.add_argument('--coinc-statmap-files', nargs='+', required=True,
                    help="Coincident statmap files, containing coincident "
                         "events to be removed from the background.")
parser.add_argument('--coinc-veto-ifar-threshold', type=float, default=1.0,
                    help="Censor triggers around coincidences with "
                         "IFAR (years) above the threshold [default=1 yr]")
parser.add_argument('--coinc-veto-window', type=float, default=0.1,
                    help="Time around each coincident event above threshold "
                         "to window out. Default = 0.1s")
parser.add_argument("--remove-n-loudest", type=int,
                    help="If given, will remove this number of triggers from "
                         "the background after applying foreground censor. "
                         "This helps prevent signal contamination, but too "
                         "many removals can adversely affect background "
                         "estimate.")

# Arguments for dealing with the injection trigger files for signal population
parser.add_argument('--inj-single-statmap-files', nargs='+', required=True,
                    help="Single statmap files for the injections. "
                         "Must be in same order as --inj-files")
parser.add_argument('--inj-files', nargs='+', required=True,
                    help="File which define injections. Must be in same "
                         "order as --inj-single-statmap-files")
parser.add_argument('--injection-window', type=float, default=1,
                    help="Window for how close a trigger is to an "
                         "injection to be considered to be associated "
                         "with it (seconds). Default=1.")

# Arguments to decide the weighting of the injection distribution
# The explicit default makes the parser agree with its own help text; the
# script only ever tests this option against 'chirp_distance', so an omitted
# option behaves exactly as before.
parser.add_argument('--distance-param', choices=['distance', 'chirp_distance'],
                    default='distance',
                    help="Parameter used to calculate injection distribution "
                         "for weighting. Default='distance'")
parser.add_argument('--distribution', default='uniform',
                    choices=['log', 'uniform', 'distancesquared', 'volume'],
                    help="Form of distribution over --distance-param. "
                         "Default='uniform'")
# NOTE(review): the help below refers to --stat-threshold, but no such option
# is defined by this parser - confirm the intended wording with the authors.
parser.add_argument('--expected-signal-rate', type=float, required=True,
                    help="Expected rate of signals (per year) with stat "
                         "value above --stat-threshold for use in p_astro "
                         "calculation.")
parser.add_argument('--signal-ifar-threshold', type=float, default=1,
                    help="Coincident IFAR threshold to consider an injection "
                         "as 'found' (years). Default=1")

# Choice of noise model used for triggers at or above --bg-distribution-limit
parser.add_argument('--pastro-method', required=True,
                    choices=['truncated_shelf', 'callister', 'background_kde'],
                    help="Which method to use for calculating p_astro. ")
parser.add_argument("--bg-distribution-limit", type=float, default=8,
                    help="The point at which the noise model will change from "
                         "the background distribution to --pastro-method")

parser.add_argument('--plot-distribution',
                    help="If given, will plot the distribution of rate "
                         "densities and p_astro given ranking statistic.")
parser.add_argument('--output-file', required=True,
                    help="name of output file")
args = parser.parse_args()

# Collapse the exponent tables declared at the top of the script to the single
# value selected by --distribution. The tables are reused rather than spelled
# out a second time, so the two copies of the numbers cannot drift apart.
# argparse restricts --distribution to exactly the keys of these tables, so
# the lookups cannot raise KeyError.
d_power = d_power[args.distribution]
mchirp_power = mchirp_power[args.distribution]


pycbc.init_logging(args.verbose)

# The detector name is taken from the first single statmap file and is then
# used to address the datasets of every file. The file is opened in a context
# manager so its handle is released as soon as the attribute has been read,
# instead of being left open for the lifetime of the script.
with pycbc.io.HFile(args.single_statmap_files[0], 'r') as first_smap:
    sngl_ifo = first_smap.attrs['ifos']

# Datasets copied from each single statmap file
groups = ['decimation_factor', 'stat', 'template_id', 'timeslide_id',
          sngl_ifo + '/time', sngl_ifo + '/trigger_id']
# Start from an empty DictArray so each file's triggers can be added with +=
empty_darray = {g: np.array([]) for g in groups}
sngl_trigs = pycbc.io.DictArray(data=empty_darray)

logging.info('Getting single-detector triggers and segments')
# Seeded with a zero-length segment; the per-file segments are appended to it
sngl_detector_segs = segmentlist([segment(0, 0)])
for smap_file in args.single_statmap_files:
    with pycbc.io.HFile(smap_file, 'r') as f_in:
        sngl_starts = f_in['segments/' + sngl_ifo + '/start'][:]
        sngl_ends = f_in['segments/' + sngl_ifo + '/end'][:]
        sngl_trigs += pycbc.io.DictArray(data={g: f_in[g][:] for g in groups})
    sngl_detector_segs += veto.start_end_to_segments(sngl_starts, sngl_ends)

# Summed length of the single-detector segments read above
fg_time = abs(sngl_detector_segs)

logging.info('Getting coinc segments and foreground vetoes')
# Both lists are seeded with a zero-length segment and grown file by file
coinc_segs = segmentlist([segment(0, 0)])
fgveto_segs = segmentlist([segment(0, 0)])
for coinc_path in args.coinc_statmap_files:
    with pycbc.io.HFile(coinc_path, 'r') as coinc_f:
        # The group holding the coinc segments depends on the statmap file:
        # either a literal 'coinc' group, or one named after the file's
        # 'ifos' attribute with the spaces removed
        if 'coinc' in coinc_f['segments']:
            seg_group = 'segments/coinc'
        else:
            seg_group = 'segments/' + coinc_f.attrs['ifos'].replace(' ', '')
        seg_starts = coinc_f[seg_group + '/start'][:]
        seg_ends = coinc_f[seg_group + '/end'][:]

        # Work out which detectors the file covers, and which foreground
        # dataset holds the event times for our detector
        if 'ifos' in coinc_f.attrs:
            file_ifos = coinc_f.attrs['ifos'].split(' ')
            time_dset = sngl_ifo + '/time'
        else:
            first_det = coinc_f.attrs['detector_1']
            file_ifos = [first_det, coinc_f.attrs['detector_2']]
            time_dset = 'time1' if sngl_ifo == first_det else 'time2'

        # Files that do not involve this detector contribute nothing
        if sngl_ifo not in file_ifos:
            logging.warning("IFO %s is not in file %s", sngl_ifo, coinc_path)
            continue

        # Times of the coinc events above the IFAR threshold, each padded by
        # the veto window on both sides
        is_loud = coinc_f['foreground/ifar'][:] > args.coinc_veto_ifar_threshold
        loud_times = coinc_f['foreground'][time_dset][:][is_loud]
        veto_starts = loud_times - args.coinc_veto_window
        veto_ends = loud_times + args.coinc_veto_window

    file_coinc_segs = veto.start_end_to_segments(seg_starts, seg_ends)
    file_veto_segs = veto.start_end_to_segments(veto_starts, veto_ends)
    coinc_segs = coinc_segs + file_coinc_segs
    fgveto_segs = fgveto_segs + file_veto_segs
    coinc_segs.coalesce()
    fgveto_segs.coalesce()


# Build the background trigger set in up to three cuts: keep triggers whose
# time lies in the coinc segments, drop those inside a foreground veto
# window, then optionally drop the loudest survivors.
sngl_time_key = sngl_ifo + '/time'

logging.info("Removing triggers in single-detector time from background.")
in_coinc_time = np.array(
    [trig_time in coinc_segs
     for trig_time in sngl_trigs.data[sngl_time_key]]
)
bg_trigs = sngl_trigs.select(np.flatnonzero(in_coinc_time))

logging.info("Removing triggers in foreground vetoed time from background.")
outside_veto = np.array(
    [trig_time not in fgveto_segs
     for trig_time in bg_trigs.data[sngl_time_key]]
)
bg_trigs = bg_trigs.select(np.flatnonzero(outside_veto))
# Background time is the coinc time with the vetoed windows taken out
bg_time = abs(coinc_segs - fgveto_segs)

n_remove = args.remove_n_loudest
if n_remove:
    logging.info("Removing %d loudest remaining triggers from background",
                 n_remove)
    # argsort is ascending, so the loudest triggers sit at the end
    ascending_stat = np.argsort(bg_trigs.data['stat'])
    bg_trigs = bg_trigs.remove(ascending_stat[-n_remove:])

logging.info("Getting injected signal triggers and information.")

# Statistic values and weights of the injections recovered as single-detector
# triggers, accumulated over all injection sets
signal_stat = np.array([])
signal_weights = np.array([])

# Injection definition files and injection single statmap files are paired by
# position, so they must be in the same order - would prefer to not have to
# do this. zip() would silently drop the unpaired files of the longer list,
# so a length mismatch is reported instead.
if len(args.inj_files) != len(args.inj_single_statmap_files):
    raise ValueError("--inj-files and --inj-single-statmap-files must be "
                     "given the same number of files, in the same order")

for inj_filename, inj_trigger_filename in zip(args.inj_files,
                                              args.inj_single_statmap_files):

    logging.info('Reading injection statistic file')
    with pycbc.io.HFile(inj_trigger_filename, 'r') as inj_trig_f:
        inj_trig_time = inj_trig_f[sngl_ifo]['time'][:]
        inj_stat = inj_trig_f['stat'][:]

    # Put the trigger times and their statistics into the same time-ascending
    # order. The searchsorted indices below refer to this sorted order, so
    # the statistic must be looked up in the sorted array too.
    time_sort = inj_trig_time.argsort()
    sorted_trig_time = inj_trig_time[time_sort]
    sorted_stat = inj_stat[time_sort]

    logging.info('Reading injection file')
    indoc = ligolw_utils.load_filename(inj_filename, False,
                                       contenthandler=LIGOLWContentHandler)
    sim_table = lsctables.SimInspiralTable.get_table(indoc)
    inj_time = np.array(sim_table.get_column('geocent_end_time') +
                        1e-9 * sim_table.get_column('geocent_end_time_ns'),
                        dtype=np.float64)

    # [left, right) is the range of sorted triggers within the injection
    # window of each injection; an injection is used only when exactly one
    # trigger falls in its window
    left = np.searchsorted(sorted_trig_time,
                           inj_time - args.injection_window, side='left')
    right = np.searchsorted(sorted_trig_time,
                            inj_time + args.injection_window, side='right')
    found = np.flatnonzero((right - left) == 1)

    found_d = np.array(sim_table.get_column('distance'),
                       dtype=np.float32)[found]
    found_m1 = np.array(sim_table.get_column('mass1'),
                        dtype=np.float32)[found]
    found_m2 = np.array(sim_table.get_column('mass2'),
                        dtype=np.float32)[found]

    signal_stat = np.append(signal_stat, sorted_stat[left[found]])

    found_mchirp = conv.mchirp_from_mass1_mass2(found_m1, found_m2)
    if args.distance_param == 'chirp_distance':
        found_d = conv.chirp_distance(found_d, found_mchirp)

    # Weight each found injection by powers of its chirp mass and distance;
    # the exponents were selected by --distribution
    weights = found_mchirp ** mchirp_power * found_d ** d_power
    signal_weights = np.append(signal_weights, weights)


# Foreground statistics are those of every single-detector trigger read in;
# background statistics are those of the triggers left after the cuts
fg_stat = sngl_trigs.data['stat']
bg_stat = bg_trigs.data['stat']

# Expected signal count = rate per year x foreground time converted to years
fg_time_years = conv.sec_to_year(fg_time)
n_sig_exp = args.expected_signal_rate * fg_time_years
logging.info("%.3f signals are expected in this amount of data", n_sig_exp)

min_bg, max_bg = bg_stat.min(), bg_stat.max()

logging.info('Getting Gaussian kernel density estimates of signal '
             'distribution')
sig_kern = gk(signal_stat, weights=signal_weights, bw_method=1)
# The signal density is normalised by the KDE mass lying above the loudest
# background statistic
sig_norm = sig_kern.integrate_box_1d(max_bg, np.inf)
sig_dens = sig_kern(fg_stat) / sig_norm

logging.info('Getting Gaussian kernel density estimates of background '
             'distribution')
bg_kern = gk(bg_stat, bw_method=1)
# Scale the unit-normalised KDE by the number of background triggers
bg_rate_dens = bg_stat.size * bg_kern(fg_stat)

# Locate the peak of the background distribution: any trigger with a
# statistic below it is given p_astro of zero, which avoids problems with
# the background distribution falling away as triggers are clustered away
bg_dens_max_idx = np.argmax(bg_rate_dens)
bg_dens_max_stat = fg_stat[bg_dens_max_idx]
fg_stat_below_peak = np.less(fg_stat, bg_dens_max_stat)

logging.info('Getting noise distribution of %s method.', args.pastro_method)
# Foreground triggers quieter than --bg-distribution-limit keep the background
# KDE as their noise density; those at or above it take the density of the
# chosen method. Start from a private copy so bg_rate_dens stays untouched.
pastro_noise_valid = fg_stat >= args.bg_distribution_limit
noise_model = bg_rate_dens.copy()

if args.pastro_method == 'callister':
    noise_model[pastro_noise_valid] = sig_dens[pastro_noise_valid]
elif args.pastro_method == 'truncated_shelf':
    fgs = fg_stat[pastro_noise_valid]
    # For each trigger find the loudest background statistic strictly below
    # it (-inf when there is none). Sorting once and bisecting replaces a
    # full scan of the background array per trigger.
    sorted_bg = np.sort(bg_stat)
    n_below = np.searchsorted(sorted_bg, fgs, side='left')
    stat_below = np.where(n_below > 0,
                          sorted_bg[np.maximum(n_below - 1, 0)],
                          -np.inf)
    # Signal KDE mass between that background statistic and the trigger
    ts_norm = np.array([sig_kern.integrate_box_1d(sbel, fs)
                        for sbel, fs in zip(stat_below, fgs)])
    noise_model[pastro_noise_valid] = sig_kern(fgs) / ts_norm
# 'background_kde' needs no branch: the noise model is already the
# background KDE everywhere.

logging.info("Calculating pastro for foreground events")
expected_sig_dens = n_sig_exp * sig_dens
p_astro = expected_sig_dens / (noise_model + expected_sig_dens)
# Triggers below the peak of the background distribution get zero
p_astro[fg_stat_below_peak] = 0

# Optional diagnostic figure: rate densities on the top panel and the
# resulting p_astro curve, with the foreground events, on the bottom panel.
if args.plot_distribution:
    logging.info("Making distribution plot")
    max_bin_c = min(signal_stat.max(), bg_stat.max() * 4)
    stat_bins = np.linspace(min_bg, max_bin_c, 201)
    stat_bin_c = (stat_bins[1:] + stat_bins[:-1]) / 2

    fig = plt.figure(figsize=[6, 6])
    ax0 = fig.add_subplot(211)
    ax1 = fig.add_subplot(212)

    # Plot background trigger rate density
    bg_rate_dens_plt = bg_kern(stat_bin_c) * bg_stat.size
    ax0.plot(stat_bin_c, bg_rate_dens_plt, linestyle="--",
             label="Background KDE")

    # Plot signal trigger rate density (from injections)
    sig_rate_dens_plt = sig_kern(stat_bin_c) / sig_norm
    ax0.plot(stat_bin_c, sig_rate_dens_plt, linestyle=":",
             label="Signals KDE")

    # Calculate noise model according to chosen method. The noise model is
    # always a separate array: it is partly overwritten below, and writing
    # through an alias would also change the signal density that the p_astro
    # curve is computed from.
    if args.pastro_method == 'callister':
        noise_mod_plt = sig_rate_dens_plt.copy()
    elif args.pastro_method == 'truncated_shelf':
        # Loudest background statistic strictly below each bin centre
        # (-inf when there is none) - the same rule as is applied to the
        # foreground events
        sorted_bg_plt = np.sort(bg_stat)
        n_below_plt = np.searchsorted(sorted_bg_plt, stat_bin_c, side='left')
        stat_below_plt = np.where(n_below_plt > 0,
                                  sorted_bg_plt[np.maximum(n_below_plt - 1, 0)],
                                  -np.inf)
        ts_norm_plt = np.array([sig_kern.integrate_box_1d(sbel, sbc)
                                for sbel, sbc in zip(stat_below_plt,
                                                     stat_bin_c)])
        noise_mod_plt = sig_kern(stat_bin_c) / ts_norm_plt
    elif args.pastro_method == 'background_kde':
        noise_mod_plt = bg_rate_dens_plt.copy()

    # triggers below the chosen distribution limit use the background
    # distribution for noise model
    past_noise_idx_plt = stat_bin_c < args.bg_distribution_limit
    noise_mod_plt[past_noise_idx_plt] = bg_rate_dens_plt[past_noise_idx_plt]
    ax0.plot(stat_bin_c, noise_mod_plt, linestyle='-', c='k',
             label=args.pastro_method)

    ax0.semilogy()
    ax0.grid()
    # Set axis limits according to the highest point of the background
    # distribution
    ax0.set_xlim([bg_dens_max_stat, max_bin_c])
    ax0.set_ylim([sig_kern(max_bin_c) / sig_norm / 1.5,
                  bg_rate_dens[bg_dens_max_idx] * 1.5])
    ax0.legend()
    ax0.set_xlabel("Ranking Statistic")
    ax0.set_ylabel("Rate Density")

    # Calculate p_astro for the plotted statistic values
    expected_sig_dens_plt = n_sig_exp * sig_rate_dens_plt
    p_astro_plt = expected_sig_dens_plt / (noise_mod_plt +
                                           expected_sig_dens_plt)

    ax1.plot(stat_bin_c, p_astro_plt, c='k', label="p_astro distribution")
    ax1.scatter(fg_stat, p_astro, c='k', marker='x', label='foreground events')
    ax1.grid()
    ax1.legend()
    ax1.set_xlim([bg_dens_max_stat, max_bin_c])
    ax1.set_ylim([0, 1])
    ax1.set_xlabel("Ranking Statistic")
    ax1.set_ylabel("p_astro (%s)" % args.pastro_method)

    fig.savefig(args.plot_distribution)

# Write the results. The output file is opened in a context manager so that
# it is closed once everything has been written, and also if a write raises.
dataset_opts = dict(compression='gzip', compression_opts=9, shuffle=True)
with pycbc.io.HFile(args.output_file, 'w') as f_out:
    # Every trigger column is stored twice: for all single-detector triggers
    # (foreground) and for the triggers kept as background
    for k in sngl_trigs.data:
        f_out.create_dataset('foreground/' + k,
                             data=sngl_trigs.data[k],
                             **dataset_opts)
        f_out.create_dataset('background/' + k,
                             data=bg_trigs.data[k],
                             **dataset_opts)

    f_out.create_dataset('foreground/p_astro',
                         data=p_astro,
                         **dataset_opts)

    # Store segments
    seg_starts_out, seg_ends_out = \
        veto.segments_to_start_end(sngl_detector_segs)
    f_out['segments/%s/start' % sngl_ifo] = seg_starts_out
    f_out['segments/%s/end' % sngl_ifo] = seg_ends_out

    f_out.attrs['foreground_time'] = fg_time
    f_out.attrs['background_time'] = bg_time
    f_out.attrs['num_of_ifos'] = 1
    f_out.attrs['pivot'] = sngl_ifo
    f_out.attrs['fixed'] = sngl_ifo
    f_out.attrs['ifos'] = sngl_ifo
    f_out.attrs['pastro_method'] = args.pastro_method
logging.info('Done')
