#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1767619825/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

# Copyright 2024 Arthur Tolley
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

"""Find trigger files and combine them into a single hdf trigger merge file."""

import numpy
import argparse
import h5py
import os
import logging

from igwn_segments import segmentlist, segment

import pycbc
from pycbc.io import live as liveio
from pycbc.events import cuts, veto

# Build the command line interface. The common PyCBC options, the live
# trigger-file selection options and the cuts option group are attached by
# PyCBC helper functions; the remaining options are specific to this script.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
liveio.add_live_trigger_selection_options(parser)
cuts.insert_cuts_option_group(parser)

# Detectors whose triggers are collated
parser.add_argument(
    '--ifos', nargs='+', required=True,
    help="The list of detectors to include triggers in the merged file",
)
# Destination of the merged triggers
parser.add_argument(
    '--output-file', required=True,
    help='The output file containing merged triggers.',
)
# Template bank; read later on to count the templates
parser.add_argument(
    "--bank-file", required=True,
    help="The bank file used in the search",
)

args = parser.parse_args()

pycbc.init_logging(args.verbose)

# Locate the input trigger files from the selection options given on the
# command line
logging.info("Finding trigger files")
trigger_files = liveio.find_trigger_files_from_cli(args)
logging.info("%s files found", len(trigger_files))

###########################
# COLLATE THE TRIGGER FILES
###########################
# A trigger file can straddle the requested GPS start / end time, so add
# end_time cuts (lower_inc / upper_inc) at the requested boundaries to any
# cuts given by the user.
if not args.trigger_cuts:
    args.trigger_cuts = []
args.trigger_cuts.extend([
    f"end_time:{args.gps_start_time}:lower_inc",
    f"end_time:{args.gps_end_time}:upper_inc",
])

trigger_cut_dict, template_cut_dict = cuts.ingest_cuts_option_group(args)

logging.info("Collating triggers to %s", args.output_file)

# Bookkeeping, keyed by detector:
#   n_triggers     - triggers found in the input files
#   n_triggers_cut - triggers remaining after the cuts
#   segs           - time spans of the files which contained triggers
file_count = 0
n_triggers = dict.fromkeys(args.ifos, 0)
n_triggers_cut = dict.fromkeys(args.ifos, 0)
segs = {ifo: segmentlist([]) for ifo in args.ifos}

# The number of templates is taken from the size of the bank's
# template_hash dataset
with h5py.File(args.bank_file, 'r') as bank:
    n_templates = bank['template_hash'].size

