#!/usr/bin/env python

# Copyright (C) 2025- The University of Notre Dame
# This software is distributed under the GNU General Public License.
# See the file COPYING for details.

# Parse a taskvine performance log, and produce ten different plots
# that summarize different aspects of task performance, worker utilization, etc.

import sys
import re
import os
import getopt
import matplotlib.pyplot as plt
from functools import reduce

# Default output format, can choose any type supported by matplotlib
extension = "svg"

# Sample the log file at this time period (in seconds), to keep graphs readable:
time_resolution = 30

# Try these time resolutions (in seconds), coarsest first, to get a reasonable
# number of points:
time_resolution_list = [30, 5, 1]

# If unspecified, then plot all the valid range.
x_range = None

# Size of the output image in pixels; divided by image_dpi to get the figure
# size in inches that matplotlib expects.
image_width = 640
image_height = 480
image_dpi = 100

# Sorted list of the record timestamps loaded from the log (filled in by the
# main program).
times = None

# A dictionary of records, keyed by timestamp.
log_entries = None

# Human-readable axis label for each unit abbreviation.
unit_labels = {
    "s": "seconds",
    "m": "minutes",
    "h": "hours",
    "d": "days",
    "GB": "GB",
    "i": "",
}

# Divisor that converts a raw value into each unit: seconds for the time
# units, bytes for GB. "KB" (1024) is also used to convert MB into GB.
# None means "no scaling".
unit_factors = {
    None: 1,
    "s": 1,
    "m": 60,
    "h": 3600,
    "d": 86400,
    "KB": 1024,
    "GB": 1073741824,
}

# Default X unit is in minutes.
x_units = "m"

# Fixed colors assigned to the plotted fields, in order.
colors = ["#1b9377", "#d95f02", "#7570b3", "#e7298a", "#66a61e", "#e6ab02", "#a676ad"]

# Search for the log file (comment) header which indicates each of the field names.


def read_fields(file, lines_patience=10):
    """Return the column names from the first '#' header line of *file*.

    Lines that do not start with '#' are consumed and counted against
    lines_patience. If that budget runs out, or the file ends, before a
    header is seen, an error is written to stderr and the program exits
    with status 1.
    """
    remaining = lines_patience
    for line in file:
        if line[0] == "#":
            # Drop the comment marker and surrounding whitespace, then split names.
            return line.strip("#\n\r\t ").split()
        remaining -= 1
        if remaining < 1:
            break
    sys.stderr.write(
        "Could not find fields descriptions (a line such as # timestamp workers_....)\n"
    )
    sys.exit(1)


def time_to_resolution(t):
    """Floor microsecond timestamp *t* to its sampling bucket, returned in seconds.

    The bucket width is the module-level ``time_resolution`` (seconds).
    """
    bucket_us = time_resolution * 1000000
    return (t - t % bucket_us) / 1000000


def time_field_p(field):
    """Return a match object when *field* starts with ``time_``, else None."""
    # re.match is anchored at the start of the string, like a leading '^'.
    return re.match(r"time_.*$", field)


# Consume the log file and return a dictionary of data records keyed by
# timestamp; the caller derives the sorted "times" list from its keys.


