#!/home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycbc_1767619825/host_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehol/bin/python

"""Generate a grid of points in the sky to be used by `pycbc_multi_inspiral` to
find multi-detector gravitational wave triggers and calculate the coherent SNRs
and related statistics. The grid is constructed using a stochastic sampling
algorithm.  The region of the sky covered by the grid can be optionally
constrained to a circular region, or via an arbitrary probability density
function and a given amount of probability to cover, typically representing the
sky location of a non-gravitational-wave transient event. The angular spacing
between the points is determined by the relative arrival times at the different
detectors, based on the method described in Section V of
https://arxiv.org/abs/1410.6042.  Please refer to the documentation of
`pycbc.types.angle_as_radians()` for the recommended configuration file syntax
for angle arguments.
"""

import logging
import numpy as np
import argparse
import itertools

import pycbc
import pycbc.distributions
from pycbc.detector import Detector
from pycbc.types import angle_as_radians
from pycbc.tmpltbank.sky_grid import SkyGrid


def spher_to_cart(sky_points):
    """Map (longitude, latitude) pairs onto unit vectors.

    `sky_points` is indexed as a 2D array whose column 0 holds the
    longitude-like angle and column 1 the latitude-like angle, in radians.
    The result is a float array with one row per input point and the
    x, y, z components in its three columns.
    """
    lon = sky_points[:, 0]
    lat = sky_points[:, 1]
    cos_lat = np.cos(lat)
    components = (np.cos(lon) * cos_lat, np.sin(lon) * cos_lat, np.sin(lat))
    # Always hand back double precision, whatever the input dtype was
    return np.stack(components, axis=1).astype(float, copy=False)


def angular_distance(test_point, grid):
    """Return the smallest angular separation between a point and a grid.

    Parameters
    ----------
    test_point : pair of floats
        The two spherical angles of the point to test, in the same
        (column 0, column 1) order and units used by `spher_to_cart()`.
    grid : 2D array
        One row per grid point, with the two angles in columns 0 and 1.

    Returns
    -------
    float
        The minimum, over all grid points, of the angle (in radians) between
        the unit vector of `test_point` and the unit vector of the grid point.
    """
    grid_cart = spher_to_cart(grid)
    test_cart = spher_to_cart(np.array([test_point]))
    # Contract over the x, y, z axis: one cosine per grid point
    cosines = np.tensordot(grid_cart, test_cart, ((1,), (1,))).ravel()
    # Floating-point rounding can push the dot product of two unit vectors
    # marginally outside [-1, 1] (e.g. for coincident points), which would
    # make arccos return NaN and corrupt the minimum. Clamp first.
    separations = np.arccos(np.clip(cosines, -1.0, 1.0))
    return np.min(separations)


def dist_from_str(s):
    """Instantiate and return a sky-location distribution from the given
    string argument.

    The string is appended to ``pycbc.distributions.`` and the result is
    evaluated as a Python expression, so it is expected to look like a
    constructor call, e.g. ``UniformSky()``.

    Parameters
    ----------
    s : str
        Constructor call of a class reachable from `pycbc.distributions`.

    Returns
    -------
    object
        Whatever the evaluated expression produces.

    Raises
    ------
    ValueError
        If `s` does not end with a closing parenthesis.
    """
    # Explicit check rather than `assert`, so the validation is not
    # silently dropped when Python runs with -O.
    if not s.endswith(')'):
        raise ValueError(
            'Sky distribution must be given as a constructor call '
            'ending in ")", got: {!r}'.format(s)
        )
    # SECURITY NOTE: eval() executes arbitrary Python code. The trailing ")"
    # check above is only a sanity check, not a sandbox. Only pass strings
    # that come from a trusted source (e.g. the user's own command line or
    # configuration file).
    return eval('pycbc.distributions.' + s)


