#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

"""
Make a plot of the histograms of compression factor for a given bank
with compressed waveforms
"""

import argparse
import matplotlib
matplotlib.use("agg")
from matplotlib import pyplot as plt
import numpy as np
import logging
import sys

import pycbc
from pycbc.io import HFile
from pycbc.results import save_fig_with_metadata
from pycbc.inference import option_utils
import pycbc.tmpltbank as tmpltbank

# Command-line interface: input banks, output location, and the switches
# which control histogram binning and axis scaling. The module docstring is
# used as the description so that --help says what the script does.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument(
    "--bank-files",
    required=True,
    nargs="+",
    help="Template bank file(s) containing compressed waveforms",
)
parser.add_argument(
    "--output",
    required=True,
    help="Output plot filename",
)
parser.add_argument(
    "--n-bins",
    type=int,
    default=128,
    help="Number of bins used when making the histogram, default=128",
)
parser.add_argument(
    "--log-comparison",
    action="store_true",
    help="Flag to plot comparison values on a log scale",
)
parser.add_argument(
    "--log-compression",
    action="store_true",
    help="Flag to histogram/plot compression factor on a log scale",
)
parser.add_argument(
    "--log-mismatch",
    action="store_true",
    help="Flag to histogram/plot mismatch on a log scale",
)
parser.add_argument(
    "--histogram-density",
    action="store_true",
    help="Flag to indicate that the histogram should be a density "
         "rather than a count per bin",
)
parser.add_argument(
    "--log-histogram",
    action="store_true",
    help="Flag to indicate that histogram values should be plotted "
         "on a log scale"
)
# Parameter used when --comparison-parameter is not given
default_param = "template_duration"
# NOTE(review): the ParseParametersArg action is expected to also fill in
# args.comparison_parameter_labels, which is read further down the script
# - confirm against pycbc.inference.option_utils
parser.add_argument(
    "--comparison-parameter",
    action=option_utils.ParseParametersArg,
    metavar="PARAM[:LABEL]",
    help="Plot the scatter plot of compression factor versus the given "
         "parameter. Optionally provide a LABEL for use in the plot. "
         "Choose from " + ", ".join(tmpltbank.conversion_options) + ", "
         "though some options may not be buildable from bank parameters. "
         "If no LABEL is provided, PARAM will be used as the LABEL. If LABEL "
         "is the same as a parameter in pycbc.waveform.parameters, the label "
         "property of that parameter will be used. Default: " + default_param
)
args = parser.parse_args()

# Resolve the comparison parameter: fall back to the default (with a
# human-readable label) if none was given, otherwise check that it is one
# of the template bank conversion options.
if args.comparison_parameter is None:
    args.comparison_parameter = default_param
    args.comparison_parameter_labels = {
        default_param: "Template Duration (s)"
    }
elif args.comparison_parameter not in tmpltbank.conversion_options:
    # parser.error() prints the usage message and exits by itself; it never
    # returns, so there is nothing to raise here
    parser.error(
        "--comparison-parameter %s not in conversion options %s, see help"
        % (args.comparison_parameter, ', '.join(tmpltbank.conversion_options))
    )

pycbc.init_logging(args.verbose)

# Label used for the comparison parameter on the axes, title and caption
comp_label = args.comparison_parameter_labels[args.comparison_parameter]

# Quieten the matplotlib logger
plt.set_loglevel("info" if args.verbose else "warning")
logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR)

logging.info("Getting information from the bank")

# The things we want from the banks, accumulated over all of the files
# for the templates which have a compressed waveform stored:
compression_factor = []
comparison_values = []
approximants = []
mismatch = []
# All templates, compressed or not, are counted for the summary message
total_templates = 0
for bank_num, bank_fname in enumerate(args.bank_files, start=1):
    logging.debug(
        "Bank %d out of %d: %s",
        bank_num,
        len(args.bank_files),
        bank_fname
    )

    with HFile(bank_fname, "r") as bank_f:
        compressed_grp = bank_f["compressed_waveforms"]
        thashes = bank_f["template_hash"][:]
        total_templates += thashes.size
        # Mask of the templates which have a compressed waveform; the
        # per-bank arrays are only meaningful where this is True
        valid_idx = np.zeros(thashes.size, dtype=bool)
        compression_thisbank = np.zeros(thashes.size, dtype=float)
        mismatch_thisbank = np.zeros(thashes.size, dtype=float)
        logging.debug(
            "Bank contains %s approximant(s)",
            ','.join(apx.decode() for apx in np.unique(bank_f["approximant"][:]))
        )

        logging.debug("Getting compression factors and mismatches")
        for tmplt_idx, thash in enumerate(thashes):
            # Compressed waveforms are stored in groups named by the
            # template hash; templates without one are skipped
            try:
                this_grp = compressed_grp[str(thash)]
            except KeyError:
                continue

            compression_thisbank[tmplt_idx] = this_grp.attrs['compression_factor']
            mismatch_thisbank[tmplt_idx] = this_grp.attrs['mismatch']
            valid_idx[tmplt_idx] = True

        logging.info(
            "%d out of %d compressed waveforms in %s",
            np.count_nonzero(valid_idx),
            thashes.size,
            bank_fname
        )
        if not valid_idx.any():
            # Nothing to collect from this bank
            continue

        compression_factor.extend(compression_thisbank[valid_idx])
        mismatch.extend(mismatch_thisbank[valid_idx])

        logging.debug("Getting approximants")
        approximants.extend(
            apx.decode() for apx in
            bank_f["approximant"][valid_idx]
        )
        logging.debug("Getting %s", args.comparison_parameter)
        comparison_values.extend(tmpltbank.get_bank_property(
            args.comparison_parameter,
            bank_f,
            template_ids=valid_idx
        ))

