#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python
import copy, argparse, logging, numpy, numpy.random
import shutil, uuid, os.path, atexit
from igwn_segments import infinity
import pycbc
from pycbc.events import veto, coinc, stat, cuts
from pycbc.io import HFile
from pycbc import pool, init_logging
from numpy.random import seed, shuffle
from pycbc.io.hdf import ReadByTemplate
from pycbc.types.optparse import MultiDetOptionAction

# Command line interface.
# Options declared with nargs='*' and action='append' produce a list of lists
# to allow multiple invocations and multiple args; these are flattened to
# plain lists once the arguments have been parsed.
parser = argparse.ArgumentParser()
pycbc.add_common_pycbc_options(parser)
parser.add_argument("--veto-files", nargs='*', action='append', default=[],
                    help="Optional veto file. Triggers within veto segments "
                         "contained in the file are ignored")
parser.add_argument("--segment-name", nargs='*', action='append', default=[],
                    help="Optional, name of veto segment in veto file")
parser.add_argument("--gating-veto-windows", nargs='+',
                    action=MultiDetOptionAction,
                    help="Seconds to be vetoed before and after the central time "
                         "of each gate. Given as detector-values pairs, e.g. "
                         "H1:-1,2.5 L1:-1,2.5 V1:0,0")
parser.add_argument("--trigger-files", nargs='*', action='append', default=[],
                    help="Files containing single-detector triggers")
parser.add_argument("--template-bank", required=True,
                    help="Template bank file in HDF format")
parser.add_argument("--pivot-ifo", required=True,
                    help="Add the ifo to use as the pivot for multi "
                         "detector coincidence")
parser.add_argument("--fixed-ifo", required=True,
                    help="Add the ifo to use as the fixed ifo for "
                         "multi detector coincidence")
# NOTE(review): this flag has no help text and is not read directly in this
# script; it may be consumed through `args` by the statistic or cuts helpers
# -- confirm before removing.
parser.add_argument("--use-maxalpha", action="store_true")
parser.add_argument("--coinc-threshold", type=float, default=0.0,
                    help="Seconds to add to time-of-flight coincidence window")
parser.add_argument("--timeslide-interval", type=float,
                    help="Interval between timeslides in seconds. Timeslides are"
                         " disabled if the option is omitted.")
parser.add_argument("--loudest-keep-values",
                    default='[6:1]',
                    help="Apply successive multiplicative levels of"
                         " decimation to coincs with stat value below the"
                         " given thresholds. Supply as a comma-separated list"
                         " of threshold:decimation value pairs surrounded by"
                         " square brackets (no spaces!). Decimation values must"
                         " be positive integers."
                         " Ex. [15:5,10:30,5:30,0:30]."
                         " Default: no decimation")
parser.add_argument("--template-fraction-range", default="0/1",
                    help="Optional, analyze only part of template bank. Format"
                         " PART/NUM_PARTS")
parser.add_argument("--randomize-template-order", action="store_true",
                    help="Random shuffle templates with fixed seed "
                         "before selecting range to analyze")
parser.add_argument("--cluster-window", type=float,
                    help="Optional, window size in seconds to cluster "
                         "coincidences over the bank")
parser.add_argument("--output-file",
                    help="File to store the coincident triggers")
parser.add_argument("--batch-singles", default=5000, type=int,
                    help="Number of single triggers to process at once")
parser.add_argument('--nprocesses', type=int, default=1,
                    help="Number of processes to use")
# The two help fragments are joined by implicit string concatenation, so the
# first one has to end with a space.
parser.add_argument('--stage-input', action='store_true',
                    help="Stage input files through to speed up "
                         "access by multiple processes")
parser.add_argument('--stage-input-dir', type=str, default='/dev/shm',
                    help="Directory to stage input files")
stat.insert_statistic_option_group(parser)
cuts.insert_cuts_option_group(parser)
args = parser.parse_args()

# flatten the list of lists of filenames to a single list (may be empty)
args.segment_name = [name for group in args.segment_name for name in group]
args.veto_files = [name for group in args.veto_files for name in group]
args.trigger_files = [name for group in args.trigger_files for name in group]

