#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

# Copyright (C) 2017 Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
"""Extracts posterior samples from a Sampler HDF file, writing to a
posterior HDF file. Parameters can be renamed in the output file using the
--parameters option.

If more than one file is provided, the samples from all of the files will
be combined. Some attempt is made to check that the files share common
attributes. Sampler info (stored in the 'sampler_info' group) will not be
written when multiple files are combined. All other groups, unless explicitly
skipped (see the --skip-groups option), will be written using the first file.
No attempt is made to check that the data in these other groups is the same
across combined files.
"""

import os
import numpy
import pycbc
from pycbc.inference.io import (ResultsArgumentParser, results_from_cli,
                                PosteriorFile, loadfile)
from pycbc.inference.io.base_hdf import format_attr
from scipy.special import logsumexp
import warnings


def isthesame(current_val, val):
    """Checks that attrs values from two hdf files are the same (or nearly so).

    Parameters
    ----------
    current_val : object
        The reference value. Its type selects the comparison: floats are
        compared with ``numpy.isclose`` (NaNs compare equal), numpy arrays
        are sorted and compared element-wise, and anything else is compared
        with ``==``.
    val : object
        The value to compare against ``current_val``.

    Returns
    -------
    bool
        True if the values are considered the same, False otherwise.
    """
    if isinstance(current_val, float):
        issame = numpy.isclose(current_val, val, equal_nan=True)
    elif isinstance(current_val, numpy.ndarray):
        # sort before comparing, so that element order does not matter
        current_val = numpy.sort(current_val)
        val = numpy.sort(val)
        if current_val.shape != val.shape:
            # arrays of different shape can never be the same; without this
            # check the comparison below would either raise a broadcasting
            # error or silently broadcast (e.g. [1.] against [1., 1.])
            return False
        try:
            # in case of floats, just check that they're close
            issame = numpy.isclose(current_val, val, equal_nan=True)
        except TypeError:
            # wasn't float, do direct comparison
            issame = current_val == val
    else:
        issame = current_val == val
    # element-wise results must agree everywhere
    if isinstance(issame, numpy.ndarray):
        issame = issame.all()
    # normalize numpy booleans to a plain bool
    return bool(issame)