def make_single_det_grid(args):
    """Construct a sky grid for a single-detector analysis.
    Since a one-detector network has no sky localization capability,
    we just pick a single point.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options. The attributes read here are `ra`,
        `dec`, `sky_error`, `input_dist`, `instruments` and `trigger_time`.

    Returns
    -------
    SkyGrid
        A grid containing the single chosen sky position.
    """
    # Each option must be compared against None individually. Chaining them
    # with `and` evaluates to the first falsy operand, so a legitimate value
    # of 0 rad would hide whether the remaining options were given at all.
    circular_region_given = all(
        value is not None for value in (args.ra, args.dec, args.sky_error)
    )

    if circular_region_given:
        # --ra/--dec arguments
        ra, dec = args.ra, args.dec
    elif args.input_dist == 'UniformSky()':
        # Covering the whole sky, just pick a reference point
        ra, dec = 0., 0.
    else:
        # Restricted non-uniform input distribution:
        # use its most probable point
        input_dist = dist_from_str(args.input_dist)
        mpp = input_dist.get_max_prob_point()
        ra, dec = mpp[0], mpp[1]
    return SkyGrid(
        ra, dec, args.instruments, args.trigger_time
    )


def make_multi_det_grid(args):
    """Construct a sky grid for a network of multiple detectors,
    using the time delay among detectors to determine the required spacing
    between points.

    Candidate positions are drawn in batches from a proposal distribution
    and a candidate is kept only if it is farther from every point already
    in the grid than the locally required angular spacing. The procedure
    stops when a whole batch yields no new point.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options. The attributes read here are `ra`,
        `dec`, `sky_error`, `input_dist`, `coverage`, `instruments`,
        `trigger_time` and `timing_uncertainty`.

    Returns
    -------
    SkyGrid
        The grid built from all accepted sky positions.

    Raises
    ------
    ValueError
        If `--coverage` is given together with a circular region or a
        `UniformSky()` input distribution, or if it is missing for any other
        input distribution.
    """
    detectors = [Detector(d) for d in args.instruments]
    detector_pairs = list(itertools.combinations(detectors, 2))

    # Each option must be compared against None individually. Chaining them
    # with `and` evaluates to the first falsy operand, so a legitimate value
    # of 0 rad would hide whether the remaining options were given at all
    # (and could let None reach UniformDiskSky below).
    circular_region_given = all(
        value is not None for value in (args.ra, args.dec, args.sky_error)
    )

    # Initialize a proposal distribution and choose the first point of the
    # grid based on the input arguments. Option combinations are validated
    # with explicit exceptions so the checks survive `python -O`.
    if circular_region_given:
        # --ra/--dec/--sky-error arguments
        if args.coverage is not None:
            raise ValueError(
                '--coverage cannot be used together with '
                '--ra/--dec/--sky-error'
            )
        sky_dist = pycbc.distributions.UniformDiskSky(
            mean_ra=args.ra, mean_dec=args.dec, radius=args.sky_error
        )
        first_point = (args.ra, args.dec)
    elif args.input_dist == 'UniformSky()':
        # Covering the whole sky
        if args.coverage is not None:
            raise ValueError(
                '--coverage cannot be used with a UniformSky() '
                'input distribution'
            )
        sky_dist = dist_from_str(args.input_dist)
        first_point = (0.0, 0.0)
    else:
        # Restricted non-uniform input distribution
        if args.coverage is None:
            raise ValueError(
                '--coverage is required when --input-dist is not '
                'UniformSky()'
            )
        sky_dist = dist_from_str(args.input_dist)
        # The first point comes from the original distribution, before it
        # is replaced by the uniform patch used for proposals.
        first_point = sky_dist.get_max_prob_point()
        sky_dist = sky_dist.to_uniform_patch(args.coverage)

    # One row per accepted point, columns are (ra, dec)
    grid = np.vstack((np.empty((0, 2)), np.reshape(first_point, (1, 2))))

    # Calculate the max possible light travel times between detector pairs
    max_light_travel_times = np.array([
        a.light_travel_time_to_detector(b)
        for a, b in detector_pairs
    ])

    while True:
        prev_size = grid.shape[0]
        sky_dist_samples = sky_dist.rvs(size=10000)
        sky_ra = sky_dist_samples['ra']
        sky_dec = sky_dist_samples['dec']
        for prop_ra, prop_dec in zip(sky_ra, sky_dec):
            # Calculate the light travel times between detector pairs
            # for the given proposed sky positions
            light_travel_times = np.array([
                a.time_delay_from_detector(
                    b, prop_ra, prop_dec, args.trigger_time
                )
                for a, b in detector_pairs
            ])

            # Calculate the required angular spacing between the sky points.
            # Every detector pair gives its own requirement; the tightest
            # (smallest) one is used.
            ang_spacings = (2 * args.timing_uncertainty) / np.sqrt(
                max_light_travel_times ** 2 - light_travel_times ** 2
            )
            angular_spacing = np.min(ang_spacings)

            # FIXME it is not necessary to calculate *all* distances,
            # we can reject as soon as we find one below threshold!
            # But in practice this is fast enough already.
            dist = angular_distance((prop_ra, prop_dec), grid)
            if dist > angular_spacing:
                # far enough from other points, accept
                grid = np.vstack((grid, (prop_ra, prop_dec)))
        num_new_accepted = grid.shape[0] - prev_size
        logging.info(
            '%d points accepted, %d total', num_new_accepted, grid.shape[0]
        )
        if num_new_accepted == 0:
            # We reached our convergence criterion
            break
    sky_grid = SkyGrid(
        grid[:, 0], grid[:, 1], args.instruments, args.trigger_time
    )
    return sky_grid


