#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

# Copyright (C) 2021 Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Determine efficiency and exclusion distances of a PyGRB run."""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import json
import matplotlib.pyplot as plt
from matplotlib import rc
import numpy as np
import scipy
from scipy import stats

import pycbc.version
from pycbc import init_logging
from pycbc.detector import Detector
from pycbc.results import save_fig_with_metadata
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.conversions import mchirp_from_mass1_mass2
from pycbc.io.hdf import HFile

# Select the non-interactive Agg backend: figures are only written to file
plt.switch_backend('Agg')
# NOTE(review): rc() is given a group name but no keyword settings, so it
# does not appear to change any rcParams -- confirm what was intended here
rc("image")

# Script metadata
__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_efficiency"


def efficiency_with_errs(found_bestnr, num_injections, num_mc_injs=0):
    """Calculate the fraction of recovered injections and its error bars.

    The results are used for the efficiency/sensitive distance plots.  The
    last element of both input arrays is dropped before any calculation: in
    this script it holds the total over all the distance bins.

    Parameters
    ----------
    found_bestnr : numpy.ndarray
        Number of recovered injections in each bin, followed by one trailing
        element that is ignored.
    num_injections : numpy.ndarray
        Total number of injections in each bin, followed by one trailing
        element that is ignored.
    num_mc_injs : int, optional
        Number of Monte-Carlo iterations the counts were accumulated over.
        When non-zero, the counts are divided by it before the error bars
        are computed.  Python and NumPy integers are both accepted.

    Returns
    -------
    err_low_mc : numpy.ndarray
        Lower error bar on the fraction, clipped to zero where negative.
    err_high_mc : numpy.ndarray
        Upper error bar on the fraction, clipped to zero where negative.
    fraction : numpy.ndarray
        Ratio of recovered to total injections in each bin.  The division
        is not guarded, so a bin with no injections yields NaN.

    Raises
    ------
    TypeError
        If num_mc_injs is not an integer.
    """

    # NumPy integer scalars are not instances of int, but they are perfectly
    # valid iteration counts
    if not isinstance(num_mc_injs, (int, np.integer)):
        err_msg = "The parameter num_mc_injs is the number of Monte-Carlo "
        err_msg += "injections.  It must be an integer."
        raise TypeError(err_msg)

    only_found_injs = found_bestnr[:-1]
    all_injs = num_injections[:-1]
    # The ratio is unaffected by the Monte-Carlo normalisation applied below
    fraction = only_found_injs / all_injs

    # Divide by Monte-Carlo iterations
    if num_mc_injs:
        only_found_injs = only_found_injs / num_mc_injs
        all_injs = all_injs / num_mc_injs

    # Lower and upper bounds on the fraction, sharing a common term and a
    # common denominator and differing by the sign of the square root term
    err_common = all_injs * (2 * only_found_injs + 1)
    err_denom = 2 * all_injs * (all_injs + 1)
    err_vary = 4 * all_injs * only_found_injs * (all_injs - only_found_injs) \
        + all_injs**2
    err_vary = err_vary**0.5
    err_low = (err_common - err_vary)/err_denom
    err_high = (err_common + err_vary)/err_denom

    # Express the bounds as (positive) distances from the fraction
    err_low_mc = fraction - err_low
    err_high_mc = err_high - fraction

    # Check for cases where error bars are negative and set them to zero
    if (err_low_mc < 0).any():
        logging.warning("Negative lower error bar(s) detected."
                        " Setting to zero.")
        err_low_mc[err_low_mc < 0] = 0
    if (err_high_mc < 0).any():
        logging.warning("Negative upper error bar(s) detected."
                        " Setting to zero.")
        err_high_mc[err_high_mc < 0] = 0

    return err_low_mc, err_high_mc, fraction


