#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Produces signal consistency plots of the form network power/bank/auto or
single detector chi-square vs single IFO/coherent/reweighted/null/coinc SNR.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import numpy
from matplotlib import pyplot as plt
from matplotlib import rc
import pycbc.version
from pycbc import init_logging
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

# Use the Agg matplotlib backend and a default font size of 14 for the plot
plt.switch_backend('Agg')
rc('font', size=14)

# Script metadata: version and date are taken from pycbc.version
__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_chisq_veto"


# =============================================================================
# Functions
# =============================================================================
# Function to calculate chi-square weight for the reweighted SNR
def new_snr_chisq(snr, new_snr, chisq_index=4.0, chisq_nhigh=3.0):
    """Return the chi-square value that reweights ``snr`` into ``new_snr``.

    The quantity ``(snr/new_snr)**chisq_index`` is evaluated first: when it
    does not exceed 1 the floor value 1E-20 is returned, otherwise the
    result is ``(2*ratio - 1)**(chisq_nhigh/chisq_index)``.
    """

    ratio_power = (snr/new_snr)**chisq_index
    # The exponent is only evaluated on the branch that actually needs it
    return (
        1E-20 if ratio_power <= 1
        else (2*ratio_power - 1)**(chisq_nhigh/chisq_index)
    )


# Function that produces the contours to be plotted
def calculate_contours(opts, new_snrs=None):
    """Generate the contours of constant reweighted SNR for the veto plots.

    Parameters
    ----------
    opts : object
        Parsed options. The attributes read here are ``newsnr_threshold``,
        ``chisq_index`` and ``chisq_nhigh``.
    new_snrs : iterable of float, optional
        Reweighted SNR values for which contours are computed. Defaults to
        ``[5.5, 6, 6.5, 7, 8, 9, 10, 11]``. The input is copied, so the
        object passed by the caller is never modified.

    Returns
    -------
    contours : numpy.ndarray
        Float array with one row per contour and one column per entry of
        ``snr_vals``, holding the chi-square values returned by
        ``new_snr_chisq``.
    snr_vals : numpy.ndarray
        SNR values at which the contours are evaluated:
        ``numpy.arange(1, 30, 0.1)`` followed by ``numpy.arange(30, 500, 1)``.
    cont_value : int
        Index of the contour at ``opts.newsnr_threshold``; this is -1 (the
        last contour) when the threshold had to be appended to the list.
    colors : list of str
        One style string per contour: "k-" for the threshold contour, "y-"
        for integer-valued contours and "y--" for the remaining ones.
    """

    # Work on a private list: the threshold may have to be appended below
    # and a list owned by the caller must not change as a side effect
    if new_snrs is None:
        new_snrs = [5.5, 6, 6.5, 7, 8, 9, 10, 11]
    else:
        new_snrs = list(new_snrs)

    # Add the new SNR threshold contour to the list if necessary
    # and keep track of where it is
    try:
        cont_value = new_snrs.index(opts.newsnr_threshold)
    except ValueError:
        new_snrs.append(opts.newsnr_threshold)
        cont_value = -1

    # Get SNR values for contours: fine sampling below 30, coarse above
    snr_vals = numpy.concatenate((numpy.arange(1, 30, 0.1),
                                  numpy.arange(30, 500, 1)))

    # For each contour, calculate the chisq variable needed at every SNR
    contours = numpy.array(
        [[new_snr_chisq(snr, new_snr, opts.chisq_index, opts.chisq_nhigh)
          for snr in snr_vals]
         for new_snr in new_snrs],
        dtype=numpy.float64
    )

    # Colors and styles of the contours
    colors = ["k-" if new_snr == opts.newsnr_threshold else
              "y-" if new_snr == int(new_snr) else
              "y--" for new_snr in new_snrs]

    return contours, snr_vals, cont_value, colors


# =============================================================================
# Main script starts here
# =============================================================================
# Start from the common PyGRB plotting parser and add the options that are
# specific to this script
parser = ppu.pygrb_initialize_plot_parser(description=__doc__)
parser.add_argument(
    "-t", "--trig-file",
    required=True,
    help="The location of the trigger file",
)
parser.add_argument(
    "--found-missed-file",
    help="The hdf injection results file",
)
parser.add_argument(
    "-z", "--zoom-in",
    action="store_true",
    help="Output file a zoomed in version of the plot.",
)
parser.add_argument(
    "-y", "--y-variable",
    required=True,
    choices=['network', 'bank', 'auto', 'power'],
    help="Quantity to plot on the vertical axis.",
)
parser.add_argument(
    "--snr-type",
    default='coherent',
    choices=['coherent', 'coincident', 'null', 'reweighted', 'single'],
    help="SNR value to plot on x-axis.",
)
# Option groups provided by the PyGRB post-processing utilities
ppu.pygrb_add_bestnr_cut_opt(parser)
ppu.pygrb_add_bestnr_opts(parser)
ppu.pygrb_add_slide_opts(parser)
opts = parser.parse_args()
ppu.slide_opts_helper(opts)

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options
trig_file = os.path.abspath(opts.trig_file)
found_missed_file = os.path.abspath(opts.found_missed_file) \
    if opts.found_missed_file else None
