"""Node class and associated graph traversal functionality."""
from abc import ABC, abstractmethod
import networkx as nx
class Node(ABC):
    """Abstract base class for an object that depends on other nodes."""

    @abstractmethod
    def parents(self):
        """Return list of parent 'Node' objects upon which this object depends."""

    def ancestors(self):
        """
        Build the nested ancestry tree of this node.

        Each parent returned by :meth:`parents` is paired with its own
        ancestry tree, obtained by calling ``ancestors()`` on that parent.
        Parents are therefore expected to provide an ``ancestors`` method
        themselves (i.e. to be Nodes).

        Returns
        -------
        anc: list
            list of (object, ancestor tree) pairs, one per direct parent and
            in the order given by :meth:`parents`.
        """
        tree = []
        for parent in self.parents():
            # Recurse: the second element is the parent's own ancestry tree.
            tree.append((parent, parent.ancestors()))
        return tree
def node_label(obj):
    """Compute a node label for this object.

    The label has the form ``<name>..<id>.<hash>`` where

    * ``<name>`` is ``obj.name`` when ``"name"`` is present in the object's
      instance dictionary, otherwise the name of the object's class;
    * ``<id>`` is ``id(obj)``;
    * ``<hash>`` is the value returned by ``obj.get_hash()``.

    Objects without an instance ``__dict__`` (for example classes using
    ``__slots__``) are supported and fall back to the class name.

    Parameters
    ----------
    obj : object
        Object providing a ``get_hash()`` method. The returned hash is
        concatenated directly, so it must be a string.

    Returns
    -------
    label : str
        The computed label.
    """
    objname = obj.__class__.__name__
    # Not every object has an instance __dict__ (e.g. __slots__ classes);
    # treat those as having no instance-level name instead of failing.
    if "name" in getattr(obj, "__dict__", ()):
        objname = obj.name
    try:
        label = str(objname) + "."
    except AttributeError:
        # Kept from the original: if converting the name to a string raises
        # AttributeError, the label is built without a name prefix.
        label = ""
    # The separator here follows the trailing "." above, so labels contain
    # a double dot after the name; the format is kept unchanged.
    label += "." + str(id(obj)) + "." + obj.get_hash()
    return label
def fill_graph(g, o):
    """Add given node to graph and all ancestors.

    The node ``o`` is added under the label computed by ``node_label``, then
    every parent returned by ``o.parents()`` is added recursively and linked
    to ``o`` by an edge ``parent -> o`` carrying ``type="ParentOf"``.

    An ancestor reachable through several paths is traversed once per path;
    this only repeats work, since adding an existing node or edge again does
    not duplicate it.

    Parameters
    ----------
    g : graph
        Graph to populate in place; must provide ``add_node`` and
        ``add_edge`` (e.g. a ``networkx.DiGraph``).
    o : object
        Object providing ``parents()`` and accepted by ``node_label``.

    Returns
    -------
    olab : str
        The label under which ``o`` was added to the graph.
    """
    olab = node_label(o)
    # Register the node explicitly: add_edge only creates nodes that take
    # part in an edge, so a node without parents would otherwise be missing.
    g.add_node(olab)
    for p in o.parents():
        # The graph must be passed down the recursion along with the parent.
        pg = fill_graph(g, p)
        g.add_edge(pg, olab, type="ParentOf")
    return olab
def compute_dag(outputs):
    """Build the dependency graph leading to the given outputs.

    A fresh ``networkx.DiGraph`` is created and ``fill_graph`` is called on
    it once for each element of ``outputs``, in iteration order.

    Parameters
    ----------
    outputs : iterable
        Objects to pass to ``fill_graph``.

    Returns
    -------
    dag : networkx.DiGraph
        The populated graph.
    """
    dag = nx.DiGraph()
    for target in outputs:
        fill_graph(dag, target)
    return dag
def unique_objects(l):
    """Given list of objects, ensure that they are all distinct python objects.

    Distinctness is decided by identity (``id``), not equality: two equal
    but separate objects are both kept, while repeated references to the
    same object are collapsed to the first occurrence.

    Parameters
    ----------
    l : iterable
        Objects to filter. It is traversed exactly once, so one-shot
        iterables such as generators are accepted as well as lists.

    Returns
    -------
    kept : list
        The distinct objects, in order of first appearance.
    """
    seen_ids = set()  # set membership keeps the whole pass linear
    kept = []
    for o in l:
        i = id(o)
        if i in seen_ids:
            continue
        seen_ids.add(i)
        kept.append(o)
    return kept
def visualize(g, filename, outputs):
    """Draw computation DAG using graphviz.

    The graph is converted with ``nx.nx_agraph.to_agraph``, styled according
    to the ``type`` attribute of its nodes and edges, laid out with the
    ``dot`` program and drawn to ``filename``.

    Styling rules:

    * nodes with ``type == "Tool"``: filled box, colour ``#CCDDFF``;
    * nodes with ``type == "Data"``: filled oval, colour ``#CCFFDD``;
    * edges with ``type == "Provides"``: green;
    * edges with ``type == "InputTo"``: magenta.

    Nodes and edges with any other ``type`` are not restyled.

    Parameters
    ----------
    g : networkx graph
        Graph to draw.
    filename : str
        Destination passed to the ``draw`` method of the converted graph.
    outputs
        Currently unused; kept for interface compatibility (see TODO below).
    """
    a = nx.nx_agraph.to_agraph(g)
    for n in a.nodes():
        if n.attr["type"] == "Tool":
            n.attr["shape"] = "box"
            n.attr["style"] = "filled"
            n.attr["color"] = "#CCDDFF"
        elif n.attr["type"] == "Data":
            n.attr["shape"] = "oval"
            n.attr["style"] = "filled"
            n.attr["color"] = "#CCFFDD"
        # TODO: Check if this is an input or output and color accordingly
    for e in a.edges():
        # Assign (not compare) the colour so the edge styling takes effect.
        if e.attr["type"] == "Provides":
            e.attr["color"] = "green"
        elif e.attr["type"] == "InputTo":
            e.attr["color"] = "magenta"
    a.layout(prog="dot")
    a.draw(filename)