#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1767619825/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot single IFO/coherent/reweighted/null SNR timeseries for a PyGRB run.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import numpy
import matplotlib.pyplot as plt
from matplotlib import rc
import pycbc.version
from pycbc import init_logging
from pycbc.events.coherent import reweightedsnr_cut
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

plt.switch_backend('Agg')
rc('font', size=14)

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_snr_timeseries"


# =============================================================================
# Functions
# =============================================================================
# Find start and end times of trigger/injection data relative to a given time
def get_start_end_times(data_time, central_time):
    """Determine padded start and end times of data relative to central_time.

    The earliest time is rounded down and the latest time is rounded up to
    whole numbers, so that every entry of data_time lies inside the returned
    interval.  The interval is then widened by 5% of its duration on each
    side.

    Parameters
    ----------
    data_time : iterable of numbers
        Times to be bracketed.  Must contain at least one element.
    central_time : number
        Reference time subtracted from both edges of the interval.

    Returns
    -------
    start, end : numbers
        Padded edges of the interval, expressed relative to central_time.

    Raises
    ------
    ValueError
        If data_time is empty.
    """
    # Local import keeps the function self-contained
    import math

    # Round outwards: truncating the latest time with int() rounds it down
    # and can leave the last points outside of the interval (or collapse the
    # interval to zero width when all times fall within the same second)
    start = math.floor(min(data_time)) - central_time
    end = math.ceil(max(data_time)) - central_time
    duration = end - start
    start -= duration*0.05
    end += duration*0.05

    return start, end


# Reset times so that t=0 corresponds to the given trigger time
def reset_times(data_time, trig_time):
    """Shift times so that the trigger time provided becomes t=0.

    Parameters
    ----------
    data_time : iterable of numbers
        Times to be shifted.
    trig_time : number
        Time subtracted from every entry of data_time.

    Returns
    -------
    list
        The entries of data_time, each with trig_time subtracted, in their
        original order.
    """

    return [sample - trig_time for sample in data_time]


# =============================================================================
# Main script starts here
# =============================================================================
# Build the command line parser.  The options added here are specific to this
# script; the remaining ones read further down (e.g., verbose, ifo,
# output_file) presumably come from the ppu helpers -- confirm against
# pygrb_postprocessing_utils
parser = ppu.pygrb_initialize_plot_parser(description=__doc__)
parser.add_argument("-t", "--trig-file", action="store",
                    default=None, required=True,
                    help="The location of the trigger file")
parser.add_argument("--found-missed-file",
                    help="The hdf injection results file", required=False)
parser.add_argument("--trigger-time", type=float, default=0,
                    help="External GPS time.  Used to center the plot.")
# The SNR type is used to build the HDF keys and the vertical axis label:
# require it, so that omitting it yields a usage error straight away rather
# than a TypeError after the input files have been loaded
parser.add_argument("-y", "--y-variable", required=True,
                    choices=['coherent', 'single', 'reweighted', 'null'],
                    help="Quantity to plot on the vertical axis.")
parser.add_argument("--onsource", default=False, action="store_true",
                    help="Include onsource data in the plot (CAUTION!)")
ppu.pygrb_add_bestnr_cut_opt(parser)
ppu.pygrb_add_slide_opts(parser)
opts = parser.parse_args()
ppu.slide_opts_helper(opts)

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options: resolve the input file paths and make sure an interferometer
# is specified when a single IFO SNR plot is requested
trig_file = os.path.abspath(opts.trig_file)
inj_file = \
    os.path.abspath(opts.found_missed_file) if opts.found_missed_file else None
snr_type = opts.y_variable
ifo = opts.ifo
if snr_type == 'single' and ifo is None:
    err_msg = "Please specify an interferometer for a single IFO plot"
    parser.error(err_msg)

logging.info("Imported and ready to go.")

# Set output directory.  exist_ok avoids a failure when the directory gets
# created by another process between an existence check and the creation
outdir = os.path.dirname(os.path.abspath(opts.output_file))
os.makedirs(outdir, exist_ok=True)

# Extract IFOs
ifos = ppu.extract_ifos(trig_file)

# Generate time-slides dictionary
slide_dict = ppu.load_time_slides(trig_file)

# Generate segments dictionary
segment_dict = ppu.load_segment_dict(trig_file)

# Construct trials removing vetoed times: the onsource is hidden unless
# --onsource was given
trial_dict, total_trials = ppu.construct_trials(
    opts.seg_files,
    segment_dict,
    ifos,
    slide_dict,
    opts.veto_file,
    hide_onsource=(not opts.onsource)
)

