#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

# Copyright (C) 2014 Andrew Lundgren, Tito Dal Canton
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Compute the optimal SNRs for every injection in an input HDF file or 
sim_inspiral table and store the result in HDF file datasets or in the
same ligolw table.
"""

import logging
import argparse
import multiprocessing
import numpy as np

from igwn_ligolw import utils as ligolw_utils
from igwn_ligolw import lsctables

import pycbc
import pycbc.inject
import pycbc.psd
from pycbc.pool import BroadcastPool as Pool
from pycbc.filter import sigma, make_frequency_series
from pycbc.types import TimeSeries, FrequencySeries, zeros, float32, \
                        MultiDetOptionAction, load_frequencyseries
from pycbc.io.ligolw import get_table_columns
from pycbc.io.hdf import HFile


class TimeIndependentPSD:
    """Callable wrapper that hands back one fixed PSD, whatever the time.

    It offers the same call signature as `TimeVaryingPSD`, so code that needs
    "the PSD at a given time" can use either class.
    """

    def __init__(self, psd_series):
        """Remember `psd_series`, the object returned by every call."""
        self.psd_series = psd_series

    def __call__(self, time=None):
        """Return the stored PSD; `time` is accepted but not used."""
        return self.psd_series

class TimeVaryingPSD(object):
    """Callable giving the PSD of one detector appropriate for a given time.

    The PSDs are read from an HDF file. The name of the first top-level group
    of the file is taken as the detector name; that group is expected to hold
    the `start_time` and `end_time` datasets and one `psds/<index>` group per
    PSD. The file attribute `low_frequency_cutoff` is read as well.

    Parameters
    ----------
    file_name : str
        Path of the HDF file holding the PSDs.
    length : int, optional
        If given, every PSD is truncated, or padded with infinity, to this
        number of samples.
    delta_f : float, optional
        If given, PSDs with a different frequency resolution are interpolated
        to this one.
    f_low : float, optional
        If given and lower than the cutoff stored in the file, the PSD samples
        below the file cutoff are set to infinity.
    """

    def __init__(self, file_name, length=None, delta_f=None, f_low=None):
        with HFile(file_name, 'r') as f:
            self.file_name = file_name
            detector = tuple(f.keys())[0]
            self.start_times = f[detector + '/start_time'][:]
            self.end_times = f[detector + '/end_time'][:]
            self.file_f_low = f.attrs['low_frequency_cutoff']
        # most recently loaded PSD and its index, keyed by process ID
        self._curr_psd = {}
        self._curr_psd_index = {}
        self.detector = detector
        self.length = length
        self.delta_f = delta_f
        self.f_low = f_low

    def __call__(self, time=None):
        """Return the PSD covering `time`, or None if there is none.

        A PSD covers `time` when `start_time <= time < end_time`. If several
        PSDs do, the one whose central time is closest to `time` is used.
        """
        mask = np.logical_and(self.start_times <= time,
                              self.end_times > time)
        if not mask.any():
            return None
        center_times = (self.start_times[mask] + self.end_times[mask]) / 2.
        closest_idx = np.argmin(abs(center_times - time))
        return self.get_psd(np.flatnonzero(mask)[closest_idx])

    def get_psd(self, index):
        """Return the PSD stored under `index`, loading it if necessary.

        Only the last PSD requested by each process is kept in memory, so
        asking repeatedly for the same index does not touch the file again.
        """
        curr_pid = multiprocessing.current_process().pid
        # -1 marks "nothing loaded yet in this process"
        if index != self._curr_psd_index.setdefault(curr_pid, -1):
            group = self.detector + '/psds/' + str(index)
            psd = load_frequencyseries(self.file_name, group=group)
            # use the resolution given to the constructor, not whatever
            # module-level `delta_f` may or may not exist
            if self.delta_f is not None and psd.delta_f != self.delta_f:
                psd = pycbc.psd.interpolate(psd, self.delta_f)
            if self.length is not None and self.length != len(psd):
                psd2 = FrequencySeries(zeros(self.length, dtype=psd.dtype),
                                       delta_f=psd.delta_f)
                if self.length > len(psd):
                    # samples beyond the end of the stored PSD are infinite
                    psd2[:] = np.inf
                    psd2[0:len(psd)] = psd
                else:
                    psd2[:] = psd[0:self.length]
                psd = psd2
            if self.f_low is not None and self.f_low < self.file_f_low:
                # avoid using the PSD below the f_low given in the file
                k = int(self.file_f_low / psd.delta_f)
                psd[0:k] = np.inf
            self._curr_psd[curr_pid] = psd
            self._curr_psd_index[curr_pid] = index
        return self._curr_psd[curr_pid]

def parse_injection_range(num_inj, rangestr):
    """Turn a 'PART/NUM_PARTS' string into the bounds of a slice.

    The `num_inj` injections are thought of as split into NUM_PARTS
    consecutive chunks; the (start, stop) indices of chunk number PART,
    counting from zero, are returned.
    """
    fields = rangestr.split('/')
    part = int(fields[0])
    pieces = int(fields[1])
    return num_inj * part // pieces, num_inj * (part + 1) // pieces

def get_gc_end_time(injection):
    """Give the geocenter end time of `injection`.

    LIGOLW and HDF injection objects store this quantity under different
    names, so both are tried: `geocent_end_time` first, then `tc`.
    """
    try:
        # preferred name; geocent time is robust to potentially incomplete
        # sim tables
        end_time = injection.geocent_end_time
    except AttributeError:
        end_time = injection.tc
    return end_time


if __name__ == '__main__':
    # ---- command line -----------------------------------------------------
    parser = argparse.ArgumentParser(description=__doc__)
    pycbc.add_common_pycbc_options(parser)
    parser.add_argument('--input-file', '-i', dest='injection_file',
                        required=True,
                        help='Input LIGOLW file defining injections')
    parser.add_argument('--injection-f-ref', type=float,
                        help='Reference frequency in Hz for '
                             'creating CBC injections from an XML '
                             'file.')
    parser.add_argument('--injection-f-final', type=float,
                        help='Override the f_final field of a CBC '
                             'XML injection file.')
    parser.add_argument('--input-sort', type=str, choices=['random', 'time'],
                        default='random',
                        help='Sort injections by the given criterion before '
                             'processing them. Sorting by time may speed up '
                             'the calculation when the approximant is very '
                             'fast and injections are tightly spaced in time')
    parser.add_argument('--output-file', '-o', dest='out_file', required=True,
                        help='Output LIGOLW file')
    parser.add_argument('--f-low', type=float, default=30.,
                        help='Start frequency of matched-filter integration '
                             'in Hz (default %(default)s)')
    parser.add_argument('--seg-length', type=float, default=256,
                        help='Segment duration in seconds (default %(default)s)')
    parser.add_argument('--sample-rate', type=float, default=16384,
                        help='Data sample rate in Hz (default %(default)s)')
    parser.add_argument('--ifos', nargs='+',
                        help='Specify ifos for default HDF dataset output')
    parser.add_argument('--snr-columns', nargs='+', action=MultiDetOptionAction,
                        metavar='DETECTOR:COLUMN',
                        help='For sim_inspiral table output, specify columns'
                        ' to store the optimal SNR for each detector. COLUMN'
                        ' should be an existing sim_inspiral column containing'
                        ' no useful data, alpha1, alpha2 etc. are good'
                        ' candidates. For HDF output the --ifos option should be'
                        ' used, datasets will be named eg "optimal_snr_H1"')
    parser.add_argument('--cores', default=1, type=int,
                        help='Parallelize the computation over the given '
                             'number of cores')
    parser.add_argument('--ignore-waveform-errors', action='store_true',
                        help='Ignore errors in waveform generation and keep '
                             'the corresponding column unchanged')
    parser.add_argument('--progress', action='store_true',
                        help='Show a progress bar (requires tqdm)')
    psd_group = pycbc.psd.insert_psd_option_group_multi_ifo(parser)
    psd_group.add_argument('--time-varying-psds', nargs='*', metavar='FILE',
                           help='Instead of time-independent PSDs, use time-varying '
                           'PSDs from the given HDF5 files and pick the appropriate '
                           'PSD for each injection')
    parser.add_argument('--injection-fraction-range', default='0/1',
                        help='Optional, analyze only a certain range of the '
                             'injections. Format PART/NUM_PARTS')
    opts = parser.parse_args()

    # The detector list and the detector -> output column mapping come from
    # exactly one of --ifos and --snr-columns.
    if opts.ifos is not None:
        if opts.snr_columns is not None:
            parser.error("Can't use both --ifos and --snr-columns !")
        detectors = opts.ifos
        opts.snr_columns = {i: 'optimal_snr_' + i for i in opts.ifos}
    elif opts.snr_columns is not None:
        detectors = list(opts.snr_columns.keys())
    else:
        # report a usage error rather than failing on `None.keys()`
        parser.error('One of --ifos or --snr-columns is required')

    # the static PSD options only matter when no time-varying PSD is given
    if not opts.time_varying_psds:
        pycbc.psd.verify_psd_options_multi_ifo(opts, parser, detectors)

    pycbc.init_logging(opts.verbose)

    # ---- derived analysis parameters --------------------------------------
    seg_len = opts.seg_length
    sample_rate = opts.sample_rate
    delta_t = 1. / sample_rate
    delta_f = 1. / seg_len
    # number of time samples in a segment and length of the matching
    # frequency series
    tlen = int(seg_len * sample_rate)
    flen = tlen // 2 + 1
    f_low = opts.f_low

    # ---- PSDs: one callable per detector, mapping a time to a PSD ---------
    logging.info("Loading PSDs")
    if opts.time_varying_psds:
        psds = {}
        for tvpsd_file in opts.time_varying_psds:
            tvpsd = TimeVaryingPSD(tvpsd_file, flen, delta_f, f_low)
            # each file declares the detector it belongs to
            psds[tvpsd.detector] = tvpsd
        if set(detectors) != set(psds.keys()):
            parser.error('Inconsistent detector list in time-varying PSD '
                         'specification (%s vs %s)'
                         % (list(detectors), list(psds.keys())))
    else:
        psds = pycbc.psd.from_cli_multi_ifos(
            opts,
            length_dict={det: flen for det in detectors},
            delta_f_dict={det: delta_f for det in detectors},
            low_frequency_cutoff_dict={det: f_low for det in detectors},
            ifos=detectors,
            strain_dict={det: None for det in detectors},
            dyn_range_factor=pycbc.DYN_RANGE_FAC)
        for det in detectors:
            psds[det] = TimeIndependentPSD(psds[det].astype(float32))

    def get_injection(injections, det, injection_time, simulation_id):
        """Build the frequency-domain strain of one injection in one detector.

        A zero-filled segment of `seg_len` seconds, ending about 4 s after
        `injection_time`, is created. The injection identified by
        `simulation_id` is added to it through `injections.apply`, and the
        output of `make_frequency_series` on that segment is returned.
        """
        # the extra 4 s after the injection time leave room for a possible
        # ringdown
        seg_start = int(injection_time + 4. - seg_len)
        blank = zeros(tlen, dtype=float32)
        strain = TimeSeries(blank, delta_t=delta_t, epoch=seg_start)
        inv_dyn_range = 1. / pycbc.DYN_RANGE_FAC
        injections.apply(strain, det, simulation_ids=[simulation_id],
                         distance_scale=inv_dyn_range)
        return make_frequency_series(strain)

    def compute_optimal_snr(inj):
        """Fill in the optimal SNR of `inj` for every configured detector.

        For each detector/column pair in `opts.snr_columns` the injection is
        generated, its `sigma` against the PSD valid at the injection time is
        computed and the value is stored in the given column. The column is
        left untouched when no PSD is available at that time, or when the
        waveform cannot be generated and --ignore-waveform-errors is set;
        without that option the waveform exception is re-raised. The
        (possibly re-viewed) injection is returned.
        """
        if not ligolw:
            # allow attribute-style access to the fields of the record
            inj = inj.view(np.recarray)
        for det, column in opts.snr_columns.items():
            injection_time = get_gc_end_time(inj)
            psd = psds[det](injection_time)
            if psd is None:
                # nothing to compute the SNR against for this detector
                continue
            sim_id = inj.simulation_id
            logging.debug('Trying injection %s at %s', sim_id, det)
            try:
                htilde = get_injection(injections, det, injection_time,
                                       simulation_id=sim_id)
            except Exception as err:
                if not opts.ignore_waveform_errors:
                    logging.error('%s: waveform generation failed with the '
                                  'following exception', sim_id)
                    raise
                logging.debug('%s: waveform generation failed, skipping (%s)',
                              sim_id, err)
                continue
            logging.debug('Injection %s at %s completed', sim_id, det)
            opt_snr = sigma(htilde, psd=psd, low_frequency_cutoff=f_low)
            if not ligolw:
                inj[column] = opt_snr
            else:
                setattr(inj, column, opt_snr)
        return inj

    logging.info("Loading injections")
    injections = pycbc.inject.InjectionSet.from_cli(opts)
    inj_table = injections.table

    # Special case kept for the traditional behavior of pycbc_optimal_snr:
    # when input and output are both LIGOLW documents, the whole document is
    # preserved and only the chosen sim_inspiral columns are modified. As soon
    # as HDF injections are involved, this does not matter.
    ligolw_suffixes = ('.xml', '.xml.gz')
    ligolw = all(
        path.endswith(ligolw_suffixes)
        for path in (opts.injection_file, opts.out_file)
    )

    if not ligolw:
        # FieldArray injections need a placeholder field for each SNR
        for column in opts.snr_columns.values():
            if column not in inj_table:
                inj_table = inj_table.add_fields(np.zeros(len(inj_table)),
                                                 column)

        # simulation IDs are required later to select single injections
        if 'simulation_id' not in inj_table:
            inj_table = inj_table.add_fields(np.arange(len(inj_table)),
                                             'simulation_id')

        inj_dtype = inj_table.dtype

    # choose the order in which injections get processed
    if opts.input_sort == 'random':
        # fixed seed, so the "random" order is the same at every run
        np.random.seed(100)

        def sort_func(_inj):
            return np.random.random()
    elif opts.input_sort == 'time':
        sort_func = get_gc_end_time
    inj_table = sorted(inj_table, key=sort_func)

    # container collecting the processed injections
    if not ligolw:
        new_inj_table = []
    else:
        new_inj_table = lsctables.SimInspiralTable.new(
            columns=get_table_columns(injections.table)
        )

    # keep only the chunk selected by --injection-fraction-range
    imin, imax = parse_injection_range(len(inj_table),
                                       opts.injection_fraction_range)
    inj_table = inj_table[imin:imax]

    if opts.cores <= 1:
        # single core: no point in spawning extra processes
        iterator = map(compute_optimal_snr, inj_table)
    else:
        logging.info('Starting workers')
        pool = Pool(processes=opts.cores)
        iterator = pool.imap_unordered(compute_optimal_snr, inj_table)

    if opts.progress:
        # the progress bar is optional: carry on without it if tqdm is missing
        try:
            from tqdm import tqdm

            iterator = tqdm(iterator, total=len(inj_table))
        except ImportError:
            logging.warning('cannot import tqdm; not showing progress bar')

    for processed in iterator:
        new_inj_table.append(processed)

    # whatever the processing order was, store injections sorted by
    # coalescence time
    new_inj_table.sort(key=get_gc_end_time)

    if not ligolw:
        new_inj_table = pycbc.io.FieldArray.from_records(
            new_inj_table, dtype=inj_dtype)

    logging.info('Writing output')

    if not ligolw:
        pycbc.inject.InjectionSet.write(opts.out_file, new_inj_table)
    else:
        # swap the original sim_inspiral table for the updated one, leaving
        # the rest of the input document as it is
        xmldoc = injections.indoc
        xml_root = xmldoc.childNodes[0]
        xml_root.removeChild(injections.table)
        xml_root.appendChild(new_inj_table)
        ligolw_utils.write_filename(xmldoc, opts.out_file, compress='auto')

    logging.info('Done')