zoom_in = opts.zoom_in
veto_type = opts.y_variable
ifo = opts.ifo
snr_type = opts.snr_type
# A detector must be specified both when the single IFO SNR is plotted on
# the horizontal axis and when a single detector chi-square (bank, auto,
# power) is plotted on the vertical axis: dataset names, plot title and axis
# labels are built from it further down, so fail here with a clear message
if ifo is None:
    if snr_type == 'single':
        err_msg = "--ifo must be given to plot single IFO SNR veto"
        parser.error(err_msg)
    if veto_type != 'network':
        err_msg = ("--ifo must be given to plot the single detector " +
                   veto_type + " chi-square")
        parser.error(err_msg)

# Veto is intended as a single IFO quantity. Network chisq will be obsolete.
# TODO: fix vetoes

# Prepare plot title and caption, unless the user provided them
veto_labels = {'network': "Network Power",
               'bank': "Bank",
               'auto': "Auto",
               'power': "Power"}
if opts.plot_title is None:
    opts.plot_title = veto_labels[veto_type] + " Chi Square"
    if veto_type != 'network':
        # Single detector vetoes: prepend the detector name, separated by a
        # space from the veto label
        opts.plot_title = ifo + " " + opts.plot_title
    if snr_type == 'single':
        opts.plot_title += f" vs {ifo} SNR"
    else:
        opts.plot_title += f" vs {snr_type.capitalize()} SNR"
if opts.plot_caption is None:
    opts.plot_caption = ("Blue crosses: background triggers. ")
    # Injections are mentioned only if an injection results file was given
    if found_missed_file:
        opts.plot_caption += "Red crosses: injections triggers. "
    # NOTE(review): the contour description is added for any SNR type, while
    # the contours are only computed for coherent SNR further down; confirm
    # the caption is accurate for the other SNR types
    if veto_type == 'network':
        opts.plot_caption += ("Gray shaded region: area cut by the " +
                              "reweighted SNR threshold. " +
                              "Black line: reweighted SNR threshold. Yellow " +
                              "lines: contours of constant reweighted SNR.")

logging.info("Imported and ready to go.")

# Set output directory
# exist_ok avoids a failure when the directory is created by something else
# (e.g., another plotting job) between an existence check and the creation
outdir = os.path.split(os.path.abspath(opts.output_file))[0]
os.makedirs(outdir, exist_ok=True)

# Extract IFOs
ifos = ppu.extract_ifos(trig_file)

# Generate time-slides dictionary
slide_dict = ppu.load_time_slides(trig_file)

# Generate segments dictionary
segment_dict = ppu.load_segment_dict(trig_file)

# Construct trials removing vetoed times
trial_dict, total_trials = ppu.construct_trials(
    opts.seg_files,
    segment_dict,
    ifos,
    slide_dict,
    opts.veto_file
)

# Load trigger and injections data: ensure that network power chi-square plots
# show all the data to see the impact of the reweighted SNR cut, otherwise
# remove points with reweighted SNR below threshold
rw_snr_threshold = None if veto_type == 'network' else opts.newsnr_threshold
trig_data = ppu.load_data(trig_file, ifos, data_tag='trigs',
                          rw_snr_threshold=rw_snr_threshold,
                          slide_id=opts.slide_id)
# Injections are always loaded with slide_id=0
inj_data = ppu.load_data(found_missed_file, ifos, data_tag='injs',
                         rw_snr_threshold=rw_snr_threshold,
                         slide_id=0)

# Dataset name for the horizontal direction: single detector SNR or one of
# the network SNRs
x_key = (ifo + '/snr') if snr_type == 'single' \
    else f'network/{snr_type}_snr'

# Dataset name for the vertical direction and for normalization: single
# detector vetoes live under the detector name, anything else falls back to
# the network chi-square
single_ifo_suffixes = {'power': '/chisq',
                       'bank': '/bank_chisq',
                       'auto': '/auto_chisq'}
