#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot null statistic or coincident SNR vs coherent SNR for a PyGRB run.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import numpy
from matplotlib import rc
import matplotlib.pyplot as plt
import pycbc.version
from pycbc import init_logging
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

plt.switch_backend('Agg')

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_null_stats"


# =============================================================================
# Functions
# =============================================================================
# Function that produces the contours to be plotted
def calculate_contours(opts, new_snrs=None):
    """Generate the null SNR contour to plot.

    Parameters
    ----------
    opts : object
        Options container. The attributes ``null_snr_threshold``,
        ``null_grad_thresh`` and ``null_grad_val`` are read from it.
    new_snrs : list, optional
        Accepted for backward compatibility only.  It does not enter the
        returned contour and is never modified by this function.

    Returns
    -------
    null_cont : numpy.ndarray
        Contour values, one per entry of ``snr_vals``: equal to
        ``opts.null_snr_threshold`` up to ``opts.null_grad_thresh`` and
        growing linearly with slope ``opts.null_grad_val`` above it.
    snr_vals : numpy.ndarray
        SNR values at which the contour is evaluated: steps of 0.1 from 4
        up to (but excluding) 30, then unit steps from 30 to 499.
    """
    # Get SNR values for contours: fine sampling at low SNR, coarse above 30
    snr_vals = numpy.concatenate(
        (numpy.arange(4, 30, 0.1), numpy.arange(30, 500, 1))
    )

    # Determine contour consistently with null_snr in coherent.py
    null_thresh = opts.null_snr_threshold
    null_grad_snr = opts.null_grad_thresh
    null_grad_val = opts.null_grad_val
    null_cont = numpy.where(
        snr_vals > null_grad_snr,
        null_thresh + (snr_vals - null_grad_snr) * null_grad_val,
        null_thresh
    )

    return null_cont, snr_vals


# =============================================================================
# Main script starts here
# =============================================================================
# Build the command line parser, starting from the one provided by the
# PyGRB post-processing utilities and adding the options used by this script
parser = ppu.pygrb_initialize_plot_parser(description=__doc__)
parser.add_argument("-t", "--trig-file", action="store",
                    default=None, required=True,
                    help="The location of the trigger file")
parser.add_argument("--found-missed-file",
                    help="The hdf injection results file", required=False)
parser.add_argument("-z", "--zoom-in", default=False, action="store_true",
                    help="Output file a zoomed in version of the plot.")
# The y-variable is used unconditionally further down to pick the axis label
# and to build the dataset key, so it must be given: requiring it here turns
# a late KeyError/TypeError into a clear usage error
parser.add_argument("-y", "--y-variable", required=True,
                    choices=['coincident', 'null'],  # TODO: overwhitened?
                    help="Quantity to plot on the vertical axis.")
ppu.pygrb_add_null_snr_opts(parser)
ppu.pygrb_add_bestnr_cut_opt(parser)
ppu.pygrb_add_slide_opts(parser)
opts = parser.parse_args()
ppu.slide_opts_helper(opts)

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options: file locations are handled as absolute paths from here on
trig_file = os.path.abspath(opts.trig_file)
found_missed_file = None
if opts.found_missed_file:
    found_missed_file = os.path.abspath(opts.found_missed_file)
zoom_in = opts.zoom_in

# Prepare plot title and caption, unless the user provided them
y_labels = {'null': "Null SNR",
            'coincident': "Coincident SNR"}  # TODO: overwhitened
if opts.plot_title is None:
    opts.plot_title = y_labels[opts.y_variable] + " vs Coherent SNR"
if opts.plot_caption is None:
    # Assemble the caption piece by piece, depending on what gets plotted
    caption_parts = ["Blue crosses: background triggers.  "]
    if found_missed_file:
        caption_parts.append("Red crosses: injections triggers.  ")
    if opts.y_variable == 'coincident':
        caption_parts.append("Green line: coincident SNR = coherent SNR.")
    else:
        caption_parts.append("Black line: veto line.  ")
        caption_parts.append("Magenta line: above this triggers have ")
        caption_parts.append("reduced detection statistic.")
    opts.plot_caption = "".join(caption_parts)

