#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

# Copyright (C) 2025 Alex Correia
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
"""Estimates the Bayes factor using the Savage-Dickey density ratio from the
given posterior file(s).

The Savage-Dickey ratio is defined with respect to a single parameter. The
ratio compares two nested models: a null hypothesis, where the parameter of
interest is fixed to one value, and a model where this parameter is allowed to
vary. The SD ratio is computed by dividing the variable model's posterior
density of the parameter by its prior density, both evaluated at the null
point. Provided the priors on all other parameters are the same under both
models, this ratio equals the Bayes factor in favor of the null hypothesis.
This program reports the inverse (prior density divided by posterior density
at the null point), i.e. the Bayes factor in favor of the variable model.

As a GW analysis example, say we wish to evaluate the presence of a QNM
overtone in the ringdown of a CBC signal. The null hypothesis states that the
QNM is not present, indicated by a QNM amplitude of zero. We can therefore
estimate the Bayes factor in favor of overtone detection by taking the ratio of
the amplitude prior to the amplitude posterior, both evaluated at 0.

Inputting multiple files evaluates the SD ratio for each file with respect to
the specified parameter. Only one parameter is accepted, and a null value is
required for the given parameter.
"""

import os
import numpy
import pycbc
from pycbc.inference.io import (ResultsArgumentParser, loadfile)
from pycbc import distributions
from pycbc.distributions.utils import prior_from_config
from pycbc.boundaries import Bounds
from scipy.stats import gaussian_kde

