#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

# Copyright 2024 Max Trevor
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

"""Combine the data-quality adjusted trigger rates from multiple days."""

import logging
import argparse

import numpy as np

import pycbc
from pycbc.io import HFile

# Command-line interface: the daily input files, the detector whose data is
# combined, and the location of the combined output file.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument("--daily-dq-files", nargs="+", required=True,
                    help="Files containing daily dq trigger rates")
parser.add_argument("--ifo", required=True,
                    help="Name of the detector group to read from each "
                         "daily file and to create in the output file")
parser.add_argument("--output", required=True,
                    help="Path of the output file holding the combined "
                         "dq trigger rates")
args = parser.parse_args()

# NOTE(review): `args.verbose` is presumably registered by
# pycbc.add_common_pycbc_options -- it is not defined in this script.
pycbc.init_logging(args.verbose)

# Work through the daily files in sorted path order.
daily_files = sorted(args.daily_dq_files)

# All combined files must have been produced with compatible settings.
# The last file in sorted order acts as the reference: its settings hash
# decides which files are used below, and its metadata and per-bin template
# ids are copied into the output file.
with HFile(daily_files[-1], 'r') as ref_file:
    ref_attrs = ref_file.attrs
    settings_hash = ref_attrs['settings_hash']
    bin_str = ref_attrs['background_bins']
    bank_file = ref_attrs['bank_file']
    f_lower = ref_attrs['f_lower']
    dq_thresh = ref_attrs['dq_thresh']
    dq_channel = ref_attrs['dq_channel']
    dq_ok_channel = ref_attrs['dq_ok_channel']
    # read the template ids into memory while the file is still open
    template_bins = {
        name: grp['tids'][:]
        for name, grp in ref_file[f'{args.ifo}/bins'].items()
    }

# number of bins = number of whitespace-separated entries in the
# 'background_bins' attribute
num_bins = len(bin_str.split())

# Running sums over every compatible daily file: livetimes are scalars,
# trigger counts are stored per bin, indexed by bin number.
total_triggers = np.zeros(num_bins)
flagged_triggers = np.zeros(num_bins)
total_livetime = 0
flagged_livetime = 0
for fpath in daily_files:
    with HFile(fpath, 'r') as daily:
        # only files whose settings hash matches the reference are combined
        if daily.attrs['settings_hash'] != settings_hash:
            logging.warning(
                f'File {fpath} has incompatible settings, skipping'
            )
            continue
        total_livetime += daily[f'{args.ifo}/observing_livetime'][()]
        flagged_livetime += daily[f'{args.ifo}/dq_flag_livetime'][()]
        for bin_name, counts in daily[f'{args.ifo}/bins'].items():
            # bins are named as 'bin{bin_num}': the number follows the
            # three-character prefix
            idx = int(bin_name[3:])
            total_triggers[idx] += counts['total_triggers'][()]
            flagged_triggers[idx] += counts['dq_triggers'][()]

# Everything that is not flagged counts as background, for both the
# trigger counts (per bin) and the livetime.
bg_triggers = total_triggers - flagged_triggers
bg_livetime = total_livetime - flagged_livetime

# Per-bin trigger rates: over all livetime, over the flagged livetime only,
# and over the remaining (background) livetime.
total_trigger_rate = total_triggers / total_livetime
flag_trigger_rate = flagged_triggers / flagged_livetime
bg_trigger_rate = bg_triggers / bg_livetime

# Write the combined result: one group per template bin under
# '<ifo>/bins', followed by file-level attributes describing the settings.
with HFile(args.output, 'w') as out:
    bins_out = out.create_group(args.ifo).create_group('bins')

    for bin_name, bin_tids in template_bins.items():
        bin_grp = bins_out.create_group(bin_name)
        bin_grp['tids'] = bin_tids
        # bins are named as 'bin{bin_num}'
        bin_num = int(bin_name[3:])
        # background and flagged rates, each expressed relative to the
        # total trigger rate of this bin
        rel_rates = np.array([bg_trigger_rate[bin_num],
                              flag_trigger_rate[bin_num]])
        rel_rates = rel_rates / total_trigger_rate[bin_num]
        bin_grp['dq_rates'] = rel_rates
        bin_grp['num_triggers'] = total_triggers[bin_num]

    # attributes are written in the order listed here
    file_attrs = {
        'settings_hash': settings_hash,
        'stat': f'{args.ifo}-dq_stat_info',
        'total_livetime': total_livetime,
        'flagged_livetime': flagged_livetime,
        'dq_thresh': dq_thresh,
        'dq_channel': dq_channel,
        'dq_ok_channel': dq_ok_channel,
        'background_bins': bin_str,
        'bank_file': bank_file,
        'f_lower': f_lower,
    }
    for key, value in file_attrs.items():
        out.attrs[key] = value
