#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot single IFO SNR vs coherent SNR for a PyGRB run.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import collections
import operator
from matplotlib import pyplot as plt
from matplotlib import rc
import numpy
import scipy

import pycbc.version
from pycbc import init_logging
from pycbc.detector import Detector
from pycbc.results import save_fig_with_metadata
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

plt.switch_backend('Agg')
rc('font', size=14)

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_coh_ifosnr"


# =============================================================================
# Functions
# =============================================================================
# Plot lines representing deviations based on non-central chi-square
def plot_deviation(percentile, snr_grid, y, ax, style):
    """Plot a deviation curve based on the non-central chi-square distribution.

    The curve is the square root of the percent point function (inverse
    CDF) of a non-central chi-square distribution with 2 degrees of
    freedom and non-centrality parameter ``y * y``.

    Parameters
    ----------
    percentile : float
        Probability at which the percent point function is evaluated.
    snr_grid : numpy.ndarray
        Values used as the x axis of the curve.
    y : numpy.ndarray
        Values whose square is the non-centrality parameter; expected to
        have the same length as ``snr_grid``.
    ax : matplotlib axes
        Object whose ``plot`` method is used to draw the curve.
    style : str
        Format string handed to ``ax.plot``.
    """
    # Import the submodules explicitly: a bare "import scipy" does not
    # guarantee that scipy.stats and scipy.interpolate are loaded
    from scipy import interpolate, stats

    # ncx2: non-central chi-squared; ppf: percent point function
    raw_vals = stats.ncx2.ppf(percentile, 2, y * y) ** 0.5

    # Plotting raw_vals directly shows "saturation", so the curve is
    # rebuilt by interpolating/extrapolating the first half of the
    # distinct (sorted) values over the whole grid
    y_vals = numpy.unique(raw_vals)
    n_vals = len(y_vals) // 2
    if n_vals < 2:
        # Not enough distinct points to define the interpolant (e.g. a
        # constant or all-NaN input): draw the raw curve rather than fail
        ax.plot(snr_grid, raw_vals, style)
        return

    x_vals = snr_grid[0:len(y_vals)]
    f = interpolate.interp1d(
        x_vals[0:n_vals],
        y_vals[0:n_vals],
        bounds_error=False,
        fill_value="extrapolate",
    )
    ax.plot(snr_grid, f(snr_grid), style)


# =============================================================================
# Main script starts here
# =============================================================================
# Start from the parser supplied by the PyGRB post-processing utilities
parser = ppu.pygrb_initialize_plot_parser(description=__doc__)

# Options specific to this script, as (flags, add_argument keywords) pairs
script_options = (
    (
        ("-t", "--trig-file"),
        {
            "action": "store",
            "default": None,
            "required": True,
            "help": "The location of the trigger file",
        },
    ),
    (
        ("--found-missed-file",),
        {
            "help": "The hdf injection results file",
            "required": False,
        },
    ),
    (
        ("-z", "--zoom-in"),
        {
            "default": False,
            "action": "store_true",
            "help": "Output file a zoomed in version of the plot.",
        },
    ),
)
for option_flags, option_kwargs in script_options:
    parser.add_argument(*option_flags, **option_kwargs)

# Options added by shared helpers, then parse and post-process the result
ppu.pygrb_add_bestnr_cut_opt(parser)
ppu.pygrb_add_slide_opts(parser)
opts = parser.parse_args()
ppu.slide_opts_helper(opts)

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options
trig_file = os.path.abspath(opts.trig_file)
found_missed_file = (
    os.path.abspath(opts.found_missed_file) if opts.found_missed_file else None
)
if opts.ifo is None:
    parser.error("Please specify an interferometer")

# Fall back to a default title and caption when none were given
if opts.plot_title is None:
    opts.plot_title = opts.ifo + " SNR vs Coherent SNR"
if opts.plot_caption is None:
    opts.plot_caption = "Blue crosses: background triggers.  "
    # The injection marker is only described when injections are plotted
    if found_missed_file:
        opts.plot_caption += "Red crosses: injections triggers.  "
    opts.plot_caption = (
        opts.plot_caption
        + "Black line: veto line.  "
        + "Gray shaded region: vetoed area - The cut is "
        + "applied only to the two most sensitive detectors, "
        + "which can vary with template parameters and sky location. "
        + "Green lines: the expected SNR for optimally "
        + "oriented injections (mean, min, and max).  "
        + "Magenta lines: 2 sigma errors.  "
        + "Cyan lines: 3 sigma errors."
    )

logging.info("Imported and ready to go.")