# Command-line interface. The module docstring doubles as the --help
# description, and the common PyCBC options are registered first.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)

# Angles are converted to radians by `angle_as_radians`
parser.add_argument(
    '--ra',
    help=(
        "Right ascension of the center of the external trigger error box. "
        "Use the rad or deg suffix to specify units, "
        "otherwise radians are assumed."
    ),
    type=angle_as_radians,
)
parser.add_argument(
    '--dec',
    help=(
        "Declination of the center of the external trigger error box. "
        "Use the rad or deg suffix to specify units, "
        "otherwise radians are assumed."
    ),
    type=angle_as_radians,
)

# Detector network (one or more names)
parser.add_argument(
    '--instruments',
    help="List of instruments to analyze.",
    required=True,
    nargs="+",
    type=str,
)

parser.add_argument(
    '--sky-error',
    help=(
        "3-sigma confidence radius of the external trigger error box. "
        "Use the rad or deg suffix to specify units, "
        "otherwise radians are assumed."
    ),
    type=angle_as_radians,
)

# Alternative way of restricting the sky region: a probability
# distribution plus the amount of probability to cover
parser.add_argument(
    "--input-dist",
    help=(
        "Input distribution of the sky map that you have, "
        "e.g. HealpixSky or FisherSky. "
        "See the sky location documentation for more details. "
        "If not specified, a UniformSky() distribution is used."
    ),
    default='UniformSky()',
)
parser.add_argument(
    '--coverage',
    help=(
        "Fraction of probability in the input distribution that you want to "
        "cover, when using --input-dist with a distribution that is not "
        "UniformSky."
    ),
    type=float,
)

# Timing information
parser.add_argument(
    '--trigger-time',
    help="Time (in s) of the external trigger",
    required=True,
    type=int,
)
parser.add_argument(
    '--timing-uncertainty',
    help=(
        "Timing uncertainty (in s) we are willing to accept to determine the "
        "density of points."
    ),
    default=0.0005,
    type=float,
)

# Destination file
parser.add_argument(
    '--output',
    help="Name of the output sky grid file.",
    required=True,
)

args = parser.parse_args()

# Put the ifos in alphabetical order
args.instruments.sort()

pycbc.init_logging(args.verbose)

# A single detector only needs one point; a network needs a proper grid
if len(args.instruments) == 1:
    sky_grid = make_single_det_grid(args)
else:
    sky_grid = make_multi_det_grid(args)

# Record the input arguments as extra attributes.
# Each option is compared against None individually: chaining them with
# `and` evaluates to the first falsy operand, so a legitimate value of 0 rad
# would hide whether the remaining options were given at all.
circular_region_given = all(
    value is not None for value in (args.ra, args.dec, args.sky_error)
)
if "Uniform" not in args.input_dist and args.coverage is not None:
    # Region defined by a non-uniform distribution and a coverage fraction
    extra_attributes = {
        'input_distribution': args.input_dist,
        'coverage': args.coverage,
    }
elif not circular_region_given and "Uniform" in args.input_dist:
    # No circular region requested, uniform input distribution
    extra_attributes = {
        'input_distribution': args.input_dist,
    }
else:
    # Circular region defined by --ra/--dec/--sky-error
    extra_attributes = {
        'trigger_ra': args.ra,
        'trigger_dec': args.dec,
        'sky_error': args.sky_error,
    }
extra_attributes['timing_uncertainty'] = args.timing_uncertainty

sky_grid.write_to_file(args.output, extra_attributes)
