#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1767619825/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

# Copyright (C) 2021 Francesco Pannarale, Viviana Caceres
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot found/missed injection properties for the triggered search (PyGRB).
"""

import logging
import os.path
import sys
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

import pycbc.conversions
import pycbc.results
import pycbc.version
from pycbc import init_logging
from pycbc.results import pygrb_postprocessing_utils as ppu

# Select the Agg backend so figures are rendered off-screen
plt.switch_backend('Agg')
# NOTE(review): rc is called for the "image" group without any keyword
# settings, so no parameter value is supplied here -- confirm whether a
# specific setting was intended
matplotlib.rc("image")

# Script metadata; version and date are read from pycbc.version
__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_injs_results"


# =============================================================================
# Functions
# =============================================================================
def process_var_strings(qty):
    """Add underscores to match HDF column name conventions.

    The command line accepts quantity names with or without underscores
    (e.g. ``cosabsincl`` and ``cos_abs_incl``); both spellings are mapped
    onto the underscored form.

    Parameters
    ----------
    qty : str
        Quantity name as given on the command line.

    Returns
    -------
    str
        The quantity name with underscores inserted where needed. Names
        that already contain the underscored form are returned unchanged.
    """

    # (compact spelling, underscored spelling), applied in this order
    substitutions = (
        ('skyerror', 'sky_error'),
        ('effsitedist', 'eff_site_dist'),
        ('effdist', 'eff_dist'),
        ('coaphase', 'coa_phase'),
        ('cos', 'cos_'),
        ('abs', 'abs_'),
        ('endtime', 'end_time'),
        ('spin1a', 'spin1_a'),
        ('spin2a', 'spin2_a'),
    )

    for compact, underscored in substitutions:
        # Skip names that already carry the underscore: a blind replace
        # would turn 'cos_incl' into 'cos__incl'
        if underscored not in qty:
            qty = qty.replace(compact, underscored)

    return qty


def complete_incl_data(injs, key, tag):
    """Build the inclination-related quantity `key` from injection data.

    The angle is read from ``injs[tag+'/thetajn']``. Depending on `key`
    ('incl', 'abs_incl', 'cos_incl' or 'cos_abs_incl') the result is the
    angle in degrees or the cosine of the angle.
    """

    # TODO: make sure it is clear that inclination is interpreted as theta_jn
    theta_jn = injs[tag+'/thetajn']

    # Angles available for the requested quantity, keyed by name
    angles = {'incl': theta_jn}

    # |incl| is only needed for abs_incl and cos_abs_incl: fold the angle
    # about pi/2
    if 'abs_' in key:
        angles['abs_incl'] = 0.5*np.pi - \
            abs(theta_jn - 0.5*np.pi)

    # Plain angles (incl, abs_incl) are returned in degrees
    if 'cos_' not in key:
        return np.rad2deg(angles[key])

    # cos(incl) or cos(|incl|): strip the prefix to find the angle
    return np.cos(angles[key.replace('cos_', '')])


def complete_mass_data(injs, key, tag):
    """Derive total mass, chirp mass or mass ratio from the component masses.

    `key` selects the quantity: 'mtotal', 'mchirp', or anything else for the
    mass ratio, which is always returned as a value not exceeding 1.
    """

    m1, m2 = injs[tag+'/mass1'], injs[tag+'/mass2']

    if key == 'mchirp':
        return pycbc.conversions.mchirp_from_mass1_mass2(m1, m2)

    if key == 'mtotal':
        return m1 + m2

    # Mass ratio: invert entries above 1 so the result is <= 1
    ratio = m2 / m1
    return np.where(ratio > 1, 1./ratio, ratio)


def complete_sky_error_data(injs, tag):
    """Compute the angle between injected and recovered sky positions.

    For any tag other than 'found' an array of None values is returned,
    with one entry per injection.
    """

    # Placeholder with one null entry per injection
    sky_error = np.full(len(injs[tag+'/mass1']), None)
    if tag != 'found':
        return sky_error

    # Injected sky position
    inj_ra = injs[tag+'/ra']
    inj_dec = injs[tag+'/dec']
    # Recovered (trigger) sky position
    trig_ra = injs['network/ra']
    trig_dec = injs['network/dec']

    # Cosine of the angular separation between the two positions
    cos_separation = np.cos(inj_dec - trig_dec) - \
        np.cos(inj_dec) * np.cos(trig_dec) * (1 - np.cos(inj_ra - trig_ra))

    return np.arccos(cos_separation)


# Quantities stored directly in the found-missed file under both the
# 'found' and 'missed' groups, so they can be read without any processing
# TODO: also in the found-missed injs file is 'inclination'
easy_keys = [
    'distance', 'mass1', 'mass2', 'polarization',
    # Spin magnitude and Cartesian components of the 1st component
    'spin1_a', 'spin1x', 'spin1y', 'spin1z',
    # Spin magnitude and Cartesian components of the 2nd component
    'spin2_a', 'spin2x', 'spin2y', 'spin2z',
    # Spin angles
    'spin1_azimuthal', 'spin1_polar',
    'spin2_azimuthal', 'spin2_polar',
    # Sky location and reference phase
    'dec', 'ra', 'phi_ref',
]


def complete_inj_data(injs, keys, tag, ifos=None):
    """Create a dictionary containing the data specified by the
    list of keys extracted from an injection file.

    Parameters
    ----------
    injs : mapping
        Injection data, indexed by strings such as ``tag+'/mass1'``.
    keys : sequence of str
        Quantities to extract. A quantity that cannot be found is logged
        and returned as an empty array.
    tag : str
        Group to read from; the script uses 'found' and 'missed'.
    ifos : sequence of str, optional
        Interferometer names. Only used when `tag` is 'found', to collect
        the ``ifo+'/end_time'`` entries.

    Returns
    -------
    dict
        One entry per requested key. For the 'found' tag it also holds
        ``ifo+'/end_time'`` for each ifo and 'reweighted_snr', each an
        empty array when absent from `injs`.
    """

    # Avoid a mutable default argument
    ifos = [] if ifos is None else ifos

    data_dict = {}

    for key in keys:
        # Fall back to an empty entry if the quantity is unavailable
        data_dict[key] = np.array([])
        try:
            if key == 'end_time':
                # NOTE(review): the trigger time is not subtracted here,
                # so this is the coalescence time as stored in the file
                data_dict[key] = injs[tag+'/tc']
            elif key in ['mchirp', 'mtotal', 'q']:
                data_dict[key] = complete_mass_data(injs, key, tag)
            elif 'incl' in key:
                data_dict[key] = complete_incl_data(injs, key, tag)
            elif key == 'sky_error':
                data_dict[key] = complete_sky_error_data(injs, tag)
            else:
                data_dict[key] = injs[tag+'/'+key]
        except KeyError:
            # Deliberately best-effort: keep the empty entry and carry on
            logging.info("%s/%s not allowed yet", tag, key)

    # Include information needed to apply vetoes and reweighted SNR cut
    if tag == 'found':
        for ifo in ifos:
            key = ifo+'/end_time'
            data_dict[key] = injs[key] if key in injs else np.array([])
        key = 'network/reweighted_snr'
        data_dict['reweighted_snr'] = \
            injs[key] if key in injs else np.array([])

    # Report the size of the first requested quantity; an empty request
    # has nothing to index
    num_injs = len(data_dict[keys[0]]) if len(keys) > 0 else 0
    logging.info("%d %s injections analysed.", num_injs, tag)

    return data_dict


# =============================================================================
# Main script starts here
# =============================================================================
parser = ppu.pygrb_initialize_plot_parser(description=__doc__)

# Mandatory inputs: injection results, offsource triggers, trigger time
for opt_name, opt_help in (
        ("--found-missed-file", "The hdf injection results file"),
        ("--trig-file", "The hdf offsource trigger file"),
        ("--trigger-time", "The GPS time of the external trigger")):
    parser.add_argument(opt_name, help=opt_help, required=True)

# Quantities accepted for either axis, with and without underscores
# FIXME: not all of these work
# TODO: effective distance
admitted_vars = easy_keys + [
    'mtotal', 'q', 'mchirp',
    'spin1a', 'spin2a',
    'incl', 'cos_incl', 'abs_incl', 'cos_abs_incl',
    'cosincl', 'absincl', 'cosabsincl',
    'sky_error', 'skyerror', 'end_time', 'endtime',
    'latitude', 'longitude', 'coaphase', 'coa_phase',
    'eff_site_dist', 'eff_dist',
    'effsitedist', 'effdist',
]

# The two axes take the same pair of options: the quantity and a log switch
for axis_letter, axis_name in (("x", "horizontal"), ("y", "vertical")):
    parser.add_argument(f"-{axis_letter}", f"--{axis_letter}-variable",
                        default=None, required=True,
                        choices=admitted_vars,
                        help=f"Quantity to plot on the {axis_name} axis. "
                             "(Underscores may be omitted in specifying "
                             "this option).")
    parser.add_argument(f"--{axis_letter}-log", action="store_true",
                        help=f"Use log {axis_name} axis")

# Appearance of the plot
parser.add_argument(
    "--colormap",
    default='cividis_r',
    help="Type of colormap to be used for the plots.",
)
parser.add_argument(
    "--far-type",
    choices=('inclusive', 'exclusive'),
    default='inclusive',
    help="Type of far to plot for the color. Choices are "
         "'inclusive' or 'exclusive'. Default = 'inclusive'",
)
parser.add_argument(
    "--missed-on-top",
    action='store_true',
    help="Plot missed injections on top of found ones and "
         "high FAR on top of low FAR",
)

# Options shared with the other PyGRB post-processing scripts
ppu.pygrb_add_bestnr_cut_opt(parser)
ppu.pygrb_add_slide_opts(parser)

opts = parser.parse_args()
ppu.slide_opts_helper(opts)

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Map the requested quantities onto the underscored naming convention
x_qty = process_var_strings(opts.x_variable)
y_qty = process_var_strings(opts.y_variable)

# The site-specific effective distance needs to know which site is meant
if 'eff_site_dist' in [x_qty, y_qty] and opts.ifo is None:
    err_msg = "A value for --ifo must be provided for "
    err_msg += "site specific effective distance"
    parser.error(err_msg)

# Store options used multiple times in local variables
outfile = opts.output_file
trig_file = os.path.abspath(opts.trig_file)
grb_time = opts.trigger_time

# Set output directory
logging.info("Setting output directory.")
outdir = os.path.split(os.path.abspath(outfile))[0]
# exist_ok avoids the race between checking for the directory and
# creating it when several jobs write to the same location
os.makedirs(outdir, exist_ok=True)

# Extract IFOs from the offsource trigger file
ifos = ppu.extract_ifos(trig_file)

# Generate time-slides dictionary (iterated by slide_id further down)
slide_dict = ppu.load_time_slides(trig_file)

# Generate segments dictionary
segment_dict = ppu.load_segment_dict(trig_file)

# Construct trials removing vetoed times
# trial_dict is indexed by slide_id and holds the trials of each slide;
# total_trials is later used as the denominator of the p-values
trial_dict, total_trials = ppu.construct_trials(
    opts.seg_files,
    segment_dict,
    ifos,
    slide_dict,
    opts.veto_file
)

# Load triggers (apply reweighted SNR cut, not vetoes)
all_off_trigs = ppu.load_data(trig_file, ifos, data_tag='trigs',
                              rw_snr_threshold=opts.newsnr_threshold,
                              slide_id=opts.slide_id)

# Extract needed trigger properties and store them as dictionaries
# Based on trial_dict: if vetoes were applied, trig_* are the veto survivors
# keys[0] is the trigger time and keys[1] the detection statistic; the
# result is indexed as trig_data[key][slide_id]
keys = ['network/end_time_gc', 'network/reweighted_snr']
trig_data = ppu.extract_trig_properties(
    trial_dict,
    all_off_trigs,
    slide_dict,
    segment_dict,
    keys
)

# Max BestNR values in each trial: these are stored in a dictionary keyed
# by slide_id, as arrays indexed by trial number
# Trials without any trigger keep the initial value of zero
background = {k: np.zeros(len(v)) for k, v in trial_dict.items()}
for slide_id in slide_dict:
    trig_times = trig_data[keys[0]][slide_id][:]
    for j, trial in enumerate(trial_dict[slide_id]):
        # True whenever the trigger is in the trial
        trial_cut = (trial[0] <= trig_times) & (trig_times < trial[1])
        # Move on if nothing was in the trial
        if not trial_cut.any():
            continue
        # Max BestNR (np.max avoids a Python-level loop over the array)
        background[slide_id][j] = np.max(
            trig_data[keys[1]][slide_id][trial_cut])
# Gather all values and max over them
background = np.concatenate(list(background.values()))
max_bkgd_reweighted_snr = background.max()
# total_trials is the denominator of the p-values computed later on, so a
# mismatch must stop the script even when assertions are disabled (-O)
if total_trials != len(background):
    raise RuntimeError(
        f"Inconsistent number of trials: expected {total_trials}, "
        f"found {len(background)} background values"
    )
logging.info("Background bestNR calculated.")

# =======================
# Post-process injections
# =======================
# Load injection data without applying the reweighted SNR cut, nor vetoes
inj_data = ppu.load_data(opts.found_missed_file, ifos, data_tag='injs',
                         slide_id=0)

# Extract the necessary data from the missed injections for the plot
missed_inj = complete_inj_data(inj_data, [x_qty, y_qty], 'missed', ifos=ifos)

# Extract the necessary data from the found injections for the plot
found_inj = complete_inj_data(inj_data, [x_qty, y_qty], 'found', ifos=ifos)

# Apply reweighted SNR cut: found injections below threshold are moved to
# the missed ones
if opts.newsnr_threshold:
    rw_snr_cut = found_inj['reweighted_snr'] < opts.newsnr_threshold
    # Visit each quantity once: when both axes show the same quantity, a
    # second pass would apply the full-length mask to an array that has
    # already been cut
    for key in dict.fromkeys([x_qty, y_qty]):
        missed_inj[key] = np.concatenate((missed_inj[key],
                                          found_inj[key][rw_snr_cut]))
        found_inj[key] = found_inj[key][~rw_snr_cut]
    for ifo in ifos:
        found_inj[ifo+'/end_time'] = found_inj[ifo+'/end_time'][~rw_snr_cut]
    found_inj['reweighted_snr'] = found_inj['reweighted_snr'][~rw_snr_cut]
    logging.info(
        "After applying reweighted SNR cut at %s: "
        "%d found injections and %d missed injections",
        opts.newsnr_threshold, len(found_inj[x_qty]), len(missed_inj[x_qty])
    )

# Split in injections found surviving vetoes and injections found but vetoed
found_after_vetoes, vetoed, found_idx, _ = ppu.apply_vetoes_to_found_injs(
    opts.found_missed_file,
    found_inj,
    ifos,
    veto_file=opts.veto_file
)

# Extract the detection statistic of injections found after vetoes
# Injections found but not louder than background are populated further down
if len(found_after_vetoes['reweighted_snr']) > 0:
    # Use the stored statistic if there is one, otherwise the reweighted
    # SNR of the surviving injections
    if 'found_after_vetoes/stat' in found_after_vetoes.keys():
        found_after_vetoes_stat = found_after_vetoes['found_after_vetoes/stat']
    else:
        found_after_vetoes_stat = found_inj['reweighted_snr'][found_idx]

    # Separate triggers into:
    # 1) Found louder than background
    louder_mask = found_after_vetoes_stat > max_bkgd_reweighted_snr
    found_louder = {key: found_after_vetoes[key][louder_mask]
                    for key in found_after_vetoes.keys()}
    found_louder['reweighted_snr'] = found_after_vetoes_stat[louder_mask]

    # 2) Found quieter than background: injections found (bestnr > 0)
    #    but not louder than background (non-zero FAP)
    quieter_mask = (found_after_vetoes_stat <= max_bkgd_reweighted_snr) \
        & (found_after_vetoes_stat != 0)
    found_quieter = {key: found_after_vetoes[key][quieter_mask]
                     for key in found_after_vetoes.keys()}
    found_quieter['reweighted_snr'] = found_after_vetoes_stat[quieter_mask]

    # TODO: ifar still missing. The intended extraction of the
    # inclusive/exclusive IFAR for injections found quieter than
    # background would read 'found_after_vetoes/ifar' (inclusive) or
    # 'found_after_vetoes/ifar_exc' (exclusive), selected with
    # opts.far_type, and apply quieter_mask to it.

    # FAP: fraction of background trials louder than the injection
    found_quieter['fap'] = np.array(
        [np.count_nonzero(background > bestnr)
         for bestnr in found_quieter['reweighted_snr']],
        dtype=float) / total_trials

    # 3) Missed due to vetoes
    # TODO: needs function to cherry-pick a subset of inj_data specified by
    # a mask on FAP values.
else:
    found_louder = {x_qty: np.array([]), y_qty: np.array([])}
    found_quieter = {x_qty: np.array([]), y_qty: np.array([])}

# Plotting order of the injections found quieter than background
# TODO: ifar still missing
if found_quieter[x_qty].size:
    # Statistics: found on top (found-missed)
    FM = np.argsort(found_quieter['fap'])
    # Statistics: missed on top (missed-found)
    MF = FM[::-1]

# Post-processing of injections ends here

# ==========
# Make plots
# ==========

# Info for site-specific plots
sitename = {'G1': 'GEO', 'H1': 'Hanford', 'L1': 'Livingston', 'V1': 'Virgo',
            'K1': 'KAGRA', 'A1': 'India Aundha'}

# Take care of axes labels
# NOTE(review): the 'end_time' label reads "Time since <trigger time>",
# but complete_inj_data does not subtract the trigger time -- confirm
axis_labels_dict = {'mchirp': "Chirp Mass (solar masses)",
                    'mtotal': "Total mass (solar masses)",
                    'q': "Mass ratio",
                    'distance': "Distance (Mpc)",
                    'eff_site_dist': f"{sitename.get(opts.ifo)} effective "
                                     "distance (Mpc)",
                    'eff_dist': "Inverse sum of effective distances (Mpc)",
                    'end_time': f"Time since {grb_time} (s)",
                    'sky_error': "Rec. sky error (radians)",
                    'coa_phase': "Phase of complex SNR (radians)",
                    'latitude': "Latitude (degrees)",
                    'longitude': "Longitude (degrees)",
                    'ra': "Right ascension (radians)",
                    'dec': "Declination (radians)",
                    'incl': "Inclination (iota)",
                    'abs_incl': 'Magnitude of inclination (|iota|)',
                    'cos_incl': "cos(iota)",
                    'cos_abs_incl': "cos(|iota|)",
                    'mass1': "Mass of 1st binary component (solar masses)",
                    'mass2': "Mass of 2nd binary component (solar masses)",
                    'polarization': "Polarization phase (radians)",
                    'spin1_a': "Spin on 1st binary component",
                    'spin1x': "Spin x-component of 1st binary component",
                    'spin1y': "Spin y-component of 1st binary component",
                    'spin1z': "Spin z-component of 1st binary component",
                    'spin2_a': "Spin on 2nd binary component",
                    'spin2x': "Spin x-component of 2nd binary component",
                    'spin2y': "Spin y-component of 2nd binary component",
                    'spin2z': "Spin z-component of 2nd binary component"}

# Admitted quantities without a dedicated label (e.g. 'phi_ref' or the
# spin angles) fall back to their own name instead of raising a KeyError
x_label = axis_labels_dict.get(x_qty, x_qty)
y_label = axis_labels_dict.get(y_qty, y_qty)

fig = plt.figure()
xscale = "log" if opts.x_log else "linear"
yscale = "log" if opts.y_log else "linear"
ax = fig.gca()
ax.set_xscale(xscale)
ax.set_yscale(yscale)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)

# Define p-value colour, honouring the --colormap option
cmap = plt.get_cmap(opts.colormap)
# Set color for out-of-range values
# cmap.set_over('g')

# Define the 'found' injection colour: the colour of a vanishing p-value
fnd_col = np.array([cmap(0)])

# The two branches draw the same four populations in opposite order, so
# that the population drawn last ends up on top
if not opts.missed_on_top:
    if missed_inj[x_qty].size and missed_inj[y_qty].size:
        ax.scatter(missed_inj[x_qty], missed_inj[y_qty],
                   c="black", marker="x", s=10)
    if vetoed[x_qty].size:
        ax.scatter(vetoed[x_qty], vetoed[y_qty], c="red", marker="x", s=10)
    if found_louder[x_qty].size:
        ax.scatter(found_louder[x_qty], found_louder[y_qty],
                   c=fnd_col, marker="+", s=30)
    if found_quieter[x_qty].size:
        # FM is only defined when there are injections quieter than
        # background
        p = ax.scatter(found_quieter[x_qty][FM], found_quieter[y_qty][FM],
                       c=found_quieter['fap'][FM],
                       cmap=cmap, vmin=0, vmax=1, s=40,
                       edgecolor="w", linewidths=2.0)
        cb = plt.colorbar(p, label="p-value")
else:
    if found_louder[x_qty].size:
        ax.scatter(found_louder[x_qty], found_louder[y_qty],
                   c=fnd_col, marker="+", s=15)
    if found_quieter[x_qty].size:
        # MF is only defined when there are injections quieter than
        # background
        p = ax.scatter(found_quieter[x_qty][MF], found_quieter[y_qty][MF],
                       c=found_quieter['fap'][MF],
                       cmap=cmap, vmin=0, vmax=1, s=40,
                       edgecolor="w", linewidths=2.0)
        cb = plt.colorbar(p, label="p-value")
    if vetoed[x_qty].size:
        ax.scatter(vetoed[x_qty], vetoed[y_qty], c="red", marker="x", s=40)
    if missed_inj[x_qty].size and missed_inj[y_qty].size:
        ax.scatter(missed_inj[x_qty], missed_inj[y_qty],
                   c="black", marker="x", s=40)
ax.grid()

# Handle axis limits when plotting spin magnitudes: start the axis at zero
# and round the upper limit up to the next tenth.
# Each axis uses the data of the quantity it actually shows; the other spin
# quantities are left to the default autoscaling.
for qty, set_lim in ((x_qty, ax.set_xlim), (y_qty, ax.set_ylim)):
    if qty not in ('spin1_a', 'spin2_a') or not missed_inj[qty].size:
        continue
    max_spin = missed_inj[qty].max()
    if found_inj[qty].size:
        max_spin = max(max_spin, found_inj[qty].max())
    set_lim([0, np.ceil(10 * max_spin) / 10])

# Handle axis limits when plotting inclination
if "incl" in x_qty or "incl" in y_qty:
    max_inc = np.pi
    # Round the largest angle up to the next multiple of 10 degrees
    max_inc_deg = np.rad2deg(max_inc)
    max_inc_deg = np.ceil(max_inc_deg/10.0)*10
    max_inc = np.deg2rad(max_inc_deg)
    # abs_incl is folded about 90 degrees, hence half the range
    if x_qty == "incl":
        ax.set_xlim(0, max_inc_deg)
    elif x_qty == "abs_incl":
        ax.set_xlim(0, max_inc_deg*0.5)
    if y_qty == "incl":
        ax.set_ylim(0, max_inc_deg)
    elif y_qty == "abs_incl":
        ax.set_ylim(0, max_inc_deg*0.5)
    # Cosine axes (cos_incl, cos_abs_incl): place the ticks at the cosines
    # of round angles and label them with the angle
    if "cos_" in x_qty or "cos_" in y_qty:
        tt = np.asarray([0, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120,
                         130, 140, 150, 180])
        tks = np.cos(np.deg2rad(tt))
        tk_labs = [f'cos({tk} deg)' for tk in tt]
        if "cos_" in x_qty:
            ax.set_xticks(tks)
            ax.set_xticklabels(tk_labs, fontsize=10)
            # Slant the long labels so that they do not overlap
            fig.autofmt_xdate()
            ax.set_xlim(np.cos(max_inc), 1)
        if "cos_" in y_qty:
            ax.set_yticks(tks)
            ax.set_yticklabels(tk_labs, fontsize=10)
            ax.set_ylim(np.cos(max_inc), 1)

# Take care of caption
plot_caption = opts.plot_caption
if plot_caption is None:
    # NOTE(review): "yellow plus" describes the default colormap; a
    # different --colormap may draw this marker in another colour
    plot_caption = (
        "A black cross indicates no trigger was found "
        "coincident with the injection.\n"
        "A red cross indicates a trigger was found "
        "coincident with the injection but it was vetoed.\n"
        "A yellow plus indicates that a trigger was found "
        "coincident with the injection and it was louder "
        "than all events in the offsource.\n"
        "A coloured circle indicates that a trigger was "
        "found coincident with the injection but it was "
        "not louder than all offsource events. The colour "
        "bar gives the p-value of the trigger."
    )

# Take care of title
plot_title = opts.plot_title
if plot_title is None:
    title_dict = {'mchirp': "chirp mass",
                  'mtotal': "total mass",
                  'q': "mass ratio",
                  'distance': "distance (Mpc)",
                  'eff_dist': "inverse sum of effective distances",
                  'eff_site_dist': "site specific effective distance",
                  'end_time': "time",
                  'coa_phase': "phase of complex SNR",
                  'latitude': "latitude",
                  'longitude': "longitude",
                  'ra': "right ascension",
                  'dec': "declination",
                  'incl': "inclination",
                  'cos_incl': "inclination",
                  'abs_incl': "inclination",
                  'cos_abs_incl': "inclination",
                  'mass1': "mass",
                  'mass2': "mass",
                  'polarization': "polarization",
                  'spin1_a': "spin",
                  'spin1x': "spin x-component",
                  'spin1y': "spin y-component",
                  'spin1z': "spin z-component",
                  'spin2_a': "spin",
                  'spin2x': "spin x-component",
                  'spin2y': "spin y-component",
                  'spin2z': "spin z-component"}

    if "sky_error" in [x_qty, y_qty]:
        plot_title = "Sky error of recovered injections"
    else:
        # Admitted quantities without a dedicated entry (e.g. 'phi_ref' or
        # the spin angles) fall back to their own name instead of raising
        # a KeyError once the plot has already been drawn
        plot_title = "Injection recovery with respect to "
        plot_title += title_dict.get(x_qty, x_qty)
        plot_title += " and " + title_dict.get(y_qty, y_qty)

# Wrap up
plt.tight_layout()
pycbc.results.save_fig_with_metadata(fig, outfile, cmd=' '.join(sys.argv),
                                     title=plot_title, caption=plot_caption)
plt.close()
logging.info("Plot complete.")