# Set output directories
# exist_ok avoids a failure when another process creates the directory
# between an existence check and the creation itself
outdir = os.path.split(os.path.abspath(opts.output_file))[0]
os.makedirs(outdir, exist_ok=True)

# Extract IFOs
ifos = ppu.extract_ifos(trig_file)

# Fail early with a clear message: the requested IFO is used further down
# to build dictionary keys, where a mismatch would only show up as a KeyError
if opts.ifo not in ifos:
    parser.error(
        f"Interferometer {opts.ifo} not found in {trig_file} "
        f"(available: {', '.join(ifos)})"
    )

# Generate time-slides dictionary
slide_dict = ppu.load_time_slides(trig_file)

# Generate segments dictionary
segment_dict = ppu.load_segment_dict(trig_file)

# Construct trials removing vetoed times
trial_dict, total_trials = ppu.construct_trials(
    opts.seg_files,
    segment_dict,
    ifos,
    slide_dict,
    opts.veto_file
)

# Load triggers/injections (apply reweighted SNR cut, not vetoes)
trig_data = ppu.load_data(trig_file, ifos, data_tag='trigs',
                          rw_snr_threshold=opts.newsnr_threshold,
                          slide_id=opts.slide_id)
# NOTE(review): found_missed_file is None when no injection file was given;
# load_data is presumably expected to cope with that - confirm
inj_data = ppu.load_data(found_missed_file, ifos, data_tag='injs',
                         rw_snr_threshold=opts.newsnr_threshold,
                         slide_id=0)

# Collect the trigger properties needed for the plot, first per slide and
# then merged over all slides.
# Based on trial_dict: if vetoes were applied, trig_* are the veto survivors
# The x axis is always the coherent SNR
x_key = 'network/coherent_snr'
keys = (
    # Coherent SNR and the parameters necessary for antenna responses
    [x_key, 'network/ra', 'network/dec', 'network/end_time_gc']
    # Single IFO and network event_ids
    + [ifo + '/event_id' for ifo in ifos]
    + ['network/event_id']
    # Single IFO SNRs
    + [ifo + '/snr' for ifo in ifos]
    # Sigma squared for each IFO
    + [ifo + '/sigmasq' for ifo in ifos]
)
found_trigs_slides = ppu.extract_trig_properties(
    trial_dict, trig_data, slide_dict, segment_dict, keys
)
# Concatenate the per-slide arrays of every property into a single array
found_trigs = {
    key: numpy.concatenate(
        [found_trigs_slides[key][slide_id][:] for slide_id in slide_dict]
    )
    for key in keys
}

# Complete the dictionary found_trigs
# 1) Do not assume individual IFO and network event_ids are sorted the same way
for ifo in ifos:
    ifo_event_ids = found_trigs[ifo + '/event_id']
    # For each network event_id, position of its first match among the
    # single IFO event_ids.  The integer dtype is forced so that the array
    # remains a valid index when there are no triggers: an empty array would
    # otherwise default to float, which cannot be used for indexing.
    sorted_ifo_ids = numpy.array(
        [
            numpy.nonzero(ifo_event_ids == idx)[0][0]
            for idx in found_trigs['network/event_id']
        ],
        dtype=numpy.intp,
    )
    # Reorder the single IFO quantities to follow the network ordering
    for key in [ifo+'/snr', ifo+'/sigmasq']:
        found_trigs[key] = found_trigs[key][sorted_ifo_ids]

# 2) Get antenna response based parameters
found_trigs['sigma_tot'] = numpy.zeros(len(found_trigs[x_key]))
for ifo in ifos:
    antenna = Detector(ifo)
    ifo_f_resp = ppu.get_antenna_responses(
        antenna,
        found_trigs['network/ra'],
        found_trigs['network/dec'],
        found_trigs['network/end_time_gc']
    )
    # Get the average for f_resp_mean and calculate sigma_tot, i.e., the
    # sum over IFOs of sigmasq weighted by the antenna response
    found_trigs[ifo+'/f_resp_mean'] = ifo_f_resp.mean()
    found_trigs['sigma_tot'] += found_trigs[ifo+'/sigmasq'] * ifo_f_resp

# 3) Calculate the mean, max, and min sigmas
for ifo in ifos:
    # Fraction of sigma_tot contributed by the sigmasq of this IFO
    sigma_norm = found_trigs[ifo+'/sigmasq'] / found_trigs['sigma_tot']
    # Default to 0 when there are no triggers
    found_trigs[ifo+'/sigma_mean'] = sigma_norm.mean() \
        if len(sigma_norm) else 0
    # The extrema are only stored for the IFO being plotted
    if ifo == opts.ifo:
        found_trigs['sigma_max'] = sigma_norm.max() if len(sigma_norm) else 0
        found_trigs['sigma_min'] = sigma_norm.min() if len(sigma_norm) else 0

