#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

# Copyright (C) 2021 Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Set up qscans and SNR timeseries plots of loudest triggers/missed injections
"""

# =============================================================================
# Preamble
# =============================================================================
import os
import argparse
import logging

import pycbc.workflow as wf
from pycbc.workflow.core import FileList, resolve_url_to_file
import pycbc.workflow.minifollowups as mini
import pycbc.version
import pycbc.events
import pycbc.results.pygrb_postprocessing_utils as ppu
from pycbc.results import layout
from pycbc.workflow.plotting import PlotExecutable
from pycbc.io.hdf import HFile

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_minifollowups"


# =============================================================================
# Functions
# =============================================================================
def add_wiki_row(outfile, cols):
    """
    Appends one wiki-formatted table row to an output file.

    Each element of cols (a list, a numpy array, or any other iterable) is
    converted with str(); the resulting cells are separated and enclosed
    by '||' and the row is terminated by a newline.
    """
    with open(outfile, 'a') as wiki_handle:
        cells = '||'.join(str(col) for col in cols)
        wiki_handle.write(f'||{cells}||\n')


def make_timeseries_plot(workflow, trig_file, snr_type, central_time,
                         out_dir, ifo=None, seg_files=None,
                         veto_file=None, tags=None):
    """Adds a node for a timeseries of PyGRB results to the workflow

    While the node is being set up, the 'slide-id' option of the
    'pygrb_plot_snr_timeseries' section of workflow.cp is forced to '0'.
    The option is put back to its original state (or removed, if it was
    not set) before returning, also when setting up the node fails.

    Parameters
    ----------
    workflow :
        Workflow the node is added to; its cp, ifos and analysis_time
        attributes are used.
    trig_file :
        Workflow file passed to the plotting job as --trig-file.
    snr_type : str
        Quantity plotted on the y-axis (passed as --y-variable); also
        used as a tag and, when ifo is None, in the plot title.
    central_time : float
        Time passed as --trigger-time and quoted in the plot title.
    out_dir : str
        Output directory of the plotting job.
    ifo : str, optional
        If given, passed as --ifo, added to the tags and used in the
        plot title instead of snr_type.
    seg_files : optional
        List of workflow files passed as --seg-files (skipped if empty).
    veto_file : optional
        Workflow file passed as --veto-file (skipped if not provided).
    tags : list of str, optional
        Tags for the job; if 'loudest_onsource_event' is among them the
        --onsource flag is added.

    Returns
    -------
    The output_files of the node that was added to the workflow.
    """

    tags = [] if tags is None else tags
    section = 'pygrb_plot_snr_timeseries'
    # Hold on to the parser that gets modified, so that the very same
    # object is the one restored at the end
    cp = workflow.cp

    # Ensure that zero-lag data is used in these follow-up plots and store
    # the slide-id option originally provided
    orig_slide_id = None
    if cp.has_option(section, 'slide-id'):
        orig_slide_id = cp.get(section, 'slide-id')
    cp.add_options_to_section(section, [('slide-id', '0')], True)

    try:
        # Initialize job node with its tags
        grb_name = cp.get('workflow', 'trigger-name')
        extra_tags = ['GRB'+grb_name, snr_type]
        if ifo is not None:
            extra_tags.append(ifo)
        node = PlotExecutable(cp, section,
                              ifos=workflow.ifos, out_dir=out_dir,
                              tags=tags+extra_tags).create_node()
        node.add_input_opt('--trig-file', trig_file)
        # Include the onsource trial if this is a follow up on the onsource
        if 'loudest_onsource_event' in tags:
            node.add_opt('--onsource')
        # Pass the segments files and veto file
        if seg_files:
            node.add_input_list_opt('--seg-files', seg_files)
        if veto_file:
            node.add_input_opt('--veto-file', veto_file)
        node.new_output_file_opt(workflow.analysis_time, '.png',
                                 '--output-file', tags=extra_tags)
        # Quantity to be displayed on the y-axis of the plot
        node.add_opt('--y-variable', snr_type)
        # Single-detector plots are labelled by the detector, all others
        # by the type of SNR
        if ifo is not None:
            node.add_opt('--ifo', ifo)
            title_str = f"'{ifo} SNR at {central_time:.3f} (s)'"
        else:
            title_str = \
                f"'{snr_type.capitalize()} SNR at {central_time:.3f} (s)'"
        node.add_opt('--x-lims=-5.,5.')
        node.add_opt('--trigger-time', central_time)
        node.add_opt('--plot-title', title_str)

        # Add job node to workflow
        workflow += node
    finally:
        # Revert the config parser back to how it was given. Comparing
        # against None (rather than truthiness) also restores an option
        # that was present with an empty value.
        if orig_slide_id is not None:
            cp.add_options_to_section(
                section,
                [('slide-id', orig_slide_id)],
                True
            )
        else:
            cp.remove_option(section, 'slide-id')

    return node.output_files


# =============================================================================
# Main script starts here
# =============================================================================
# The module docstring, stripped of its leading newline, is the description
parser = argparse.ArgumentParser(description=__doc__[1:])
pycbc.add_common_pycbc_options(parser)

# Files read and written by this executable
parser.add_argument(
    '--trig-file',
    help="HDF file with the triggers found by PyGRB",
)
parser.add_argument(
    '--followups-file',
    help="HDF file with the triggers/injections to follow up",
)
parser.add_argument(
    '--wiki-file',
    help="Name of file to save wiki-formatted table in",
)

# Segment and veto files, handed on to the timeseries plotting jobs
parser.add_argument(
    "-a", "--seg-files",
    nargs="+",
    action="store",
    default=[],
    help=("The location of the buffer, "
          "onsource and offsource txt segment files."),
)
parser.add_argument(
    "-V", "--veto-file",
    action="store",
    help="The location of the xml veto file.",
)

# Options contributed by the workflow and PyGRB post-processing modules
wf.add_workflow_command_line_group(parser)
wf.add_workflow_settings_cli(parser, include_subdax_opts=True)
ppu.pygrb_add_bestnr_cut_opt(parser)

args = parser.parse_args()

pycbc.init_logging(args.verbose,
                   format="%(asctime)s: %(levelname)s: %(message)s")

workflow = wf.Workflow(args)

wf.makedir(args.output_dir)

# Groups of output files to be arranged in the results page layout
layouts = []

# Read the file with the triggers/injections to follow up
logging.info('Reading list of triggers/injections to followup')
fp = HFile(args.followups_file, "r")

# Initialize a wiki table and add the column headers.
# wiki_file is always defined (None when no wiki table was requested) so
# that any later check on it cannot raise a NameError.
wiki_file = None
if args.wiki_file:
    wiki_file = os.path.join(args.output_dir, args.wiki_file)
    add_wiki_row(wiki_file, fp.keys())

# Establish the number of follow-ups to perform: the number requested in
# the configuration file, capped by the number of entries available
num_events = int(workflow.cp.get_opt_tags('workflow-minifollowups',
                                          'num-events', ''))
num_events = min(num_events, len(fp['BestNR']))

# File instance of the input trigger file
trig_file_path = os.path.abspath(args.trig_file)
trig_file = resolve_url_to_file(trig_file_path)

# Determine ifos used in the analysis
ifos = ppu.extract_ifos(trig_file_path)

# File instance of the veto file (paths are taken relative to the
# directory the script was started from)
start_rundir = os.getcwd()
veto_file = args.veto_file
if veto_file:
    veto_file = wf.resolve_url_to_file(
        os.path.join(start_rundir, veto_file)
    )

# Convert the segments files to a FileList
seg_files = wf.FileList([
    wf.resolve_url_to_file(os.path.join(start_rundir, f))
    for f in args.seg_files
])

# (Loudest) off/on-source events are on time-slid data, so their follow-up
# file contains the time shift columns; if these are missing, injections
# are being followed up. Testing for the column itself (rather than
# reading its first entry) also works when the file lists no events.
is_injection_followup = (ifos[0] + ' time shift (s)') not in fp

# Loop over triggers/injections to be followed up
for num_event in range(num_events):
    files = FileList([])
    logging.info('Processing event: %s', num_event+1)
    gps_time = fp['GPS time'][num_event].astype(float)
    # Every job of this event is tagged with the (1-based) event number
    tags = args.tags + [str(num_event+1)]
    # Only add a row to the wiki table if a wiki file was requested
    if args.wiki_file:
        add_wiki_row(wiki_file, [fp[key][num_event] for key in fp.keys()])
    # Handle injections (which are on unslid data)
    if is_injection_followup:
        for snr_type in ['reweighted', 'coherent']:
            files += make_timeseries_plot(workflow, trig_file,
                                          snr_type, gps_time,
                                          args.output_dir, ifo=None,
                                          seg_files=seg_files,
                                          veto_file=veto_file,
                                          tags=tags)
        for ifo in ifos:
            files += mini.make_qscan_plot(workflow, ifo, gps_time,
                                          args.output_dir,
                                          tags=tags)
    # Handle off/on-source loudest triggers follow-up (which may be on slid
    # data in the case of the off-source)
    else:
        for ifo in ifos:
            # Time of the event in this detector, with its time shift
            # applied
            time_shift = fp[ifo+' time shift (s)'][num_event]
            ifo_time = gps_time + time_shift
            files += make_timeseries_plot(workflow, trig_file,
                                          'single', ifo_time,
                                          args.output_dir, ifo=ifo,
                                          seg_files=seg_files,
                                          veto_file=veto_file,
                                          tags=tags)
            files += mini.make_qscan_plot(workflow, ifo, ifo_time,
                                          args.output_dir,
                                          tags=tags)

    # Arrange the plots of this event in pairs
    layouts += list(layout.grouper(files, 2))

# Nothing else is read from the follow-ups file: release it
fp.close()

workflow.save()
layout.two_column_layout(args.output_dir, layouts)
