Skip to content

Commit

Permalink
Merge pull request #347 from lsst/tickets/DM-33027
Browse files Browse the repository at this point in the history
DM-33027: add PipelineGraph class
  • Loading branch information
TallJimbo committed Aug 4, 2023
2 parents eb4b848 + 8091ec1 commit d81147c
Show file tree
Hide file tree
Showing 20 changed files with 6,457 additions and 313 deletions.
1 change: 1 addition & 0 deletions doc/changes/DM-33027.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a PipelineGraph class that represents a Pipeline with all configuration overrides applied as a graph.
5 changes: 5 additions & 0 deletions doc/lsst.pipe.base/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Developing Pipelines

creating-a-pipeline.rst
testing-pipelines-with-mocks.rst
working-with-pipeline-graphs.rst

.. _lsst.pipe.base-contributing:

Expand All @@ -77,6 +78,10 @@ Python API reference
:no-main-docstr:
:skip: BuildId
:skip: DatasetTypeName
:skip: PipelineGraph

.. automodapi:: lsst.pipe.base.pipeline_graph
:no-main-docstr:

.. automodapi:: lsst.pipe.base.testUtils
:no-main-docstr:
Expand Down
88 changes: 88 additions & 0 deletions doc/lsst.pipe.base/working-with-pipeline-graphs.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
.. _pipe_base_pipeline_graphs:

.. py:currentmodule:: lsst.pipe.base.pipeline_graph
############################
Working with Pipeline Graphs
############################

Pipeline objects are written as YAML documents, but once they are fully configured, they are conceptually directed acyclic graphs (DAGs).
In code, this graph version of a pipeline is represented by the `PipelineGraph` class.
`PipelineGraph` objects are usually constructed by calling the `.Pipeline.to_graph` method::

from lsst.daf.butler import Butler
from lsst.pipe.base import Pipeline

butler = Butler("/some/repo")
pipeline = Pipeline.from_uri("my_pipeline.yaml")
graph = pipeline.to_graph(registry=butler.registry)

The ``registry`` argument here is optional, but without it the graph will be incomplete ("unresolved") and the pipeline will not be checked for correctness until the `~PipelineGraph.resolve` method is called.
Resolving a graph compares all of the task connections (which are edges in the graph) that reference each dataset type to each other and to the registry's definition of that dataset to determine a common graph-wide definition.
A definition in the registry always takes precedence, followed by the output connection that produces the dataset type (if there is one).
When a pipeline graph is used to register dataset types in a data repository, it is this common definition in the dataset type node that is registered.
Edge dataset type descriptions represent storage class overrides for a task, or specify that the task only wants a component.

Simple Graph Inspection
-----------------------

The basic structure of the graph can be explored via the `~PipelineGraph.tasks` and `~PipelineGraph.dataset_types` mapping attributes.
These are keyed by task label and *parent* (never component) dataset type name, and they have `TaskNode` and `DatasetTypeNode` objects as values, respectively.
A resolved pipeline graph is always sorted, which means iterations over these mappings will be in topological order.
`TaskNode` objects have an `~TaskNode.init` attribute that holds a `TaskInitNode` instance - these resemble `TaskNode` objects and have edges to dataset types as well, but these edges represent the "init input" and "init output" connections of those tasks.

Task and dataset type node objects have attributes holding all of their edges, but to get neighboring nodes, you have to go back to the graph object::

task_node = graph.tasks["task_a"]
for edge in task.inputs.values():
dataset_type_node = graph.dataset_types[edge.parent_dataset_type_name]
print(f"{task_node.label} takes {dataset_type_node.name} as an input.")

There are also convenience methods on `PipelineGraph` to get the edges or neighbors of a node:

- `~PipelineGraph.producing_edge_of`: an alternative to `DatasetTypeNode.producing_edge`
- `~PipelineGraph.consuming_edges_of`: an alternative to `DatasetTypeNode.consuming_edges`
- `~PipelineGraph.producer_of`: a shortcut for getting the task that write a dataset type
- `~PipelineGraph.consumers_of`: a shortcut for getting the tasks that read a dataset type
- `~PipelineGraph.inputs_of`: a shortcut for getting the dataset types that a task reads
- `~PipelineGraph.outputs_of`: a shortcut for getting the dataset types that a task writes

Pipeline graphs also hold the `~PipelineGraph.description` and `~PipelineGraph.data_id` (usually just an instrument value) of the pipeline used to construct them, as well as the same mapping of labeled task subsets (`~PipelineGraph.task_subsets`).

Modifying PipelineGraphs
------------------------

Usually the tasks in a pipeline are set before a `PipelineGraph` is ever constructed.
In some cases it may be more convenient to add tasks to an existing `PipelineGraph`, either because a related graph is being created from an existing one, or because a (rare) task needs to be configured in a way that depends on the content or structure of the rest of the graph.
`PipelineGraph` provides a number of mutation methods:

- `~PipelineGraph.add_task` adds a brand new task from a `.PipelineTask` type object and its configuration;
- `~PipelineGraph.add_task_nodes` adds one or more tasks from a different `PipelineGraph` instance;
- `~PipelineGraph.reconfigure_tasks` replaces the configuration of an existing task with new configuration (note that this is typically less convenient than adding config *overrides* to a `Pipeline` object, because all configuration in a `PipelineGraph` must be validated and frozen);
- `~PipelineGraph.remove_task_nodes` removes existing tasks;
- `~PipelineGraph.add_task_subset` and `~PipelineGraph.remove_task_subset` modify the mapping of labeled task subsets (which can also be modified in-place).

**The most important thing to remember when modifying `PipelineGraph` objects is that modifications typically reset some or all of the graph to an unresolved state.**

The reference documentation for these methods describes exactly what guarantees they make about existing resolutions in detail, and what operations are still supported on unresolved or partially-resolved graphs, but it is easiest to just ensure `resolve` is called after any modifications are complete.

`PipelineGraph` mutator methods provide strong exception safety (the graph is left unchanged when an exception is raised and caught by calling code) unless the exception type raised is `PipelineGraphExceptionSafetyError`.

Exporting to NetworkX
---------------------

NetworkX is a powerful Python library for graph manipulation, and in addition to being used in the implementation, `PipelineGraph` provides methods to create various native NetworkX graph objects.
The node attributes of these graphs provide much of the same information as the `TaskNode` and `DatasetTypeNode` objects (see the documentation for those objects for details).

The export methods include:

- `~PipelineGraph.make_xgraph` exports all nodes, including task nodes, dataset type nodes, and task init nodes, and the edges between them.
This is a `networkx.MultiDiGraph` because there can be (albeit) rarely multiple edges (representing different connections) between a dataset type and a task.
The edges of this graph have attributes as well as the nodes.
- `~PipelineGraph.make_bipartite_graph` exports just task nodes and dataset type nodes and the edges between them (or, if ``init=True``, just task init nodes and the dataset type nodes and edges between them).
A "bipartite" graph is one in which there are two kinds of nodes and edges only connect one type to the other.
This is also a `networkx.MultiDiGraph`, and its edges also have attributes.
- `~PipelineGraph.make_task_graph` exports just task (or task init) nodes; it is one "bipartite projection" of the full graph.
This is a `networkx.DiGraph`, because all dataset types that connect a pair of tasks are rolled into one edge, and edges have no state.
- `~PipelineGraph.make_dataset_type_graph` exports just dataset type nodes; it is one "bipartite projection" of the full graph.
This is a `networkx.DiGraph`, because all tasks that connect a pair of dataset types are rolled into one edge, and edges have no state.
6 changes: 5 additions & 1 deletion python/lsst/pipe/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import automatic_connection_constants, connectionTypes, pipelineIR
from . import automatic_connection_constants, connectionTypes, pipeline_graph, pipelineIR
from ._dataset_handle import *
from ._instrument import *
from ._observation_dimension_packer import *
Expand All @@ -11,6 +11,10 @@
from .graph import *
from .graphBuilder import *
from .pipeline import *