init_logging(args.verbose)

def parse_template_range(num_templates, rangestr):
    """Return the half-open template index range for one part of the bank.

    Parameters
    ----------
    num_templates : int
        Total number of templates in the bank.
    rangestr : str
        Selection in the form ``PART/NUM_PARTS``; ``PART`` counts from zero.

    Returns
    -------
    tmin, tmax : int
        Templates ``tmin`` up to and including ``tmax - 1`` form the part.

    Raises
    ------
    ValueError
        If one of the two fields of ``rangestr`` is not an integer.
    IndexError
        If ``rangestr`` does not contain a ``/``.
    ZeroDivisionError
        If ``NUM_PARTS`` is zero.
    """
    fields = rangestr.split('/')
    part = int(fields[0])
    pieces = int(fields[1])
    # Exact integer arithmetic. The floating point form n / pieces * part can
    # land just below an integer boundary (1 / 49.0 * 49 is 0.999...), which
    # after truncation silently drops the last template of a part.
    tmin = num_templates * part // pieces
    tmax = num_templates * (part + 1) // pieces
    return tmin, tmax

logging.info('Starting...')

# Translate the cut options into the trigger and template cut dictionaries
trigger_cut_dict, template_cut_dict = cuts.ingest_cuts_option_group(args)

# Only the length of the template_hash dataset is needed here, so the bank
# file is closed again right away instead of staying open for the whole run
with HFile(args.template_bank, "r") as bank_file:
    num_templates = len(bank_file['template_hash'])
tmin, tmax = parse_template_range(num_templates, args.template_fraction_range)
logging.info('Analyzing template %s - %s', tmin, tmax - 1)

class MultiifoTrigs:
    """Parallel lists holding ifo names, shift multiples and trigger info.

    Entry ``k`` of each list belongs to the same detector.
    """

    def __init__(self):
        # All three containers start out empty and are filled in step
        self.ifos, self.to_shift, self.singles = [], [], []


trigs = MultiifoTrigs()
# Paths of staged input copies that have to be deleted when the script exits
cleanup_files = []


def exit_cleaning():
    """Delete every file listed in ``cleanup_files``.

    Registered with ``atexit`` below. Removal is best effort: a file that
    cannot be deleted is reported and the remaining files are still handled.
    """
    for fname in cleanup_files:
        logging.info('cleaning up %s', fname)
        try:
            os.remove(fname)
        except OSError as err:
            # Report through logging, like the rest of the script, rather
            # than printing to stdout
            logging.warning('could not remove %s: %s', fname, err)


atexit.register(exit_cleaning)

# A template lacking triggers in one or more ifos cannot produce coincs, so
# keep track of the template ids that have triggers in every ifo
tids_with_trigs = None

for i, trig_path in enumerate(args.trigger_files):
    if not args.stage_input:
        dest = trig_path
    else:
        # Work on a uniquely named copy inside the staging directory; the
        # copy is removed again at exit
        dest = os.path.join(args.stage_input_dir, str(uuid.uuid4()) + '.hdf')
        logging.info("Moving %s to shared memory as %s", trig_path, dest)
        cleanup_files.append(dest)
        shutil.copyfile(trig_path, dest)

    logging.info('Opening trigger file %s: %s', i, dest)
    reader = ReadByTemplate(dest, args.template_bank, args.segment_name,
                            args.veto_files, args.gating_veto_windows)
    ifo = reader.ifo
    trigs.ifos.append(ifo)

    # With fewer than 2**27 triggers the set of populated template ids is
    # worked out, so that templates without triggers can be skipped later
    tid_dset = reader.file[ifo]['template_id']
    if len(tid_dset) < 2**27:
        uniq = numpy.unique(tid_dset[:]).astype(numpy.int64)
        if tids_with_trigs is None:
            tids_with_trigs = numpy.arange(0, num_templates, dtype=numpy.int64)
        tids_with_trigs = numpy.intersect1d(tids_with_trigs, uniq)

    # time shift is subtracted from pivot ifo time
    shift = -1 if ifo == args.pivot_ifo else 0
    trigs.to_shift.append(shift)
    logging.info('Applying time shift multiple %i to ifo %s', shift, ifo)
    trigs.singles.append(reader)

