#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1767619825/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python
""" Bin triggers by their dq value and calculate trigger rates in each bin
"""
import logging
import argparse

import numpy as np
import h5py as h5

from igwn_segments import segmentlist

import pycbc
from pycbc.events import stat as pystat
from pycbc.events.veto import (select_segments_by_definer,
                               start_end_to_segments,
                               segments_to_start_end)
from pycbc.types.optparse import MultiDetOptionAction
from pycbc.io.hdf import SingleDetTriggers

# Command-line interface. The module docstring doubles as the description.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument("--template-bins-file", required=True,
                    help="HDF file holding, for each detector, the template "
                         "bins and the template ids ('tids') in each bin")
parser.add_argument("--trig-file", required=True,
                    help="HDF file containing the single-detector triggers")
parser.add_argument("--flag-file", required=True,
                    help="File containing the data-quality flag segments")
parser.add_argument("--flag-name", required=True,
                    help="Data-quality flag to use, given as IFO:FLAG_NAME")
parser.add_argument("--analysis-segment-file", required=True,
                    help="File containing the analysis segments")
parser.add_argument("--analysis-segment-name", required=True,
                    help="Name of the analysis segments to select from "
                         "--analysis-segment-file")
# NOTE: the pieces of this help string are concatenated, so each piece must
# carry its own trailing space.
parser.add_argument("--gating-windows", nargs='+',
                    action=MultiDetOptionAction,
                    help="Seconds to reweight before and after the central "
                         "time of each gate. Given as detector-values pairs, "
                         "e.g. H1:-1,2.5 L1:-1,2.5 V1:0,0")
parser.add_argument("--stat-threshold", type=float, default=1.,
                    help="Only consider triggers with --sngl-ranking value "
                    "above this threshold")
parser.add_argument("--output-file", required=True,
                    help="HDF file to write the results to")

pystat.insert_statistic_option_group(
    parser, default_ranking_statistic='single_ranking_only')
args = parser.parse_args()
pycbc.init_logging(args.verbose)

logging.info('Start')

# --flag-name is given as IFO:FLAG_NAME
ifo, flag_name = args.flag_name.split(':')

if args.gating_windows:
    # Collect the central times of every type of gate stored for this
    # detector in the trigger file.
    gate_times = []
    with h5.File(args.trig_file, 'r') as trig_file:
        logging.info('Getting gated times')
        try:
            gating_types = trig_file[f'{ifo}/gating'].keys()
            for gt in gating_types:
                gate_times += list(trig_file[f'{ifo}/gating/{gt}/time'][:])
        except KeyError:
            logging.warning('No gating found in trigger file')
    # Convert unconditionally: the gate segments are later built with array
    # arithmetic (gate_times + offset), which fails on a plain list. When no
    # gating was found this yields an empty array, i.e. no gate segments.
    gate_times = np.unique(gate_times)

# Read the triggers for this detector, filtered on the single-detector
# ranking using --stat-threshold
trigs = SingleDetTriggers(args.trig_file, ifo,
                          filter_rank=args.sngl_ranking,
                          filter_threshold=args.stat_threshold)

# Pull out the per-trigger quantities used below
tmplt_ids, trig_times = trigs.template_id, trigs.end_time
stat = trigs.get_ranking(args.sngl_ranking)

# Map each template bin name to the template ids it contains
with h5.File(args.template_bins_file, 'r') as f:
    ifo_grp = f[ifo]
    bin_tids_dict = {
        bin_name: ifo_grp[bin_name]['tids'][:]
        for bin_name in ifo_grp.keys()
    }

# Segments analysed for this detector, and their total duration
analysis_segs = select_segments_by_definer(args.analysis_segment_file,
                                           segment_name=args.analysis_segment_name,
                                           ifo=ifo)
livetime = abs(analysis_segs)

# Segments during which the data-quality flag is active
flag_segs = select_segments_by_definer(
    args.flag_file,
    segment_name=flag_name,
    ifo=ifo)