logging.info("Imported and ready to go.")

# Set output directories: make sure the directory hosting the output file
# exists.  exist_ok=True avoids a failure if the directory is created by
# another process between an existence check and the creation itself.
outdir = os.path.dirname(os.path.abspath(opts.output_file))
os.makedirs(outdir, exist_ok=True)

# Extract IFOs
ifos = ppu.extract_ifos(trig_file)

# Generate time-slides dictionary
slide_dict = ppu.load_time_slides(trig_file)

# slide_dict is looped over further down: when a specific slide was
# requested, discard all the other slides here
if opts.slide_id is not None:
    unwanted_slides = [sid for sid in slide_dict if sid != opts.slide_id]
    for sid in unwanted_slides:
        del slide_dict[sid]

# Generate segments dictionary
segment_dict = ppu.load_segment_dict(trig_file)

# Construct trials removing vetoed times
trial_dict, total_trials = ppu.construct_trials(
    opts.seg_files, segment_dict, ifos, slide_dict, opts.veto_file
)

# Load triggers/injections (apply reweighted SNR cut, not vetoes)
rw_snr_cut = opts.newsnr_threshold
trig_data = ppu.load_data(trig_file, ifos, data_tag='trigs',
                          rw_snr_threshold=rw_snr_cut,
                          slide_id=opts.slide_id)
inj_data = ppu.load_data(found_missed_file, ifos, data_tag='injs',
                         rw_snr_threshold=rw_snr_cut, slide_id=0)

# Keys of the quantities to plot: coherent SNR always goes on the x-axis,
# while the SNR on the y-axis is the one selected by the user
x_key = 'network/coherent_snr'
y_key = 'network/' + opts.y_variable + '_snr'
plot_keys = (x_key, y_key)

# Extract the needed trigger properties, stored as dictionaries.
# Based on trial_dict: if vetoes were applied, these are the veto survivors
found_trigs_slides = ppu.extract_trig_properties(
    trial_dict, trig_data, slide_dict, segment_dict, list(plot_keys)
)

# Merge the triggers of all the slides into a single array per key
found_trigs = {
    key: numpy.concatenate(
        [found_trigs_slides[key][slide_id][:] for slide_id in slide_dict]
    )
    for key in plot_keys
}

# Gather injections found surviving vetoes
found_injs, *_ = ppu.apply_vetoes_to_found_injs(
    found_missed_file, inj_data, ifos,
    veto_file=opts.veto_file, keys=list(plot_keys)
)

# Generate plots
logging.info("Plotting...")

# Contours: defaults, overridden below depending on the y-axis quantity
cont_colors = ['g-']
snr_vals = None
shade_cont_value = None
# Coincident SNR plot case: we want a coinc=coh diagonal line on the plot
if y_key == 'network/coincident_snr':
    x_max = plu.axis_max_value(found_trigs[x_key], found_injs[x_key],
                               found_missed_file)
    snr_vals = [4, x_max]
    null_stat_conts = [[4, x_max]]
# Null SNR plot case: draw the veto line from calculate_contours
else:
    cont_colors = ['k-', 'm-']
    null_cont, snr_vals = calculate_contours(opts, new_snrs=None)
    null_stat_conts = [null_cont]
    if zoom_in:
        # Second contour, one unit below the veto line (second colour in
        # cont_colors, i.e. magenta)
        null_stat_conts.append(null_cont - 1)
    shade_cont_value = 0

# Null SNR or coincident SNR vs coherent SNR plot.
# When zooming in, fall back on default axis ranges if none were requested
if not opts.x_lims and zoom_in:
    opts.x_lims = '6,30'
if not opts.y_lims and zoom_in:
    opts.y_lims = '0,30'
# Set rcParams
rc('font', size=14)
plu.pygrb_plotter([found_trigs[x_key], found_trigs[y_key]],
                  [found_injs[x_key], found_injs[y_key]],
                  "Coherent SNR", y_labels[opts.y_variable], opts,
                  snr_vals=snr_vals, conts=null_stat_conts,
                  shade_cont_value=shade_cont_value,
                  colors=cont_colors, vert_spike=True,
                  cmd=' '.join(sys.argv))