# Start from a segment spanning the whole time axis and intersect it with the
# segments of each ifo, leaving only times where all ifos are analyzed
coinc_segs = veto.start_end_to_segments([-infinity()], [infinity()])
for sngl in trigs.singles:
    coinc_segs = coinc_segs & sngl.segs

# Every reader is then restricted to the common segments
for sngl in trigs.singles:
    sngl.segs = coinc_segs
    sngl.valid = veto.segments_to_start_end(sngl.segs)

# Instance of the statistic class used to rank the coincs
rank_method = stat.get_statistic_from_opts(args, trigs.ifos)

# Sanity check, time slide interval should be larger than twice the
# Earth crossing time, which is approximately 0.085 seconds.
TWOEARTH = 0.085
if args.timeslide_interval is None:
    # slide = 0 means don't do timeslides
    args.timeslide_interval = 0.0
elif args.timeslide_interval <= TWOEARTH:
    # parser.error() prints the message and terminates the program itself;
    # it returns nothing that could be raised
    parser.error("The time slide interval should be larger "
                 "than twice the Earth crossing time.")

if not args.randomize_template_order:
    template_ids = numpy.arange(tmin, tmax)
else:
    # The seed is fixed, so the shuffled order is the same on every run
    seed(0)
    template_ids = numpy.arange(0, num_templates)
    shuffle(template_ids)
    template_ids = template_ids[tmin:tmax]

# Size of the selected range before any template is discarded
original_bank_len = len(template_ids)

# Templates without triggers in every ifo cannot produce coincs
if tids_with_trigs is not None:
    template_ids = numpy.intersect1d(tids_with_trigs, template_ids)

# Apply cuts to templates
template_ids = cuts.apply_template_cuts(trigs.singles[0].bank,
                                        template_cut_dict,
                                        statistic=rank_method,
                                        ifos=trigs.ifos,
                                        template_ids=template_ids)

logging.info("%d out of %d templates kept after applying template cuts",
             len(template_ids), original_bank_len)

# 'data' collects the output of the coinc finding: every key maps to a list
# that is extended as coincs are found. Next to the coinc level quantities,
# the trigger times and trigger ids in each ifo are stored as well.
data = {name: [] for name in ('stat', 'decimation_factor',
                              'timeslide_id', 'template_id')}
for ifo in trigs.ifos:
    data.update({'%s/time' % ifo: [], '%s/trigger_id' % ifo: []})

# Decimation levels. The first level (infinite threshold, factor 1) is always
# present; further levels come from --loudest-keep-values.
factors = [1]
threshes = [numpy.inf]
loudest_keep_vals = args.loudest_keep_values.strip('[]').split(',')

for decstr in loudest_keep_vals:
    if not decstr:
        # Empty entry, e.g. from "[]" or a trailing comma: nothing to add
        continue
    thresh, factor = decstr.split(':')
    # Parse once as float so that integral values written with a decimal
    # point (e.g. "5.0") pass the check below instead of failing in int()
    factor = float(factor)
    if factor % 1:
        raise RuntimeError("Non-integer decimation is not supported")
    if factor < 1:
        raise RuntimeError("Negative or zero decimation does not make sense")
    if factor == 1:
        # A factor of one removes nothing, so the level is dropped
        continue
    threshes.append(float(thresh))
    factors.append(int(factor))

# Sort the threshes into descending order
# - allows decimation factors to be given in any order
threshorder = numpy.argsort(threshes)[::-1]
threshes = numpy.array(threshes)[threshorder]
factors = numpy.array(factors)[threshorder]

# Decimation factors are applied successively in descending order of threshold
total_factors = numpy.cumprod(factors)