def read_log_entries(file, fields, reset_until_dispatch=False):
    """Parse the data lines of a performance log into records keyed by time.

    file: iterator over the log lines that follow the header.
    fields: column names, in order, as returned by read_fields().
    reset_until_dispatch: when true, drop records until one reports
        tasks_dispatched > 0, and measure time from that record on.

    Returns a dict mapping each record's "timestamp" (seconds since the first
    kept record) to a dict of that record's values. Only the first record in
    each time_resolution bucket is kept. time_* columns are divided by 1e6
    (microseconds to seconds), byte counters are converted to GB, disk sizes
    from MB to GB, and cumulative stack_time_* percentages are added.

    Blank lines are skipped. Lines with a non-numeric value are ignored with a
    warning. Lines with too few columns, or a log lacking a column this
    function needs, abort the program with status 1.
    """
    log_entries = {}
    pairs = list(enumerate(fields))
    epoch = None
    count_lines = 0
    prev_line = None

    for line in file:
        count_lines += 1
        if not line.strip():
            # A blank line (e.g. a trailing newline) carries no data.
            continue
        try:
            numbers = [float(x) for x in line.split()]
            record = {}

            for i, field in pairs:
                if field == "timestamp":
                    record["time_step"] = time_to_resolution(numbers[i])
                    # `is None`, not truthiness: a first timestamp of 0 is valid.
                    if epoch is None:
                        epoch = numbers[i] / 1000000
                    record["timestamp"] = numbers[i] / 1000000 - epoch
                elif time_field_p(field):
                    record[field] = numbers[i] / 1000000
                else:
                    record[field] = numbers[i]

            if reset_until_dispatch:
                if record["tasks_dispatched"] > 0:
                    reset_until_dispatch = False
                else:
                    # Restart the clock so time 0 is the first dispatching record.
                    epoch = None
                    continue

            # Keep only the first record of each sampling bucket.
            if prev_line and record["time_step"] == prev_line["time_step"]:
                continue

            log_entries[record["timestamp"]] = record

            # The first kept record is compared with itself (all deltas zero).
            if prev_line is None:
                prev_line = record

            delta = {
                name: record[name] - prev_line[name] for name in _STACKED_TIME_FIELDS
            }
            _add_stacked_times(record, delta)

            record["bytes_sent"] /= unit_factors["GB"]
            record["bytes_received"] /= unit_factors["GB"]
            record["bytes_transfered"] = record["bytes_sent"] + record["bytes_received"]

            record["total_disk"] /= unit_factors["KB"]  # MB to GB
            record["committed_disk"] /= unit_factors["KB"]  # MB to GB

            try:
                record["inuse_cache"] /= unit_factors["KB"]  # MB to GB
            except KeyError:
                record["inuse_cache"] = 0  # inuse_cache is not a stat that all logs have

            prev_line = record

        except ValueError:
            sys.stderr.write(
                "Line %d has an invalid value. Ignoring.\n" % (count_lines,)
            )
            continue
        except IndexError:
            sys.stderr.write(
                "Line %d has less than %d fields. Aborting.\n"
                % (count_lines, len(fields))
            )
            sys.exit(1)
        except KeyError as e:
            sys.stderr.write(
                "Line %d: the log has no '%s' field. Aborting.\n"
                % (count_lines, e.args[0])
            )
            sys.exit(1)

    return log_entries


# Manager time components, in the order they are stacked (bottom to top) for
# the "times-stacked" plot.
_STACKED_TIME_FIELDS = [
    "time_status_msgs",
    "time_internal",
    "time_polling",
    "time_send",
    "time_receive",
    "time_application",
]


def _add_stacked_times(record, delta):
    """Store cumulative percentage shares of the manager time components in record.

    delta maps each name in _STACKED_TIME_FIELDS to its change since the
    previously kept record. record["stack_<name>"] becomes the running sum of
    percentages up to and including <name>, and record["stack_time_other"] is
    100, so the bars can be drawn over one another. Everything is 0 when no
    time elapsed.
    """
    res_factor = 100
    # Normalize by the sum of the components rather than wall time, since one
    # component may exceed the sampling period; this forces equal bar heights.
    total_time = sum(delta[name] for name in _STACKED_TIME_FIELDS)

    if total_time > 0:
        cumulative = 0
        for name in _STACKED_TIME_FIELDS:
            cumulative = res_factor * (delta[name] / total_time) + cumulative
            record["stack_" + name] = cumulative
        record["stack_time_other"] = res_factor
    else:
        for name in _STACKED_TIME_FIELDS:
            record["stack_" + name] = 0
        record["stack_time_other"] = 0


def sort_time(log_entries):
    """Return the keys (record timestamps) of *log_entries* in ascending order."""
    return sorted(log_entries)


