#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1767619825/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

# Copyright 2024 Max Trevor
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

"""
Calculate the noise trigger rate adjustment as a function of data-quality state
for a day of PyCBC Live triggers.
"""

import logging
import argparse
import hashlib
import os
import glob

import numpy as np

import gwdatafind

import pycbc
from pycbc.io import HFile

from gwpy.timeseries import TimeSeriesDict
from gwpy.segments import Segment, DataQualityFlag


def find_frames(ifo, frame_dir, frame_type, start, end):
    """
    Find the frame files for a given time range.

    Frames are searched for on disk in directories of the form
    ``<frame_dir>/<type>/<X>-<type>_llhoft-<N>/``, where ``<type>`` is
    ``frame_type`` with ``{ifo}`` substituted, ``<X>`` is the first character
    of ``ifo`` and ``<N>`` is a GPS time divided by 10000 and truncated to an
    integer.

    Parameters
    ----------
    ifo : str
        Detector name; its first character is used as the file prefix.
    frame_dir : str
        Top-level directory containing the frame files. It is used verbatim,
        so it may contain ``{`` or ``}`` characters.
    frame_type : str
        Frame type, optionally containing an ``{ifo}`` placeholder that is
        replaced by ``ifo``.
    start : int or float
        GPS start time of the range.
    end : int or float
        GPS end time of the range.

    Returns
    -------
    list of str
        Paths of all ``.gwf`` files found in the directories covering
        ``start`` to ``end``. Whole directories are returned, so the list
        can contain frames outside of the requested range. Directories are
        visited in increasing order of ``<N>`` and the files of each
        directory are sorted by name, so the result does not depend on the
        order in which the filesystem lists them. Directories that do not
        exist contribute nothing.
    """
    ifo_frame_type = frame_type.format(ifo=ifo)
    type_dir = os.path.join(frame_dir, ifo_frame_type)
    folder_prefix = f'{ifo[0]}-{ifo_frame_type}_llhoft-'

    # directories are indexed by the GPS time in units of 10000 s
    first_dir = int(start / 10000)
    last_dir = int(end / 10000)

    files = []
    for supertime in range(first_dir, last_dir + 1):
        framedir = os.path.join(type_dir, f'{folder_prefix}{supertime}')
        # glob returns its matches in arbitrary order, so sort them
        files.extend(sorted(glob.glob(os.path.join(framedir, '*.gwf'))))

    return files


# Command-line interface. Options documented as accepting '{ifo}' are
# formatted with the value of --ifo before they are used.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--trigger-file", required=True,
                    help='HDF file containing the PyCBC Live triggers and '
                         'their data-quality states.')
parser.add_argument("--ifo", required=True,
                    help='Name of the detector to analyse.')
parser.add_argument("--gps-start-time", required=True, type=int,
                    help='GPS start time of the period in which the '
                         'analysis flag is queried.')
parser.add_argument("--gps-end-time", required=True, type=int,
                    help='GPS end time of the period in which the '
                         'analysis flag is queried.')
parser.add_argument("--template-bin-file", required=True,
                    help='HDF file listing the template ids that belong '
                         'to each template bin.')
parser.add_argument("--use-gwdatafind", action='store_true',
                    help='Use gwdatafind to find frame files. If provided, '
                         'frame-directory argument will be ignored.')
parser.add_argument("--frame-directory",
                    help='Directory containing frame files. Required if '
                         'not using gwdatafind.')
parser.add_argument("--frame-type", required=True,
                    help='Frame type of the frame files containing the '
                         'data-quality channels. May contain {ifo}.')
parser.add_argument("--analysis-flag-name", required=True,
                    help='Name of the flag defining the observing time. '
                         'May contain {ifo}.')
parser.add_argument("--dq-channel", required=True,
                    help='Name of the data-quality channel. May contain '
                         '{ifo}.')
parser.add_argument("--dq-ok-channel", required=True,
                    help='Name of the channel indicating whether the '
                         'data-quality channel is valid; it is treated as '
                         'valid where this channel equals 1. May contain '
                         '{ifo}.')
parser.add_argument("--dq-thresh", required=True, type=float,
                    help='Times at which the data-quality channel is at or '
                         'below this value are flagged.')
parser.add_argument("--dq-padding", type=float, default=1.0,
                    help='Time in seconds by which the flagged segments '
                         'are extended at both ends. Default: 1.0')
parser.add_argument("--replay-offset", type=int, default=0,
                    help='Time in seconds by which the observing segments '
                         'are shifted. Default: 0')
parser.add_argument("--output", required=True,
                    help='Path of the output HDF file.')
pycbc.add_common_pycbc_options(parser)
args = parser.parse_args()

pycbc.init_logging(args.verbose)

# frames are located either with gwdatafind or in a local directory, so one
# of the two has to be available
if not args.use_gwdatafind and args.frame_directory is None:
    raise ValueError('Must provide frame-directory if not using gwdatafind.')

# Query the analysis flag over the requested GPS interval
ar_flag_name = args.analysis_flag_name.format(ifo=args.ifo)
day_seg = Segment(args.gps_start_time, args.gps_end_time)

# shift observing flag to current time: the start and the end of every
# segment are padded by the same replay offset
observing_flag = DataQualityFlag.query(ar_flag_name, day_seg).pad(
    args.replay_offset,
    args.replay_offset,
)

# total observing time and the individual segments making it up
livetime = observing_flag.livetime
observing_segs = observing_flag.active
logging.info(f'Found {livetime} seconds of observing time at {args.ifo}.')