# We import the main PipelineGraph type and the module (above), but we don't
# lift all symbols to package scope.
from .pipeline_graph import PipelineGraph
from .pipelineTask import *
from .struct import *
from .task import *
Expand Down
175 changes: 40 additions & 135 deletions python/lsst/pipe/base/pipeTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,51 +27,24 @@
# No one should do import * from this module
__all__ = ["isPipelineOrdered", "orderPipeline"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
import itertools
from collections.abc import Iterable
from typing import TYPE_CHECKING

# -----------------------------
# Imports for other modules --
# -----------------------------
from .connections import iterConnections
from .pipeline import Pipeline, TaskDef

# Exceptions re-exported here for backwards compatibility.
from .pipeline_graph import DuplicateOutputError, PipelineDataCycleError, PipelineGraph # noqa: F401

if TYPE_CHECKING:
from .pipeline import Pipeline, TaskDef
from .taskFactory import TaskFactory

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


class MissingTaskFactoryError(Exception):
"""Exception raised when client fails to provide TaskFactory instance."""

pass


class DuplicateOutputError(Exception):
"""Exception raised when Pipeline has more than one task for the same
output.
"""

pass


class PipelineDataCycleError(Exception):
"""Exception raised when Pipeline has data dependency cycle."""

pass


def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskFactory | None = None) -> bool:
"""Check whether tasks in pipeline are correctly ordered.
Expand All @@ -80,134 +53,66 @@ def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskF
Parameters
----------
pipeline : `pipe.base.Pipeline`
pipeline : `Pipeline` or `collections.abc.Iterable` [ `TaskDef` ]
Pipeline description.
taskFactory: `pipe.base.TaskFactory`, optional
Instance of an object which knows how to import task classes. It is
only used if pipeline task definitions do not define task classes.
taskFactory: `TaskFactory`, optional
Ignored; present only for backwards compatibility.
Returns
-------
True for correctly ordered pipeline, False otherwise.
is_ordered : `bool`
True for correctly ordered pipeline, False otherwise.
Raises
------
ImportError
Raised when task class cannot be imported.
DuplicateOutputError
Raised when there is more than one producer for a dataset type.
MissingTaskFactoryError
Raised when TaskFactory is needed but not provided.
"""
# Build a map of DatasetType name to producer's index in a pipeline
producerIndex = {}
for idx, taskDef in enumerate(pipeline):
for attr in iterConnections(taskDef.connections, "outputs"):
if attr.name in producerIndex:
raise DuplicateOutputError(
"DatasetType `{}' appears more than once as output".format(attr.name)
)
producerIndex[attr.name] = idx

# check all inputs that are also someone's outputs
for idx, taskDef in enumerate(pipeline):
# get task input DatasetTypes, this can only be done via class method
inputs = {name: getattr(taskDef.connections, name) for name in taskDef.connections.inputs}
for dsTypeDescr in inputs.values():
# all pre-existing datasets have effective index -1
prodIdx = producerIndex.get(dsTypeDescr.name, -1)
if prodIdx >= idx:
# not good, producer is downstream
return False

if isinstance(pipeline, Pipeline):
graph = pipeline.to_graph()
else:
graph = PipelineGraph()
for task_def in pipeline:
graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
# Can't use graph.is_sorted because that requires sorted dataset type names
# as well as sorted tasks.
tasks_xgraph = graph.make_task_xgraph()
seen: set[str] = set()
for task_label in tasks_xgraph:
successors = set(tasks_xgraph.successors(task_label))
if not successors.isdisjoint(seen):
return False
seen.add(task_label)
return True


