Skip to content

Commit

Permalink
Merge pull request #300 from lsst/tickets/DM-45701
Browse files Browse the repository at this point in the history
DM-45701: Move dotTools to pipe_base
  • Loading branch information
timj committed Aug 9, 2024
2 parents 366e7aa + 336701f commit 8eedeb4
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 476 deletions.
1 change: 1 addition & 0 deletions doc/changes/DM-45701.removal.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The ``lsst.ctrl.mpexec.dotTools`` package has been relocated to ``lsst.pipe.base.dot_tools``.
2 changes: 1 addition & 1 deletion python/lsst/ctrl/mpexec/cmdLineFwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@
buildExecutionButler,
)
from lsst.pipe.base.all_dimensions_quantum_graph_builder import AllDimensionsQuantumGraphBuilder
from lsst.pipe.base.dot_tools import graph2dot
from lsst.pipe.base.pipeline_graph import NodeType
from lsst.utils import doImportType
from lsst.utils.logging import getLogger
from lsst.utils.threads import disable_implicit_threading

from .dotTools import graph2dot
from .executionGraphFixup import ExecutionGraphFixup
from .mpGraphExecutor import MPGraphExecutor
from .preExecInit import PreExecInit, PreExecInitLimited
Expand Down
300 changes: 24 additions & 276 deletions python/lsst/ctrl/mpexec/dotTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,141 +33,20 @@