# =============================================================================
# Main script starts here
# =============================================================================
# Start from the parser provided by the PyGRB post-processing utilities and
# add the options specific to this script
parser = ppu.pygrb_initialize_plot_parser(description=__doc__)
parser.add_argument(
    "-F", "--trig-file",
    action="store",
    required=True,
    help="Location of off-source trigger file.",
)
parser.add_argument(
    "--onsource-file",
    action="store",
    help=("Location of on-source trigger file (or a "
          "background trigger file to be treated as such)."),
)
parser.add_argument(
    "--background-output-file",
    help="Detection efficiency output file.",
)
parser.add_argument(
    "--onsource-output-file",
    help="Exclusion distance output file.",
)
parser.add_argument(
    "--exclusion-dist-output-file",
    help="JSON file containing exclusion distances.",
)
parser.add_argument(
    "-g", "--glitch-check-factor",
    action="store",
    type=float,
    default=1.0,
    help=("When deciding exclusion efficiencies this value is multiplied "
          "to the offsource around the injection trigger to "
          "determine if it is just a loud glitch."),
)
parser.add_argument(
    "--found-missed-file",
    action="store",
    type=str,
    required=True,
    help="Location of found/missed injections and trigger file",
)
parser.add_argument(
    "--injection-set-name",
    action="store",
    type=str,
    default="",
    help="Name of the injection set to be used in the plot title.",
)
parser.add_argument(
    "--trial-name",
    action="store",
    type=str,
    required=True,
    help="Name of trial used for this run (i.e. ONSOURCE, OFFTRIAL)",
)
parser.add_argument(
    "-C", "--cluster-window",
    action="store",
    type=float,
    default=0.1,
    help="The cluster window used to cluster triggers in time.",
)
parser.add_argument(
    "--bank-file",
    action="store",
    type=str,
    required=True,
    help="Location of the full template bank used.",
)
# Option groups defined in the PyGRB post-processing utilities
ppu.pygrb_add_injmc_opts(parser)
ppu.pygrb_add_bestnr_cut_opt(parser)
ppu.pygrb_add_slide_opts(parser)

# Parse the command line and post-process the slide options
opts = parser.parse_args()
ppu.slide_opts_helper(opts)

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Fail early if the template bank cannot be opened.  The handle is released
# straight away instead of being left open: the content of the bank is read
# further down, right where the chirp masses are needed.
with HFile(opts.bank_file, 'r'):
    pass

# Check options: the exclusion distance products are meant to be built from
# an on-source file.  The problem is only logged and the script carries on.
if opts.exclusion_dist_output_file is not None or \
        opts.onsource_output_file is not None:
    if opts.onsource_file is None:
        logging.error("Requesting the --exclusion-dist-output-file or "
                      "the --onsource-output-file requires the "
                      "--onsource-file as input.")

# Store options used multiple times in local variables
trig_file = opts.trig_file
onsource_file = opts.onsource_file
found_missed_file = opts.found_missed_file
veto_file = opts.veto_file
inj_set_name = opts.injection_set_name
# Calibration error options, keyed by detector
cal_errs = {
    'G1': opts.g1_cal_error,
    'H1': opts.h1_cal_error,
    'K1': opts.k1_cal_error,
    'L1': opts.l1_cal_error,
    'V1': opts.v1_cal_error,
}
# DC calibration error options, keyed by detector
cal_dc_errs = {
    'G1': opts.g1_dc_cal_error,
    'H1': opts.h1_dc_cal_error,
    'K1': opts.k1_dc_cal_error,
    'L1': opts.l1_dc_cal_error,
    'V1': opts.v1_dc_cal_error,
}
# pycbc_multi_inspiral already applies sngl, coinc and null SNR cuts: the
# reweighted SNR threshold is read from opts.newsnr_threshold wherever
# triggers are loaded
upper_dist = opts.upper_inj_dist
lower_dist = opts.lower_inj_dist
num_bins = opts.num_bins
wav_err = opts.waveform_error
cluster_window = opts.cluster_window
glitch_check_fac = opts.glitch_check_factor
num_mc_injs = opts.num_mc_injections
# Initialize random number generator
np.random.seed(opts.seed)
# %s rather than %d, so that the message is still formatted if no seed is set
logging.info("Setting random seed to %s.", opts.seed)

# Make sure the directory of every requested output file exists
for output_file in [opts.exclusion_dist_output_file,
                    opts.background_output_file, opts.onsource_output_file]:
    if output_file is not None:
        outdir = os.path.dirname(os.path.abspath(output_file))
        if not os.path.isdir(outdir):
            logging.info("Creating the output directory %s.", outdir)
            # exist_ok guards against the directory appearing in the meantime
            os.makedirs(outdir, exist_ok=True)

