#! /usr/bin/env python

# Copyright (C) 2025- The University of Notre Dame
# This software is distributed under the GNU General Public License.
# See the file COPYING for details.

# vine_plot_taskgraph converts the raw taskgraph format into an external form that can be plotted.
# The taskgraph written by the TaskVine manager comes in one of two formats:
#
# 1 - The old format was written directly as an input file for the Graphviz DOT language. This made it easy to render with dot specifically, but difficult for other tools to parse and use. If the first line of the input begins with "digraph" (as in "digraph g {"), then this tool simply passes the file to the dot program to generate a PDF.
#
# 2 - The new format is somewhat compressed and writes out exactly one line for each task and file:
#
# taskvine taskgraph version 2
# TASK taskid "name" INPUTS fileid1 fileid2 fileid3 ... OUTPUTS fileid4 fileid5 ...
# FILE fileid "name" size
#
# This format is easier to parse, and permits us to adjust the style of the graph
# post-runtime, and facilitates converting to other rendering formats in the future.

import argparse
import os
import shlex
import shutil
import subprocess
import sys

# Number of errors to tolerate on the input before giving up
max_error_count = 5

# Dictionary of tasks, key is taskid.
taskinfo = {}

# Dictionary of files, key is fileid.
fileinfo = {}

# Determine whether a taskgraph is in the old format by checking whether its
# first line begins with "digraph", the preamble of the Graphviz DOT format.

def taskgraph_is_old_format( filename ):
    """Return True if filename holds an old-style (version 1) taskgraph.

    Version 1 taskgraphs are Graphviz DOT files, so their first line begins
    with the "digraph" keyword.  Only the first line of the file is read;
    an empty file is reported as not being in the old format.
    """
    with open(filename) as stream:
        first_line = stream.readline()
    return first_line.startswith("digraph")
            
# Parse the taskgraph file and load all nodes/edges into the taskinfo and fileinfo dictionaries.

def parse_taskgraph_file( args, filename ):
    """Parse a version 2 taskgraph file into the taskinfo/fileinfo globals.

    Recognized lines:
      TASK taskid "name" INPUTS fileid ... OUTPUTS fileid ...
          adds taskinfo[taskid] = {"name", "inputs", "outputs"}
      FILE fileid "name" size
          adds fileinfo[fileid] = {"name", "size", "source", "sinks"};
          "size" is kept as the string read from the file, "source" is set
          to None and "sinks" to an empty list (neither is filled in here).

    Blank lines, lines starting with "#", and a "taskvine taskgraph ..."
    header line are skipped.  Any other or malformed line is reported as a
    warning; once more than max_error_count such lines have been seen the
    program exits with status 1.

    args is currently unused.
    """

    global taskinfo
    global fileinfo
    global max_error_count

    error_count = 0

    with open(filename) as file:
        for linenum, nline in enumerate(file, start=1):
            line = nline.strip()

            # Skip blank lines, comments, and the format header.  The header
            # is accepted whether or not it is written as a comment.
            if not line or line.startswith("#") or line.startswith("taskvine taskgraph"):
                continue

            try:
                # use shlex.split to respect quoted fields in the input
                fields = shlex.split(line)
            except ValueError:
                # e.g. an unbalanced quote: treat the line as invalid
                fields = []

            valid = False

            if fields and fields[0] == "TASK":
                # TASK taskid "name" INPUTS <inputs...> OUTPUTS <outputs...>
                # Inputs start at index 4, after the INPUTS keyword; the
                # OUTPUTS keyword must appear somewhere after that.
                if "OUTPUTS" in fields[4:]:
                    outputs_at = fields.index("OUTPUTS", 4)
                    taskinfo[fields[1]] = {
                        "name" : fields[2],
                        "inputs" : fields[4:outputs_at],
                        "outputs" : fields[outputs_at+1:],
                    }
                    valid = True

            elif fields and fields[0] == "FILE":
                # FILE fileid "name" size
                if len(fields) == 4:
                    fileid, name, size = fields[1:]
                    fileinfo[fileid] = { "name" : name, "size" : size, "source": None, "sinks": list() }
                    valid = True

            if not valid:
                print("WARNING: invalid data at {}:{}: {}".format(filename,linenum,line))
                error_count += 1
                if error_count > max_error_count:
                    print("aborting after {} errors".format(error_count))
                    sys.exit(1)


