#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

# Copyright (C) 2016  Soumi De
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
The code reads in a (optionally compressed) template bank and splits it up
into smaller banks. Either the number of banks or the (approximate) number
of templates per bank can be specified.
"""

import argparse
import numpy
import logging
from numpy import random

import pycbc
from pycbc.waveform import bank

__author__ = "Soumi De <soumi.de@ligo.org>"

# Command-line interface.  The module docstring doubles as the description.
parser = argparse.ArgumentParser(description=__doc__[1:])
pycbc.add_common_pycbc_options(parser)
parser.add_argument("--bank-file", type=str, required=True,
                    help="Bank hdf file to load.")
# Exactly one of the following decides how many sub-banks are written.
outbanks = parser.add_mutually_exclusive_group(required=True)
outbanks.add_argument("--templates-per-bank", type=int,
                      help="Approximate number of templates in each output "
                      "sub-bank. Specify only one of this, --number-of-banks "
                      "and --output-filenames.")
outbanks.add_argument("--number-of-banks", type=int,
                      help="Number of output sub-banks. Specify only one of "
                      "this, --templates-per-bank and --output-filenames.")
# nargs='+' (not '*'): an empty list of names would leave the number of
# output banks undefined further down the script.
outbanks.add_argument("--output-filenames", nargs='+',
                      action="store",
                      help="Directly specify the names of the output files. "
                      "The bank will be split equally between files.")
# Sub-bank names are built as prefix + index + '.hdf' with a 0-based index.
parser.add_argument("--output-prefix",
                    help="Prefix to add to the output template bank names, "
                    "for example 'H1L1-BANK'. Output file names would then be "
                    "'H1L1-BANK{x}.hdf' where {x} is 0,1,2,...")
# At most one ordering may be applied before splitting.
sortopt = parser.add_mutually_exclusive_group()
sortopt.add_argument("--mchirp-sort", action="store_true", default=False,
                    help="Sort templates by chirp mass before splitting")
sortopt.add_argument("--random-sort", action="store_true", default=False,
                    help="Sort templates randomly before splitting")
parser.add_argument("--random-seed", type=int,
                    help="Random seed for --random-sort")
parser.add_argument("--force", action="store_true", default=False,
                    help="Overwrite the output hdf files if they exist. "
                         "Otherwise, an error is raised.")

args = parser.parse_args()

pycbc.init_logging(args.verbose)

# Input checks that argparse cannot express by itself.  Mutual exclusivity
# inside each option group is already enforced when the arguments are parsed.
if args.mchirp_sort and (args.random_sort or args.random_seed is not None):
    parser.error("Can't use random sort or random seed if using mchirp sort")

if args.output_filenames is None and args.output_prefix is None:
    parser.error("Must specify either output filenames or a prefix!")

if args.output_filenames and args.output_prefix:
    parser.error("Can't specify both output filenames and a prefix")

# Reject counts that would otherwise only fail later, as an unbound variable
# or a division by zero, while deciding how many sub-banks to write.
if args.output_filenames is not None and not args.output_filenames:
    parser.error("--output-filenames requires at least one file name")

if args.number_of_banks is not None and args.number_of_banks < 1:
    parser.error("--number-of-banks must be a positive integer")

if args.templates_per_bank is not None and args.templates_per_bank < 1:
    parser.error("--templates-per-bank must be a positive integer")

# The seed is only consumed by --random-sort; tell the user it is ignored.
if args.random_seed is not None and not args.random_sort:
    logging.warning("--random-seed has no effect without --random-sort")

# Read the input bank; its ``table`` holds the template parameters that are
# split up below.
logging.info("Loading bank")
tmplt_bank = bank.TemplateBank(args.bank_file)
templates = tmplt_bank.table

# Optional reordering before the split.  Each reordered table is stored back
# on ``tmplt_bank`` because the sub-banks are later written through that
# object's ``write_to_hdf``.
if args.random_sort:
    # Seed numpy's global generator so that the shuffle is reproducible.
    if args.random_seed is not None:
        random.seed(args.random_seed)
    # permutation(n) is arange(n) shuffled in place by the same generator.
    shuffled_order = numpy.random.permutation(templates.size)
    templates = templates[shuffled_order]
    tmplt_bank.table = templates

if args.mchirp_sort:
    # Ascending chirp mass.
    mchirp_order = numpy.argsort(templates.mchirp)
    templates = templates[mchirp_order]
    tmplt_bank.table = templates

# Decide how many output banks we are going to have.  argparse guarantees
# that exactly one of the three options below was supplied.  They are tested
# with ``is not None`` so that an explicit 0 or an empty list is validated
# instead of silently falling through and leaving ``num_files`` unbound.

num_templates = templates.size

if args.output_filenames is not None:
    num_files = len(args.output_filenames)
elif args.number_of_banks is not None:
    num_files = args.number_of_banks
else:
    if args.templates_per_bank < 1:
        raise ValueError("--templates-per-bank must be a positive integer")
    # round() gives 0 when the bank holds fewer than half the requested
    # templates per bank; always write at least one output bank.
    num_files = max(1, round(num_templates / args.templates_per_bank))

if num_files < 1:
    raise ValueError("The user choices imply fewer than one output bank")

# Calculate the number of templates to be stored per bank. For an even
# distribution, this must be a float, converted to int later in the loop.

num_per_file = num_templates / num_files

if num_per_file < 1:
    raise ValueError("The user choices imply less than one template per bank")

# Write out the sub-banks.  Sub-bank k covers the half-open index range
# [int(k * num_per_file), int((k + 1) * num_per_file)), except that the last
# one always runs to the end of the table so truncation never drops templates.

logging.info("Generating the output sub-banks")
boundaries = [int(k * num_per_file) for k in range(num_files)]
boundaries.append(templates[:].size)

for ii, (start_idx, end_idx) in enumerate(zip(boundaries[:-1],
                                               boundaries[1:])):
    # Name of the ii'th output file: taken from the explicit list if one was
    # given, otherwise built from the prefix and the 0-based index.
    if args.output_filenames:
        outname = args.output_filenames[ii]
    elif args.output_prefix:
        outname = f"{args.output_prefix}{ii}.hdf"
    else:
        raise RuntimeError("I shouldn't be able to reach this point. One out "
                           "of --output-filenames and --output-prefix must "
                           "have been supplied!")

    # Store the [start_idx, end_idx) slice of the (possibly reordered) bank
    # and release the returned handle straight away.
    output = tmplt_bank.write_to_hdf(outname, start_idx, end_idx,
                                     force=args.force)
    output.close()

logging.info("finished")