# Extract the IFOs from the trigger file
ifos = ppu.extract_ifos(trig_file)

# Time-slides and segments dictionaries, also from the trigger file
slide_dict = ppu.load_time_slides(trig_file)
segment_dict = ppu.load_segment_dict(trig_file)

# Construct trials removing vetoed times
trial_dict, total_trials = ppu.construct_trials(
    opts.seg_files, segment_dict, ifos, slide_dict, veto_file)

# Load triggers (apply reweighted SNR cut, not vetoes)
all_off_trigs = ppu.load_data(
    trig_file,
    ifos,
    data_tag='offsource',
    rw_snr_threshold=opts.newsnr_threshold,
    slide_id=opts.slide_id,
)

# Extract needed trigger properties and store them as dictionaries
# Based on trial_dict: if vetoes were applied, trig_* are the veto survivors
keys = ['network/end_time_gc', 'network/reweighted_snr']
trig_data = ppu.extract_trig_properties(
    trial_dict, all_off_trigs, slide_dict, segment_dict, keys)

# Loudest BestNR in each trial, as a dictionary keyed by slide_id holding
# arrays indexed by trial number; trials without triggers are left at zero
background = {k: np.zeros(len(v)) for k, v in trial_dict.items()}
time_key, stat_key = keys
for slide_id in slide_dict:
    slide_times = trig_data[time_key][slide_id]
    for trial_idx, trial in enumerate(trial_dict[slide_id]):
        # True for the triggers falling within [trial[0], trial[1])
        in_trial = (trial[0] <= slide_times) & (slide_times < trial[1])
        if in_trial.any():
            background[slide_id][trial_idx] = \
                max(trig_data[stat_key][slide_id][in_trial])

# Max and median values of reweighted SNR,
# and sorted (loudest in trial) reweighted SNR values
max_bestnr, median_bestnr, sorted_bkgd = ppu.max_median_stat(
    slide_dict, background, trig_data[stat_key], total_trials)
assert total_trials == len(sorted_bkgd)

logging.info("Background bestNR calculated.")

# Output details of loudest offsource triggers: only triggers compatible
# with the trial_dict are considered
# NOTE(review): offsource_trigs does not appear to be read again in the rest
# of this script -- confirm whether this section is still needed
offsource_trigs = []
sorted_off_trigs = ppu.sort_trigs(
    trial_dict,
    all_off_trigs,
    slide_dict,
    segment_dict
)
# For each slide, pair every reweighted SNR value with the entry of
# sorted_off_trigs found at the same position
for slide_id in slide_dict:
    offsource_trigs.extend(
        zip(trig_data[keys[1]][slide_id], sorted_off_trigs[slide_id])
    )
# Sort by reweighted SNR in ascending order, then reverse the list so that
# the loudest entries come first
offsource_trigs.sort(key=lambda element: element[0])
offsource_trigs.reverse()


# Chirp masses of the templates, from the component masses stored in the bank
logging.info('Reading template chirp masses')
with HFile(opts.bank_file, 'r') as bank_file:
    bank_mass1 = bank_file['mass1'][:]
    bank_mass2 = bank_file['mass2'][:]
    template_mchirps = mchirp_from_mass1_mass2(bank_mass1, bank_mass2)

# =======================
# Load on source triggers
# =======================
if onsource_file:

    logging.info("Processing onsource.")
    # Load onsource triggers (apply reweighted SNR cut, not vetoes)
    on_trigs = ppu.load_data(
        onsource_file,
        ifos,
        data_tag=None,
        rw_snr_threshold=opts.newsnr_threshold,
        slide_id=0,
    )

    # Chirp masses of the templates of the onsource triggers
    on_mchirp = template_mchirps[on_trigs['network/template_id']]

    # Loudest onsource event by BestNR: start from "no event" and update
    # this only if there is at least one trigger
    loud_on_bestnr = 0
    if on_trigs and on_trigs['network/reweighted_snr'].size > 0:
        on_bestnrs = on_trigs['network/reweighted_snr']
        loud_on_bestnr_idx = np.argmax(on_bestnrs)
        loud_on_bestnr = np.max(on_bestnrs)

    # Convert to float for output
    loud_on_bestnr = float(loud_on_bestnr)

    # A loudest event with bestnr = 0 means there is no event at all
    no_onsource_event = loud_on_bestnr == 0
    if no_onsource_event:
        loud_on_bestnr_idx = None
        loud_on_fap = 1

    logging.info("Onsource analysed.")

    # FAP of the loudest onsource event: fraction of the background entries
    # that are louder than it
    if loud_on_bestnr_idx is not None:
        num_louder_bkgd = sum(sorted_bkgd > loud_on_bestnr)
        loud_on_fap = num_louder_bkgd / total_trials