# Gather the coincs from a single template
def process_template(tnum):
    """Find the coincident triggers of a single template.

    The triggers of template ``tnum`` are read for every ifo in ``trigs`` and
    the trigger cuts are applied. Batches of at most ``args.batch_singles``
    triggers of the fixed and of the pivot ifo are then tested for time
    coincidence. Coincs with slide id zero are always kept; time-slid coincs
    are kept according to the decimation levels given by the module-level
    ``threshes`` and ``total_factors``.

    Parameters
    ----------
    tnum : integer
        Template id handed to ``set_template`` of each trigger reader.

    Returns
    -------
    dict
        Deep copy of the module-level ``data`` dict. Each list in it holds
        one array per processed pair of fixed and pivot batches; the lists
        are empty if any ifo has no triggers left after the cuts.
    """
    local_data = copy.deepcopy(data)
    # Per-ifo trigger times, single-detector statistics and trigger ids of
    # all triggers that survive the cuts
    times_full = {}
    sds_full = {}
    tids_full = {}
    logging.debug('Obtaining trigs for template %i ..' % (tnum))
    for i, sngl in zip(trigs.ifos, trigs.singles):
        # Apply cuts to triggers
        tids_uncut = sngl.set_template(tnum)
        trigger_keep_ids = cuts.apply_trigger_cuts(sngl, trigger_cut_dict,
                                                   statistic=rank_method)

        tids_full[i] = tids_uncut[trigger_keep_ids]
        times_full[i] = sngl['end_time'][trigger_keep_ids]
        logging.debug('%s:%s', i, len(tids_uncut))
        if len(tids_full[i]) < len(tids_uncut):
            logging.info("%s triggers cut",
                         len(tids_uncut) - len(tids_full[i]))


        # get single-detector statistic
        sds_full[i] = rank_method.single(sngl)[trigger_keep_ids]


    # Without triggers in every ifo no coinc can be formed
    mintrigs = min([len(ti) for ti in tids_full.values()])
    if mintrigs == 0:
        logging.info('No triggers in at least one ifo for template %i, '
                     'skipping' % tnum)
        return local_data

    # Extreme values over all ifos, passed on to coinc_lim_for_thresh below.
    # They are only available if the single-detector statistic has the
    # corresponding named fields.
    if type(rank_method.single_dtype) == list:
        entries = [e[0] for e in rank_method.single_dtype]
        if 'snr' in entries:
            min_snr = min([numpy.min(sds_full[i]['snr']) for i in trigs.ifos])
        else:
            min_snr = None
        if 'sigmasq'in entries:
            # NOTE(review): this is the smallest of the per-ifo maxima, not
            # the overall maximum the name suggests -- confirm it is intended
            max_sigmasq = min([numpy.max(sds_full[i]['sigmasq']) for i in trigs.ifos])
        else:
            max_sigmasq = None
    else:
        min_snr = None
        max_sigmasq = None

    # Test whether sds_full contains single arrays or record arrays
    # this depends on the stat being used
    try:
        pivot_stat = sds_full[args.pivot_ifo]['snglstat'].copy()
        fixed_stat = sds_full[args.fixed_ifo]['snglstat'].copy()
    except IndexError:
        pivot_stat = sds_full[args.pivot_ifo].copy()
        fixed_stat = sds_full[args.fixed_ifo].copy()

    # Orderings of the pivot and fixed triggers by single-detector statistic;
    # batches are taken from these orderings
    pivot_sort = numpy.argsort(pivot_stat)
    fixed_sort = numpy.argsort(fixed_stat)

    if not rank_method.single_increasing:
        pivot_sort = pivot_sort[::-1]
        fixed_sort = fixed_sort[::-1]

    # Loop over the single triggers and calculate the coincs they can form
    # (outer loop: batches of fixed ifo triggers)
    start0 = 0
    while start0 < len(sds_full[args.fixed_ifo]):
        start1 = 0

        end0 = start0 + args.batch_singles
        if end0 > len(sds_full[args.fixed_ifo]):
            end0 = len(sds_full[args.fixed_ifo])

        fixed_idxs = fixed_sort[start0:end0]

        # Shallow copies: entries are replaced per batch without touching
        # the *_full dictionaries
        times = times_full.copy()
        times[args.fixed_ifo] = times_full[args.fixed_ifo][fixed_idxs]

        sds = sds_full.copy()
        sds[args.fixed_ifo] = sds_full[args.fixed_ifo][fixed_idxs]

        tids = tids_full.copy()
        tids[args.fixed_ifo] = tids_full[args.fixed_ifo][fixed_idxs]

        # if decimation is being used precalculate the fixed network coincidences
        # (the names set in this branch are only used inside the loops over
        # range(..., len(threshes)) below, which are empty without decimation)
        if len(threshes) > 1:
            fixed_ifos = [i for i in trigs.ifos if i != args.pivot_ifo]
            fixed_times = {i: times[i] for i in fixed_ifos}

            if len(fixed_ifos) > 1:
                fixed_ids, fixed_slide = coinc.time_multi_coincidence(fixed_times, 0.,
                                                                      args.coinc_threshold,
                                                                      fixed_ifos[0],
                                                                      fixed_ifos[1])
                # No coincidence among the non-pivot ifos: move on to the
                # next batch of fixed ifo triggers
                if len(fixed_slide) == 0:
                    start0 += args.batch_singles
                    continue

                # list in ifo order of remaining trigger data
                fixed_single_info = [(i, sds[i][fixed_ids[i]]) for i in fixed_ifos]
                fixed_times = {i: fixed_times[i][fixed_ids[i]] for i in fixed_ifos}

            else:
                # Only one non-pivot ifo: each of its triggers is used as is
                det = fixed_ifos[0]
                fixed_ids = {det: numpy.arange(len(times[det]))}
                fixed_single_info = [(i, sds[i]) for i in fixed_ifos]
                fixed_times = {i: times[i] for i in fixed_ifos}

        pivot_lims = {}
        pivot_lower = {}
        for kidx in range(1, len(threshes)):
            # For each trigger in the fixed network calculate the limit in
            # the pivot detector to pass the current decimation threshold
            pivot_lims[kidx] = rank_method.coinc_lim_for_thresh(
                fixed_single_info, threshes[kidx],
                args.pivot_ifo,
                time_addition=args.coinc_threshold,
                min_snr=min_snr,
                max_sigmasq=max_sigmasq
            )
            # Same sign convention as pivot_stat in the inner loop
            if not rank_method.single_increasing:
                pivot_lims[kidx] *= -1.
            # subtract small amount to account for errors due to rounding
            pivot_lims[kidx] -= 1e-6
            # Get the minimum statistic required for all triggers at the
            # current decimation threshold
            pivot_lower[kidx] = pivot_lims[kidx].min()

        # Inner loop: batches of pivot ifo triggers
        while start1 < len(sds_full[args.pivot_ifo]):
            end1 = start1 + args.batch_singles
            if end1 > len(sds_full[args.pivot_ifo]):
                end1 = len(sds_full[args.pivot_ifo])

            pivot_idxs = pivot_sort[start1:end1]

            times[args.pivot_ifo] = times_full[args.pivot_ifo][pivot_idxs]

            sds[args.pivot_ifo] = sds_full[args.pivot_ifo][pivot_idxs]

            tids[args.pivot_ifo] = tids_full[args.pivot_ifo][pivot_idxs]

            # Do time coincidence for slides that will be kept after the last decimation
            ids, slide = coinc.time_multi_coincidence(times,
                                                      args.timeslide_interval*total_factors[-1],
                                                      args.coinc_threshold,
                                                      args.pivot_ifo,
                                                      args.fixed_ifo)
            # The coincidence used an interval scaled by total_factors[-1];
            # presumably this turns the slide ids back into multiples of
            # args.timeslide_interval, which is what rank_stat_coinc is given
            slide *= total_factors[-1]

            single_info = [(i, sds[i][ids[i]]) for i in trigs.ifos]
            cstat = rank_method.rank_stat_coinc(
                single_info, slide, args.timeslide_interval,
                to_shift=trigs.to_shift,
                time_addition=args.coinc_threshold
            )

            #index values of the zerolag triggers
            fi = numpy.where(slide == 0)[0]

            #index values of the background triggers
            # (only those below the lowest decimation threshold are kept here)
            bi = numpy.where(slide != 0)[0]
            bl = bi[cstat[bi] < threshes[-1]]

            ti = numpy.concatenate([fi, bl])

            ids = {i: t[ti] for i, t in ids.items()}

            cstat = cstat[ti]
            slide = slide[ti]
            # Decimation factor of each kept coinc: one for zerolag, the
            # full cumulative factor for the background kept above
            dec = numpy.concatenate([numpy.ones(len(fi)), numpy.repeat(total_factors[-1], len(bl))])

            try:
                pivot_stat = sds[args.pivot_ifo]['snglstat'].copy()
            except IndexError:
                pivot_stat = sds[args.pivot_ifo].copy()

            # After the sign flip the values are in ascending order for
            # either convention, as the batch follows pivot_sort
            if not rank_method.single_increasing:
                pivot_stat *= -1.

            # Starting from the largest decimation threshold, find the first decimation step
            # where the loudest single detector trigger in pivot can pass the decimation threshold
            # with any trigger in the fixed network
            # (tidx == len(threshes) leaves the loop over kidx below empty)
            tidx = len(threshes)
            for i in range(1, len(threshes)):
                if pivot_stat[-1] >= pivot_lower[i]:
                    tidx = i
                    break

            # loop through decimation steps starting from the first step where passing
            # the threshold is possible
            for kidx in range(tidx, len(threshes)):

                # Remove triggers in pivot that cannot form coincidences above
                # the current decimation threshold
                # (searchsorted relies on the ascending order of pivot_stat)
                pivot_cut = numpy.searchsorted(pivot_stat, pivot_lower[kidx])

                pivot_s = pivot_stat[pivot_cut:]

                test_times = fixed_times.copy()
                test_times[args.pivot_ifo] = times[args.pivot_ifo][pivot_cut:]

                # Do time coincidence for the current decimation factor
                set_ids, set_slide = coinc.time_multi_coincidence(test_times,
                                                                  args.timeslide_interval*total_factors[kidx - 1],
                                                                  args.coinc_threshold,
                                                                  args.pivot_ifo,
                                                                  args.fixed_ifo)
                set_slide *= total_factors[kidx - 1]

                # Remove foreground triggers
                sets = set_slide != 0

                # Each precalculated coincidence from the fixed network should
                # be treated like a single trigger, only keep coincidences
                # where all triggers from the fixed network coincidence are still
                # together
                for i in range(1, len(fixed_ifos)):
                    sets *= set_ids[fixed_ifos[i-1]] == set_ids[fixed_ifos[i]]

                set_ids = {i: set_ids[i][sets] for i in trigs.ifos}
                set_slide = set_slide[sets]

                # Only keep coincidences where pivot has a single stat above the threshold
                # calculated earlier
                above = pivot_s[set_ids[args.pivot_ifo]] >= pivot_lims[kidx][set_ids[fixed_ifos[0]]]

                # Map the indices back onto the per-batch arrays: through
                # fixed_ids for the fixed network, and by adding the cut
                # offset for the sliced pivot triggers
                test_ids = {i: fixed_ids[i][set_ids[i][above]] for i in fixed_ifos}
                test_ids[args.pivot_ifo] = set_ids[args.pivot_ifo][above] + pivot_cut
                set_slide = set_slide[above]
                test_single_info = [(i, sds[i][test_ids[i]]) for i in trigs.ifos]

                test_cstat = rank_method.rank_stat_coinc(
                    test_single_info, set_slide, args.timeslide_interval,
                    to_shift=trigs.to_shift,
                    time_addition=args.coinc_threshold)

                # Keep triggers in the current decimation range
                test_bi = numpy.where(test_cstat >= threshes[kidx])[0]
                test_bi = test_bi[test_cstat[test_bi] < threshes[kidx - 1]]

                for i in trigs.ifos:
                    ids[i] = numpy.concatenate([ids[i], test_ids[i][test_bi]])
                cstat = numpy.concatenate([cstat, test_cstat[test_bi]])
                slide = numpy.concatenate([slide, set_slide[test_bi]])
                dec = numpy.concatenate([dec, numpy.repeat(total_factors[kidx - 1], len(test_bi))])

            # temporary storage for decimated trigger ids
            # (one array per key is appended for this pair of batches)
            for ifo in trigs.ifos:
                addtime = times[ifo][ids[ifo]]
                addtriggerid = tids[ifo][ids[ifo]]
                local_data['%s/time' % ifo] += [addtime]
                local_data['%s/trigger_id' % ifo] += [addtriggerid]
            local_data['stat'] += [cstat]
            local_data['decimation_factor'] += [dec]
            local_data['timeslide_id'] += [slide]
            local_data['template_id'] += [numpy.repeat(tnum, len(slide))]
            start1 += args.batch_singles
        start0 += args.batch_singles
    return local_data

