#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1767619825/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python
import sys, h5py
import matplotlib
matplotlib.use('agg')
from matplotlib import pyplot as plt
import numpy as np
import argparse, logging
from pycbc import init_logging, add_common_pycbc_options, results
from pycbc.events import significance
from pycbc.conversions import sec_to_year as convert_s_to_y

# Command-line interface: the common PyCBC options and the significance
# option group are added by the imported helpers; the remaining options are
# specific to this plot.
parser = argparse.ArgumentParser(
    usage='pycbc_page_ifar_vs_stat [--options]',
    # Adjacent string literals are joined verbatim, so the separating space
    # must be part of one of them (otherwise --help reads "forbackgrounds").
    description='Plots cumulative IFAR vs stat for '
                'backgrounds of different event types.')
add_common_pycbc_options(parser)
# get singles fit options
significance.insert_significance_option_group(parser)
# The script iterates over the trigger files unconditionally, so the option
# is mandatory; let argparse report its absence instead of a TypeError later.
parser.add_argument('--trigger-files', nargs='+', required=True,
                    help='Paths to separate-event-type statmap files.')
parser.add_argument('--ifo-combos', nargs='+',
                    help='Which event types (detector combinations) to plot.')
parser.add_argument('--min-x', type=float, default=-15.,
                    help='Lower X axis limit.')
parser.add_argument('--max-x', type=float, default=25.,
                    help='Upper X axis limit.')
parser.add_argument('--output-file', required=True,
                    help='Path to output plot.')
opts = parser.parse_args()
# NOTE(review): opts.verbose is presumably defined by
# add_common_pycbc_options -- it is not added in this file.
init_logging(opts.verbose)

# Per-combination significance settings. Later in this script it is indexed
# by detector-combination key (e.g. 'H1L1') to read the 'method' and
# 'fit_threshold' entries.
sig_dict = significance.digest_significance_options(opts.ifo_combos, opts)

# Line colour for every supported event type, keyed by the detector letters
# of the combination: single-detector types first, then multi-detector ones.
coinc_color_map = dict(
    H='#ee0000',   # red
    L='#4ba6ff',   # blue
    V='#9b59b6',   # magenta/purple
    HL='#0000DD',
    LV='#FFCC33',
    HV='#990000',
    HLV=np.array([20, 120, 20]) / 255.,  # RGB triplet scaled to [0, 1]
)

# Line style for every event type: dashed for single-detector types,
# solid for multi-detector types. Keys (and their order) follow the
# colour map above.
coinc_line_map = {
    combo: '--' if len(combo) == 1 else '-'
    for combo in coinc_color_map
}

fig = plt.figure(figsize=(9, 6))
ax = fig.add_subplot(111)

# Grid of x-axis values, spanning the requested plot range, at which each
# cumulative rate curve is evaluated.
ifar_calc_points = np.linspace(opts.min_x, opts.max_x, 200)

fitted = []  # List of combos for which a fit is used
for f in opts.trigger_files:
    logging.info(f"Opening {f}")
    # Copy everything needed out of the file so it can be closed before the
    # rate calculation and plotting.
    with h5py.File(f, 'r') as exc_zlag_f:
        logging.info(f"Ifos {exc_zlag_f.attrs['ifos']}")  # eg 'H1 L1'
        sig_key = exc_zlag_f.attrs['ifos'].replace(' ', '')  # eg 'H1L1'
        coinc_key = sig_key.replace('1', '')  # eg 'HL'
        stat_exc = exc_zlag_f['background_exc']['stat'][:]
        dec_fac_exc = exc_zlag_f['background_exc']['decimation_factor'][:]
        bg_time = convert_s_to_y(exc_zlag_f.attrs['background_time_exc'])
        ifar_file = exc_zlag_f['background_exc']['ifar'][:]
    logging.info(f"{coinc_key} background time (y): {bg_time}")

    # If fitting / extrapolation is done
    if sig_dict[sig_key]['method'] == 'trigger_fit':
        fitted.append(sig_key)
        fit_thresh = sig_dict[sig_key]['fit_threshold']
        far_tuple = significance.get_far(stat_exc,
                                         ifar_calc_points,
                                         dec_fac_exc,
                                         bg_time,
                                         **sig_dict[sig_key])
        far = far_tuple[1]  # Version-proof: get_far returns 2 or 3 values

        # Split the grid at the fit threshold. The mask is computed once and
        # its complement used for the other segment, so every grid point is
        # drawn in exactly one of the two segments (two strict comparisons
        # would drop a point lying exactly on the threshold).
        # NOTE(review): a point equal to the threshold is drawn with the
        # solid segment here -- confirm this matches how
        # significance.get_far treats stat == fit_threshold.
        above_fit = ifar_calc_points > fit_thresh

        # Plot n-louder points with a solid line and sngl fits with dashed
        ax.plot(ifar_calc_points[~above_fit],
                far[~above_fit],
                '-', c=coinc_color_map[coinc_key], zorder=-5)
        # Only this segment carries the legend label, so each event type
        # appears once in the legend.
        ax.plot(ifar_calc_points[above_fit],
                far[above_fit],
                coinc_line_map[coinc_key],
                c=coinc_color_map[coinc_key], label=coinc_key, zorder=-5)
        del fit_thresh, above_fit  # avoid variables hanging around

    else:
        far_tuple = significance.get_far(stat_exc,
                                         ifar_calc_points,
                                         dec_fac_exc,
                                         bg_time,
                                         method='n_louder')

        ax.plot(ifar_calc_points, far_tuple[1], '-',
                c=coinc_color_map[coinc_key], label=coinc_key, zorder=-5)

    del sig_key, coinc_key, far_tuple  # avoid variables hanging around

# Mark the fit threshold of every fitted combination with a dash-dot
# vertical line in that combination's colour, drawn behind the curves.
for fit_combo in fitted:
    ax.axvline(
        sig_dict[fit_combo]['fit_threshold'],
        c=coinc_color_map[fit_combo.replace('1', '')],
        ls='-.',
        zorder=-10,
    )

# Axis scaling, range and decoration
ax.semilogy()
ax.set_xlim(opts.min_x, opts.max_x)
ax.set_ylim(1e-3, 1e6)
ax.grid()
ax.legend(ncol=2, fontsize=12)
ax.tick_params(axis='both', labelsize=14)
ax.set_xlabel('Ranking Statistic', fontsize=16)
ax.set_ylabel('Cumulative Event Rate [y$^{-1}$]', fontsize=16)

# Write the figure to the requested output path, together with its title,
# caption and the command line that produced it.
caption = ' '.join((
    'Cumulative rates of noise events for separate event types.',
    'Solid lines represent estimates from counts of louder (zerolag or time '
    'shifted) events.',
    'Dashed lines represent estimates from fitting / extrapolation.',
))

results.save_fig_with_metadata(
    fig,
    opts.output_file,
    title='Cumulative Rate vs. Statistic',
    caption=caption,
    cmd=' '.join(sys.argv),
    fig_kwds={'bbox_inches': 'tight'},
)