if not compression_factor:
    raise ValueError(
        "No compressed waveforms found in any of the given banks"
    )


# Swap the accumulated python lists for numpy arrays, so that they
# can be masked per approximant and reduced below
compression_factor, mismatch, comparison_values, approximants = (
    np.array(collected) for collected in
    (compression_factor, mismatch, comparison_values, approximants)
)

logging.info(
    "%d out of %d templates have compressed waveforms",
    mismatch.size,
    total_templates
)

# The extremes of compression factor and mismatch are kept, as the
# histogram bins and plot limits are derived from them
min_factor, max_factor = compression_factor.min(), compression_factor.max()
min_mmatch, max_mmatch = mismatch.min(), mismatch.max()

# Set the bin edges
# Add 5% in either linear or log space to the edges used
# for the bins and plots, so that things aren't too cramped
# in the plots
# NOTE(review): if every value is identical the range is zero and all
# of the edges coincide - confirm whether that case needs handling
if args.n_bins < 1:
    raise ValueError("--n-bins must be at least 1, got %d" % args.n_bins)
# n_bins bins are delimited by n_bins + 1 edges
n_edges = args.n_bins + 1

if args.log_compression:
    # In this branch factor_min / factor_max are log10 exponents; they
    # are converted back with 10 ** x where the axis limits are set
    log_max_factor = np.log10(max_factor)
    log_min_factor = np.log10(min_factor)
    factor_range = log_max_factor - log_min_factor
    factor_max = log_max_factor + 0.05 * factor_range
    # The padded lower limit is not allowed to go below 10 ** 0 = 1
    factor_min = max(log_min_factor - 0.05 * factor_range, 0)
    bin_comp_edges = np.logspace(
        factor_min,
        factor_max,
        n_edges
    )
else:
    factor_range = max_factor - min_factor
    # The padded lower limit is not allowed to go below 1
    factor_min = max(min_factor - 0.05 * factor_range, 1)
    factor_max = max_factor + 0.05 * factor_range
    bin_comp_edges = np.linspace(
        factor_min,
        factor_max,
        n_edges
    )

if args.log_mismatch:
    # In this branch mmatch_min / mmatch_max are log10 exponents
    # NOTE(review): np.log10 of a zero mismatch is -inf; this assumes all
    # mismatches are strictly positive - confirm
    log_max_mmatch = np.log10(max_mmatch)
    log_min_mmatch = np.log10(min_mmatch)
    mmatch_range = log_max_mmatch - log_min_mmatch
    mmatch_max = log_max_mmatch + 0.05 * mmatch_range
    mmatch_min = log_min_mmatch - 0.05 * mmatch_range
    bin_mmatch_edges = np.logspace(
        mmatch_min,
        mmatch_max,
        n_edges
    )
else:
    mmatch_range = max_mmatch - min_mmatch
    mmatch_min = min_mmatch - 0.05 * mmatch_range
    mmatch_max = max_mmatch + 0.05 * mmatch_range
    bin_mmatch_edges = np.linspace(
        mmatch_min,
        mmatch_max,
        n_edges
    )

# These are used as the x values in the histogram plot;
# this will be each bin edge repeated twice in order,
# except the first and last which will appear once
bin_comp_step = np.repeat(bin_comp_edges, 2)[1:-1]
bin_mmatch_step = np.repeat(bin_mmatch_edges, 2)[1:-1]

# 2x2 grid: histograms on the top row, scatter plots against the
# comparison parameter on the bottom row. Each column shares its x axis:
# compression factor on the left, mismatch on the right.
fig, axes = plt.subplots(
    2, 2,
    figsize=(8,8),
    layout="constrained",
    sharex='col',
)

# Make the histogram and scatter points for each approximant separately
apx_names, apx_counts = np.unique(approximants, return_counts=True)
# Wrap around the colour cycle, so that having more approximants than
# colours in the cycle re-uses colours rather than failing
color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
apx_colors = {
    apx_name: color_cycle[apx_idx % len(color_cycle)]
    for apx_idx, apx_name in enumerate(apx_names)
}

# This makes it so that if there aren't that many points,
# they are visible, but more points become more cloud-like.
# It depends on the total number of points, so is the same for every
# approximant
scatter_alpha = max(
    min(100 / len(compression_factor), 1),
    0.05
)