def orderPipeline(pipeline: list[TaskDef]) -> list[TaskDef]:
def orderPipeline(pipeline: Pipeline | Iterable[TaskDef]) -> list[TaskDef]:
"""Re-order tasks in pipeline to satisfy data dependencies.
When possible new ordering keeps original relative order of the tasks.
Parameters
----------
pipeline : `list` of `pipe.base.TaskDef`
pipeline : `Pipeline` or `collections.abc.Iterable` [ `TaskDef` ]
Pipeline description.
Returns
-------
Correctly ordered pipeline (`list` of `pipe.base.TaskDef` objects).
ordered : `list` [ `TaskDef` ]
Correctly ordered pipeline.
Raises
------
`DuplicateOutputError` is raised when there is more than one producer for a
dataset type.
`PipelineDataCycleError` is also raised when pipeline has dependency
cycles. `MissingTaskFactoryError` is raised when `TaskFactory` is needed
but not provided.
DuplicateOutputError
Raised when there is more than one producer for a dataset type.
PipelineDataCycleError
Raised when the pipeline has dependency cycles.
"""
# This is a modified version of Kahn's algorithm that preserves order

# build mapping of the tasks to their inputs and outputs
inputs = {} # maps task index to its input DatasetType names
outputs = {} # maps task index to its output DatasetType names
allInputs = set() # all inputs of all tasks
allOutputs = set() # all outputs of all tasks
dsTypeTaskLabels: dict[str, str] = {} # maps DatasetType name to the label of its parent task
for idx, taskDef in enumerate(pipeline):
# task outputs
dsMap = {name: getattr(taskDef.connections, name) for name in taskDef.connections.outputs}
for dsTypeDescr in dsMap.values():
if dsTypeDescr.name in allOutputs:
raise DuplicateOutputError(
f"DatasetType `{dsTypeDescr.name}' in task `{taskDef.label}' already appears as an "
f"output in task `{dsTypeTaskLabels[dsTypeDescr.name]}'."
)
dsTypeTaskLabels[dsTypeDescr.name] = taskDef.label
outputs[idx] = set(dsTypeDescr.name for dsTypeDescr in dsMap.values())
allOutputs.update(outputs[idx])

# task inputs
connectionInputs = itertools.chain(taskDef.connections.inputs, taskDef.connections.prerequisiteInputs)
inputs[idx] = set(getattr(taskDef.connections, name).name for name in connectionInputs)
allInputs.update(inputs[idx])

# for simplicity add pseudo-node which is a producer for all pre-existing
# inputs, its index is -1
preExisting = allInputs - allOutputs
outputs[-1] = preExisting

# Set of nodes with no incoming edges, initially set to pseudo-node
queue = [-1]
result = []
while queue:
# move to final list, drop -1
idx = queue.pop(0)
if idx >= 0:
result.append(idx)

# remove task outputs from other tasks inputs
thisTaskOutputs = outputs.get(idx, set())
for taskInputs in inputs.values():
taskInputs -= thisTaskOutputs

# find all nodes with no incoming edges and move them to the queue
topNodes = [key for key, value in inputs.items() if not value]
queue += topNodes
for key in topNodes:
del inputs[key]

# keep queue ordered
queue.sort()

# if there is something left it means cycles
if inputs:
# format it in usable way
loops = []
for idx, inputNames in inputs.items():
taskName = pipeline[idx].label
outputNames = outputs[idx]
edge = " {} -> {} -> {}".format(inputNames, taskName, outputNames)
loops.append(edge)
raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops))

return [pipeline[idx] for idx in result]
if isinstance(pipeline, Pipeline):
graph = pipeline.to_graph()
else:
graph = PipelineGraph()
for task_def in pipeline:
graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
graph.sort()
return list(graph._iter_task_defs())
Loading

0 comments on commit d81147c

Please sign in to comment.