if veto_type in single_ifo_suffixes:
    y_key = opts.ifo + single_ifo_suffixes[veto_type]
else:
    y_key = 'network/my_network_chisq'

# The network chi-square is already normalized so it does not require a key
# for the number of degrees of freedom
keys = [x_key, y_key]
if veto_type != 'network':
    keys.append(f'{y_key}_dof')

# Extract needed trigger properties and store them as dictionaries
# Based on trial_dict: if vetoes were applied, trig_* are the veto survivors
found_trigs_slides = ppu.extract_trig_properties(
    trial_dict, trig_data, slide_dict, segment_dict, keys
)
# For every dataset, join the values of all the slides into a single array
found_trigs = {
    key: numpy.concatenate(
        [found_trigs_slides[key][slide_id][:] for slide_id in slide_dict]
    )
    for key in keys
}

# Gather injections found surviving vetoes (only the first returned item
# is used here)
found_injs, *_ = ppu.apply_vetoes_to_found_injs(
    opts.found_missed_file, inj_data, ifos,
    veto_file=opts.veto_file, keys=keys
)

# Sanity checks: there must be something to plot along both axes
for data_key, axis_name in zip(keys[:2], ('x', 'y')):
    if found_trigs[data_key] is None and found_injs[data_key] is None:
        raise RuntimeError(
            f"No data to be plotted on the {axis_name}-axis was found"
        )

# Normalize chi-squares with the number of degrees of freedom: a third key
# is present only when such a normalization is required
if len(keys) == 3:
    chisq_key, dof_key = keys[1], keys[2]
    found_trigs[chisq_key] /= found_trigs[dof_key]
    found_injs[chisq_key] /= found_injs[dof_key]

# Single detector chi-squares are initialized to 0: we floor possible
# remaining 0s to 0.005 to avoid asking for logarithms of 0 in the plot
trig_chisq = found_trigs[y_key]
numpy.putmask(trig_chisq, trig_chisq == 0, 0.005)

# Generate plots
logging.info("Plotting...")

# Label of the horizontal axis: the detector name is used for single IFO
# SNR, the capitalized SNR type otherwise
x_label = ifo if snr_type == 'single' else snr_type.capitalize()
x_label += " SNR"

# Determine the minimum and maximum SNR value we are dealing with,
# scaled by 0.9 and 1.1, respectively
x_min = 0.9*plu.axis_min_value(found_trigs[x_key], found_injs[x_key],
                               found_missed_file)
x_max = 1.1*plu.axis_max_value(found_trigs[x_key], found_injs[x_key],
                               found_missed_file)

# Determine the minimum and maximum chi-square value we are dealing with,
# scaled by 0.9 and 1.1, respectively
y_min = 0.9*plu.axis_min_value(found_trigs[y_key], found_injs[y_key],
                               found_missed_file)
y_max = 1.1*plu.axis_max_value(found_trigs[y_key], found_injs[y_key],
                               found_missed_file)

# Determine y-axis label
y_label = "Network power chi-square" if veto_type == 'network' \
    else f"{ifo} Single {veto_labels[veto_type].lower()} chi-square"

# Determine contours for plots
conts = None
snr_vals = None
cont_value = None
colors = None
# Enable contours of constant reweighted SNR as a function of coherent SNR
if snr_type == 'coherent':
    conts, snr_vals, cont_value, colors = calculate_contours(opts,
                                                             new_snrs=None)
# The cut in reweighted SNR involves only the network power chi-square
if veto_type != 'network':
    cont_value = None

# Produce the veto vs. SNR plot
# Default axis limits are filled in only when no horizontal limits were
# requested. Vertical limits already set on the options are kept rather than
# overwritten (presumably they come from a --y-lims option of the common
# plot parser; verify against ppu.pygrb_initialize_plot_parser)
if not opts.x_lims:
    x_upper = '50' if zoom_in else str(x_max)
    y_upper = '20000' if zoom_in else str(10*y_max)
    opts.x_lims = str(x_min) + ',' + x_upper
    if not getattr(opts, 'y_lims', None):
        opts.y_lims = str(y_min) + ',' + y_upper
plu.pygrb_plotter([found_trigs[x_key], found_trigs[y_key]],
                  [found_injs[x_key], found_injs[y_key]],
                  x_label, y_label, opts,
                  snr_vals=snr_vals, conts=conts, colors=colors,
                  shade_cont_value=cont_value, vert_spike=True,
                  cmd=' '.join(sys.argv))