# ====================
# Load injections data
# ====================
# injs holds the found/missed injections together with the triggers they
# generated: the reweighted SNR cut is applied here, vetoes are not
injs = ppu.load_data(
    found_missed_file,
    ifos,
    data_tag='injs',
    rw_snr_threshold=opts.newsnr_threshold,
    slide_id=0,
)

# Keep every entry whose key does not contain 'missed'
found_inj = {k: injs[k] for k in injs.keys() if 'missed' not in k}

# Split the found injections into those surviving the vetoes and those that
# were found but vetoed; any further return value is discarded
found_after_vetoes, vetoed, *_ = ppu.apply_vetoes_to_found_injs(
    found_missed_file, found_inj, ifos, veto_file=veto_file)

# ===================================================================
# Post-process injections: skip this if there are no found injections
# ===================================================================
# NOTE: the names defined in this section (dist_bins, num_injections,
# found_max_bestnr, found_on_bestnr, ...) only exist if the condition below
# holds; the plotting code further down repeats the same check before use
if len(found_after_vetoes['network/template_id']):
    # Calculate quantities not included in trigger files, such as chirp mass
    found_trig_mchirp = \
        template_mchirps[found_after_vetoes['network/template_id']]

    # Construct conditions for injection:
    # 1) found (surviving vetoes) louder than background,
    zero_fap = \
        np.zeros(len(found_after_vetoes['network/end_time_gc'])).astype(bool)
    zero_fap_cut = found_after_vetoes['network/reweighted_snr'] > max_bestnr
    zero_fap = zero_fap | (zero_fap_cut)

    # 2) found (bestnr>0, and surviving vetoes) but not louder than background
    nonzero_fap = ~zero_fap & (found_after_vetoes['network/reweighted_snr']
                               != 0)

    # 3) missed after being recovered (i.e., vetoed) are in vetoed

    # Non-zero FAP triggers (g_ifar)
    g_ifar = {}
    g_ifar['bestnr'] = \
        found_after_vetoes['network/reweighted_snr'][nonzero_fap]
    g_ifar['stat'] = np.zeros([len(g_ifar['bestnr'])])
    # For each such trigger, count the entries of sorted_bkgd louder than it
    # NOTE(review): the chirp mass mc is unpacked but not used in the loop
    for ix, (mc, bestnr) in \
            enumerate(zip(found_trig_mchirp[nonzero_fap], g_ifar['bestnr'])):
        g_ifar['stat'][ix] = (sorted_bkgd > bestnr).sum()
    # Turn the counts into fractions of the total number of trials
    g_ifar['stat'] = g_ifar['stat'] / total_trials

    # Set the sigma values
    inj_sigma = {ifo: found_after_vetoes[f'{ifo}/sigmasq'][:] for ifo in ifos}
    # If the sigmasqs are not populated, we can still do calibration errors,
    # but only in the 1-detector case
    for ifo in ifos:
        if sum(inj_sigma[ifo] == 0):
            logging.info("%s: sigmasq not set for at least one trigger.", ifo)
        if sum(inj_sigma[ifo] != 0) == 0:
            logging.info("%s: sigmasq not set for any trigger.", ifo)
            if len(ifos) == 1:
                msg = "Since this is a single ifo analysis, sigmasq will be "
                msg += "set to unity for all triggers in order to build the "
                msg += "calibration errors."
                logging.info(msg)
                inj_sigma[ifo][:] = 1.

    # Antenna responses of each detector at the sky position and time of the
    # found injections
    f_resp = {}
    for ifo in ifos:
        antenna = Detector(ifo)
        f_resp[ifo] = ppu.get_antenna_responses(
            antenna,
            found_after_vetoes['found/ra'][:],
            found_after_vetoes['found/dec'][:],
            found_after_vetoes['found/tc'][:])

    # Per-detector product of sigma and antenna response: one row per ifo
    inj_sigma_mult = (np.asarray(list(inj_sigma.values())) *
                      np.asarray(list(f_resp.values())))

    # Sum of the rows, i.e., over the detectors
    # NOTE(review): inj_sigma_tot is a view of row 0, so the += below also
    # updates inj_sigma_mult[0] in place; inj_sigma_mult is not read again in
    # this section
    inj_sigma_tot = inj_sigma_mult[0, :]
    for i in range(1, len(ifos)):
        inj_sigma_tot += inj_sigma_mult[i, :]

    # Mean over the injections of the share each detector has in the total
    inj_sigma_mean = {}
    for ifo in ifos:
        inj_sigma_mean[ifo] = \
            ((inj_sigma[ifo]*f_resp[ifo])/inj_sigma_tot).mean()

    msg = f"{len(found_after_vetoes['found/tc'])} injections found and "
    msg += f"surviving vetoes and {len(injs['missed/tc'])} missed injections "
    msg += "analysed."
    logging.info(msg)

    # Create new set of injections for efficiency calculations:
    # these are as many as the original injections
    total_injs = len(injs['found/distance']) + len(injs['missed/distance'])
    long_inj = {}
    # Distances drawn uniformly between upper_dist and
    # upper_dist + (upper_dist - lower_dist)
    long_inj['dist'] = stats.uniform.rvs(size=total_injs) * \
        (upper_dist-lower_dist) + upper_dist

    logging.info("%d long distance injections created.", total_injs)

    # Set distance bins and data arrays
    # Each bin is a (lower edge, upper edge) pair of width
    # (upper_dist-lower_dist)/num_bins; the lower edges start at lower_dist
    # and extend past upper_dist, over the range of the long distance
    # injections
    dist_bins = zip(np.arange(lower_dist, upper_dist + (upper_dist-lower_dist),
                              (upper_dist-lower_dist)/num_bins),
                    np.arange(lower_dist, upper_dist + (upper_dist-lower_dist),
                              (upper_dist-lower_dist)/num_bins) +
                             (upper_dist-lower_dist)/num_bins)
    dist_bins = list(dist_bins)

    # One entry per distance bin, plus a trailing entry that accumulates the
    # total over all bins
    num_dist_bins_plus_one = len(dist_bins) + 1
    num_injections = {}
    found_max_bestnr = {}
    found_on_bestnr = {}
    for key in ['mc', 'no_mc']:
        num_injections[key] = np.zeros(num_dist_bins_plus_one)
        found_max_bestnr[key] = np.zeros(num_dist_bins_plus_one)
        found_on_bestnr[key] = np.zeros(num_dist_bins_plus_one)

    # Construct FAP list for all found injections
    # (entries not selected by nonzero_fap are left at zero)
    inj_fap = np.zeros(len(found_after_vetoes['found/distance']))
    inj_fap[nonzero_fap] = g_ifar['stat']

    # Calculate the amplitude error
    # Begin by calculating the components from each detector
    # and add them in quadrature
    cal_error = 0
    for ifo in ifos:
        cal_error += cal_errs[ifo]**2 * inj_sigma_mean[ifo]**2
    cal_error = cal_error**0.5

    # NOTE(review): the maximum is taken over all the detectors in
    # cal_dc_errs, not only over those in ifos -- confirm this is intended
    max_dc_cal_error = max(cal_dc_errs.values())

    # Calibration phase uncertainties are neglected
    logging.info("Calibration amplitude uncertainty calculated.")

    # Now create the numbers for the efficiency plots; these include
    # calibration and waveform errors. These are incorporated by running over
    # each injection num_mc_injs times, where each time we draw a random value
    # of distance

    # Distribute injections
    # NOTE: the loop on num_mc_injs would fill up the *_inj['dist_mc']'s at the
    # same time, so filling them up sequentially will vary the numbers a little
    # (this is an MC, order of operations matters!)
    found_inj_dist_mc = ppu.mc_cal_wf_errs(
        num_mc_injs,
        found_after_vetoes['found/distance'],
        cal_error,
        wav_err,
        max_dc_cal_error
    )
    # Injections found but vetoed are treated together with the missed ones
    missed_inj_dist_mc = ppu.mc_cal_wf_errs(
        num_mc_injs,
        np.concatenate((vetoed['found/distance'], injs['missed/distance'])),
        cal_error,
        wav_err,
        max_dc_cal_error
    )
    long_inj['dist_mc'] = ppu.mc_cal_wf_errs(num_mc_injs,
                                             long_inj['dist'],
                                             cal_error,
                                             wav_err,
                                             max_dc_cal_error)

    logging.info("MC injection set distributed with %d iterations.",
                 num_mc_injs)

    # Check injections against on source
    # (without an on-source file, a FAP of 0.5 is used as reference)
    if onsource_file:
        more_sig_than_onsource = (inj_fap <= loud_on_fap)
    else:
        more_sig_than_onsource = (inj_fap <= 0.5)

    # NOTE(review): distance_count is not used in the rest of this section
    distance_count = np.zeros(len(dist_bins))

    # Injections louder than the loudest background event
    found_trig_max_bestnr = \
        np.empty(len(found_after_vetoes['network/event_id']))
    found_trig_max_bestnr.fill(max_bestnr)

    max_bestnr_cut = \
        (found_after_vetoes['network/reweighted_snr'] > found_trig_max_bestnr)

    # Check louder than on source
    # (without an on-source file, the median background value is used)
    found_trig_loud_on_bestnr = \
        np.empty(len(found_after_vetoes['network/event_id']))
    if onsource_file:
        found_trig_loud_on_bestnr.fill(loud_on_bestnr)
    else:
        found_trig_loud_on_bestnr.fill(median_bestnr)
    on_bestnr_cut = found_after_vetoes['network/reweighted_snr'] > \
        found_trig_loud_on_bestnr

    # Check whether injection is found for the purposes of exclusion
    # distance calculation.
    # Found: if louder than all on source
    # Missed: if not louder than loudest on source
    found_excl = on_bestnr_cut & (more_sig_than_onsource) & \
        (found_after_vetoes['network/reweighted_snr'] != 0)
    # If not missed, double check bestnr against nearby triggers
    near_test = np.zeros((found_excl).sum()).astype(bool)
    for j, (t, bestnr) in enumerate(zip(
            found_after_vetoes['found/tc'][found_excl],
            found_after_vetoes['network/reweighted_snr'][found_excl])):
        # 0 is the zero-lag timeslide
        near_bestnr = trig_data[keys[1]][0][np.abs(trig_data[keys[0]][0]-t) <
                                            cluster_window]
        # The test is passed if no nearby trigger, rescaled by
        # glitch_check_fac, is louder than the injection trigger
        near_test[j] = ~((near_bestnr * glitch_check_fac > bestnr).any())
    # Apply the local test
    # The c-th True entry of found_excl is replaced by near_test[c]
    c = 0
    for z, b in enumerate(found_excl):
        if found_excl[z]:
            found_excl[z] = near_test[c]
            c += 1

    # Loop over each random instance of the injection set
    # The counts for k = 0 are stored under 'no_mc', those for k >= 1 are
    # accumulated under 'mc'
    # NOTE(review): presumably row 0 of the *_dist_mc arrays holds the
    # distances without Monte-Carlo errors -- confirm against
    # ppu.mc_cal_wf_errs
    for k in range(num_mc_injs+1):
        # Loop over the distance bins
        for j, dist_bin in enumerate(dist_bins):
            # Construct distance cut
            found_dist_cut = (dist_bin[0] <= found_inj_dist_mc[k, :]) &\
                             (found_inj_dist_mc[k, :] < dist_bin[1])
            missed_dist_cut = (dist_bin[0] <= missed_inj_dist_mc[k, :]) &\
                              (missed_inj_dist_mc[k, :] < dist_bin[1])
            long_dist_cut = (dist_bin[0] <= long_inj['dist_mc'][k, :]) &\
                            (long_inj['dist_mc'][k, :] < dist_bin[1])

            # Count all injections in this distance bin
            num_found_pass = (found_dist_cut).sum()
            num_missed_pass = (missed_dist_cut).sum()
            num_long_pass = long_dist_cut.sum() or 0
            # Count only zero FAR injections
            num_zero_far = (found_dist_cut & max_bestnr_cut).sum()
            # Count number found for exclusion
            num_excl = (found_dist_cut & (found_excl)).sum()

            # Record number of injections, number found for exclusion
            # and number of zero FAR
            if k == 0:
                key = 'no_mc'
            else:
                key = 'mc'
            # Each count is added both to its bin and to the trailing total
            num_pass = num_found_pass + num_missed_pass + num_long_pass
            num_injections[key][j] += num_pass
            num_injections[key][-1] += num_pass
            found_max_bestnr[key][j] += num_zero_far
            found_max_bestnr[key][-1] += num_zero_far
            found_on_bestnr[key][j] += num_excl
            found_on_bestnr[key][-1] += num_excl

    logging.info("Found/missed injection efficiency calculations completed.")

    # Post-processing of injections ends here