# Write the contents of the taskgraph back out as a graphviz dot file.
# Apply various styling parameters to control the output.
                    
def print_taskgraph_as_dot( args, filename ):
    """Write the parsed taskgraph to filename as a Graphviz DOT file.

    Reads the module-level taskinfo and fileinfo dictionaries (as filled in
    by parse_taskgraph_file).  Tasks are drawn as filled green ovals and
    files as filled blue rectangles.  An edge runs from each input file to
    its task, and from each task to each of its output files.

    args.labels   -- if true, label each node with its task/file name;
                     otherwise nodes get an empty label.
    args.fontsize -- font size, in points, for node labels.
    """

    global taskinfo
    global fileinfo

    def quote(text):
        # Produce a double-quoted DOT string.  Backslashes and quotes in
        # names would otherwise terminate or corrupt the string.
        text = str(text).replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")
        return '"{}"'.format(text)

    # Task ids and file ids are not guaranteed to be disjoint, so prefix
    # them.  Otherwise a task and a file with the same id would collapse
    # into a single DOT node.
    def task_node(taskid):
        return quote("task-{}".format(taskid))

    def file_node(fileid):
        return quote("file-{}".format(fileid))

    def label_for(info):
        return quote(info["name"] if args.labels else "")

    with open(filename,"w") as fp:

        print("digraph g {",file=fp)

        # Emit all of the tasks with a constant green style.
        # ("fontname" is the Graphviz attribute for the font family.)
        print("node [style=filled,fontname=Helvetica,fontsize={},shape=oval,color=green];".format(args.fontsize), file=fp)
        for taskid, info in taskinfo.items():
            print("{} [label={}];".format(task_node(taskid), label_for(info)), file=fp)

        # Emit all of the files with a constant blue style.
        print("node [style=filled,fontname=Helvetica,fontsize={},shape=rect,color=blue];".format(args.fontsize), file=fp)
        for fileid, info in fileinfo.items():
            print("{} [label={}];".format(file_node(fileid), label_for(info)), file=fp)

        # Finally emit all of the edges.
        for taskid, info in taskinfo.items():
            for fileid in info["inputs"]:
                print("{} -> {};".format(file_node(fileid), task_node(taskid)), file=fp)
            for fileid in info["outputs"]:
                print("{} -> {};".format(task_node(taskid), file_node(fileid)), file=fp)

        print("}",file=fp)

        
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="vine_plot_taskgraph",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Plot TaskVine workflow structure from a taskgraph log file."
    )

    # No nargs="?" on the options below: a bare "--output" or "--fontsize"
    # would otherwise silently produce None instead of an argument error.
    parser.add_argument("taskgraph", help="taskgraph log file")
    parser.add_argument("--output", required=True, help="output PDF file to create")
    parser.add_argument("--labels", action="store_true", help="display task/file labels")
    parser.add_argument("--fontsize", type=float, default=10, help="fontsize of task/file labels, in points")

    args = parser.parse_args()

    print("reading taskgraph {}...".format(args.taskgraph))

    if taskgraph_is_old_format(args.taskgraph):
        # Version 1 files are already DOT input; hand them straight to dot.
        print("taskgraph is in version 1 format")
        dotfile = args.taskgraph
    else:
        print("taskgraph is in version 2 format")
        parse_taskgraph_file( args, args.taskgraph )

        dotfile = "{}.dot".format(args.taskgraph)

        print("writing dot graph {}...".format(dotfile))
        print_taskgraph_as_dot( args, dotfile )

    path_to_dot = shutil.which("dot")
    if path_to_dot:
        print("found graphviz dot at {}".format(path_to_dot))
    else:
        print("{}: could not find graphviz dot in PATH!".format(parser.prog), file=sys.stderr)
        print("Please install graphviz before proceeding.", file=sys.stderr)
        sys.exit(1)

    print("running dot to create {} (this may be slow)...".format(args.output))
    # Pass arguments as a list (no shell), so file names containing spaces or
    # shell metacharacters are handled safely.  -o makes dot write the output
    # file itself instead of relying on shell redirection.
    result = subprocess.run([path_to_dot, "-Tpdf", dotfile, "-o", args.output])
    if result.returncode != 0:
        print("{}: dot failed with exit status {}".format(parser.prog, result.returncode), file=sys.stderr)
        sys.exit(1)
    
