#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1768936173/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Produce the sky grid plot for the triggered search (PyGRB)."""

# =============================================================================
# Preamble
# =============================================================================

import sys
import os
import logging
import numpy
import h5py
from matplotlib import pyplot as plt
import matplotlib.colors as colors
from matplotlib import rc
from matplotlib.ticker import MaxNLocator
import pycbc.version
from pycbc import init_logging
from pycbc.results import save_fig_with_metadata
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.detector import Detector
import pycbc.distributions

# Program metadata reported by the PyGRB tooling.
__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_skygrid"

# Use a non-interactive backend: the figure is only saved to a file, never
# shown on screen.  Also bump the default font size for readability.
plt.switch_backend('Agg')
rc('font', size=14)


def define_rows_cols_subplot(nplots):
    """Choose a near-square subplot grid able to hold ``nplots`` panels.

    Parameters
    ----------
    nplots : int
        Number of panels to lay out (expected to be positive).

    Returns
    -------
    tuple of int
        ``(cols, rows)`` -- note the order, columns first.  ``cols`` is
        ceil(sqrt(nplots)) and ``rows`` is the smallest number of rows
        that fits ``nplots`` panels at that width.
    """
    side = numpy.sqrt(nplots)
    ncols = int(numpy.ceil(side))
    nrows = int(numpy.ceil(nplots / ncols))
    return ncols, nrows

def ra_to_ra_mollweide(ra):
    """Wrap right ascension onto the (-pi, pi] range used by Mollweide axes.

    Angles are first reduced modulo 2*pi into [0, 2*pi]; those strictly
    greater than pi are then shifted down by 2*pi.

    Parameters
    ----------
    ra : float or array_like
        Right ascension(s) in radians.  Any real value is accepted, as are
        plain Python sequences.

    Returns
    -------
    numpy.ndarray or numpy.float64
        The wrapped angle(s) in radians.  A scalar input gives a scalar
        result; array-like input gives a new array (the input is never
        modified).
    """
    two_pi = 2 * numpy.pi
    # numpy.remainder already returns a non-negative value for a positive
    # divisor, so no offset has to be added first (adding 2*pi beforehand
    # only threw away precision for small angles).
    wrapped = numpy.remainder(ra, two_pi)
    # numpy.where, unlike in-place boolean-mask assignment, also works for
    # scalar inputs.
    wrapped = numpy.where(wrapped > numpy.pi, wrapped - two_pi, wrapped)
    # numpy.where always returns an ndarray; unwrap 0-d results to a scalar.
    return wrapped[()] if wrapped.ndim == 0 else wrapped

# =============================================================================
# Main script starts here
# =============================================================================
# Command-line interface: the standard PyGRB plotting options plus the
# location of the sky grid file to display.
parser = ppu.pygrb_initialize_plot_parser(description=__doc__)
parser.add_argument(
    "--sky-grid",
    required=True,
    help="The location of the sky grid file",
)
opts = parser.parse_args()

init_logging(
    opts.verbose,
    format="%(asctime)s:%(levelname)s : %(message)s",
)

# Resolve the input/output locations and fall back to a default title.
sky_grid = os.path.abspath(opts.sky_grid)
outfile = opts.output_file
opts.plot_title = ('PyGRB sky grid' if opts.plot_title is None
                   else opts.plot_title)
plot_caption = 'Search sky grid points.'

logging.info("Imported and ready to go.")

# Set output directories: make sure the directory that will hold the output
# plot exists.  exist_ok=True replaces an isdir()-then-makedirs() check, which
# could raise FileExistsError if another process created the directory in
# between the two calls.
outdirs = [os.path.dirname(os.path.abspath(outfile))]
for outdir in outdirs:
    os.makedirs(outdir, exist_ok=True)

# Extract all the information needed from the sky grid file: the grid point
# coordinates, a sample drawn from the input sky distribution, the detectors
# and the reference GPS time.
with h5py.File(sky_grid, "r") as sky_grid_file:
    ra = sky_grid_file['ra'][:]
    dec = sky_grid_file['dec'][:]
    attrs = sky_grid_file.attrs
    dist = attrs['input_distribution']
    # NOTE(review): the distribution object is rebuilt by eval()-ing a string
    # read from the HDF5 file, so whatever code that attribute contains gets
    # executed.  Only run this on sky grid files from a trusted source.
    input_dist = eval("pycbc.distributions." + dist)
    samples = input_dist.rvs(1000000)
    input_ra = samples['ra']
    input_dec = samples['dec']
    ifos = attrs["detectors"]
    detectors = list(map(Detector, ifos))
    gps_time = attrs['ref_gps_time']

xlabel, ylabel = "RA (deg)", "DEC (deg)"

# Uniform all-sky sample used as the background of the antenna-pattern panels
uni_points = pycbc.distributions.UniformSky().rvs(size=100000)