with h5py.File(args.output_file, 'w') as destination:
    # Create the ifo groups in the trigger file
    for ifo in args.ifos:
        if ifo not in destination:
            destination.create_group(ifo)

    # ------------------------------------------------------------------
    # Stage 1: append the triggers of every input file to the output file
    # ------------------------------------------------------------------
    for file_count, source_file in enumerate(trigger_files):
        # The GPS start time and duration are read from the third and
        # fourth dash-separated fields of the file name, once the file
        # extension has been removed.
        name_stem = os.path.splitext(os.path.basename(source_file))[0]
        name_fields = name_stem.split('-')
        start_time = float(name_fields[2])
        duration = float(name_fields[3])
        end_time = start_time + duration

        if file_count % 100 == 0:
            logging.info(
                "Files appended: %d/%d",
                file_count,
                len(trigger_files)
            )

        with h5py.File(source_file, 'r') as source:
            for ifo in args.ifos:
                try:
                    n_trigs_ifo = source[ifo]['snr'].size
                except KeyError:
                    # No triggers in this IFO in this file
                    continue
                n_triggers[ifo] += n_trigs_ifo

                # Triggers were generated, so add to the segment list
                segs[ifo].append(segment(start_time, end_time))

                # Read every dataset which has one entry per trigger,
                # skipping the members which are known not to
                ifo_source = source[ifo]
                triggers = {
                    k: ifo_source[k][:] for k in ifo_source.keys()
                    if k not in ('loudest', 'stat', 'gates', 'psd')
                    and ifo_source[k].size == n_trigs_ifo
                }
                # The stored chisq is actually reduced chisq, so convert back
                # to unreduced chisq using chisq_dof. Assigning in place
                # keeps the dtype of the chisq array as read from the file.
                tmpchisq = triggers['chisq'][:] * (2 * triggers['chisq_dof'][:] - 2)
                triggers['chisq'][:] = tmpchisq

                # Apply the cuts to triggers
                keep_idx = cuts.apply_trigger_cuts(triggers, trigger_cut_dict)

                # triggers contains the datasets that we want to use for
                # the template cuts, so here it can be used as the template bank
                keep_idx = cuts.apply_template_cuts(
                    triggers,
                    template_cut_dict,
                    template_ids=keep_idx
                )
                # keep_idx holds the *indices* of the surviving triggers, so
                # emptiness must be tested on its size: testing truthiness
                # of the values would discard a lone surviving trigger at
                # index 0.
                if keep_idx.size == 0:
                    # No triggers kept after cuts in this ifo for this file
                    continue

                if 'approximant' not in triggers:
                    # Nothing is written for this file, so it must be
                    # skipped before the kept-trigger counter is updated;
                    # otherwise later resizes would leave unfilled gaps in
                    # the output datasets.
                    continue

                n_triggers_cut[ifo] += keep_idx.size

                triggers = {
                    k: triggers[k][keep_idx]
                    for k in triggers.keys()
                }

                ifo_dest = destination[ifo]
                for name, dataset in triggers.items():
                    if name in ifo_dest:
                        # Grow the existing dataset to the running total and
                        # write the new triggers into the freshly added tail
                        ifo_dest[name].resize(
                            n_triggers_cut[ifo],
                            axis=0
                        )
                        ifo_dest[name][-keep_idx.size:] = dataset
                    else:
                        # First triggers for this dataset: create it
                        # resizable so that later files can be appended
                        ifo_dest.create_dataset(
                            name,
                            data=dataset,
                            chunks=True,
                            maxshape=(None,)
                        )

                # Carry the file-level attributes over to the output file
                for attr_name, attr_value in source.attrs.items():
                    destination.attrs[attr_name] = attr_value

    # ------------------------------------------------------------------
    # Stage 2: store the segments covered by the files, per detector
    # ------------------------------------------------------------------
    for ifo in args.ifos:
        # Collect the segments, and output to the file
        search_grp = destination[ifo].create_group('search')
        segs[ifo].coalesce()
        seg_starts, seg_ends = veto.segments_to_start_end(segs[ifo])
        search_grp.create_dataset(
            'end_time',
            data=seg_ends,
        )
        search_grp.create_dataset(
            'start_time',
            data=seg_starts,
        )
        logging.info(
            "Found %d %s triggers in %s seconds",
            n_triggers[ifo],
            ifo,
            abs(segs[ifo])
        )
        if n_triggers[ifo] != n_triggers_cut[ifo]:
            logging.info(
                "  %d triggers after cuts",
                n_triggers_cut[ifo],
            )

    # ------------------------------------------------------------------
    # Stage 3: sort by template and add the template boundaries / region
    # references
    # ------------------------------------------------------------------
    for ifo in args.ifos:
        logging.info("Processing %s", ifo)
        triggers = destination[ifo]
        try:
            template_ids = triggers['template_id'][:]
        except KeyError:
            logging.info("No triggers for %s, skipping", ifo)
            continue
        n_trigs = len(template_ids)

        # Work out which members hold one entry per trigger *before*
        # template_boundaries is written, and consider datasets only.
        # This keeps the 'search' subgroup and template_boundaries itself
        # out of the sort even if their length happens to equal the
        # number of triggers.
        per_trigger_keys = [
            key for key, item in triggers.items()
            if isinstance(item, h5py.Dataset)
            and item.shape[:1] == (n_trigs,)
        ]

        logging.info("Calculating template boundaries")
        sorted_indices = numpy.argsort(template_ids)
        sorted_template_ids = template_ids[sorted_indices]
        # For each template id in the bank, the index of its first trigger
        # in the template-sorted trigger arrays
        template_boundaries = numpy.searchsorted(
            sorted_template_ids,
            numpy.arange(n_templates)
        )
        triggers['template_boundaries'] = template_boundaries

        # Sort other datasets by template_id so it makes sense
        # Datasets with the same length as the number of triggers:
        #   approximant, chisq, chisq_dof, coa_phase, end_time, f_lower
        #   mass_1, mass_2, sg_chisq, sigmasq, snr, spin1z, spin2z,
        #   template_duration, template_hash, template_id
        for key in per_trigger_keys:
            logging.info('Sorting %s by template id', key)
            sorted_key = triggers[key][:][sorted_indices]
            triggers[key][:] = sorted_key

        logging.info("Setting up region references")
        # Datasets which need region references:
        region_ref_datasets = ('chisq_dof', 'chisq', 'coa_phase',
                               'end_time', 'sg_chisq', 'snr',
                               'template_duration', 'sigmasq')
        if 'psd_var_val' in triggers.keys():
            region_ref_datasets += ('psd_var_val',)
        if 'dq_state' in triggers.keys():
            region_ref_datasets += ('dq_state',)

        # Each template's triggers span from its own boundary up to the
        # next template's boundary; the last template runs to the end.
        # numpy.roll returns a new array, so template_boundaries is not
        # modified by the assignment below.
        start_boundaries = template_boundaries
        end_boundaries = numpy.roll(start_boundaries, -1)
        end_boundaries[-1] = n_trigs

        for dataset in region_ref_datasets:
            logging.info(
                "Region references for %s",
                dataset
            )
            refs = [
                triggers[dataset].regionref[left:right]
                for left, right in zip(start_boundaries, end_boundaries)
            ]

            logging.info("Adding to file")
            triggers.create_dataset(
                dataset + '_template',
                data=refs,
                dtype=h5py.special_dtype(ref=h5py.RegionReference)
            )

logging.info("Done!")