# Reweighted SNR threshold used when loading triggers and injections.  It is
# disabled (None) for reweighted SNR plots, so that all points are kept and
# the impact of the cut can be shown; otherwise the user-provided threshold
# is used to remove points with reweighted SNR below it
if snr_type == 'reweighted':
    rw_snr_threshold = None
else:
    rw_snr_threshold = opts.newsnr_threshold

# The trigger data is tagged only when the onsource is included
# NOTE(review): the tag presumably controls whether load_data logs the number
# of triggers found -- confirm against ppu.load_data
data_tag = None
if opts.onsource:
    data_tag = 'trigs'

# Load triggers for the requested slide and injections with slide_id=0
trig_data = ppu.load_data(
    trig_file,
    ifos,
    data_tag=data_tag,
    rw_snr_threshold=rw_snr_threshold,
    slide_id=opts.slide_id
)
inj_data = ppu.load_data(
    inj_file,
    ifos,
    data_tag='injs',
    rw_snr_threshold=rw_snr_threshold,
    slide_id=0
)

# HDF file keys of the x quantity (time) and of the y quantity (SNR)
if snr_type == 'single':
    x_key, y_key = ifo + '/end_time', ifo + '/snr'
else:
    x_key, y_key = 'network/end_time_gc', 'network/' + snr_type + '_snr'

# Trigger properties needed for the plot, stored as dictionaries.  These are
# based on trial_dict: if vetoes were applied, they are the veto survivors
found_trigs = ppu.extract_trig_properties(
    trial_dict, trig_data, slide_dict, segment_dict, [x_key, y_key]
)

# Found injections surviving vetoes: only the first returned item is needed
found_after_vetoes, *_ = ppu.apply_vetoes_to_found_injs(
    opts.found_missed_file, inj_data, ifos,
    veto_file=opts.veto_file, keys=[x_key, y_key]
)

# Trigger times and SNRs, gathered over all the time slides
trig_data_time, trig_data_snr = (
    numpy.concatenate([found_trigs[key][sid][:] for sid in slide_dict])
    for key in (x_key, y_key)
)

# Injection times and SNRs, if an injection file was provided
if inj_file:
    inj_data_time = found_after_vetoes[x_key][:]
    inj_data_snr = found_after_vetoes[y_key][:]
else:
    inj_data_time = inj_data_snr = None

# For reweighted SNR plots the threshold was not applied when loading the
# data, so the cut is applied to the SNR values here instead
# NOTE(review): reweightedsnr_cut presumably alters the values below the
# threshold without removing them -- confirm in pycbc.events.coherent
if snr_type == 'reweighted':
    trig_data_snr = reweightedsnr_cut(trig_data_snr, opts.newsnr_threshold)
    if inj_file:
        inj_data_snr = reweightedsnr_cut(inj_data_snr, opts.newsnr_threshold)

# The plot is centered (t=0) on the trigger time given on the command line
central_time = opts.trigger_time

# Padded time range spanned by the triggers, relative to the central time.
# This is computed before the trigger times are shifted, as
# get_start_end_times subtracts the central time itself
start, end = get_start_end_times(trig_data_time, central_time)

# Express trigger and injection times relative to the central time
trig_data_time = reset_times(trig_data_time, central_time)
if inj_file:
    inj_data_time = reset_times(inj_data_time, central_time)

# Generate plots
logging.info("Plotting...")

# Vertical axis label for each SNR type
y_labels = {kind: kind.capitalize() + " SNR"
            for kind in ('coherent', 'null', 'reweighted')}
y_labels['single'] = f"{ifo} SNR"
y_label = y_labels[snr_type]

# Default title and caption, used only if none were provided
if opts.plot_title is None:
    opts.plot_title = f"{y_label} vs Time"
if opts.plot_caption is None:
    caption_parts = ["Blue crosses: background triggers.  "]
    if inj_file:
        caption_parts.append("Red crosses: injections triggers.")
    opts.plot_caption = "".join(caption_parts)

# Default horizontal axis limits: the padded range spanned by the triggers
opts.x_lims = opts.x_lims or f"{start},{end}"

# SNR versus time plot of the triggers and, if available, of the injections
trig_xy = [trig_data_time, trig_data_snr]
inj_xy = [inj_data_time, inj_data_snr]
x_label = f"Time since {central_time:.3f} (s)"
plu.pygrb_plotter(trig_xy, inj_xy, x_label, y_label,
                  opts, cmd=' '.join(sys.argv))