# construct gate segments
gating_segs = segmentlist([])
if args.gating_windows:
    # The window for this detector is given as "BEFORE,AFTER": offsets in
    # seconds relative to the central time of each gate.
    gating_windows = args.gating_windows[ifo].split(',')
    if len(gating_windows) != 2:
        # Fail with a clear message rather than an IndexError (too few
        # values) or silently ignoring any extra values.
        raise ValueError(
            f"Gating window for {ifo} must be given as two comma-separated "
            f"values (before,after), got '{args.gating_windows[ifo]}'")
    gate_before = float(gating_windows[0])
    gate_after = float(gating_windows[1])
    if gate_before > 0 or gate_after < 0:
        raise ValueError("Gating window values must be negative "
                         "before gates and positive after gates.")
    # A 0,0 window leaves the gate segment list empty
    if not (gate_before == 0 and gate_after == 0):
        gating_segs = start_end_to_segments(
                gate_times + gate_before,
                gate_times + gate_after
        ).coalesce()

# make segments into mutually exclusive dq states
# Restrict both lists to the analysed time first.
gating_segs = gating_segs & analysis_segs
flag_segs = flag_segs & analysis_segs

# State 2: within a gating window.
# State 1: flag active, but not within a gating window.
# State 0: the remaining analysed time.
# The insertion order (2, 1, 0) is the order in which states are iterated
# later in the script.
dq_state_segs_dict = {}
dq_state_segs_dict[2] = gating_segs
dq_state_segs_dict[1] = flag_segs - gating_segs
dq_state_segs_dict[0] = analysis_segs - flag_segs - gating_segs


# utility function to get the dq state at a given time
def dq_state_at_time(t):
    """Return the dq state whose segments contain time ``t``.

    The states in ``dq_state_segs_dict`` are checked in the dictionary's
    iteration order and the first match is returned. ``None`` is returned
    when ``t`` lies in none of the segment lists.
    """
    matching_states = (
        dq_state
        for dq_state, state_segs in dq_state_segs_dict.items()
        if t in state_segs
    )
    return next(matching_states, None)


# compute and save results
with h5.File(args.output_file, 'w') as out_file:
    # Layout: <ifo>/bins/<bin_name>/... and <ifo>/dq_segments/dq_state_<n>/...
    det_grp = out_file.create_group(ifo)
    bins_root = det_grp.create_group('bins')
    dq_root = det_grp.create_group('dq_segments')

    # Per template bin: store the template ids, the relative trigger rate in
    # each dq state and the number of triggers in the bin
    for bin_name, bin_tids in bin_tids_dict.items():
        this_bin = bins_root.create_group(bin_name)
        this_bin['tids'] = bin_tids

        # Times of the triggers whose template belongs to this bin, and the
        # dq state each of those times falls in (None if in no state)
        times_in_bin = trig_times[np.isin(tmplt_ids, bin_tids)]
        states_in_bin = np.array(
            [dq_state_at_time(trig_time) for trig_time in times_in_bin]
        )

        # Rate for a state = (fraction of the bin's triggers in that state)
        # divided by (fraction of the analysed time spent in that state)
        rates = np.zeros(3, dtype=np.float64)
        for state, state_segs in dq_state_segs_dict.items():
            trig_fraction = np.mean(states_in_bin == state)
            time_fraction = abs(state_segs) / livetime
            rates[state] = trig_fraction / time_fraction
        this_bin['dq_rates'] = rates
        this_bin['num_triggers'] = len(times_in_bin)

    # Store the segments and total duration of each dq state
    for state, state_segs in dq_state_segs_dict.items():
        state_grp = dq_root.create_group(f'dq_state_{state}')
        seg_starts, seg_ends = segments_to_start_end(state_segs)
        state_grp['segment_starts'] = seg_starts
        state_grp['segment_ends'] = seg_ends
        state_grp['livetime'] = abs(state_segs)

    # File-level metadata recording how the triggers were selected
    out_file.attrs['stat'] = f'{ifo}-dq_stat_info'
    out_file.attrs['sngl_ranking'] = args.sngl_ranking
    out_file.attrs['sngl_ranking_threshold'] = args.stat_threshold

logging.info('Done!')