class VinePlot:
    """One chart of selected log fields plotted against manager lifetime.

    Data is read from the module-level ``times`` (sorted record keys, in
    seconds) and ``log_entries`` (records keyed by those times), which the
    main program fills in before plot() is called.
    """

    def __init__(
        self,
        title,
        ylabel,
        fields,
        labels=None,
        x_units=x_units,
        y_units=None,
        range=x_range,
        stack=False,
        logscale=False,
    ):
        """Describe a plot.

        title: chart title.
        ylabel: y-axis label; "(y_units)" is appended when y_units is given.
        fields: record keys to draw, one series each.
        labels: legend text for each field, same order (defaults to fields).
        x_units: key into unit_factors/unit_labels for the time axis.
        y_units: unit used to rescale time_* fields and annotate the y label.
        range: optional x interval "min:max", "min:" or ":max", in x_units.
        stack: draw bars instead of lines.
        logscale: use a logarithmic y axis.
        """
        self.title = title
        self.fields = fields
        self.labels = labels or self.fields
        self.x_units = x_units
        self.y_units = y_units
        self.ylabel = ylabel
        self.range = range
        self.stack = stack
        self.logscale = logscale

    def plot(self, output):
        """Render the chart to file *output*; exit with status 1 if it cannot be written."""
        try:
            self.doplot(output)
        except IOError:
            sys.stderr.write("Could not generate file %s.\n" % (output,))
            sys.exit(1)

    def doplot(self, output):
        """Draw each field as a line (or bar when stacking) and save to *output*."""
        plt.figure(
            figsize=(image_width / image_dpi, image_height / image_dpi), dpi=image_dpi
        )

        for i, field in enumerate(self.fields):
            (x, y) = self.get_data_for_field(field)
            # Fall back to the field name if fewer labels than fields were given.
            label = self.labels[i] if i < len(self.labels) else field
            color = colors[i % len(colors)]
            if self.stack:
                plt.bar(x, y, label=label, color=color)
            else:
                plt.plot(x, y, label=label, color=color)

        if self.y_units:
            plt.ylabel("{} ({})".format(self.ylabel, self.y_units))
        else:
            plt.ylabel(self.ylabel)

        if self.logscale:
            plt.yscale("log")

        plt.title(self.title)
        plt.xlabel("manager lifetime in {}".format(unit_labels[self.x_units]))
        plt.legend(loc="upper left")
        plt.savefig(output)
        plt.close()

    def _x_limits(self):
        """Parse self.range ("min:max", "min:" or ":max") into (low, high).

        A missing bound is returned as None. An unset range gives (None, None).
        A malformed range is reported on stderr and exits with status 1.
        """
        if not self.range:
            return (None, None)
        try:
            low, high = self.range.split(":")
            return (
                float(low) if low.strip() else None,
                float(high) if high.strip() else None,
            )
        except ValueError:
            sys.stderr.write(
                "Range '%s' is not valid. Expected min:max, min:, or :max.\n"
                % (self.range,)
            )
            sys.exit(1)

    def get_data_for_field(self, field):
        """Return (xvalues, yvalues) for *field*, limited to the configured range.

        x is the record time converted to x_units. y is the field value, and
        time_* fields are additionally divided by the y_units factor. If the
        field is missing from a record, a warning is written and the values
        gathered so far are returned.
        """
        time_scale = unit_factors[self.x_units]
        # Only duration fields (time_*) are rescaled into y_units.
        mod = unit_factors[self.y_units] if time_field_p(field) else 1
        low, high = self._x_limits()

        xvalues = []
        yvalues = []

        for t in times:
            x = t / time_scale
            if (low is not None and x < low) or (high is not None and x > high):
                continue
            try:
                y = log_entries[t][field] / mod
            except KeyError:
                sys.stderr.write("Field '%s' does not exist in the log\n" % (field,))
                break
            xvalues.append(x)
            yvalues.append(y)

        return (xvalues, yvalues)