# Gather injections found surviving vetoes
# (only the first returned item is needed here)
found_injs, *_ = ppu.apply_vetoes_to_found_injs(
    found_missed_file,
    inj_data,
    ifos,
    veto_file=opts.veto_file,
    keys=keys
)

# Generate plots
logging.info("Plotting...")

# Order the IFOs by sensitivity, most sensitive first: the figure of merit
# is the mean antenna response times the mean normalised sigma
ifo_sensitivity = {
    ifo: found_trigs[ifo+'/f_resp_mean'] * found_trigs[ifo+'/sigma_mean']
    for ifo in ifos
}
ifo_sensitivity = collections.OrderedDict(
    sorted(ifo_sensitivity.items(), key=operator.itemgetter(1), reverse=True)
)
# One label per position in the sensitivity ranking: the list is extended
# so that it has an entry for every IFO, however many there are
loudness_labels = ['first', 'second', 'third', 'fourth', 'fifth']
loudness_labels += [
    f"{rank}th" for rank in range(len(loudness_labels) + 1, len(ifos) + 1)
]

# Determine the maximum coherent SNR value we are dealing with
x_max = plu.axis_max_value(
    found_trigs[x_key], found_injs[x_key], found_missed_file
)
# The SNR grid used for the plotted lines extends at least up to 50
max_snr = x_max
if x_max < 50.0:
    max_snr = 50.0

# Determine the maximum single IFO SNR value we are dealing with
y_key = opts.ifo+'/snr'
y_max = plu.axis_max_value(
    found_trigs[y_key], found_injs[y_key], found_missed_file
)

# Set up the figure
fig = plt.figure()
ax = fig.gca()

# Background triggers (blue crosses) on a grid
ax.plot(found_trigs[x_key], found_trigs[y_key], 'bx')
ax.grid()
# Injections (red pluses), only when an injection file was given
if found_missed_file:
    ax.plot(found_injs[x_key], found_injs[y_key], 'r+')

# Green lines: snr * sqrt(response * sigma) for the mean, min, and max sigma
ifo_f_resp_mean = found_trigs[opts.ifo+'/f_resp_mean']
sigma_stats = {
    'mean': found_trigs[opts.ifo+'/sigma_mean'],
    'min': found_trigs['sigma_min'],
    'max': found_trigs['sigma_max'],
}
snr_grid = numpy.arange(0.01, max_snr, 1)
y_data = {
    label: snr_grid * (ifo_f_resp_mean * sigma) ** 0.5
    for label, sigma in sigma_stats.items()
}
for expected_snr in y_data.values():
    ax.plot(snr_grid, expected_snr, 'g-')

# Deviation lines: lower tail around the min line and upper tail around
# the max line, for 2 sigma (0.9545, magenta) and 3 sigma (0.9973, cyan)
for tail_prob, line_style in ((0.02275, 'm-'), (0.00135, 'c-')):
    plot_deviation(tail_prob, snr_grid, y_data['min'], ax, line_style)
    plot_deviation(1 - tail_prob, snr_grid, y_data['max'], ax, line_style)

# Black horizontal line at 4, plus labels and limits of the non-zoomed plot
ax.plot([0, max_snr], [4, 4], 'k-')
ax.set_xlabel("Coherent SNR")
ax.set_ylabel(opts.ifo + " SNR")
ax.set_xlim([0, 1.01 * x_max])
ax.set_ylim([0, 1.20 * y_max])

# Veto applies to the two most sensitive IFOs, so shade the region below
# the black line only if this IFO is one of them
loudness_index = list(ifo_sensitivity).index(opts.ifo)
if loudness_index < 2:
    y_floor = ax.get_ylim()[0]
    ax.fill(
        [0, max_snr, max_snr, 0],
        [4, 4, y_floor, y_floor],
        color='#dddddd',
    )
# Report the position of this IFO in the sensitivity ranking in the title
opts.plot_title += f" ({loudness_labels[loudness_index]} average sensitivity)"

# Zoom in if asked to do so
if opts.zoom_in:
    ax.set_xlim([6, 50])
    ax.set_ylim([0, 20])

# Save the plot together with its metadata, then release the figure
save_fig_with_metadata(
    fig,
    opts.output_file,
    cmd=' '.join(sys.argv),
    title=opts.plot_title,
    caption=opts.plot_caption,
)
plt.close()