# Build the command line interface. The module docstring doubles as the
# program description shown by --help.
parser = ResultsArgumentParser(defaultparams='all', autoparamlabels=False,
                               description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument("--output-file", type=str, required=True,
                    help="Output file to create.")
parser.add_argument("--force", action="store_true", default=False,
                    help="If the output-file already exists, overwrite it. "
                         "Otherwise, an IOError is raised.")
parser.add_argument("--skip-groups", default=None, nargs="+",
                    help="Don't write the specified groups in the output "
                         "(aside from samples; samples are always written), "
                         "for example, 'sampler_info'. If 'all' skip "
                         "all groups, only writing the samples. Default is "
                         "to write all groups if only one file is provided, "
                         "and all groups from the first file except "
                         "sampler_info if multiple files are provided.")
# NOTE: the help used to state that the priors of all files must not overlap,
# which contradicted the default documented for --mutually-exclusive-priors
# (identical priors); it now defers to that option instead.
parser.add_argument("--combine-via-sampling", action='store_true',
                    default=False,
                    help="Specify whether to combine the posteriors by "
                         "sampling. By default, extract_samples will dump all "
                         "samples from multiple inputs into one output file. "
                         "If this option is specified, the samples will "
                         "be randomly sampled with weighting based on the "
                         "evidence in each input. For example, if the "
                         "evidence of input 1 is twice that of input 2, "
                         "the resulting posterior file with this option "
                         "specified will have twice as many points from input "
                         "1 as from input 2. The output will have the same "
                         "number of samples as the smallest input file. "
                         "An error is thrown if this option is specified and "
                         "any of the input files does not have a log_evidence "
                         "attribute (e.g., the file used a sampler like emcee "
                         "that does not report evidence). See "
                         "--mutually-exclusive-priors for how the evidences "
                         "of the input files are combined.")
parser.add_argument("--mutually-exclusive-priors", action='store_true',
                    default=False,
                    help="If specifying --combine-via-sampling, specify "
                         "whether to treat priors as mutually exclusive. By "
                         "default, the provided input files are assumed to "
                         "have identical priors, and their evidences will be "
                         "averaged. If this option is specified, the priors "
                         "are assumed to be non-overlapping across files, and "
                         "their evidences will be summed together.")

opts = parser.parse_args()

pycbc.init_logging(opts.verbose)

# refuse to clobber an existing output file unless --force was given
output_exists = os.path.exists(opts.output_file)
if output_exists and not opts.force:
    raise IOError(
        "output file already exists; use --force if you wish to "
        "overwrite."
    )

# load the samples; with a single input file the file handle is wrapped in a
# list below so that the rest of the script can always iterate over fps
fps, params, labels, samples = results_from_cli(opts)
if len(opts.input_file) == 1:
    fps = [fps]

# convert samples to a dict in which the keys are the labels
# also stack results if multiple files were provided
if len(opts.input_file) > 1:
    # number of samples each file contributes to the stacked arrays; taken
    # from the loaded samples themselves so that the sampling weights below
    # always line up with the concatenated arrays
    len_list = [len(s) for s in samples]
    stacked = {labels[p]: numpy.concatenate([s[p] for s in samples])
               for p in params}
    if opts.combine_via_sampling:
        logz_list = []
        dlogz_list = []
        # reuse the handles that are already open rather than opening every
        # file a second time (those extra handles were never closed); fps is
        # taken to be in the same order as opts.input_file, as the stacking
        # above already relies on
        for fname, fp in zip(opts.input_file, fps):
            # get evidence from each file if possible
            try:
                logz, dlogz = fp.log_evidence
            except KeyError as err:
                raise ValueError(f"Cannot combine evidences; file {fname} "
                                 "does not have a log_evidence attr") from err
            logz_list.append(logz)
            dlogz_list.append(dlogz)
        # compute sampling weights from evidences: each file gets a total
        # weight of exp(logz - logz_net), shared equally among its samples
        logz_net = logsumexp(logz_list)
        weights = numpy.concatenate([
            numpy.full(nsamp, numpy.exp(logz - logz_net)) / nsamp
            for logz, nsamp in zip(logz_list, len_list)])
        # randomly sample indices from all samples; the output has as many
        # samples as the smallest input
        idx = numpy.random.choice(sum(len_list), size=min(len_list),
                                  replace=True, p=weights)
        samples = {name: vals[idx] for name, vals in stacked.items()}
    else:
        samples = stacked
else:
    samples = {labels[p]: samples[p] for p in params}
    if opts.combine_via_sampling:
        warnings.warn("Specified combine_via_sampling with only one input "
                      "file. This option will have no effect.")

# open the output in 'w' mode, using the posterior file type
out = loadfile(opts.output_file, 'w', filetype=PosteriorFile.name)

# the (possibly combined) samples are written first
out.write_samples(samples)

# Preserve samples group metadata: the first file to provide an attribute
# sets it in the output; every later file must agree with that value
for fp in fps:
    for key, val in fp[fp.samples_group].attrs.items():
        val = format_attr(val)
        try:
            # format the stored value the same way as the incoming one (as is
            # done for the top-level attrs), so that both sides of the
            # comparison below have been through format_attr
            current_val = format_attr(out[out.samples_group].attrs[key])
        except KeyError:
            out[out.samples_group].attrs[key] = val
            current_val = val
        # enforce that the metadata must be the same across multiple files
        if not isthesame(current_val, val):
            raise ValueError("cannot combine all files; samples group attr {} "
                             "is not the same across all files. ({} vs {})"
                             .format(key, current_val, val))

# Preserve top-level metadata...
# ...except for: the filetype, since that's set when the file was loaded;
# the automatic thinning settings, since those will no longer make sense;
# cmd (we will replace it with this program's command)
# resume points (may not be the same across multiple files)
# effective nsamples (that can just be obtained from the samples size)
skip_attrs = ['filetype', 'thin_start', 'thin_interval', 'thin_end',
              'thinned_by', 'cmd', 'resume_points', 'effective_nsamples',
              'run_start_time', 'run_end_time']
# also skip evidence if multiple files are being combined; this will be handled
# via sampling if specified
if len(opts.input_file) > 1:
    skip_attrs += ['log_evidence', 'dlog_evidence']

# make sure attrs are the same between files...
for fp in fps:
    for key in map(format_attr, fp.attrs):
        if key in skip_attrs:
            continue
        val = format_attr(fp.attrs[key])
        try:
            current_val = format_attr(out.attrs[key])
        except KeyError:
            # first file to provide this attr: store it, then read it back so
            # the comparison below uses the value as stored in the output
            out.attrs[key] = val
            current_val = format_attr(out.attrs[key])
        if isthesame(current_val, val):
            continue
        if key == 'remapped_params':
            # ...unless it's remapped_params; just save first file's
            # entries if they don't match between files
            warnings.warn("WARNING: remapped_params metadata does not "
                          "match between files; saving metadata from "
                          "first file")
        else:
            raise ValueError("cannot combine all files; file attr {} is "
                             "not the same across all files ({} vs {})"
                             .format(key, current_val, val))

# store what parameters were renamed. This is done unconditionally, so it
# replaces any remapped_params value copied from the input files above (the
# flag that used to guard this write was never set, so it always ran)
out.attrs['remapped_params'] = list(labels.items())

# write combined evidence and dlog evidence if combining via sampling
if opts.combine_via_sampling:
    if len(opts.input_file) == 1:
        # there's only one file; just return what's in the input. The option
        # is documented as having no effect in this case, so a file without
        # evidence must not make the script fail here
        try:
            logz, dlogz = fps[0].log_evidence
        except KeyError:
            pass
        else:
            out.attrs['log_evidence'] = logz
            out.attrs['dlog_evidence'] = dlogz
    else:
        # quadrature sum of the per-file evidence errors
        # NOTE(review): the dlog_evidence values are combined as-is; confirm
        # this is the intended propagation for errors on the log evidence
        dlogz_quad = numpy.sqrt(sum([d**2 for d in dlogz_list]))
        if opts.mutually_exclusive_priors:
            # add together the evidences; quadrature sum the residuals
            out.attrs['log_evidence'] = logsumexp(logz_list)
            out.attrs['dlog_evidence'] = dlogz_quad
        else:
            # average the evidences; quadrature sum and scale the residuals
            n = len(opts.input_file)
            out.attrs['log_evidence'] = logsumexp(logz_list) - numpy.log(n)
            out.attrs['dlog_evidence'] = dlogz_quad / n

# every group other than the samples is taken from the first input file
first_fp = fps[0]
requested = opts.skip_groups
if requested is None:
    ignore_groups = None
elif 'all' in requested:
    # 'all' means: skip everything except the samples group
    ignore_groups = [name for name in first_fp.keys()
                     if name != first_fp.samples_group]
else:
    ignore_groups = requested
# sampler info is never written when several files were combined
if len(opts.input_file) > 1:
    if ignore_groups is None:
        ignore_groups = []
    ignore_groups.append('sampler_info')

first_fp.copy_info(out, ignore=ignore_groups)

# close all input files, then the output
for handle in fps:
    handle.close()
out.close()
