#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

# Copyright (C) 2020 Collin D. Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""Performs a PP test on a collection of parameters, writing results out to
an html table.
"""

import logging
import numpy
import sys
from scipy import stats
import pycbc
from pycbc import results
from pycbc.inference.option_utils import add_injsamples_map_opt
from pycbc.inference.io import (ResultsArgumentParser, results_from_cli,
                                injections_from_cli)

# Build the command-line interface on top of ResultsArgumentParser, adding
# the common PyCBC options, the output path, and whatever options
# add_injsamples_map_opt registers.
parser = ResultsArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
# The script writes an HTML table (not a plot) to this path.
parser.add_argument("--output-file", type=str, required=True,
                    help="Path to the output HTML file that the results "
                         "table is written to.")
add_injsamples_map_opt(parser)

# parse the command line
opts = parser.parse_args()

# set up logging using the verbosity given on the command line
pycbc.init_logging(opts.verbose)

# load the parameter names, their labels and the samples requested on the
# command line; the first returned item is not needed here
logging.info('Loading parameters')
_, parameters, labels, samples = results_from_cli(opts)

# a single (non-list) result is wrapped so that it can be iterated over
if not isinstance(samples, list):
    samples = [samples]

logging.info("Getting percentiles of injections")
# the number of injections is taken to be the number of input files; this has
# to be read before opts.input_file is overwritten in the loop below
ninjections = len(opts.input_file)
# maps parameter name -> list of injection percentiles, one per input file
measured_percentiles = {}
for input_file, file_samples in zip(opts.input_file, samples):
    # point the options at the current file only before loading injections
    # (presumably injections_from_cli reads opts.input_file -- confirm)
    opts.input_file = input_file
    injected = injections_from_cli(opts)

    for name in parameters:
        true_val = injected[name]
        posterior_vals = file_samples[name]
        # percentage of samples that are <= the injected value
        percentile = stats.percentileofscore(posterior_vals, true_val,
                                             kind='weak')
        measured_percentiles.setdefault(name, []).append(percentile)

logging.info("Performing percentile-percentile test")
# KS-test p-value of each parameter, in the same order as ``parameters``
p_values = numpy.zeros(len(parameters))
# rows of the output table: [label, KS statistic, p-value]
table = []
# all numbers in the table are shown with three decimal places
fmt = '{:.3f}'
for ii, param in enumerate(parameters):
    meas = numpy.array(measured_percentiles[param])
    meas.sort()
    # percentileofscore reports values on a 0-100 scale, so rescale to [0, 1]
    # before comparing against the standard uniform distribution with the
    # KS test
    ks_stat, p_value = stats.kstest(meas/100., 'uniform')
    p_values[ii] = p_value
    table.append([labels[param], fmt.format(ks_stat), fmt.format(p_value)])

# do the p-value of p-values test if more than one parameter provided
if len(parameters) > 1:
    # KS test of the per-parameter p-values against a uniform distribution
    summary_ks, summary_p = stats.kstest(p_values, 'uniform')
    # the summary goes in as a final table row, rendered in bold
    table.append(['<b>p-values</b>',
                  '<b>' + fmt.format(summary_ks) + '</b>',
                  '<b>' + fmt.format(summary_p) + '</b>'])
    # extra explanation appended to the caption of the results page
    extra_caption = (
           ' If all parameters '
           'satisfy the percentile-percentile test, then the distribution of '
           'p-values should be uniform. The p-value of p-values is obtained '
           'by applying the KS test to the distribution of p-values. The '
           'smaller this value, the lower the probability of obtaining the '
           'observed distribution of parameters assuming that the analysis '
           'did satisfy the expectation that $X$ percent of injections fall '
           'in the $X$ percent credible interval.')
else:
    # with a single parameter there is no summary row to explain
    extra_caption = ''

# create table header
headers = ["Parameter", "$D_{{KS}}$", "p-value"]

# add mathjax header to display latex, followed by the table itself
html = (results.mathjax_html_header() + '\n'
        + str(results.static_table(table, headers)))

# add the number of injections that were used
mdatatmplt = '<h4><b>{}:</b> {}</h4>'
metadata = [mdatatmplt.format('Number of injections', ninjections)]

html += '\n'.join(metadata) + '<br />\n<br />\n'

# save HTML table, recording the command line that produced it along with a
# title and an explanatory caption
results.save_fig_with_metadata(
    html, opts.output_file, {},
    cmd=" ".join(sys.argv),
    title="Percentile-Percentile Test",
    caption='Percentile-percentile test results. The value of the KS '
            'statistic $D_{KS}$ gives the maximum distance '
            'between the measured CDF and the '
            'expected (uniform) CDF. The associated two-tailed p-value gives '
            'the probability of getting a maximum distance larger than the '
            'observed $D_{KS}$ '
            'assuming that the measured CDF is the same as the expected. '
            'In other words, the larger (smaller) the p-value ($D_{KS}$), the '
            'more likely the measured distribution is the same as the '
            'expected.' + extra_caption)
