#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1767619825/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python
# Copyright (C) 2015-2023 Alexander Harvey Nitz, Gareth Cabourn Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Prepare files for upload to GraceDB for foreground events
"""
import os
import argparse
import logging
import numpy as np

import igwn_segments as segments

from pycbc import init_logging, add_common_pycbc_options
import pycbc.workflow as wf
from pycbc.types import MultiDetOptionAction
from pycbc.events import select_segments_by_definer
from pycbc.io import HFile
import pycbc.workflow.minifollowups as mini
from pycbc.workflow.core import resolve_url_to_file, resolve_td_option

# Command-line interface: the common PyCBC options come first, followed by
# the inputs specific to this script and finally the workflow options.
parser = argparse.ArgumentParser(description=__doc__[1:])
add_common_pycbc_options(parser)

# Search result files
parser.add_argument(
    '--bank-file',
    help="HDF format template bank file",
)
parser.add_argument(
    '--statmap-file',
    help="HDF format clustered coincident trigger result file",
)
parser.add_argument(
    '--xml-all-file',
    help="XML format result file containing all events",
)
parser.add_argument(
    '--single-detector-triggers',
    action=MultiDetOptionAction,
    nargs='+',
    help="HDF format merged single detector trigger files",
)

# Inspiral segment file and the names of the segment lists inside it
parser.add_argument(
    '--inspiral-segments',
    help="xml segment files containing the inspiral analysis times",
)
parser.add_argument(
    '--inspiral-data-read-name',
    help="Name of inspiral segmentlist containing data read in by each "
         "analysis job.",
)
parser.add_argument(
    '--inspiral-data-analyzed-name',
    help="Name of inspiral segmentlist containing data analyzed by each "
         "analysis job.",
)

# Per-detector PSDs and the significance cut for upload preparation
parser.add_argument(
    '--psd-files',
    action=MultiDetOptionAction,
    nargs='+',
    help="HDF format merged single detector PSD files",
)
parser.add_argument(
    '--ifar-thresh',
    type=float,
    help="IFAR threshold for preparing SNR timeseries files for upload. "
         "Default=No upload prep",
)

# Generic workflow options, including those for sub-daxes
wf.add_workflow_command_line_group(parser)
wf.add_workflow_settings_cli(parser, include_subdax_opts=True)

args = parser.parse_args()

# Logging defaults to info level here; --verbose adds to this
init_logging(args.verbose, default_level=1)

workflow = wf.Workflow(args)

wf.makedir(args.output_dir)

# Map each workflow detector to its "<ifo>-channel-name" option, looked up
# through the workflow's config parser with an empty tag argument
channel_opts = {
    ifo: workflow.cp.get_opt_tags(
        "workflow",
        f"{ifo.lower()}-channel-name",
        "",
    )
    for ifo in workflow.ifos
}

# Empty list of layouts; nothing in this script adds to it
layouts = []
logging.info(
    "Grabbing inputs: template bank, insp segs and XML with all events"
)
# Resolve the three detector-independent inputs, in this order, from their
# absolute paths
tmpltbank_file, insp_segs, xml_all = (
    resolve_url_to_file(os.path.abspath(path))
    for path in (args.bank_file, args.inspiral_segments, args.xml_all_file)
)

# Per-detector inputs: resolved workflow files go into the lists, open HDF
# handles and coalesced segment lists go into dictionaries keyed by detector
single_triggers, psd_files = [], []
fsdt, insp_data_seglists, insp_analysed_seglists = {}, {}, {}

for ifo in args.single_detector_triggers:
    strig_fname = args.single_detector_triggers[ifo]

    logging.info("Getting %s single-detector trigger file", ifo)
    single_triggers.append(
        resolve_url_to_file(
            os.path.abspath(strig_fname),
            attrs={'ifos': ifo},
        )
    )

    logging.info("Getting %s PSD file", ifo)
    psd_files.append(
        resolve_url_to_file(
            os.path.abspath(args.psd_files[ifo]),
            attrs={'ifos': ifo},
        )
    )

    # Keep the trigger file open; it is handed to the template parameter
    # lookup for each event later in the script
    fsdt[ifo] = HFile(strig_fname, 'r')

    logging.info("Loading inspiral segments information")
    for seglists, seg_name in (
        (insp_data_seglists, args.inspiral_data_read_name),
        (insp_analysed_seglists, args.inspiral_data_analyzed_name),
    ):
        seglists[ifo] = select_segments_by_definer(
            args.inspiral_segments,
            segment_name=seg_name,
            ifo=ifo,
        )
        seglists[ifo].coalesce()

# Template bank handle; it stays open because the per-event template
# parameter lookup further down reads from it
bank_data = HFile(args.bank_file, 'r')

ifar_limit = args.ifar_thresh

# Times and tids need to be reset for this set of events:
times = {}
tids = {}

# Read everything needed from the statmap file inside a context manager so
# the handle is released even if one of the reads fails
with HFile(args.statmap_file, 'r') as f:
    event_ifars = f['foreground/ifar'][:]

    if ifar_limit is None:
        # --ifar-thresh has no default value and its help text documents
        # the default as "No upload prep", so select no events rather than
        # failing on a comparison against None
        events_to_read = 0
        logging.info(
            "No IFAR threshold given, no events will be prepared for upload"
        )
    else:
        # Count the events which pass the IFAR threshold
        events_to_read = np.count_nonzero(event_ifars > ifar_limit)
        logging.info(
            "%d events exceed the IFAR threshold of %.3f years",
            events_to_read,
            ifar_limit
        )

    # Sort by IFAR, descending, and keep only the events that passed
    event_idx = event_ifars.argsort()[::-1][:events_to_read]

    logging.info("Getting event information")
    ifo_list = f.attrs['ifos'].split(' ')
    for ifo in ifo_list:
        # Load each dataset fully, then reorder to match event_idx
        times[ifo] = f[f'foreground/{ifo}/time'][:][event_idx]
        tids[ifo] = f[f'foreground/{ifo}/trigger_id'][:][event_idx]
    bank_ids = f['foreground/template_id'][:][event_idx]

# Build the upload-preparation jobs for every selected event, in order of
# decreasing IFAR
for curr_idx in range(event_idx.size):
    logging.info("Event number %d", curr_idx)
    logging.info("Getting template parameters")
    params = mini.get_single_template_params(
        curr_idx,
        times,
        bank_data,
        bank_ids[curr_idx],
        fsdt,
        tids
    )

    # Extract approximant; if it is not stored in params, None is passed
    # on so that the default is used
    appx = params.pop('approximant', None)

    event_time = params['mean_time']

    # Resolve the channel name of every detector at the event time. Each
    # name is followed by a single space, including the last one.
    event_seg = segments.segment(event_time, event_time)
    channel_name = "".join(
        resolve_td_option(channel_opts[ifo], event_seg) + " "
        for ifo in ifo_list
    )

    upload_tags = ['upload', str(curr_idx)]

    single_temp_files = []
    for ifo in ifo_list:
        # Skip detectors whose analysed segments do not contain the event
        if event_time not in insp_analysed_seglists[ifo]:
            logging.info("Mean time %.3f not in segment list",
                         event_time)
            continue
        logging.info(
            "Making single-template files"
        )
        single_temp_files += mini.make_single_template_files(
            workflow,
            insp_segs,
            ifo,
            args.inspiral_data_read_name,
            args.inspiral_data_analyzed_name,
            params,
            args.output_dir,
            store_file=True,
            tags=args.tags + upload_tags,
        )

    mini.make_upload_files(
        workflow,
        psd_files,
        single_temp_files,
        xml_all,
        curr_idx,
        appx,
        args.output_dir,
        channel_name,
        tags=args.tags + upload_tags
    )

# The bank and single-detector trigger handles are only passed to the
# template parameter lookup inside the loop above, so close them now
# instead of leaving them open until the interpreter exits
bank_data.close()
for trig_file in fsdt.values():
    trig_file.close()

workflow.save()