__all__ = ["graph2dot", "pipeline2dot"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
import html
import io
import re
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
from typing import Any

# -----------------------------
# Imports for other modules --
# -----------------------------
from lsst.daf.butler import DatasetType, DimensionUniverse
from lsst.pipe.base import Pipeline, connectionTypes, iterConnections
from deprecated.sphinx import deprecated
from lsst.pipe.base import Pipeline, QuantumGraph, TaskDef
from lsst.pipe.base.dot_tools import graph2dot as _graph2dot
from lsst.pipe.base.dot_tools import pipeline2dot as _pipeline2dot

if TYPE_CHECKING:
from lsst.daf.butler import DatasetRef
from lsst.pipe.base import QuantumGraph, QuantumNode, TaskDef

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# Attributes applied to directed graph objects.
_NODELABELPOINTSIZE = "18"
_ATTRIBS = dict(
defaultGraph=dict(splines="ortho", nodesep="0.5", ranksep="0.75", pad="0.5"),
defaultNode=dict(shape="box", fontname="Monospace", fontsize="14", margin="0.2,0.1", penwidth="3"),
defaultEdge=dict(color="black", arrowsize="1.5", penwidth="1.5"),
task=dict(style="filled", color="black", fillcolor="#B1F2EF"),
quantum=dict(style="filled", color="black", fillcolor="#B1F2EF"),
dsType=dict(style="rounded,filled,bold", color="#00BABC", fillcolor="#F5F5F5"),
dataset=dict(style="rounded,filled,bold", color="#00BABC", fillcolor="#F5F5F5"),
@deprecated(
"graph2dot should now be imported from lsst.pipe.base.dot_tools. Will be removed in v29.",
version="v27.0",
category=FutureWarning,
)


def _renderDefault(type: str, attribs: dict[str, str], file: io.TextIOBase) -> None:
"""Set default attributes for a given type."""
default_attribs = ", ".join([f'{key}="{val}"' for key, val in attribs.items()])
print(f"{type} [{default_attribs}];", file=file)


def _renderNode(file: io.TextIOBase, nodeName: str, style: str, labels: list[str]) -> None:
    """Emit a GraphViz node whose label is an HTML-like single-column table.

    Each entry of ``labels`` becomes one table row; node attributes are
    taken from the module-level ``_ATTRIBS`` entry selected by ``style``.
    """
    table_open = '<<TABLE BORDER="0" CELLPADDING="5"><TR><TD>'
    table_close = "</TD></TR></TABLE>>"
    body = "</TD></TR><TR><TD>".join(labels)
    attrib_dict = dict(_ATTRIBS[style], label=body)
    rendered = []
    for key, val in attrib_dict.items():
        if key == "label":
            # The label is HTML-like and must not be quoted.
            rendered.append(f"{key}={table_open}{val}{table_close}")
        else:
            rendered.append(f'{key}="{val}"')
    print(f'"{nodeName}" [{", ".join(rendered)}];', file=file)


def _renderTaskNode(nodeName: str, taskDef: TaskDef, file: io.TextIOBase, idx: Any = None) -> None:
    """Render the GV node describing a task.

    The label shows the task label (large, bold), the task class name,
    an optional numeric index, and the task's dimensions if any.
    """
    header = f'<B><FONT POINT-SIZE="{_NODELABELPOINTSIZE}">{html.escape(taskDef.label)}</FONT></B>'
    labels = [header, html.escape(taskDef.taskName)]
    if idx is not None:
        labels.append(f"<I>index:</I>&nbsp;{idx}")
    if taskDef.connections:
        # Join the names ourselves rather than printing the collection,
        # which would add visually noisy quote characters.
        dims = ", ".join(sorted(taskDef.connections.dimensions))
        labels.append(f"<I>dimensions:</I>&nbsp;{html.escape(dims)}")
    _renderNode(file, nodeName, "task", labels)


def _renderQuantumNode(
    nodeName: str, taskDef: TaskDef, quantumNode: QuantumNode, file: io.TextIOBase
) -> None:
    """Render the GV node describing a single quantum of a task.

    The label shows the quantum's node ID, the task label, and one
    ``key = value`` row per required data ID dimension.
    """
    dataId = quantumNode.quantum.dataId
    assert dataId is not None, "Quantum DataId cannot be None"
    labels = [f"{quantumNode.nodeId}", html.escape(taskDef.label)]
    for key in sorted(dataId.required.keys()):
        labels.append(f"{key} = {dataId[key]}")
    _renderNode(file, nodeName, "quantum", labels)


def _renderDSTypeNode(name: str, dimensions: list[str], file: io.TextIOBase) -> None:
    """Render the GV node describing a dataset type.

    The node name doubles as its displayed title; dimensions, when
    present, are listed sorted on a second row.
    """
    title = f'<B><FONT POINT-SIZE="{_NODELABELPOINTSIZE}">{html.escape(name)}</FONT></B>'
    labels = [title]
    if dimensions:
        dims = html.escape(", ".join(sorted(dimensions)))
        labels.append(f"<I>dimensions:</I>&nbsp;{dims}")
    _renderNode(file, name, "dsType", labels)


def _renderDSNode(nodeName: str, dsRef: DatasetRef, file: io.TextIOBase) -> None:
    """Render the GV node describing a concrete dataset.

    The label shows the dataset type name, the run collection, and one
    ``key = value`` row per required data ID dimension.
    """
    labels = [html.escape(dsRef.datasetType.name), f"run: {dsRef.run!r}"]
    for key in sorted(dsRef.dataId.required.keys()):
        labels.append(f"{key} = {dsRef.dataId[key]}")
    _renderNode(file, nodeName, "dataset", labels)


def _renderEdge(fromName: str, toName: str, file: io.TextIOBase, **kwargs: Any) -> None:
"""Render GV edge"""
if kwargs:
attrib = ", ".join([f'{key}="{val}"' for key, val in kwargs.items()])
print(f'"{fromName}" -> "{toName}" [{attrib}];', file=file)
else:
print(f'"{fromName}" -> "{toName}";', file=file)


def _datasetRefId(dsRef: DatasetRef) -> str:
"""Make an identifying string for given ref"""
dsId = [dsRef.datasetType.name]
dsId.extend(f"{key} = {dsRef.dataId[key]}" for key in sorted(dsRef.dataId.required.keys()))
return ":".join(dsId)


def _makeDSNode(dsRef: DatasetRef, allDatasetRefs: dict[str, str], file: io.TextIOBase) -> str:
    """Return the node name for a dataset ref, rendering it on first use.

    ``allDatasetRefs`` maps dataset-ref identifiers to node names and is
    updated in place; a node is written to ``file`` only for refs not
    seen before, so each dataset appears exactly once in the graph.
    """
    refId = _datasetRefId(dsRef)
    if refId in allDatasetRefs:
        return allDatasetRefs[refId]
    # Node names are assigned sequentially in first-seen order.
    nodeName = f"dsref_{len(allDatasetRefs)}"
    allDatasetRefs[refId] = nodeName
    _renderDSNode(nodeName, dsRef, file)
    return nodeName


# ------------------------
# Exported definitions --
# ------------------------


def graph2dot(qgraph: QuantumGraph, file: Any) -> None:
"""Convert QuantumGraph into GraphViz digraph.
Expand All @@ -183,45 +62,19 @@ def graph2dot(qgraph: QuantumGraph, file: Any) -> None:
Raises
------
`OSError` is raised when output file cannot be open.
`ImportError` is raised when task class cannot be imported.
OSError
Raised if the output file cannot be opened.
ImportError
Raised if the task class cannot be imported.
"""
# open a file if needed
close = False
if not hasattr(file, "write"):
file = open(file, "w")
close = True

print("digraph QuantumGraph {", file=file)
_renderDefault("graph", _ATTRIBS["defaultGraph"], file)
_renderDefault("node", _ATTRIBS["defaultNode"], file)
_renderDefault("edge", _ATTRIBS["defaultEdge"], file)

allDatasetRefs: dict[str, str] = {}
for taskId, taskDef in enumerate(qgraph.taskGraph):
quanta = qgraph.getNodesForTask(taskDef)
for qId, quantumNode in enumerate(quanta):
# node for a task
taskNodeName = f"task_{taskId}_{qId}"
_renderQuantumNode(taskNodeName, taskDef, quantumNode, file)

# quantum inputs
for dsRefs in quantumNode.quantum.inputs.values():
for dsRef in dsRefs:
nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
_renderEdge(nodeName, taskNodeName, file)

# quantum outputs
for dsRefs in quantumNode.quantum.outputs.values():
for dsRef in dsRefs:
nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
_renderEdge(taskNodeName, nodeName, file)

print("}", file=file)
if close:
file.close()
_graph2dot(qgraph, file)


@deprecated(
"pipeline2dot should now be imported from lsst.pipe.base.dot_tools. Will be removed in v29.",
version="v27.0",
category=FutureWarning,
)
def pipeline2dot(pipeline: Pipeline | Iterable[TaskDef], file: Any) -> None:
"""Convert `~lsst.pipe.base.Pipeline` into GraphViz digraph.
Expand All @@ -238,114 +91,9 @@ def pipeline2dot(pipeline: Pipeline | Iterable[TaskDef], file: Any) -> None:
Raises
------
`OSError` is raised when output file cannot be open.
`ImportError` is raised when task class cannot be imported.
`MissingTaskFactoryError` is raised when TaskFactory is needed but not
provided.
OSError
Raised if the output file cannot be opened.
ImportError
Raised if the task class cannot be imported.
"""
universe = DimensionUniverse()

def expand_dimensions(connection: connectionTypes.BaseConnection) -> list[str]:
    """Return expanded list of dimensions, with special skypix treatment.

    Parameters
    ----------
    connection : `lsst.pipe.base.connectionTypes.BaseConnection`
        Connection to examine; only dimensioned connections contribute
        dimensions.

    Returns
    -------
    dimensions : `list` [`str`]
        Expanded list of dimension names, with ``skypix`` (if present)
        re-appended unexpanded at the end.
    """
    dimension_set = set()
    if isinstance(connection, connectionTypes.DimensionedConnection):
        dimension_set = set(connection.dimensions)
    # Hold "skypix" aside before conforming — presumably because the
    # universe does not recognize the generic placeholder; TODO confirm.
    skypix_dim = []
    if "skypix" in dimension_set:
        dimension_set.remove("skypix")
        skypix_dim = ["skypix"]
    # NOTE(review): `universe` is closed over from the enclosing scope.
    dimensions = universe.conform(dimension_set)
    return list(dimensions.names) + skypix_dim

# open a file if needed
close = False
if not hasattr(file, "write"):
file = open(file, "w")
close = True

print("digraph Pipeline {", file=file)
_renderDefault("graph", _ATTRIBS["defaultGraph"], file)
_renderDefault("node", _ATTRIBS["defaultNode"], file)
_renderDefault("edge", _ATTRIBS["defaultEdge"], file)

allDatasets: set[str | tuple[str, str]] = set()
if isinstance(pipeline, Pipeline):
# TODO: DM-40639 will rewrite this code and finish off the deprecation
# of toExpandedPipeline.
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=FutureWarning)
pipeline = pipeline.toExpandedPipeline()

# The next two lines are a workaround until DM-29658 at which time metadata
# connections should start working with the above code
labelToTaskName = {}
metadataNodesToLink = set()

for idx, taskDef in enumerate(sorted(pipeline, key=lambda x: x.label)):
# node for a task
taskNodeName = f"task{idx}"

# next line is workaround until DM-29658
labelToTaskName[taskDef.label] = taskNodeName

_renderTaskNode(taskNodeName, taskDef, file, None)

metadataRePattern = re.compile("^(.*)_metadata$")
for attr in sorted(iterConnections(taskDef.connections, "inputs"), key=lambda x: x.name):
if attr.name not in allDatasets:
dimensions = expand_dimensions(attr)
_renderDSTypeNode(attr.name, dimensions, file)
allDatasets.add(attr.name)
nodeName, component = DatasetType.splitDatasetTypeName(attr.name)
_renderEdge(attr.name, taskNodeName, file)
# connect component dataset types to the composite type that
# produced it
if component is not None and (nodeName, attr.name) not in allDatasets:
_renderEdge(nodeName, attr.name, file)
allDatasets.add((nodeName, attr.name))
if nodeName not in allDatasets:
dimensions = expand_dimensions(attr)
_renderDSTypeNode(nodeName, dimensions, file)
# The next if block is a workaround until DM-29658 at which time
# metadata connections should start working with the above code
if (match := metadataRePattern.match(attr.name)) is not None:
matchTaskLabel = match.group(1)
metadataNodesToLink.add((matchTaskLabel, attr.name))

for attr in sorted(iterConnections(taskDef.connections, "prerequisiteInputs"), key=lambda x: x.name):
if attr.name not in allDatasets:
dimensions = expand_dimensions(attr)
_renderDSTypeNode(attr.name, dimensions, file)
allDatasets.add(attr.name)
# use dashed line for prerequisite edges to distinguish them
_renderEdge(attr.name, taskNodeName, file, style="dashed")

for attr in sorted(iterConnections(taskDef.connections, "outputs"), key=lambda x: x.name):
if attr.name not in allDatasets:
dimensions = expand_dimensions(attr)
_renderDSTypeNode(attr.name, dimensions, file)
allDatasets.add(attr.name)
_renderEdge(taskNodeName, attr.name, file)

# This for loop is a workaround until DM-29658 at which time metadata
# connections should start working with the above code
for matchLabel, dsTypeName in metadataNodesToLink:
# only render an edge to metadata if the label is part of the current
# graph
if (result := labelToTaskName.get(matchLabel)) is not None:
_renderEdge(result, dsTypeName, file)

print("}", file=file)
if close:
file.close()
_pipeline2dot(pipeline, file)
Loading

0 comments on commit 8eedeb4

Please sign in to comment.