#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

# Copyright (C) 2023 Sumit K.
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Generate the calibration configuration file for a given event/GPS time.
Also plots the calibration envelope for each detector as a sanity check.
"""

import argparse
import logging
import os

import pycbc
import numpy
import matplotlib.pyplot as plt

from pycbc.types import MultiDetOptionAction
from pycbc.strain.recalibrate import get_calibration_files_O1_O2_O3, read_calibration_envelop_file


def _str_to_bool(value):
    """Convert a command-line string into a boolean.

    ``type=bool`` cannot be used with argparse because ``bool('False')`` is
    ``True`` (every non-empty string is truthy), which made it impossible to
    switch an option off from the command line.

    Parameters
    ----------
    value : str
        Text given on the command line.

    Returns
    -------
    bool
        The parsed flag.

    Raises
    ------
    argparse.ArgumentTypeError
        If ``value`` is not a recognised spelling of true or false.
    """
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays False, which is what bool('') used to return.
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean (true/false), got %r" % (value,))


# command line usage
parser = argparse.ArgumentParser(
    description="Generate the calibration configuration files "
                "to be used with pycbc_inference.")

pycbc.add_common_pycbc_options(parser)

parser.add_argument('--calibration-files-path',
                    help='Calibration envelope file path. It can be pre-'
                         'downloaded or downloaded on the fly. Cannot be '
                         'combined with --calibration-files.',
                    type=str)

parser.add_argument('--calibration-files', nargs='+',
                    action=MultiDetOptionAction,
                    metavar='DETECTOR:FILE', type=str,
                    help='Calibration file for each detector. Cannot be '
                         'combined with --calibration-files-path.')

# The rest of the script iterates over the detectors and looks up both
# frequency bounds for each of them unconditionally, so these three options
# are mandatory.
parser.add_argument('--ifos', nargs='+', required=True,
                    help='Specify ifos for the analysis')

parser.add_argument('--minimum-frequency', nargs='+',
                    action=MultiDetOptionAction, required=True,
                    metavar='DETECTOR:FREQUENCY', type=float,
                    help='For each detector, provide minimum frequency for '
                         'the analysis. The minimum frequency should be '
                         'higher than the minimum frequency in calibration '
                         'envelope file so that we do not extrapolate.')

parser.add_argument('--maximum-frequency', nargs='+',
                    action=MultiDetOptionAction, required=True,
                    metavar='DETECTOR:FREQUENCY', type=float,
                    help='For each detector, provide maximum frequency for '
                         'the analysis. The maximum frequency should be '
                         'lower than the maximum frequency in calibration '
                         'envelope file so that we do not extrapolate.')

parser.add_argument('--gps-time',
                    help='GPS time at which calibration configuration files '
                         'are needed',
                    required=True, type=float)

parser.add_argument('--output-dir',
                    help="Path for the output calibration config files",
                    default='.', type=str)

parser.add_argument('--plots-dir',
                    help="Path for the 'sanity check' plots",
                    default='.', type=str)

parser.add_argument('--n-nodes',
                    help='Number of frequency nodes to be used',
                    default=10, type=int)

parser.add_argument('--correction-type', nargs='+',
                    action=MultiDetOptionAction,
                    metavar='DETECTOR:TYPE',
                    help="Provide the correction type for each detector",
                    required=True)

parser.add_argument('--plot-sanity-checks',
                    help='Make some plots for sanity checks '
                         '(true/false, default: true)',
                    default=True, type=_str_to_bool)

parser.add_argument('--tag',
                    help='Provide your tag for naming calibration file',
                    default=None)


# parse the command line
opts = parser.parse_args()

# Configure the logging module; without this the info/debug messages emitted
# further down are dropped by the default WARNING threshold.
# NOTE(review): relies on add_common_pycbc_options() registering --verbose.
pycbc.init_logging(opts.verbose)


# The two ways of locating the calibration envelope files are mutually
# exclusive.
if opts.calibration_files_path and opts.calibration_files:
    raise ValueError("Only one of the arguments should be provided among: "
                     "--calibration-files-path and --calibration-files")

if not opts.ifos:
    raise ValueError("At least one detector must be provided with --ifos")

# Unpack the options into the module-level names used by the rest of the
# script.
calibration_files_path = opts.calibration_files_path
min_freq_list = opts.minimum_frequency
max_freq_list = opts.maximum_frequency
correction_type_list = opts.correction_type
gps_time = opts.gps_time
output_dir = opts.output_dir
n_nodes = opts.n_nodes
# Without an explicit tag the output files are named after the integer GPS
# second of the event.
tag = int(opts.gps_time) if opts.tag is None else opts.tag

# Fail early, with a readable message, if a per-detector option has no value
# for one of the requested detectors.
per_ifo_options = (
    ('--minimum-frequency', min_freq_list),
    ('--maximum-frequency', max_freq_list),
    ('--correction-type', correction_type_list),
)
for ifo in opts.ifos:
    for option_name, per_ifo in per_ifo_options:
        try:
            value = per_ifo[ifo]
        except (KeyError, TypeError):
            # TypeError covers the option container itself being None.
            value = None
        if value is None:
            raise ValueError("No %s value provided for detector %s"
                             % (option_name, ifo))
    if min_freq_list[ifo] >= max_freq_list[ifo]:
        raise ValueError(
            "Minimum frequency (%s) must be lower than maximum frequency "
            "(%s) for detector %s"
            % (min_freq_list[ifo], max_freq_list[ifo], ifo))
    logging.info("Minimum %s frequency: %.3f", ifo, min_freq_list[ifo])

# Read calibration files
# The files are looked up below under the upper-case detector name, so the
# user-supplied files are stored under that same key. This keeps both
# branches consistent when detectors are given in lower case.
if opts.calibration_files:
    calibration_files_dict = {}
    for ifo in opts.ifos:
        try:
            calib_file = opts.calibration_files[ifo]
        except KeyError:
            calib_file = None
        if calib_file is None:
            raise ValueError("No --calibration-files entry provided for "
                             "detector %s" % ifo)
        calibration_files_dict[ifo.upper()] = calib_file
else:
    calibration_files_dict = get_calibration_files_O1_O2_O3(
        opts.ifos, gps_time, calibration_files_path)

logging.debug("Calibration files dictionary: %s", calibration_files_dict)

# Collect the per-detector inputs under flat '<quantity>_<ifo>' keys; the
# config writer and the sanity plots read them back from this dictionary.
input_dict = {}
for ifo in opts.ifos:
    ifo = str(ifo)
    input_dict['min_freq_%s' % ifo] = min_freq_list[ifo]
    input_dict['max_freq_%s' % ifo] = max_freq_list[ifo]
    input_dict['correction_type_%s' % ifo] = correction_type_list[ifo]
    input_dict['calibration_envelop_file_%s' % ifo] = \
        calibration_files_dict[ifo.upper()]
logging.debug("Input dictionary: %s", input_dict)

# The mix of upper- and lower-case detector names may look unnecessary, but
# it is deliberate: it keeps the convention expected by the recalibration
# module in the inference config file as well as the detector notation.
prior_dict = {}
for ifo in opts.ifos:
    # Evaluate the calibration envelope at n_nodes frequencies between the
    # requested bounds; the five returned arrays are stored per detector.
    (log_nodes,
     amplitude_median_nodes,
     amplitude_sigma_nodes,
     phase_median_nodes,
     phase_sigma_nodes) = read_calibration_envelop_file(
        input_dict['calibration_envelop_file_%s' % ifo],
        input_dict['correction_type_%s' % ifo],
        input_dict['min_freq_%s' % ifo],
        input_dict['max_freq_%s' % ifo],
        n_nodes)
    prior_dict['log_nodes_%s' % ifo] = log_nodes
    prior_dict['amplitude_median_nodes_%s' % ifo] = amplitude_median_nodes
    prior_dict['amplitude_sigma_nodes_%s' % ifo] = amplitude_sigma_nodes
    prior_dict['phase_median_nodes_%s' % ifo] = phase_median_nodes
    prior_dict['phase_sigma_nodes_%s' % ifo] = phase_sigma_nodes
    logging.info('Done for detector %s', ifo)


# Write the calibration configuration file. The whole content is assembled
# in memory first so that no file is created if building it fails.
calibration_filename = '%s/calibration-%s.ini'%(output_dir, tag)

config_lines = ["[calibration] \n"]
for ifo in opts.ifos:
    ifo_lc = ifo.lower()
    config_lines.append("%s_model = cubic_spline \n" % ifo_lc)
    # %.12g instead of %d: %d silently truncated fractional frequencies
    # (20.5 -> 20), making the config disagree with the bounds that were
    # used to compute the nodes. Whole numbers are still written as '20'.
    config_lines.append("%s_minimum_frequency = %.12g \n"
                        % (ifo_lc, input_dict['min_freq_%s' % ifo]))
    config_lines.append("%s_maximum_frequency = %.12g \n"
                        % (ifo_lc, input_dict['max_freq_%s' % ifo]))
    config_lines.append("%s_n_points = %d \n" % (ifo_lc, n_nodes))
config_lines.append(" \n")

# One amplitude and one phase parameter per detector and node.
config_lines.append("[variable_params] \n")
for ifo in opts.ifos:
    ifo_lc = ifo.lower()
    for ii in range(n_nodes):
        config_lines.append("recalib_amplitude_%s_%d = \n" % (ifo_lc, ii))
        config_lines.append("recalib_phase_%s_%d = \n" % (ifo_lc, ii))
config_lines.append(" \n")

# Gaussian prior for every parameter: all amplitude sections first, then all
# phase sections. The variance is the square of the envelope sigma.
for quantity in ('amplitude', 'phase'):
    for ifo in opts.ifos:
        ifo_lc = ifo.lower()
        median_nodes = prior_dict['%s_median_nodes_%s' % (quantity, ifo)]
        sigma_nodes = prior_dict['%s_sigma_nodes_%s' % (quantity, ifo)]
        for ii in range(n_nodes):
            param = "recalib_%s_%s_%d" % (quantity, ifo_lc, ii)
            config_lines.append("[prior-%s] \n" % param)
            config_lines.append("name = gaussian \n")
            config_lines.append("%s_mean = %.3g \n"
                                % (param, median_nodes[ii]))
            config_lines.append("%s_var = %.3g \n"
                                % (param, sigma_nodes[ii]**2))
            config_lines.append(" \n")
config_lines.append(" \n")
config_lines.append(" \n")

if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Mode 'x' creates the file only if it does not exist yet, so an existing
# configuration can never be overwritten, even if it appears between a check
# and the open.
try:
    text_file = open(calibration_filename, "x")
except FileExistsError:
    raise ValueError(
        'Calibration file already present. '
        'Either delete it or rename it. Overwriting '
        'this file is not allowed!!!'
    ) from None
with text_file:
    text_file.writelines(config_lines)

# Sanity check plots
def _plot_envelope(envelope, columns, log_nodes, prior_median, prior_sigma,
                   ylabel, title, filename):
    """Plot one calibration envelope together with the derived prior nodes.

    Parameters
    ----------
    envelope : numpy.ndarray
        Table loaded from the calibration envelope file; column 0 is used as
        the frequency axis.
    columns : tuple of int
        Column indices of the curves labelled mu, mu-sigma and mu+sigma, in
        that order.
    log_nodes : numpy.ndarray
        Natural logarithm of the node frequencies.
    prior_median, prior_sigma : numpy.ndarray
        Prior median and standard deviation at each node.
    ylabel, title : str
        Axis label and figure title.
    filename : str
        Path of the image to write.
    """
    frequencies = envelope[:, 0]
    node_frequencies = numpy.exp(log_nodes)
    curve_labels = (r'$\mu$', r'$\mu-\sigma$', r'$\mu+\sigma$')

    fig = plt.figure(figsize=(8, 5))
    for column, label in zip(columns, curve_labels):
        plt.plot(frequencies, envelope[:, column], label=label)
    plt.plot(node_frequencies, prior_median, '.',
             label=r'$\mu$ (Prior)')
    plt.plot(node_frequencies, prior_median - prior_sigma, '.',
             label=r'$\mu-\sigma$ (Prior)')
    plt.plot(node_frequencies, prior_median + prior_sigma, '.',
             label=r'$\mu+\sigma$ (Prior)')
    plt.xscale('log')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.savefig(filename)
    # Close rather than clear: plt.clf() keeps the figure alive, so two
    # figures per detector used to pile up in memory.
    plt.close(fig)


if opts.plot_sanity_checks:
    # Honour --plots-dir when it is set to something other than its '.'
    # default; otherwise keep the historical '<output_dir>/plots' location.
    # Consequence: an explicit '--plots-dir .' also selects the historical
    # location.
    plots_dir = opts.plots_dir
    if plots_dir in (None, '.'):
        plots_dir = '%s/plots'%output_dir
    # os.makedirs instead of a shell 'mkdir -p': no shell is involved, so
    # paths containing spaces or shell metacharacters are handled safely.
    os.makedirs(plots_dir, exist_ok=True)

    for ifo in opts.ifos:
        calib_env_path = input_dict['calibration_envelop_file_%s'%ifo]
        log_nodes = prior_dict['log_nodes_%s'%ifo]
        d1 = numpy.loadtxt(calib_env_path)

        # Amplitude: envelope columns 1, 3 and 5. The prior nodes are
        # plotted with 1 added to the median so they overlay those curves.
        _plot_envelope(
            d1, (1, 3, 5), log_nodes,
            prior_dict['amplitude_median_nodes_%s'%ifo] + 1,
            prior_dict['amplitude_sigma_nodes_%s'%ifo],
            'Amplitude',
            'Calibration envelop (Amplitude) for %s'%ifo,
            '%s/calibration_envelop_amplitude_%s_%s.png'%(plots_dir,tag,ifo))

        # Phase: envelope columns 2, 4 and 6, with the nodes plotted as is.
        _plot_envelope(
            d1, (2, 4, 6), log_nodes,
            prior_dict['phase_median_nodes_%s'%ifo],
            prior_dict['phase_sigma_nodes_%s'%ifo],
            'Phase',
            'Calibration envelop (Phase) for %s'%ifo,
            '%s/calibration_envelop_phase_%s_%s.png'%(plots_dir,tag,ifo))