def show_usage():
    """Print command-line help, filled in with the current default settings."""
    usage_string = """
{command} [options] <vine-performance-log>
        -h                  This message.
        -l                  Graph y-axis with logscale.
        -D                  Start plots when the first task is dispatched.
        -o <prefix-output>  Generate prefix-output.{{workers,workers-accum,
                                                    workers-disk,tasks,
                                                    tasks-capacities,tasks-accum,
                                                    time-manager,time-workers,
                                                    transfer,times-stacked}}.{extension}
                            Default is <vine-performance-log>.
        -r <range>          Range of time to plot, in time units (see -u) from
                            the start of execution. Of the form: min:max,
                            min:, or :max.
        -s <seconds>        Sample log every <seconds> (default is {time_resolution}).
        -u <time-unit>      Time scale to output. One of s,m,h or d, for seconds,
                            minutes (default), hours or days.
        -T <output-format>  Set output format. Default is {extension}.
        -g <width,height>   Size of each plot. Default is {image_width},{image_height}.
""".format(
        command=os.path.basename(sys.argv[0]),
        extension=extension,
        time_resolution=time_resolution,
        image_width=image_width,
        image_height=image_height,
    )
    print(usage_string)


if __name__ == "__main__":
    # Command-line entry point. This runs at module scope, so the options
    # below rebind the module-level settings (x_range, x_units, image_width,
    # extension, time_resolution, times, log_entries, ...) the functions read.

    try:
        optlist, args = getopt.getopt(sys.argv[1:], "c:Dg:hlo:r:s:T:u:")
    except getopt.GetoptError as e:
        sys.stderr.write(str(e) + "\n")
        show_usage()
        sys.exit(1)

    if len(args) < 1:
        show_usage()
        sys.exit(1)

    logname = args[0]
    prefix = logname
    reset_until_dispatch = False
    logscale = False

    # unit_factors also holds size units (KB, GB); only these are valid for -u.
    time_units = ("s", "m", "h", "d")

    # Fewest samples considered readable before retrying at a finer resolution.
    min_samples = 30

    for opt, arg in optlist:
        if opt == "-o":
            prefix = arg
        elif opt == "-h":
            show_usage()
            sys.exit(0)
        elif opt == "-r":
            x_range = arg
        elif opt == "-s":
            # If a specific resolution is given, use that one alone.
            try:
                resolution = float(arg)
            except ValueError:
                resolution = 0
            # `not > 0` also rejects NaN; 0 would cause a modulo by zero.
            if not resolution > 0:
                sys.stderr.write(
                    "Sample period '%s' must be a positive number of seconds.\n"
                    % (arg,)
                )
                sys.exit(1)
            time_resolution_list = [resolution]
        elif opt == "-g":
            # Must be numbers: they are divided by image_dpi when plotting.
            try:
                (image_width, image_height) = (int(v) for v in arg.split(","))
            except ValueError:
                sys.stderr.write(
                    "Plot size '%s' must be of the form width,height.\n" % (arg,)
                )
                sys.exit(1)
        elif opt == "-T":
            extension = arg
        elif opt == "-u":
            if arg in time_units:
                x_units = arg
            else:
                sys.stderr.write(
                    "Time scale factor '%s' is not valid. Options: s,m,h or d.\n"
                    % (arg,)
                )
                sys.exit(1)
        elif opt == "-D":
            reset_until_dispatch = True
        elif opt == "-l":
            logscale = True

    try:
        # time_resolution is the module-level global read by time_to_resolution().
        for time_resolution in time_resolution_list:
            with open(logname) as log_file:
                log_entries = read_log_entries(
                    log_file, read_fields(log_file), reset_until_dispatch
                )
            times = sort_time(log_entries)

            # A short log gives few samples at a coarse resolution: keep the
            # first resolution that yields enough, otherwise try a finer one.
            if len(times) >= min_samples:
                break
    except IOError:
        sys.stderr.write("Could not open file %s\n" % (logname,))
        sys.exit(1)

    plots = {}
    plots["workers"] = VinePlot(
        x_units=x_units,
        ylabel="number of workers",
        range=x_range,
        logscale=logscale,
        title="workers instantaneous counts",
        fields=["workers_connected", "workers_idle", "workers_busy"],
        labels=["connected", "idle", "busy"],
    )

    plots["workers-accum"] = VinePlot(
        x_units=x_units,
        ylabel="number of workers",
        range=x_range,
        logscale=logscale,
        title="workers cumulative counts",
        fields=[
            "workers_joined",
            "workers_removed",
            "workers_released",
            "workers_slow",
            "workers_idled_out",
            "workers_lost",
        ],
        labels=["joined", "removed", "released", "slow", "idled out", "lost"],
    )

    plots["workers-disk"] = VinePlot(
        x_units=x_units,
        y_units="GB",
        ylabel="disk space",
        range=x_range,
        logscale=logscale,
        title="workers disk utilization",
        fields=["total_disk", "committed_disk", "inuse_cache"],
        labels=["total", "committed to tasks", "used by cache"],
    )

    plots["tasks"] = VinePlot(
        x_units=x_units,
        ylabel="number of tasks",
        range=x_range,
        logscale=logscale,
        title="tasks instantaneous counts",
        fields=[
            "tasks_waiting",
            "tasks_on_workers",
            "tasks_running",
            "tasks_with_results",
        ],
        labels=["waiting", "on workers", "running", "with results"],
    )

    plots["tasks-capacities"] = VinePlot(
        x_units=x_units,
        ylabel="number of tasks",
        range=x_range,
        logscale=logscale,
        title="tasks instantaneous capacities",
        fields=["tasks_running", "capacity_instantaneous", "capacity_weighted"],
        labels=["tasks running", "tasks capacity raw", "tasks capacity weighted"],
    )

    plots["tasks-accum"] = VinePlot(
        x_units=x_units,
        ylabel="number of tasks",
        range=x_range,
        logscale=logscale,
        title="tasks cumulative counts",
        fields=[
            "tasks_submitted",
            "tasks_dispatched",
            "tasks_done",
            "tasks_failed",
            "tasks_cancelled",
            "tasks_exhausted_attempts",
        ],
        labels=[
            "submitted",
            "dispatched",
            "done",
            "failed",
            "cancelled",
            "exhausted attempts",
        ],
    )

    plots["time-manager"] = VinePlot(
        x_units=x_units,
        ylabel="cumulative time",
        y_units=x_units,
        range=x_range,
        logscale=logscale,
        title="cumulative times at the manager",
        fields=[
            "time_send",
            "time_receive",
            "time_polling",
            "time_status_msgs",
            "time_internal",
            "time_application",
        ],
        labels=[
            "send",
            "receive",
            "manager polling",
            "manager status msgs",
            "manager internal",
            "manager application",
        ],
    )

    plots["time-workers"] = VinePlot(
        x_units=x_units,
        y_units="h",
        ylabel="cumulative time",
        range=x_range,
        logscale=logscale,
        title="cumulative times at workers",
        fields=["time_execute", "time_execute_good", "time_execute_exhaustion"],
        labels=["execute", "execute good", "execute exhaustion"],
    )

    plots["transfer"] = VinePlot(
        x_units=x_units,
        y_units="GB",
        ylabel="size",
        range=x_range,
        logscale=logscale,
        title="manager data transfer",
        fields=["bytes_sent", "bytes_received"],
        labels=["sent", "received"],
    )

    plots["times-stacked"] = VinePlot(
        x_units=x_units,
        y_units="%",
        ylabel="utilization per time sample",
        range=x_range,
        stack=True,
        logscale=logscale,
        title="manager time proportions",
        fields=[
            "stack_time_other",
            "stack_time_application",
            "stack_time_receive",
            "stack_time_send",
            "stack_time_polling",
            "stack_time_internal",
            "stack_time_status_msgs",
        ],
        labels=[
            "other",
            "application",
            "receive",
            "send",
            "polling",
            "internal",
            "status msgs",
        ],
    )

    for name, chart in plots.items():
        chart.plot(prefix + "." + name + "." + extension)

# vim: tabstop=8 shiftwidth=4 softtabstop=4 expandtab shiftround autoindent