# for each segment, check how much time was dq flagged
flagged_time = 0
dq_channel = args.dq_channel.format(ifo=args.ifo)
dq_ok_channel = args.dq_ok_channel.format(ifo=args.ifo)
for seg in observing_segs:
    # locate the frame files covering this observing segment
    if args.use_gwdatafind:
        # the frame type is given by the --frame-type option, which is the
        # only frame type option defined by the parser
        frame_type = args.frame_type.format(ifo=args.ifo)
        frames = gwdatafind.find_urls(
            args.ifo[0],
            frame_type,
            seg[0],
            seg[1],
        )
    else:
        frames = find_frames(
            args.ifo,
            args.frame_directory,
            args.frame_type,
            seg[0],
            seg[1],
        )

    # read the data-quality channel and its validity channel; gaps are
    # padded with 0, which never equals the "ok" value of 1 checked below
    tsdict = TimeSeriesDict.read(
        frames,
        channels=[dq_channel, dq_ok_channel],
        start=seg[0],
        end=seg[1],
        pad=0,
    )

    # times at which the dq channel is valid
    dq_ok_flag = (tsdict[dq_ok_channel] == 1).to_dqflag()
    # times at which the dq channel is at or below the threshold
    dq_flag = (tsdict[dq_channel] <= args.dq_thresh).to_dqflag()

    # only count flagged times at which the dq channel can be trusted
    valid_bad_dq = dq_flag & dq_ok_flag

    # pad flag outwards to match the padding used in the live script
    valid_bad_dq.protract(args.dq_padding)
    valid_bad_dq.coalesce()

    # the padding may reach outside of the observing time, so intersect
    # with the observing flag before accumulating
    flagged_time += (valid_bad_dq & observing_flag).livetime

logging.info(f'Found {flagged_time} seconds of dq flagged time at {args.ifo}.')

# time spent in each dq state: index 0 is unflagged, index 1 is flagged
bg_livetime = livetime - flagged_time
state_time = np.array([bg_livetime, flagged_time])

# read in template bins
logging.info(f'Reading template bins from {args.template_bin_file}')
with HFile(args.template_bin_file, 'r') as bin_file:
    # settings describing how the bins were made
    f_lower = bin_file.attrs['f_lower']
    bank_file = bin_file.attrs['bank_file']
    bin_string = bin_file.attrs['background_bins']

    # only ever one group in this file
    grp = bin_file[list(bin_file.keys())[0]]
    num_bins = len(grp.keys())

    # map the name of each bin to the ids of the templates it contains
    template_bins = {name: grp[name]['tids'][:] for name in grp.keys()}

# total number of templates over all bins
num_templates = sum(len(tids) for tids in template_bins.values())

# for each bin, get total number of triggers and number of dq flagged triggers
bin_total_triggers = np.zeros(num_bins)
bin_dq_triggers = np.zeros(num_bins)

logging.info(f'Reading triggers from {args.trigger_file}')
with HFile(args.trigger_file, 'r') as trigf:
    # the datasets are the same for every template, so look them up once
    # rather than once per template
    dq_state_template = trigf[f'{args.ifo}/dq_state_template']
    dq_state = trigf[f'{args.ifo}/dq_state']

    for bin_name, template_ids in template_bins.items():
        # bins are named as 'bin{bin_num}'
        bin_num = int(bin_name[3:])
        for template_id in template_ids:
            # get dq states for all triggers with this template
            # dq state is either 0 or 1 for each trigger
            dq_key = dq_state_template[template_id]
            dq_states = dq_state[dq_key]

            # update trigger counts: every trigger counts towards the
            # total, and the sum of the 0/1 states counts the flagged ones
            bin_total_triggers[bin_num] += len(dq_states)
            bin_dq_triggers[bin_num] += np.sum(dq_states)

# write outputs to file
logging.info(f'Writing results to {args.output}')
with HFile(args.output, 'w') as f:
    ifo_group = f.create_group(args.ifo)
    ifo_group.create_dataset('observing_livetime', data=livetime)
    ifo_group.create_dataset('dq_flag_livetime', data=flagged_time)

    # one group per template bin, holding its templates, counts and rates
    bin_group = ifo_group.create_group('bins')
    for bin_name, tids in template_bins.items():
        # bins are named as 'bin{bin_num}'
        bin_num = int(bin_name[3:])
        n_total = bin_total_triggers[bin_num]
        n_flagged = bin_dq_triggers[bin_num]

        bgrp = bin_group.create_group(bin_name)
        bgrp.create_dataset('tids', data=tids)
        bgrp.create_dataset('total_triggers', data=n_total)
        bgrp.create_dataset('dq_triggers', data=n_flagged)

        # trigger rate in each dq state (unflagged, flagged) relative to
        # the mean trigger rate of the bin
        # NOTE(review): a dq state without livetime or a bin without
        # triggers gives a zero denominator here - confirm that readers of
        # this file cope with non-finite rates
        num_trigs = np.array([n_total - n_flagged, n_flagged])
        normalized_rates = (num_trigs / state_time) / (n_total / livetime)
        bgrp.create_dataset('dq_rates', data=normalized_rates)

    # record the settings used to produce this file
    file_attrs = {
        'dq_thresh': args.dq_thresh,
        'dq_channel': dq_channel,
        'dq_ok_channel': dq_ok_channel,
        'gps_start_time': args.gps_start_time,
        'gps_end_time': args.gps_end_time,
        'f_lower': f_lower,
        'bank_file': bank_file,
        'background_bins': bin_string,
    }
    for attr_name, attr_value in file_attrs.items():
        f.attrs[attr_name] = attr_value

    # hash is used to check if different files have compatible settings
    settings_to_hash = (args.dq_thresh, dq_channel, dq_ok_channel,
                        f_lower, bank_file, bin_string)
    setting_str = ' '.join(map(str, settings_to_hash))
    f.attrs['settings_hash'] = hashlib.sha256(setting_str.encode()).hexdigest()
