Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-45701: Move dotTools to pipe_base #300

Merged
merged 2 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/DM-45701.removal.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The ``lsst.ctrl.mpexec.dotTools`` package has been relocated to ``lsst.pipe.base.dot_tools``.
2 changes: 1 addition & 1 deletion python/lsst/ctrl/mpexec/cmdLineFwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@
buildExecutionButler,
)
from lsst.pipe.base.all_dimensions_quantum_graph_builder import AllDimensionsQuantumGraphBuilder
from lsst.pipe.base.dot_tools import graph2dot
from lsst.pipe.base.pipeline_graph import NodeType
from lsst.utils import doImportType
from lsst.utils.logging import getLogger
from lsst.utils.threads import disable_implicit_threading

from .dotTools import graph2dot
from .executionGraphFixup import ExecutionGraphFixup
from .mpGraphExecutor import MPGraphExecutor
from .preExecInit import PreExecInit, PreExecInitLimited
Expand Down
300 changes: 24 additions & 276 deletions python/lsst/ctrl/mpexec/dotTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,141 +33,20 @@

__all__ = ["graph2dot", "pipeline2dot"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
import html
import io
import re
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
from typing import Any

# -----------------------------
# Imports for other modules --
# -----------------------------
from lsst.daf.butler import DatasetType, DimensionUniverse
from lsst.pipe.base import Pipeline, connectionTypes, iterConnections
from deprecated.sphinx import deprecated
from lsst.pipe.base import Pipeline, QuantumGraph, TaskDef
from lsst.pipe.base.dot_tools import graph2dot as _graph2dot
from lsst.pipe.base.dot_tools import pipeline2dot as _pipeline2dot

if TYPE_CHECKING:
from lsst.daf.butler import DatasetRef
from lsst.pipe.base import QuantumGraph, QuantumNode, TaskDef

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# Attributes applied to directed graph objects.
_NODELABELPOINTSIZE = "18"
_ATTRIBS = dict(
defaultGraph=dict(splines="ortho", nodesep="0.5", ranksep="0.75", pad="0.5"),
defaultNode=dict(shape="box", fontname="Monospace", fontsize="14", margin="0.2,0.1", penwidth="3"),
defaultEdge=dict(color="black", arrowsize="1.5", penwidth="1.5"),
task=dict(style="filled", color="black", fillcolor="#B1F2EF"),
quantum=dict(style="filled", color="black", fillcolor="#B1F2EF"),
dsType=dict(style="rounded,filled,bold", color="#00BABC", fillcolor="#F5F5F5"),
dataset=dict(style="rounded,filled,bold", color="#00BABC", fillcolor="#F5F5F5"),
@deprecated(
"graph2dot should now be imported from lsst.pipe.base.dot_tools. Will be removed in v29.",
version="v27.0",
category=FutureWarning,
)


def _renderDefault(type: str, attribs: dict[str, str], file: io.TextIOBase) -> None:
"""Set default attributes for a given type."""
default_attribs = ", ".join([f'{key}="{val}"' for key, val in attribs.items()])
print(f"{type} [{default_attribs}];", file=file)


def _renderNode(file: io.TextIOBase, nodeName: str, style: str, labels: list[str]) -> None:
    """Write a GraphViz node statement with an HTML-like table label.

    The node takes the attributes registered for ``style`` in ``_ATTRIBS``;
    each entry of ``labels`` becomes one table row of the label.
    """
    table_open = '<<TABLE BORDER="0" CELLPADDING="5"><TR><TD>'
    table_close = "</TD></TR></TABLE>>"
    row_break = r"</TD></TR><TR><TD>"
    attributes = dict(_ATTRIBS[style], label=row_break.join(labels))
    rendered = []
    for key, val in attributes.items():
        if key == "label":
            # The label is HTML-like markup, so it is wrapped in <...> rather
            # than quoted like the plain attributes.
            rendered.append(f"{key}={table_open}{val}{table_close}")
        else:
            rendered.append(f'{key}="{val}"')
    file.write(f'"{nodeName}" [{", ".join(rendered)}];\n')


def _renderTaskNode(nodeName: str, taskDef: TaskDef, file: io.TextIOBase, idx: Any = None) -> None:
    """Render the GraphViz node describing a single task.

    The label shows the task label (bold, enlarged), the task class name,
    an optional index, and the task dimensions when connections are defined.
    """
    rows = [
        f'<B><FONT POINT-SIZE="{_NODELABELPOINTSIZE}">{html.escape(taskDef.label)}</FONT></B>',
        html.escape(taskDef.taskName),
    ]
    if idx is not None:
        rows.append(f"<I>index:</I>&nbsp;{idx}")
    if taskDef.connections:
        # Join the names ourselves: printing a collection of str directly
        # would clutter the label with quotes.
        dims = ", ".join(sorted(taskDef.connections.dimensions))
        rows.append(f"<I>dimensions:</I>&nbsp;{html.escape(dims)}")
    _renderNode(file, nodeName, "task", rows)


def _renderQuantumNode(
    nodeName: str, taskDef: TaskDef, quantumNode: QuantumNode, file: io.TextIOBase
) -> None:
    """Render the GraphViz node for a single quantum.

    The label carries the quantum node ID, the task label, and one
    ``key = value`` row per required data-ID dimension, sorted by key.
    """
    rows = [f"{quantumNode.nodeId}", html.escape(taskDef.label)]
    dataId = quantumNode.quantum.dataId
    assert dataId is not None, "Quantum DataId cannot be None"
    for key in sorted(dataId.required.keys()):
        rows.append(f"{key} = {dataId[key]}")
    _renderNode(file, nodeName, "quantum", rows)


def _renderDSTypeNode(name: str, dimensions: list[str], file: io.TextIOBase) -> None:
    """Render the GraphViz node for a dataset type.

    Shows the dataset type name (bold, enlarged) plus its sorted dimensions
    when any are given; the GraphViz node name is the dataset type name.
    """
    rows = [f'<B><FONT POINT-SIZE="{_NODELABELPOINTSIZE}">{html.escape(name)}</FONT></B>']
    if dimensions:
        dims = ", ".join(sorted(dimensions))
        rows.append(f"<I>dimensions:</I>&nbsp;{html.escape(dims)}")
    _renderNode(file, name, "dsType", rows)


def _renderDSNode(nodeName: str, dsRef: DatasetRef, file: io.TextIOBase) -> None:
    """Render the GraphViz node for a concrete dataset.

    The label lists the dataset type name, the run collection, and one
    ``key = value`` row per required data-ID dimension, sorted by key.
    """
    rows = [html.escape(dsRef.datasetType.name), f"run: {dsRef.run!r}"]
    for key in sorted(dsRef.dataId.required.keys()):
        rows.append(f"{key} = {dsRef.dataId[key]}")
    _renderNode(file, nodeName, "dataset", rows)


def _renderEdge(fromName: str, toName: str, file: io.TextIOBase, **kwargs: Any) -> None:
"""Render GV edge"""
if kwargs:
attrib = ", ".join([f'{key}="{val}"' for key, val in kwargs.items()])
print(f'"{fromName}" -> "{toName}" [{attrib}];', file=file)
else:
print(f'"{fromName}" -> "{toName}";', file=file)


def _datasetRefId(dsRef: DatasetRef) -> str:
"""Make an identifying string for given ref"""
dsId = [dsRef.datasetType.name]
dsId.extend(f"{key} = {dsRef.dataId[key]}" for key in sorted(dsRef.dataId.required.keys()))
return ":".join(dsId)


def _makeDSNode(dsRef: DatasetRef, allDatasetRefs: dict[str, str], file: io.TextIOBase) -> str:
    """Return the GraphViz node name for ``dsRef``, rendering it on first use.

    ``allDatasetRefs`` maps dataset identifiers to already-assigned node
    names; a fresh ``dsref_<n>`` name is allocated and the node written to
    ``file`` only when the reference has not been seen before.
    """
    refId = _datasetRefId(dsRef)
    try:
        return allDatasetRefs[refId]
    except KeyError:
        nodeName = f"dsref_{len(allDatasetRefs)}"
        allDatasetRefs[refId] = nodeName
        _renderDSNode(nodeName, dsRef, file)
        return nodeName


# ------------------------
# Exported definitions --
# ------------------------


def graph2dot(qgraph: QuantumGraph, file: Any) -> None:
"""Convert QuantumGraph into GraphViz digraph.

Expand All @@ -183,45 +62,19 @@ def graph2dot(qgraph: QuantumGraph, file: Any) -> None:

Raises
------
`OSError` is raised when output file cannot be open.
`ImportError` is raised when task class cannot be imported.
OSError
Raised if the output file cannot be opened.
ImportError
Raised if the task class cannot be imported.
"""
# open a file if needed
close = False
if not hasattr(file, "write"):
file = open(file, "w")
close = True

print("digraph QuantumGraph {", file=file)
_renderDefault("graph", _ATTRIBS["defaultGraph"], file)
_renderDefault("node", _ATTRIBS["defaultNode"], file)
_renderDefault("edge", _ATTRIBS["defaultEdge"], file)

allDatasetRefs: dict[str, str] = {}
for taskId, taskDef in enumerate(qgraph.taskGraph):
quanta = qgraph.getNodesForTask(taskDef)
for qId, quantumNode in enumerate(quanta):
# node for a task
taskNodeName = f"task_{taskId}_{qId}"
_renderQuantumNode(taskNodeName, taskDef, quantumNode, file)

# quantum inputs
for dsRefs in quantumNode.quantum.inputs.values():
for dsRef in dsRefs:
nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
_renderEdge(nodeName, taskNodeName, file)

# quantum outputs
for dsRefs in quantumNode.quantum.outputs.values():
for dsRef in dsRefs:
nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
_renderEdge(taskNodeName, nodeName, file)

print("}", file=file)
if close:
file.close()
_graph2dot(qgraph, file)


@deprecated(
"pipeline2dot should now be imported from lsst.pipe.base.dot_tools. Will be removed in v29.",
version="v27.0",
category=FutureWarning,
)
def pipeline2dot(pipeline: Pipeline | Iterable[TaskDef], file: Any) -> None:
"""Convert `~lsst.pipe.base.Pipeline` into GraphViz digraph.

Expand All @@ -238,114 +91,9 @@ def pipeline2dot(pipeline: Pipeline | Iterable[TaskDef], file: Any) -> None:

Raises
------
`OSError` is raised when output file cannot be open.
`ImportError` is raised when task class cannot be imported.
`MissingTaskFactoryError` is raised when TaskFactory is needed but not
provided.
OSError
Raised if the output file cannot be opened.
ImportError
Raised if the task class cannot be imported.
"""
universe = DimensionUniverse()

    def expand_dimensions(connection: connectionTypes.BaseConnection) -> list[str]:
        """Return expanded list of dimensions, with special skypix treatment.

        Parameters
        ----------
        connection : `lsst.pipe.base.connectionTypes.BaseConnection`
            Connection to examine.

        Returns
        -------
        dimensions : `list` [`str`]
            Expanded list of dimensions.
        """
        dimension_set = set()
        if isinstance(connection, connectionTypes.DimensionedConnection):
            dimension_set = set(connection.dimensions)
        skypix_dim = []
        # "skypix" is a placeholder rather than a dimension known to the
        # universe, so pull it out before conforming and append it back after.
        if "skypix" in dimension_set:
            dimension_set.remove("skypix")
            skypix_dim = ["skypix"]
        dimensions = universe.conform(dimension_set)
        return list(dimensions.names) + skypix_dim

# open a file if needed
close = False
if not hasattr(file, "write"):
file = open(file, "w")
close = True

print("digraph Pipeline {", file=file)
_renderDefault("graph", _ATTRIBS["defaultGraph"], file)
_renderDefault("node", _ATTRIBS["defaultNode"], file)
_renderDefault("edge", _ATTRIBS["defaultEdge"], file)

allDatasets: set[str | tuple[str, str]] = set()
if isinstance(pipeline, Pipeline):
# TODO: DM-40639 will rewrite this code and finish off the deprecation
# of toExpandedPipeline.
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=FutureWarning)
pipeline = pipeline.toExpandedPipeline()

# The next two lines are a workaround until DM-29658 at which time metadata
# connections should start working with the above code
labelToTaskName = {}
metadataNodesToLink = set()

for idx, taskDef in enumerate(sorted(pipeline, key=lambda x: x.label)):
# node for a task
taskNodeName = f"task{idx}"

# next line is workaround until DM-29658
labelToTaskName[taskDef.label] = taskNodeName

_renderTaskNode(taskNodeName, taskDef, file, None)

metadataRePattern = re.compile("^(.*)_metadata$")
for attr in sorted(iterConnections(taskDef.connections, "inputs"), key=lambda x: x.name):
if attr.name not in allDatasets:
dimensions = expand_dimensions(attr)
_renderDSTypeNode(attr.name, dimensions, file)
allDatasets.add(attr.name)
nodeName, component = DatasetType.splitDatasetTypeName(attr.name)
_renderEdge(attr.name, taskNodeName, file)
# connect component dataset types to the composite type that
# produced it
if component is not None and (nodeName, attr.name) not in allDatasets:
_renderEdge(nodeName, attr.name, file)
allDatasets.add((nodeName, attr.name))
if nodeName not in allDatasets:
dimensions = expand_dimensions(attr)
_renderDSTypeNode(nodeName, dimensions, file)
# The next if block is a workaround until DM-29658 at which time
# metadata connections should start working with the above code
if (match := metadataRePattern.match(attr.name)) is not None:
matchTaskLabel = match.group(1)
metadataNodesToLink.add((matchTaskLabel, attr.name))

for attr in sorted(iterConnections(taskDef.connections, "prerequisiteInputs"), key=lambda x: x.name):
if attr.name not in allDatasets:
dimensions = expand_dimensions(attr)
_renderDSTypeNode(attr.name, dimensions, file)
allDatasets.add(attr.name)
# use dashed line for prerequisite edges to distinguish them
_renderEdge(attr.name, taskNodeName, file, style="dashed")

for attr in sorted(iterConnections(taskDef.connections, "outputs"), key=lambda x: x.name):
if attr.name not in allDatasets:
dimensions = expand_dimensions(attr)
_renderDSTypeNode(attr.name, dimensions, file)
allDatasets.add(attr.name)
_renderEdge(taskNodeName, attr.name, file)

# This for loop is a workaround until DM-29658 at which time metadata
# connections should start working with the above code
for matchLabel, dsTypeName in metadataNodesToLink:
# only render an edge to metadata if the label is part of the current
# graph
if (result := labelToTaskName.get(matchLabel)) is not None:
_renderEdge(result, dsTypeName, file)

print("}", file=file)
if close:
file.close()
_pipeline2dot(pipeline, file)
Loading
Loading