def _parse_bool(value):
    """Interpret a command-line string as a boolean.

    ``type=bool`` cannot be used with argparse because any non-empty string,
    including ``"False"``, is truthy. This accepts the usual spellings
    instead. Unrecognized strings raise ``ValueError``, which argparse reports
    as an invalid option value.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError(f"cannot interpret {value!r} as a boolean")


parser = ResultsArgumentParser(defaultparams='all', autoparamlabels=False,
                               description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument("--parameter-null-value", type=float, required=True,
                    help="Value at which to evaluate the posterior and prior "
                         "when calculating the Savage-Dickey ratio.")
parser.add_argument("--kde-samples", type=int, default=50000,
                    help="The number of samples to use for generating the "
                         "prior KDE. Default 50000.")
# The scalar is passed to scipy's gaussian_kde as ``bw_method``, where it
# multiplies the standard deviation of the samples: larger means smoother.
parser.add_argument("--kde-bandwidth", type=float, default=0.01,
                    help="The bandwidth factor of the KDEs, which determines "
                         "how 'smooth' the KDEs are. The kernel width is this "
                         "factor times the standard deviation of the samples "
                         "(including any reflected samples). Higher values "
                         "lead to smoother KDEs; lower values lead to more "
                         "'jagged' KDEs but will potentially capture more "
                         "structure. Use '--plot-name' to aid in fine-tuning "
                         "this.")
parser.add_argument("--prior-name", type=str, default=None,
                    help="The type of prior distribution to generate. "
                         "Accepts valid prior names from the distributions "
                         "module (e.g. 'uniform', 'uniform_log10', etc.). "
                         "If specifying a name, a min and max must also be "
                         "specified using --prior-min and --prior-max "
                         "respectively. By default, the prior is read from "
                         "the input posterior file(s).")
parser.add_argument("--prior-min", type=float, default=None,
                    help="The minimum value of the prior if specifying "
                         "--prior-name. By default (i.e. if prior-name is not "
                         "specified), this is read from the input file(s).")
parser.add_argument("--prior-max", type=float, default=None,
                    help="The maximum value of the prior if specifying "
                         "--prior-name. By default (i.e. if prior-name is not "
                         "specified), this is read from the input file(s).")
# nargs='?' + const=True keeps '--prior-cyclic True' working while also
# allowing the bare flag; _parse_bool makes '--prior-cyclic False' mean False.
parser.add_argument("--prior-cyclic", type=_parse_bool, nargs='?',
                    const=True, default=False,
                    help="Specify whether the prior has cyclic bounds if "
                         "specifying --prior-name. Accepts an optional "
                         "boolean value (e.g. '--prior-cyclic False'); giving "
                         "the flag alone means True. By default (i.e. if "
                         "prior-name is not specified), this is read from the "
                         "input file(s).")
parser.add_argument("--reflect-kde-on-left", action="store_true", default=False,
                    help="Specify whether to reflect samples on the left bound "
                         "when generating KDEs. This is done to prevent KDE "
                         "sampling issues at the edges of a histogram. This is "
                         "not done by default.")
parser.add_argument("--reflect-kde-on-right", action="store_true", default=False,
                    help="Specify whether to reflect samples on the right bound "
                         "when generating KDEs. This is done to prevent KDE "
                         "sampling issues at the edges of a histogram. This is "
                         "not done by default.")
parser.add_argument("--plot-name", type=str, default=None, nargs='+',
                    help="Specify the name(s) of the output plot. If a name is "
                         "specified, a plot will be generated showing the "
                         "prior and posterior KDEs and samples for each file. "
                         "By default, a plot is not generated. One name is "
                         "required per file specified by '--input-file'.")

def reflect_samples(samples, lower, upper, reflect_left=False,
                    reflect_right=False):
    """Return ``samples`` followed by their mirror images about the bounds.

    A Gaussian KDE leaks probability across a hard boundary, which biases the
    density estimate near that boundary. Appending the samples reflected about
    the boundary compensates for this.

    Parameters
    ----------
    samples : numpy.ndarray
        One-dimensional array of samples.
    lower, upper : float
        The bounds to reflect about.
    reflect_left, reflect_right : bool, optional
        Whether to append the reflection about ``lower`` and/or ``upper``.

    Returns
    -------
    numpy.ndarray
        The original samples, then (if requested) their reflection about
        ``lower``, then (if requested) their reflection about ``upper``. Both
        reflections are of the original samples only.

    Raises
    ------
    ValueError
        If asked to reflect about an infinite bound.
    """
    original = numpy.asarray(samples)
    pieces = [original]
    if reflect_left:
        if not numpy.isfinite(lower):
            raise ValueError("Cannot reflect samples about an infinite lower "
                             "bound")
        pieces.append(2 * lower - original)
    if reflect_right:
        if not numpy.isfinite(upper):
            raise ValueError("Cannot reflect samples about an infinite upper "
                             "bound")
        pieces.append(2 * upper - original)
    return numpy.concatenate(pieces)


def savage_dickey_bayes_factor(pos, pri, null, lower, upper, bw_method,
                               reflect_left=False, reflect_right=False):
    """Estimate the Bayes factor in favor of the variable model.

    Gaussian KDEs of the posterior and prior samples are normalized to unit
    probability within ``[lower, upper]``. The Bayes factor is estimated as
    ``prior(null) / posterior(null)``, the inverse of the Savage-Dickey
    ratio.

    Parameters
    ----------
    pos, pri : numpy.ndarray
        One-dimensional posterior and prior samples of the parameter.
    null : float
        The null value of the parameter, in the same space as the samples.
    lower, upper : float
        The prior bounds of the parameter, in the same space as the samples.
    bw_method : float or str
        Passed to ``scipy.stats.gaussian_kde`` as ``bw_method``.
    reflect_left, reflect_right : bool, optional
        Whether to reflect the samples about ``lower``/``upper`` before
        building the KDEs (see ``reflect_samples``).

    Returns
    -------
    bayes_factor : float
        The estimated Bayes factor (``inf``/``nan`` if the posterior KDE
        vanishes at ``null``).
    pos_pdf, prior_pdf : scipy.stats.gaussian_kde
        The KDEs, not normalized within the bounds.
    pos_norm, prior_norm : float
        Integral of each KDE over ``[lower, upper]``; dividing a KDE by its
        norm gives a density normalized within the bounds.
    """
    pos_pdf = gaussian_kde(
        reflect_samples(pos, lower, upper, reflect_left, reflect_right),
        bw_method=bw_method)
    prior_pdf = gaussian_kde(
        reflect_samples(pri, lower, upper, reflect_left, reflect_right),
        bw_method=bw_method)

    # Kernel tails (and any reflected copies) put part of each KDE's mass
    # outside the bounds, so renormalize within the bounds before taking the
    # ratio.
    pos_norm = pos_pdf.integrate_box_1d(lower, upper)
    prior_norm = prior_pdf.integrate_box_1d(lower, upper)

    posterior_at_null = pos_pdf(null)[0] / pos_norm
    prior_at_null = prior_pdf(null)[0] / prior_norm
    bayes_factor = float(prior_at_null / posterior_at_null)
    return bayes_factor, pos_pdf, prior_pdf, pos_norm, prior_norm


def plot_kdes(pos, pri, pos_pdf, prior_pdf, pos_norm, prior_norm, lower,
              upper, log_scale, param, input_name, out_name):
    """Save a plot of the prior/posterior samples and their KDEs.

    ``pos``, ``pri``, ``lower`` and ``upper`` are in the space in which the
    KDEs were built, i.e. log10 space when ``log_scale`` is True. In that
    case the plot is drawn in linear space on a logarithmic x axis.
    """
    import matplotlib.pyplot as plt

    if not log_scale:
        _, bins, _ = plt.hist(pos, bins=100, density=True, alpha=0.4,
                              label="Posterior Samples")
        plt.hist(pri, bins=bins, density=True, alpha=0.4,
                 label="Prior Samples")
        plt.plot(bins, pos_pdf(bins) / pos_norm, color='green',
                 label="Posterior KDE")
        plt.plot(bins, prior_pdf(bins) / prior_norm, color='red',
                 label="Prior KDE")
    else:
        # special handling for log distributions
        bins = numpy.logspace(lower, upper, 100)
        log_bins = numpy.log10(bins)
        plt.hist(10**pos, bins=bins, density=True, alpha=0.4,
                 label="Posterior Samples")
        plt.hist(10**pri, bins=bins, density=True, alpha=0.4,
                 label="Prior Samples")
        # convert log scale KDEs to linear scale
        # p(x) = p(y) / (x ln 10), where y = log x
        jac = 1 / bins / numpy.log(10)
        pos_plot = pos_pdf(log_bins) * jac
        pri_plot = prior_pdf(log_bins) * jac
        # numpy.trapezoid only exists in numpy >= 2.0; fall back to trapz
        trapezoid = getattr(numpy, 'trapezoid', None) or numpy.trapz
        pos_plot_norm = trapezoid(pos_plot, bins)
        pri_plot_norm = trapezoid(pri_plot, bins)
        plt.plot(bins, pos_plot / pos_plot_norm, color='green',
                 label="Posterior KDE")
        plt.plot(bins, pri_plot / pri_plot_norm, color='red',
                 label="Prior KDE")
        plt.xscale('log')

    base = os.path.basename(input_name)
    plt.title(f"Posterior and Prior of {param} from {base}")
    plt.xlabel(f"{param}")
    plt.legend()
    plt.savefig(f'{out_name}')
    plt.close()


if __name__ == "__main__":
    opts = parser.parse_args()
    pycbc.init_logging(opts.verbose)

    # The Savage-Dickey ratio is a one-dimensional density ratio, so exactly
    # one parameter is supported.
    if not opts.parameters:
        raise ValueError("No parameter specified; select exactly one "
                         "parameter with --parameters")
    if len(opts.parameters) > 1:
        raise ValueError("Multiple parameters specified. Only one parameter "
                         "is accepted at a time")

    param = opts.parameters[0]
    null = float(opts.parameter_null_value)

    # Validate option combinations once, before any file is processed.
    if opts.prior_name is not None:
        if opts.prior_min is None:
            raise ValueError("Must specify --prior-min when using --prior-name")
        if opts.prior_max is None:
            raise ValueError("Must specify --prior-max when using --prior-name")
    else:
        if opts.prior_min is not None:
            raise ValueError("Must specify --prior-name when using --prior-min")
        if opts.prior_max is not None:
            raise ValueError("Must specify --prior-name when using --prior-max")
    if (opts.plot_name is not None
            and len(opts.plot_name) != len(opts.input_file)):
        raise ValueError(f"Number of plot names ({len(opts.plot_name)}) "
                         f"does not match number of input files "
                         f"({len(opts.input_file)})")

    print(f"Calculating Bayes factor for selected parameter {param} versus "
          f"{param} at null value {null}")
    print("======================================================================")

    for idx, fname in enumerate(opts.input_file):
        # Read the prior definition and the posterior samples, closing the
        # file once everything needed has been read.
        with loadfile(fname, 'r') as fp:
            if opts.prior_name is not None:
                prior_name = opts.prior_name
                bounds = Bounds(min_bound=opts.prior_min,
                                max_bound=opts.prior_max,
                                cyclic=opts.prior_cyclic)
            else:
                cp = fp.read_config_file()
                bounds = prior_from_config(cp).bounds[param]
                prior_name = cp.get_opt_tag('prior', 'name', param)
            pos_samples = fp.read_samples([param])[param]

        # Both densities vanish outside the prior support, so the ratio is
        # undefined there (and log10 of a non-positive null would be invalid).
        if not bounds.min <= null <= bounds.max:
            raise ValueError(f"Null value {null} lies outside the prior "
                             f"bounds [{bounds.min}, {bounds.max}] of {param} "
                             f"in {fname}")

        prior = distributions.distribs[prior_name](**{param: bounds})
        prior_samples = prior.rvs(opts.kde_samples)[param]

        # For log priors, build the KDEs in log10 space. The transformed null
        # is kept in a per-file variable so it is not re-logged for later
        # files.
        log_scale = 'log' in prior.name
        if log_scale:
            pos = numpy.log10(pos_samples)
            pri = numpy.log10(prior_samples)
            null_eval = numpy.log10(null)
            lower = numpy.log10(bounds.min)
            upper = numpy.log10(bounds.max)
        else:
            pos, pri, null_eval = pos_samples, prior_samples, null
            lower, upper = bounds.min, bounds.max

        bayes_factor, pos_pdf, prior_pdf, pos_norm, prior_norm = \
            savage_dickey_bayes_factor(
                pos, pri, null_eval, lower, upper, opts.kde_bandwidth,
                reflect_left=opts.reflect_kde_on_left,
                reflect_right=opts.reflect_kde_on_right)

        if opts.plot_name is not None:
            plot_kdes(pos, pri, pos_pdf, prior_pdf, pos_norm, prior_norm,
                      lower, upper, log_scale, param, fname,
                      opts.plot_name[idx])

        # print output
        print(f"File: {fname}")
        if bayes_factor > 1e-3:
            print(f"Bayes factor: {bayes_factor:.3f}")
        else:
            print(f"Bayes factor: {bayes_factor:.5e}")
        print("")

