#!/usr/bin/env python3
#
# Copyright (C) 2019  Patrick Godwin
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


from collections import defaultdict
import configparser
import argparse
import json
import logging
import sys, os
import time
import timeit

import numpy

from kafka import KafkaConsumer

from ligo.scald.io import influx


#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#

# Read command line options
def parse_command_line():
	"""
	Parse the command line options of the aggregator.

	Returns the argparse.Namespace produced by the parser; its
	``config_file`` attribute holds the value given with --config-file.

	The option is mandatory.  Without it ``config_file`` would be None and
	the program would only fail later, when that value is handed to
	ConfigParser.read(), with an unhelpful TypeError.  argparse now reports
	the missing option itself and exits with a usage message (SystemExit).
	"""
	parser = argparse.ArgumentParser(description="Online calibration aggregator")

	# path of the configuration file describing the monitoring setup
	parser.add_argument("--config-file", required=True, help="Specify configuration file.")

	return parser.parse_args()

# Parse config sections
def ConfigSectionMap(section):
	"""
	Return the options of one section of the global ``Config`` parser as a
	dictionary mapping option name to its (string) value.

	An option whose value cannot be retrieved (for instance because of a
	broken interpolation reference) is reported on stdout and stored as
	None, so one bad option does not prevent the others from being read.

	Raises configparser.NoSectionError if ``section`` does not exist; that
	lookup happens outside of the per-option error handling.
	"""
	section_values = {}
	for option in Config.options(section):
		try:
			section_values[option] = Config.get(section, option)
		except configparser.Error:
			# Only configuration errors are swallowed; KeyboardInterrupt and
			# friends propagate.  print() is kept on purpose: this function
			# runs before logging.basicConfig() is called, and logging here
			# would install a default handler and turn that call into a no-op.
			print("exception on %s!" % option)
			section_values[option] = None
	return section_values

#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#

if __name__ == '__main__':
	# Load the configuration file named on the command line.  ``Config`` has
	# to stay a module-level name: ConfigSectionMap() reads it as a global.
	options = parse_command_line()
	Config = configparser.ConfigParser()
	Config.read(options.config_file)
	MonitoringConfigs = ConfigSectionMap("MonitoringConfigurations")
	CalibrationConfigs = ConfigSectionMap("CalibrationConfigurations")
	OutputConfigs = ConfigSectionMap("OutputConfigurations")

	# Read in monitoring options
	kafka_server = MonitoringConfigs["kafkaserver"]
	influx_hostname = MonitoringConfigs["influxhostname"]
	influx_port = MonitoringConfigs["influxport"]
	influx_database_name = MonitoringConfigs["influxdatabasename"]
	enable_auth = Config.getboolean("MonitoringConfigurations", "enableauth")
	enable_https = Config.getboolean("MonitoringConfigurations", "enablehttps")
	across_jobs = Config.getboolean("MonitoringConfigurations", "acrossjobs")
	data_type = MonitoringConfigs["datatype"]
	dump_period = float(MonitoringConfigs["dumpperiod"])

	# Read in ifo specific options
	ifo = CalibrationConfigs["ifo"]
	# The order matters: the first topic also names the consumer group below.
	topic_suffixes = (
		'latency_production',
		'latency_redundant',
		'latency_testing',
		'latency_other_1',
		'latency_other_2',
		'statevector_bit_check_production',
		'statevector_bit_check_redundant',
		'statevector_bit_check_testing',
		'statevector_bit_check_other_1',
	)
	topics = ['%s_%s' % (ifo, suffix) for suffix in topic_suffixes]
	channel = OutputConfigs["frametype"]
	# NOTE(review): latency metric names are assumed to look like
	# "<frametype>_<stage>" -- inferred from the prefix removal below, confirm
	# against the producer.
	channel_prefix = channel + '_'
	statevector_tags = ['TDCFs_valid', 'monitor_on']

	logging.basicConfig(level = logging.INFO, format = "%(asctime)s %(levelname)s:%(processName)s(%(process)d):%(funcName)s: %(message)s")

	consumer = KafkaConsumer(
		*topics,
		bootstrap_servers=[kafka_server],
		value_deserializer=lambda m: json.loads(m.decode('utf-8')),
		group_id='%s_aggregator' % topics[0],
		auto_offset_reset='latest',
		max_poll_interval_ms = 60000,
		session_timeout_ms=30000,
		heartbeat_interval_ms=10000,
		reconnect_backoff_ms=5000,
		reconnect_backoff_max_ms=30000
	)

	# Everything that uses the consumer lives inside the try block so that the
	# connection is closed on any way out.  The loop never ends by itself, so
	# clean-up placed after it would never run.
	try:
		# set up aggregator sink
		agg_sink = influx.Aggregator(
			hostname=influx_hostname,
			port=influx_port,
			db=influx_database_name,
			auth=enable_auth,
			https=enable_https,
			reduce_across_tags=across_jobs
		)

		# register measurement schemas for aggregators
		for topic in topics:
			if 'latency' in topic:
				agg_sink.register_schema(topic, columns='data', column_key='data', tags='stage', tag_key='stage')
			elif 'statevector' in topic:
				agg_sink.register_schema(topic, columns='data', column_key='data', tags='check', tag_key='check')

		# start an infinite loop to keep updating and aggregating data
		while True:
			logging.info("sleeping for %.1f s", dump_period)
			time.sleep(dump_period)

			logging.info("retrieving data from kafka")
			start = timeit.default_timer()
			data = {topic: defaultdict(lambda: {'time': [], 'fields': {'data': []}}) for topic in topics}

			### poll consumer for messages
			msg_pack = consumer.poll(timeout_ms = 1000, max_records = 1000)
			for messages in msg_pack.values():
				for message in messages:
					topic = message.topic
					values = message.value

					# Read the timestamp before touching ``data``: a message
					# without one must not leave an empty series behind.
					try:
						timestamp = values['time']
					except KeyError: ### no metrics
						continue

					if 'latency' in topic:
						tag = next((name for name in values if channel in name), None)
						if tag is None:
							# no metric for this channel in the message
							continue
						# Remove the "<channel>_" prefix only.  str.strip() would
						# treat its argument as a set of characters and eat
						# into the stage name from both ends.
						if tag.startswith(channel_prefix):
							formatted_tag = tag[len(channel_prefix):]
						else:
							formatted_tag = tag
						series = data[topic][formatted_tag]
						series['time'].append(timestamp)
						series['fields']['data'].append(values[tag])
					elif 'statevector' in topic:
						for tag in [name for name in values if name in statevector_tags]:
							series = data[topic][tag]
							series['time'].append(timestamp)
							series['fields']['data'].append(values[tag])

			### convert series to numpy arrays
			for topic in topics:
				for series in data[topic].values():
					series['time'] = numpy.array(series['time'])
					series['fields']['data'] = numpy.array(series['fields']['data'])

			elapsed = timeit.default_timer() - start
			logging.info("time to retrieve data: %.1f s", elapsed)

			# store and reduce data for each job
			start = timeit.default_timer()
			for topic in topics:
				logging.info("storing and reducing timeseries for measurement: %s", topic)
				agg_sink.store_columns(topic, data[topic], aggregate=data_type)
			elapsed = timeit.default_timer() - start
			logging.info("time to store/reduce timeseries: %.1f s", elapsed)
	finally:
		# close connection to consumer
		consumer.close()

	#
	# always end on an error so that condor won't think we're done and will
	# restart us
	#

	sys.exit(1)