# Convert RA from [0, 2*pi] to [-pi, pi] for the Mollweide plots
uni_ra = ra_to_ra_mollweide(uni_points["ra"])
uni_dec = uni_points["dec"]
in_ra = ra_to_ra_mollweide(input_ra)
grb_ra = ra_to_ra_mollweide(ra)

# 600 x 600-bin 2D histogram of the input distribution samples over the
# whole sky; this density is contoured underneath the sky grid points.
input_density, ra_edge, dec_edge = numpy.histogram2d(
    in_ra,
    input_dec,
    bins=600,
    range=[[-numpy.pi, numpy.pi], [-numpy.pi / 2, numpy.pi / 2]],
)

# Lay the panels out on a near-square grid: the sky grid over the input
# distribution, a zoom on the sky grid points, then one antenna-pattern panel
# per detector.  squeeze=False keeps ``ax`` two-dimensional even when the grid
# has a single row, so ax[row, col] indexing is always valid.
cols, rows = define_rows_cols_subplot(len(detectors) + 2)
fig, ax = plt.subplots(nrows=rows, ncols=cols, squeeze=False,
                       subplot_kw=dict(projection="mollweide"),
                       figsize=(20, 20))

# numpy.histogram2d returns bin *edges*; contouring the counts against the
# left edges would shift the density map by half a bin, so use bin centres.
ra_centers = 0.5 * (ra_edge[:-1] + ra_edge[1:])
dec_centers = 0.5 * (dec_edge[:-1] + dec_edge[1:])

# Sky grid over input distribution plot (Mollweide axes, angles in radians)
levels = MaxNLocator(nbins=100).tick_values(0, input_density.max())
ax[0, 0].contourf(ra_centers, dec_centers, input_density.T, levels=levels,
                  cmap="Greens")
ax[0, 0].plot(grb_ra, dec, 'x', c='red')
cb = fig.colorbar(
    plt.cm.ScalarMappable(colors.Normalize(vmin=0, vmax=input_density.max()),
                          cmap="Greens"),
    location="bottom", ax=ax[0, 0], label="Counts"
)
ax[0, 0].set_xlabel(xlabel)
ax[0, 0].set_ylabel(ylabel)
ax[0, 0].set_title("Sky Grid over input distribution")
ax[0, 0].grid(True)

# Zoomed-in plot: replace the second (Mollweide) cell with a rectilinear
# axes on the same figure so that limits can be set in degrees around the
# sky grid points.
ax[0, 1].remove()
ax[0, 1] = fig.add_subplot(rows, cols, 2)
ax[0, 1].set_xlim(numpy.degrees(grb_ra.min()) - 5,
                  numpy.degrees(grb_ra.max()) + 5)
ax[0, 1].set_ylim(numpy.degrees(dec.min()) - 5, numpy.degrees(dec.max()) + 5)
ax[0, 1].contourf(numpy.degrees(ra_centers), numpy.degrees(dec_centers),
                  input_density.T, levels=levels, cmap="Greens")
ax[0, 1].plot(numpy.degrees(grb_ra), numpy.degrees(dec), 'x', c='red')
ax[0, 1].set_xlabel(xlabel)
ax[0, 1].set_ylabel(ylabel)
ax[0, 1].grid(True)

# Sky grid over antenna pattern plots: one panel per detector, filling the
# grid cells that follow the first two panels in row-major order.  Working on
# a flattened view of the axes array avoids manual row/column bookkeeping and
# works whether ``ax`` is one- or two-dimensional.
axes_flat = numpy.ravel(ax)
n_panels = len(detectors) + 2
for panel, det in enumerate(detectors, start=2):
    det_ax = axes_flat[panel]
    logging.info('Plotting %s antenna pattern', det.name)
    ant_pat = det.antenna_pattern(uni_points["ra"], uni_points["dec"], 0,
                                  t_gps=gps_time)
    # Quadrature sum of the plus and cross responses
    quad_ant = numpy.sqrt(ant_pat[0]**2 + ant_pat[1]**2)
    sc = det_ax.scatter(uni_ra, uni_dec, c=quad_ant, marker='.',
                        cmap="viridis")
    det_ax.scatter(grb_ra, dec, c='red', marker='x')
    det_ax.set_title(f"Sky grid over {det.name} antenna pattern")
    det_ax.set_xlabel(xlabel)
    det_ax.set_ylabel(ylabel)
    fig.colorbar(sc, ax=det_ax, label=r"$\sqrt{F_+^2 + F_\times^2}$",
                 location="bottom")
    det_ax.grid(True)

# Hide every leftover grid cell that received no plot (there can be up to
# cols - 1 of them, not just one or two).
for unused_ax in axes_flat[n_panels:]:
    unused_ax.axis("Off")

# Wrap up: save the figure together with its provenance metadata (the full
# command line, title and caption), then close it.
command_line = ' '.join(sys.argv)
save_fig_with_metadata(
    fig,
    outfile,
    cmd=command_line,
    title=opts.plot_title,
    caption=plot_caption,
)
plt.close()