# Process every selected template, either in this process or in a pool
if args.nprocesses != 1:
    p = pool.BroadcastPool(args.nprocesses)
    ldatas = p.map(process_template, list(template_ids))
else:
    ldatas = [process_template(tnum) for tnum in list(template_ids)]

logging.info('merging data from the templates')
for ldata in ldatas:
    for key in data:
        data[key].extend(ldata[key])

# Join the per-batch arrays into one array per key. The length is checked
# again below because the joined array itself may still be empty.
if len(data['stat']) > 0:
    for key in data:
        data[key] = numpy.concatenate(data[key])

if args.cluster_window and len(data['stat']) > 0:
    # Clustering uses the trigger times in the pivot and in the fixed ifo
    pivot_time_key = '%s/time' % args.pivot_ifo
    fixed_time_key = '%s/time' % args.fixed_ifo
    cid = coinc.cluster_coincs(data['stat'],
                               data[pivot_time_key],
                               data[fixed_time_key],
                               data['timeslide_id'],
                               args.timeslide_interval,
                               args.cluster_window)

logging.info('saving coincident triggers')
# The context manager closes the output file even if writing fails part way,
# instead of relying on interpreter shutdown to flush it
with HFile(args.output_file, 'w') as f:
    if len(data['stat']) > 0:
        for key in data:
            # cid only exists if clustering was requested and coincs exist
            var = data[key][cid] if args.cluster_window else data[key]
            f.create_dataset(key, data=var,
                             compression='gzip',
                             compression_opts=9,
                             shuffle=True)

    # Store coinc segments keyed by detector combination
    key = ''.join(sorted(trigs.ifos))
    f['segments/%s/start' % key], f['segments/%s/end' % key] = trigs.singles[0].valid

    f.attrs['timeslide_interval'] = args.timeslide_interval
    f.attrs['coinc_time'] = abs(coinc_segs)
    f.attrs['num_of_ifos'] = len(args.trigger_files)
    f.attrs['pivot'] = args.pivot_ifo
    f.attrs['fixed'] = args.fixed_ifo
    for i, sngl in zip(trigs.ifos, trigs.singles):
        f.attrs['%s_foreground_time' % i] = abs(sngl.segs)
    f.attrs['ifos'] = ' '.join(sorted(trigs.ifos))

    # Number of time slides: the largest abs(segs) over all ifos divided by
    # the slide interval, truncated to an integer; zero when slides are off.
    # Every reader was assigned the same coinc_segs earlier in this script,
    # so the maximum runs over equal values.
    # NOTE(review): presumably the number of slide offsets that fit into the
    # analysed time -- confirm.
    if args.timeslide_interval:
        maxtrigs = max(abs(sngl.segs) for sngl in trigs.singles)
        nslides = int(maxtrigs / args.timeslide_interval)
    else:
        nslides = 0
    f.attrs['num_slides'] = nslides


logging.info('Done')
