#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Duncan Macleod
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Combine triggers from a splitbank GRB run
"""

import argparse
import os
import logging

import numpy

import tqdm

import h5py

from gwdatafind.utils import (file_segment, filename_metadata)

import igwn_segments as segments
from igwn_segments.utils import fromsegwizard

from pycbc import add_common_pycbc_options, init_logging
from pycbc.results.pygrb_postprocessing_utils import template_hash_to_id
from pycbc.io.hdf import HFile

__author__ = "Duncan Macleod <duncan.macleod@ligo.org>"

TQDM_BAR_FORMAT = ("{desc}: |{bar}| "
                   "{n_fmt}/{total_fmt} {unit} ({percentage:3.0f}%) "
                   "[{elapsed} | ETA {remaining}]")
TQDM_KW = {
    "ascii": " -=#",
    "bar_format": TQDM_BAR_FORMAT,
    "smoothing": 0.05,
}


# -- utilities -----------------------------------

def merge_hdf5_files(inputfiles, outputfile, verbose=False, **compression_kw):
    """Merge several HDF5 files into a single file

    Input files whose ``network/event_id`` dataset is empty are skipped.
    Datasets with ``/search/`` or ``/gating/`` in their name are copied
    from the first non-skipped input file only; every other dataset is
    concatenated across the input files in the order given.

    Event-ID datasets are not concatenated as-is: the first event-ID
    dataset is read from each file, shifted by the number of events
    gathered so far, and the resulting array is stored under every
    event-ID dataset name.

    Datasets with ``template_id`` in their name are not copied;
    ``network/template_id`` is rebuilt with `template_hash_to_id`.

    Parameters
    ----------
    inputfiles : `list` of `str`
        the paths of the input HDF5 files to merge

    outputfile : `str`
        the path of the output HDF5 file to write

    verbose : `bool`, optional
        if `True`, show progress bars while scanning and merging

    **compression_kw
        keyword arguments passed to ``create_dataset`` for each merged
        dataset; ``compression`` and ``compression_opts`` default to the
        values of the first non-search dataset found while scanning

    Notes
    -----
    The template bank path is read from the module-level ``args.bank_file``,
    so this function can only be called after the command line was parsed.
    """
    nfiles = len(inputfiles)

    # metadata gathered while scanning the input files
    attributes = {}
    datasets = {}

    def _scan_dataset(name, obj):
        """Record the shape and dtype of each dataset that is visited"""
        if isinstance(obj, h5py.Dataset) and "/search/" in name:
            datasets[name] = (obj.shape, obj.dtype)
        elif isinstance(obj, h5py.Dataset):
            shape = obj.shape
            dtype = obj.dtype
            try:
                # accumulate the total length over all files seen so far
                shape = numpy.sum(datasets[name][0] + shape, keepdims=True)
            except KeyError:
                # first time this dataset is seen
                pass
            else:
                assert dtype == datasets[name][1], (
                    "Cannot merge {0}/{1}, does not match dtype".format(
                        obj.file.filename, name,
                    ))
            datasets[name] = (shape, dtype)

            # use default compression options from this file
            for copt in ('compression', 'compression_opts'):
                compression_kw.setdefault(copt, getattr(obj, copt))

    # get list of datasets
    for filename in tqdm.tqdm(inputfiles, desc="Scanning trigger files",
                              disable=not verbose, total=nfiles, unit="files",
                              **TQDM_KW):
        with HFile(filename, 'r') as h5f:
            # note: attributes are only recorded from the last file,
            #       since we presume all files have the same attributes
            attributes = dict(h5f.attrs)
            h5f.visititems(_scan_dataset)

    # print summary of what we found
    logging.info(
        "Found %d events across %d files",
        datasets["network/event_id"][0][0],
        nfiles,
    )

    def _network_first(name):
        """Key function to sort datasets in the network group first

        This is required so that when we merge the network/X1_event_id
        columns, we can reference the X1/event_id event count from the
        _previous_ iteration so get the correct increments.
        """
        if name.startswith("network"):
            return 0
        return 1

    # get ordered list of dataset names to loop over
    dataset_names = sorted(datasets, key=_network_first)

    # get list of event_id columns
    sngl_event_id_names = [
        x for x in dataset_names if x.endswith("/event_id")
    ]
    network_event_id_names = [
        "network/{}".format(x.replace('/', '_')) for
        x in sngl_event_id_names if not x.startswith("network")
    ]
    all_event_id_names = sngl_event_id_names + network_event_id_names

    # handle search datasets as a special case
    # (they will be the same in all files)
    search_datasets = set(filter(lambda x: "/search/" in x, dataset_names))
    gating_datasets = set(filter(lambda x: "/gating/" in x, dataset_names))
    once_only_datasets = search_datasets.union(gating_datasets)

    with HFile(outputfile, 'w') as h5out:
        h5out.attrs.update(attributes)

        # copy dataset contents
        for dset in dataset_names:
            # Template ID becomes unhelpful in the merged file.
            # This will be handled separately.
            if "template_id" in dset or dset in all_event_id_names:
                continue

            logging.info("Merging %s dataset", dset)
            data = []
            for filename in tqdm.tqdm(inputfiles, desc="Merging trigger files",
                                      disable=not verbose, total=nfiles,
                                      unit="files", **TQDM_KW):
                with HFile(filename, 'r') as h5in:
                    if len(h5in['network']['event_id']) == 0:
                        continue
                    data.append(h5in[dset][:])
                    # read the search datasets and gating_datasets only once
                    if dset in once_only_datasets:
                        break

            # NOTE(review): numpy.concatenate raises ValueError when every
            # input file was skipped as empty - confirm this is acceptable
            h5out.create_dataset(dset, data=numpy.concatenate(data),
                                 **compression_kw)
            del data
        # END OF DATASET LOOP

        # Handling event_id datasets separately
        dset = all_event_id_names[0]
        n_events = 0
        logging.info("Merging %s dataset", dset)
        data = []
        for filename in tqdm.tqdm(inputfiles, desc="Merging trigger files",
                                  disable=not verbose, total=nfiles,
                                  unit="files", **TQDM_KW):
            with HFile(filename, 'r') as h5in:
                if len(h5in['network']['event_id']) == 0:
                    continue

                # read the IDs from disk once, then shift them by the
                # number of events gathered so far
                event_ids = h5in[dset][:]
                n_events_h5in = len(event_ids)
                data.append(
                    event_ids + numpy.full(n_events_h5in, n_events)
                )
                n_events += n_events_h5in
        data = numpy.concatenate(data)
        # the same shifted IDs are stored under every event-ID dataset name
        for dset in all_event_id_names:
            logging.info("Storing %s dataset", dset)
            h5out.create_dataset(dset, data=data, **compression_kw)
        del data

        # Handling template_id separately
        logging.info("Preparing template IDs")
        template_ids = template_hash_to_id(
            trigger_file=h5out, bank_path=args.bank_file
        )
        logging.info("Storing template IDs")
        h5out.create_dataset("network/template_id",
                             data=template_ids,
                             compression="gzip",
                             compression_opts=9)
    # END OF WITH H5OUT

    logging.info("Merged triggers written to %s", outputfile)


def bin_events(inputfile, bins, outdir, filetag,
               column="network/end_time_gc", verbose=False):
    """Write one HDF5 file per bin holding the events that fall inside it

    Parameters
    ----------
    inputfile : `str`
        path of the HDF5 file to read; its name is also parsed with
        `filename_metadata` to build the output file names

    bins : `dict`
        mapping of bin name to either a `list` (an event is kept when its
        ``column`` value is ``in`` the list) or a ``(start, end)`` pair
        (an event is kept when ``start <= value < end``)

    outdir : `str`
        directory in which the per-bin files are written

    filetag : `str`
        tag inserted in the name of each output file

    column : `str`, optional
        name of the dataset holding the value used to bin network events

    verbose : `bool`, optional
        if `True`, show a progress bar for each bin
    """
    ifotag, _, seg = filename_metadata(inputfile)

    def _copy_selected(source, target, group, mask, progress):
        """Copy masked datasets, and any search group, from ``group``"""
        for member in group.values():
            if isinstance(member, h5py.Dataset):
                target.create_dataset(
                    member.name,
                    data=member[()][mask],
                    compression=member.compression,
                    compression_opts=member.compression_opts,
                )
            elif isinstance(member, h5py.Group) and "search" in member.name:
                # search groups are copied over in full
                source.copy(source[member.name], target, member.name)
            else:
                continue
            progress.update()

    with HFile(inputfile, "r") as h5in:
        ifos = [key for key in h5in.keys() if key != "network"]
        times = h5in[column][()]

        for label, span in bins.items():
            # find which network triggers to keep
            span_is_list = isinstance(span, list)

            def _in_bin(t):
                if span_is_list:
                    return t in span
                return span[0] <= t < span[1]

            include = numpy.vectorize(_in_bin, otypes=[bool])(times)

            # find which single-ifo events to keep
            ifo_index = {}
            for ifo in ifos:
                network_ids = h5in["network/{}_event_id".format(ifo)][()]
                ifo_index[ifo] = numpy.unique(network_ids[include])

            # generate output file
            outf = os.path.join(
                outdir,
                f"{ifotag}-{filetag}_{label}-{seg[0]}-{abs(seg)}.h5",
            )

            # count the datasets one level below each top-level group
            nsets = 0
            for group in h5in.values():
                for item in group.values():
                    if isinstance(item, h5py.Dataset):
                        nsets += 1

            msg = f"Slicing {include.sum()} network events for {label}"
            bar = tqdm.tqdm(total=nsets, desc=msg, disable=not verbose,
                            unit="datasets", **TQDM_KW)
            with HFile(outf, "w") as h5out:
                _copy_selected(h5in, h5out, h5in["network"], include, bar)
                for ifo in ifos:
                    keep = numpy.isin(
                        h5in[ifo]["event_id"][()],
                        ifo_index[ifo],
                    )
                    _copy_selected(h5in, h5out, h5in[ifo], keep, bar)
            bar.close()
            logging.info("%s written to %s", label, outf)


def read_segment_files(segdir):
    """Read the buffer, off-source and on-source segment files

    Parameters
    ----------
    segdir : `str`
        directory holding ``bufferSeg.txt``, ``offSourceSeg.txt`` and
        ``onSourceSeg.txt``

    Returns
    -------
    segs : `dict`
        the single segment read from each file, keyed by ``"buffer"``,
        ``"off"`` and ``"on"``

    Raises
    ------
    ValueError
        if a file does not hold exactly one segment
    """
    filenames = {
        "buffer": "bufferSeg.txt",
        "off": "offSourceSeg.txt",
        "on": "onSourceSeg.txt",
    }
    segs = {}
    for name, filename in filenames.items():
        path = os.path.join(segdir, filename)
        with open(path, "r") as f:
            seglist = fromsegwizard(f)
        # check the count explicitly so that an empty file is not
        # reported as holding "more than one segment", and so that
        # errors raised while parsing keep their own message
        if len(seglist) != 1:
            raise ValueError(
                "expected exactly one segment in {}, found {}".format(
                    path, len(seglist),
                )
            )
        segs[name] = seglist[0]
    return segs


# -- parse command line -------------------------

parser = argparse.ArgumentParser(description=__doc__)

add_common_pycbc_options(parser)

# tags
parser.add_argument(
    "-i",
    "--ifo-tag",
    required=True,
    help="the IFO tag, e.g. H1L1",
)
parser.add_argument(
    "-u",
    "--user-tag",
    default="PYGRB",
    type=str.upper,
    help="the user tag (default: %(default)s)",
)
parser.add_argument(
    "-j",
    "--job-tag",
    type=str.upper,
    help="the job tag, for use when more than one trig_combiner "
         "job is included in a workflow",
)
parser.add_argument(
    "-S",
    "--slide-tag",
    type=str.upper,
    help="the slide tag, used to differentiate long slides",
)

# run parameters
parser.add_argument(
    "-n", "--grb-name",
    type=str.upper,
    help="GRB event name, e.g. 010203",
)
parser.add_argument(
    "-T",
    "--num-trials",
    type=int,
    default=6,
    help="The number of off source trials, default: %(default)d",
)
parser.add_argument(
    "-e",
    "--seed",
    type=int,
    default=None,
    help="Seed for the NumPy random number generator, default: %(default)s",
)
parser.add_argument(
    "-p",
    "--trig-start-time",
    type=int,
    required=True,
    help="The start time of the analysis segment",
)
parser.add_argument(
    "-a",
    "--segment-dir",
    required=True,
    help="directory holding buffer, on and off source segment files",
)
parser.add_argument(
    "-s",
    "--short-slides",
    action="store_true",
    help="Did analysis use short time slides?",
)
parser.add_argument(
    "-t",
    "--long-slides",
    action="store_true",
    help="Are these triggers from long time slides?",
)

# input/output
parser.add_argument(
    "-f",
    "--input-files",
    # at least one file is needed, since the first input file is used to
    # define the analysis segment; "+" makes argparse reject an empty list
    # instead of letting the script fail later with an IndexError
    nargs="+",
    required=True,
    metavar="TRIGGER FILE",
    help="read in listed trigger files",
)
parser.add_argument(
    "--bank-file",
    required=True,
    help="full template bank file for mapping hashes to ids",
)
parser.add_argument(
    "-o",
    "--output-dir",
    default=os.getcwd(),
    help="output directory (default: %(default)s)",
)
parser.add_argument(
    "-c",
    "--no-compression",
    action="store_true",
    default=False,
    help="don't compress output files (default: %(default)s)",
)

args = parser.parse_args()

init_logging(args.verbose)

logging.info("Welcome to the PyGRB trigger combiner")

# extend the user tag with the GRB name and the job tag, when given
if args.grb_name:
    args.user_tag = f"{args.user_tag}_GRB{args.grb_name}"
if args.job_tag:
    args.user_tag = f"{args.user_tag}_{args.job_tag}"

# the analysis segment is taken from the first input file
analysis = segments.segmentlist([file_segment(args.input_files[0])])
start, end = analysis[0]

# an empty dict leaves the compression choice to merge_hdf5_files
compression_kw = (
    {"compression": None, "compression_opts": None}
    if args.no_compression
    else {}
)

# -- construct segments -------------------------

segs = read_segment_files(args.segment_dir)

# each off-source trial lasts as long as the on-source segment
trialtime = abs(segs["on"])

# the off-source is what is left of the analysis segment once the
# buffer segment has been removed
bins = {"ONSOURCE": segs["on"]}
bins["OFFSOURCE"] = analysis - segments.segmentlist([segs["buffer"]])

logging.info("Parsed parameters and generated off-source trials:")
logging.info("           trial time : %d seconds", trialtime)
logging.info("    on-source segment : %s", bins["ONSOURCE"])
logging.info("       buffer segment : %s", segs["buffer"])
_offsource_text = ", \n                         ".join(
    [str(_offseg) for _offseg in bins["OFFSOURCE"]]
)
logging.info("  off-source segments : [%s]", _offsource_text)

# The on-source is expected to sit inside the analysis segment, leaving
# one stretch of off-source before it and one after it.
# The code below only relies on there being at least one such stretch.
# Count how many whole trials fit in each off-source stretch.
offsource_ntrials = [int(abs(b) / trialtime) for b in bins["OFFSOURCE"]]

# Seed random number generator
numpy.random.seed(args.seed)

# Randomly pick num_trials distinct offsource trials, in increasing order
offtrials = numpy.sort(
    numpy.random.choice(
        range(sum(offsource_ntrials)), args.num_trials, replace=False
    )
)
for trial, index in enumerate(offtrials, start=1):
    if index < offsource_ntrials[0]:
        # the trial lies in the first offsource stretch
        base = bins["OFFSOURCE"][0][0]
        offset = index
    else:
        # the trial lies in the second offsource stretch: count from the
        # start of that stretch, so that the trial cannot overflow out of
        # the offsource segments
        base = bins["OFFSOURCE"][1][0]
        offset = index - offsource_ntrials[0]
    trial_start = base + offset * trialtime
    seg = segments.segment(trial_start, trial_start + trialtime)
    bins["OFFTRIAL_{}".format(trial)] = seg
    if seg not in bins["OFFSOURCE"]:
        raise ValueError(f"off-trial {trial} not in off-source segments\n"
                         f"off-trial {trial} : {seg}\n")
    logging.info("          off-trial %d : %s", trial, seg)

# -- read triggers ------------------------------

logging.info("Merging events")

# combining short and long slides is not supported
if args.long_slides and args.short_slides:
    raise NotImplementedError

# all events are first merged into one file covering the analysis segment
outfile = os.path.join(
    args.output_dir,
    f"{args.ifo_tag}-{args.user_tag}_ALL_TIMES-{start}-{end - start}.h5",
)
merge_hdf5_files(
    args.input_files,
    outfile,
    verbose=args.verbose,
    **compression_kw
)

# the merged file is then split into one file per bin
logging.info("Binning events")
bin_events(
    outfile,
    bins,
    outdir=args.output_dir,
    filetag=args.user_tag,
    verbose=args.verbose,
)

logging.info("All done")
