#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1767619825/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python
""" Plot variation in PSD
"""
import matplotlib
matplotlib.use('Agg')
import numpy
import argparse
from matplotlib import pyplot as plt
import sys

import pycbc
import pycbc.results
import pycbc.psd
from pycbc.io.hdf import HFile


# Command-line interface. The module docstring doubles as the description.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument("--psd-files", nargs='+', required=True,
                    help='HDF file(s) containing the PSDs to plot')
parser.add_argument('--hdf-group', default=None,
                    help="Group in the HDF file(s) to read the psd from. "
                         "PSDs will be read from {hdf_group}/{ifo}. Default "
                         "is to look in the top level.")
parser.add_argument('--low-frequency-cutoff', type=float,
                    help="The low frequency cutoff of the PSD. If not "
                         "provided, will try to retrieve it from the "
                         "hdf-group's attrs.")
parser.add_argument("--ifos", nargs="+",
                    help="What ifo(s) to plot. If none provided, will use  "
                         "the names of the top-level keys in the psd file(s).")
parser.add_argument("--dyn-range-factor", type=float,
                    default=pycbc.DYN_RANGE_FAC,
                    help="Dynamic range factor that was applied to the PSDs "
                         "in the PSD files. Default is pycbc.DYN_RANGE_FAC.")
parser.add_argument("--output-file", required=True, help='Output file name')
parser.add_argument("--memory-limit", default=2., type=float, metavar="X GB",
                    help="Reading in all PSD files can use a lot of memory. "
                         "If given, this option will read the PSDs in slices "
                         "ensuring that memory usage approximately is limited "
                         "to the given value in GB. For general usage on LDG "
                         "clusters, this defaults to 2GB if this option is "
                         "not provided explicitly.")
# Options describing an optional reference PSD, and the plot style option
pycbc.psd.insert_psd_option_group(parser, output=False)
pycbc.results.add_style_opt_to_parser(parser)
args = parser.parse_args()

# The limit is used further down to size the slices read from disk, so it
# only makes sense as a positive number. Written as "not > 0" so that NaN is
# rejected as well.
if not args.memory_limit > 0:
    parser.error("--memory-limit must be a positive number of GB")

# Initialise pycbc logging with the verbosity requested on the command line
pycbc.init_logging(args.verbose)

# Apply the matplotlib style selected on the command line
pycbc.results.set_style_from_cli(args)

# One figure (number 0) whose single set of axes receives every curve
fig = plt.figure(num=0)
ax = fig.gca()
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel(r'Amplitude Spectral Density (Strain / $\sqrt{\rm Hz}$)')
# Faint, thin grid lines on both the major and minor ticks
ax.grid(which='both', linestyle='solid', linewidth=0.3, alpha=0.2)

# Smallest 5th-percentile ASD value seen over all files and detectors; it is
# used after this loop to choose the y-axis limits.
y_min = None

# The square root of each stored PSD value is multiplied by this factor to
# undo the dynamic range scaling given by --dyn-range-factor.
fac = 1.0 / args.dyn_range_factor

# Number of PSD samples allowed in memory at once. Assumes 16 bytes of memory
# usage per PSD point read; that number was determined empirically using top
# on a 16GB machine.
max_points_read = args.memory_limit * (1E9 / 16.)

for psd_file in args.psd_files:
    # The context manager closes the file once all of its detectors have been
    # reduced to plain numpy arrays and plotted.
    with HFile(psd_file, 'r') as base_file:
        if args.hdf_group is None:
            f = base_file
        else:
            f = base_file[args.hdf_group]

        if args.ifos is not None:
            ifo_list = args.ifos
        else:
            ifo_list = list(f.keys())

        for ifo in ifo_list:
            psd_group = f[ifo + '/psds']
            keys = list(psd_group.keys())
            df = psd_group['0'].attrs['delta_f']

            # Low frequency cutoff: command line first, then the attrs of the
            # group being read, then a per-detector attribute on the file.
            if args.low_frequency_cutoff is not None:
                flow = args.low_frequency_cutoff
            elif 'low_frequency_cutoff' in f.attrs:
                flow = f.attrs['low_frequency_cutoff']
            else:
                try:
                    flow = base_file.attrs[
                        '{}_likelihood_low_freq'.format(ifo)]
                except KeyError as err:
                    raise RuntimeError("hdf file %s does not contain an attribute giving "
                                       "low frequency cutoff, please specify a value using "
                                       "the --low-frequency-cutoff option" % psd_file) from err

            # Frequency-bin range to plot. The dataset length comes from its
            # shape so that no PSD data has to be read just to count samples.
            kmin = int(flow / df)
            kmax = psd_group[keys[0]].shape[0]
            klen = kmax - kmin
            if klen <= 0:
                raise ValueError(
                    "low frequency cutoff %s Hz leaves no frequency bins to "
                    "plot for %s in %s" % (flow, ifo, psd_file))

            # Number of frequency bins read per pass so that one pass over
            # every PSD stays within the sample budget; at least one bin is
            # always read so that the loop below makes progress.
            chunk_len = max(1, int(min(klen, max_points_read / len(keys))))

            high = numpy.zeros(klen)
            low = numpy.zeros(klen)
            middle = numpy.zeros(klen)
            samples = numpy.arange(kmin, kmax) * df

            # Percentiles are taken independently for each frequency bin, so
            # working through the bins in chunks does not change the result.
            for curr_kmin in range(kmin, kmax, chunk_len):
                curr_kmax = min(curr_kmin + chunk_len, kmax)
                psds = [psd_group[key][curr_kmin:curr_kmax] for key in keys]
                p5, p50, p95 = numpy.percentile(psds, [5, 50, 95], axis=0)

                # Positions of this chunk within the output arrays
                out = slice(curr_kmin - kmin, curr_kmax - kmin)
                low[out] = p5 ** 0.5 * fac
                middle[out] = p50 ** 0.5 * fac
                high[out] = p95 ** 0.5 * fac

            if y_min is None or y_min > low.min():
                y_min = low.min()

            color = pycbc.results.ifo_color(ifo)

            # Shaded 5th-95th percentile band with the median drawn on top
            ax.fill_between(samples, low, high, alpha=0.4, linewidth=0,
                            color=color)
            ax.loglog(samples, middle, linewidth=0.3, color=color, label=ifo)
            ax.set_xlim(flow, samples[-1])

# Overlay a reference curve when one of the PSD options was given.
# NOTE(review): the positional arguments to from_cli are presumably the
# length, delta_f and low frequency cutoff of the reference PSD, with no
# strain data supplied -- confirm against pycbc.psd.from_cli.
if args.psd_model or args.psd_file or args.asd_file:
    reference_psd = pycbc.psd.from_cli(args, 2048, 1., 10., None)
    if reference_psd is not None:
        ax.loglog(reference_psd.sample_frequencies, reference_psd ** 0.5,
                  '-k', lw=0.3, label='Reference')

# y_min is still None when no detector was plotted, and the log-scaled axis
# cannot take zero, negative or NaN limits; in those cases the automatic
# limits are kept instead of failing.
if y_min is not None and y_min > 0:
    ax.set_ylim(y_min * 0.5, y_min * 100)

ax.legend(loc='upper right', fontsize='small')
# Write the figure with its title, caption and the command line that made it
pycbc.results.save_fig_with_metadata(fig, args.output_file,
    title="Noise Amplitude Spectral Density",
    caption="Median amplitude spectral density plotted with a shaded region "
              "between the 5th and 95th percentiles. ",
    cmd=' '.join(sys.argv),
    fig_kwds={'dpi': 200,
              'bbox_inches': 'tight'})