# ==========
# Make plots
# ==========
logging.info("Plotting.")

# Horizontal axis values: the mean of the two edges of each distance bin.
# These are only defined if there are found injections.
if len(found_after_vetoes['network/template_id']):
    dist_plot_vals = [np.asarray(bin_edges).mean() for bin_edges in dist_bins]

# Efficiency plot based on the loudest background event
if opts.background_output_file:
    fig = plt.figure()
    ax = fig.gca()

    # The curves can only be determined if there are found injections
    if len(found_after_vetoes['network/template_id']):
        # Efficiencies and their error bars using the max BestNR of the
        # background, with and without Monte-Carlo marginalisation
        yerr_low_mc, yerr_high_mc, fraction_mc = efficiency_with_errs(
            found_max_bestnr['mc'], num_injections['mc'],
            num_mc_injs=num_mc_injs)
        yerr_low_no_mc, yerr_high_no_mc, fraction_no_mc = \
            efficiency_with_errs(found_max_bestnr['no_mc'],
                                 num_injections['no_mc'])

        ax.plot(dist_plot_vals, fraction_no_mc, 'g-',
                label='No marginalisation')
        ax.errorbar(dist_plot_vals, fraction_no_mc,
                    yerr=[yerr_low_no_mc, yerr_high_no_mc], c='green')
        # The marginalised curve is drawn only if it is not identically
        # zero/NaN
        marg_eff = fraction_mc
        if np.nansum(marg_eff) > 0:
            ax.plot(dist_plot_vals, marg_eff, 'r-', label='Marginalised')
            ax.errorbar(dist_plot_vals, marg_eff,
                        yerr=[yerr_low_mc, yerr_high_mc], c='red')
        ax.legend()

    ax.grid()
    ax.set_ylim([0, 1])
    ax.set_xlim(0, 1.3*upper_dist)
    ax.set_ylabel("Fraction of injections found louder than "
                  "loudest background")
    ax.set_xlabel("Distance (Mpc)")

    plot_title = f"Detection efficiency - {inj_set_name}"
    plot_caption = ("Injection recovery efficiency using "
                    "BestNR as detection statistic.  "
                    "Injections louder than loudest background trigger.")
    fig_path = opts.background_output_file
    save_fig_with_metadata(fig, fig_path, cmd=' '.join(sys.argv),
                           title=plot_title, caption=plot_caption)
    plt.close()