for apx_name, apx_count in zip(apx_names, apx_counts):
    # Filter to just this approximant
    this_appx = approximants == apx_name
    apx_color = apx_colors[apx_name]
    legend_label = f'{apx_name}: {apx_count}'

    logging.info("Plotting %s", apx_name)
    logging.info("Making %s compression factor histogram", apx_name)
    comp_hist, _ = np.histogram(
        compression_factor[this_appx],
        bins=bin_comp_edges,
        density=args.histogram_density,
    )
    if args.histogram_density:
        # Scale the density by the fraction of all compressed templates
        # which use this approximant
        comp_hist *= apx_count / compression_factor.size

    logging.debug("Plotting compression factor histogram")
    # Each histogram value is drawn at both edges of its bin, which
    # gives a stepped line
    axes[0,0].plot(
        bin_comp_step,
        np.repeat(comp_hist, 2),
        c=apx_color,
        label=legend_label,
    )

    logging.info("Making %s mismatch histogram", apx_name)
    mmatch_hist, _ = np.histogram(
        mismatch[this_appx],
        bins=bin_mmatch_edges,
        density=args.histogram_density,
    )
    if args.histogram_density:
        mmatch_hist *= apx_count / mismatch.size

    logging.debug("Plotting mismatch histogram")
    axes[0,1].plot(
        bin_mmatch_step,
        np.repeat(mmatch_hist, 2),
        c=apx_color,
        label=legend_label,
    )

    logging.debug(
        "Plotting compression vs %s scatter",
        args.comparison_parameter
    )
    axes[1,0].scatter(
        compression_factor[this_appx],
        comparison_values[this_appx],
        alpha=scatter_alpha,
        c=apx_color,
        s=5,
    )
    logging.debug(
        "Plotting mismatch vs %s scatter",
        args.comparison_parameter
    )
    axes[1,1].scatter(
        mismatch[this_appx],
        comparison_values[this_appx],
        alpha=scatter_alpha,
        c=apx_color,
        s=5,
    )
    # This one is for use in the legend: an empty scatter, so that the
    # legend marker is not drawn with the alpha of the real points
    axes[1,0].scatter(
        [],[],
        c=apx_color,
        label=legend_label
    )

logging.debug("Setting scales and limits of axes")
# The x axes are shared within each column of the figure, so the limits
# and scales set on the bottom row also apply to the histograms above.
if args.log_compression:
    # factor_min / factor_max are log10 exponents in this case
    axes[1,0].set_xlim(10 ** factor_min, 10 ** factor_max)
    axes[1,0].set_xscale("log")
else:
    # Only the x limits are set here: the compression factor limits do
    # not belong on the y axis of the mismatch histogram
    axes[1,0].set_xlim(factor_min, factor_max)

if args.log_mismatch:
    # mmatch_min / mmatch_max are log10 exponents in this case
    axes[1,1].set_xscale("log")
    axes[1,1].set_xlim(10 ** mmatch_min, 10 ** mmatch_max)
else:
    axes[1,1].set_xlim(mmatch_min, mmatch_max)

if args.log_comparison:
    axes[1,0].set_yscale("log")
    axes[1,1].set_yscale("log")

if args.log_histogram:
    axes[0,0].set_yscale("log")
    axes[0,1].set_yscale("log")
else:
    # Linear histograms start from zero
    axes[0,0].set_ylim(bottom=0)
    axes[0,1].set_ylim(bottom=0)

logging.info("Setting axes labels")
if args.histogram_density:
    axes[0,0].set_ylabel("Template Density")
    axes[0,1].set_ylabel("Template Density")
else:
    axes[0,0].set_ylabel("Number of Templates")
    axes[0,1].set_ylabel("Number of Templates")
axes[1,0].set_xlabel("Compression Factor")

axes[1,0].set_ylabel(comp_label)
axes[1,1].set_ylabel(comp_label)
axes[1,1].set_xlabel("Mismatch")

axes[0,0].legend(loc='upper left')
axes[1,0].legend(loc='upper left')

for ax in axes.flatten():
    ax.grid(zorder=-100)

# Caption stored with the figure metadata; it describes the four panels
# of the 2x2 figure
caption = (
    "Plot showing histograms of compression factor (top left) and "
    "mismatch (top right), and scatter plots of compression factor vs "
    "{label} ({parameter}) (bottom left) and mismatch vs {label} "
    "({parameter}) (bottom right). "
    "Legend entries indicate the number of templates per approximant. "
).format(label=comp_label, parameter=args.comparison_parameter)

if args.histogram_density:
    caption += (
        "Density for each histogram is weighted by the fraction of "
        "templates which use that approximant. "
    )

logging.info("Saving figure")
save_fig_with_metadata(
    fig,
    args.output,
    title="Bank compression vs %s" % comp_label,
    caption=caption,
    cmd=' '.join(sys.argv)
)
logging.info("Done!")