# Calculate the 50% and 90% exclusion distances
# The excl_dist dictionary is keyed by "50%" and "90%": its values are left
# to the string 'NaN' unless there are found injections
excl_dist = {f"{percentile}%": 'NaN' for percentile in [50, 90]}
if len(found_after_vetoes['network/template_id']):
    # Calculate error bars for efficiency/distance plots and datafiles
    # using max BestNR of foreground
    yerr_low_no_mc, yerr_high_no_mc, fraction_no_mc = efficiency_with_errs(
        found_on_bestnr['no_mc'],
        num_injections['no_mc'])
    yerr_low, yerr_high, fraction_mc = \
        efficiency_with_errs(found_on_bestnr['mc'], num_injections['mc'],
                             num_mc_injs=num_mc_injs)

    # Marginalized efficiency lowered by its counting error
    # (isf = inverse survival function)
    red_efficiency = fraction_mc - yerr_low * stats.norm.isf(0.1)

    for percentile in [50, 90]:
        threshold = percentile / 100.
        # Bins where the marginalized efficiency is below the threshold; if
        # there are none (e.g., the efficiency is NaN everywhere), fall back
        # to the efficiency without marginalisation
        eff_idx = np.where(red_efficiency < threshold)[0]
        if eff_idx.size == 0:
            excl_efficiency = fraction_no_mc
            eff_idx = np.where(excl_efficiency < threshold)[0]
        else:
            excl_efficiency = red_efficiency
        if eff_idx.size and eff_idx[0] != 0:
            # Linear interpolation between the first bin below the threshold
            # and the bin right before it
            i = eff_idx[0]
            d = dist_plot_vals[i]
            d_low = dist_plot_vals[i-1]
            e = excl_efficiency[i]
            e_low = excl_efficiency[i-1]
            excl_dist[f"{percentile}%"] = \
                d + (e - threshold) * (d - d_low) / (e_low - e)
        else:
            # Warn the user if the exclusion distance cannot be established,
            # but let the workflow continue: the user will see the plot(s) and
            # repeat with more injections
            excl_dist[f"{percentile}%"] = 0
            if eff_idx.size:
                logging.warning("Efficiency below %d%% in first bin!",
                                percentile)
            else:
                # No bin is below the threshold: there is nothing to
                # interpolate between
                logging.warning("Efficiency never drops below %d%% in the "
                                "distance range considered: exclusion "
                                "distance cannot be established!", percentile)

# Save the exclusion distances to a JSON file, together with the names of
# the injection set and of the trial
if opts.exclusion_dist_output_file:
    excl_dist.update(inj_set=inj_set_name, trial_name=opts.trial_name)
    with open(opts.exclusion_dist_output_file, 'w') as excl_dist_file:
        json.dump(excl_dist, excl_dist_file)

# Plot efficiency using loudest foreground
if opts.onsource_output_file:
    fig = plt.figure()
    ax = fig.gca()
    # The curves can only be drawn if there are found injections
    if len(found_after_vetoes['network/template_id']):
        ax.plot(dist_plot_vals, fraction_no_mc, 'g-',
                label='No marginalisation')
        ax.errorbar(dist_plot_vals, fraction_no_mc,
                    yerr=[yerr_low_no_mc, yerr_high_no_mc], c='green')
        # The next two curves are skipped when identically zero/NaN
        marg_eff = fraction_mc
        if np.nansum(marg_eff) > 0:
            ax.plot(dist_plot_vals, marg_eff, 'r-', label='Marginalised')
            ax.errorbar(dist_plot_vals, marg_eff, yerr=[yerr_low, yerr_high],
                        c='red')
        if np.nansum(red_efficiency) > 0:
            ax.plot(dist_plot_vals, red_efficiency, 'm-',
                    label='Inc. counting errors')
        ax.legend()
        ax.get_legend().get_frame().set_alpha(0.5)
        # Mark the 90% exclusion distance
        ax.plot([excl_dist["90%"]], [0.9], 'gx')
    # Switch the grid on explicitly: grid() without arguments toggles the
    # grid, so repeated calls would only leave it on by parity
    ax.grid(True)
    ax.set_ylabel("Fraction of injections found louder than " +
                  "loudest foreground")
    ax.set_xlabel("Distance (Mpc)")
    # Axes limits are set once, after everything has been drawn
    ax.set_ylim([0, 1])
    ax.set_xlim(0, 1.3*upper_dist)
    plot_title = "Exclusion distance - "+inj_set_name
    plot_caption = "Injection recovery efficiency using "
    plot_caption += "BestNR as detection statistic.  "
    plot_caption += "Injections louder than loudest foreground trigger.\n"
    # NOTE(review): inside an f-string '%%' is emitted as two literal percent
    # signs; presumably the caption is %-interpolated when the metadata is
    # read back -- confirm before changing
    plot_caption += f" 90%% exclusion distance: {excl_dist['90%']} Mpc\n"
    plot_caption += f" 50%% sensitive distance: {excl_dist['50%']} Mpc"
    fig_path = opts.onsource_output_file
    save_fig_with_metadata(fig, fig_path, cmd=' '.join(sys.argv),
                           title=plot_title, caption=plot_caption)
    plt.close()

logging.info("Done.")