DM-33027: add PipelineGraph and supporting classes #221

Closed · wants to merge 6 commits
python/lsst/pipe/base/__init__.py (5 additions, 1 deletion)
@@ -1,4 +1,4 @@
from . import automatic_connection_constants, connectionTypes, pipelineIR
from . import automatic_connection_constants, connectionTypes, pipeline_graph, pipelineIR
from ._dataset_handle import *
from ._instrument import *
from ._observation_dimension_packer import *
@@ -11,6 +11,10 @@
from .graph import *
from .graphBuilder import *
from .pipeline import *

# We import the main PipelineGraph type and the module (above), but we don't
# lift all symbols to package scope.
from .pipeline_graph import PipelineGraph
from .pipelineTask import *
from .struct import *
from .task import *
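The `__init__.py` change makes `PipelineGraph` importable at package scope while keeping the rest of the new subpackage behind its module name. A minimal sketch of the two resulting import styles; the `DuplicateOutputError` name is taken from the re-exports visible in the `pipeTools.py` diff below, and constructing it directly here is illustrative only:

```python
# Sketch of the import surface after this change.
from lsst.pipe.base import PipelineGraph  # the main type, lifted to package scope
from lsst.pipe.base import pipeline_graph  # the module, for everything else

graph = PipelineGraph()

# Supporting symbols stay namespaced under the module rather than being
# lifted wholesale into lsst.pipe.base.
try:
    # Assumed message-style constructor, for illustration only.
    raise pipeline_graph.DuplicateOutputError("two producers for one dataset type")
except pipeline_graph.DuplicateOutputError as err:
    print(err)
```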
python/lsst/pipe/base/pipeTools.py (41 additions, 136 deletions)
@@ -2,7 +2,7 @@
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
@@ -27,51 +27,24 @@
# No one should do import * from this module
__all__ = ["isPipelineOrdered", "orderPipeline"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
import itertools
from collections.abc import Iterable
from typing import TYPE_CHECKING

# -----------------------------
# Imports for other modules --
# -----------------------------
from .connections import iterConnections
from .pipeline import Pipeline, TaskDef

# Exceptions re-exported here for backwards compatibility.
from .pipeline_graph import DuplicateOutputError, PipelineDataCycleError, PipelineGraph # noqa: F401

if TYPE_CHECKING:
from .pipeline import Pipeline, TaskDef
from .taskFactory import TaskFactory

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


class MissingTaskFactoryError(Exception):
"""Exception raised when client fails to provide TaskFactory instance."""

pass


class DuplicateOutputError(Exception):
"""Exception raised when Pipeline has more than one task for the same
output.
"""

pass


class PipelineDataCycleError(Exception):
"""Exception raised when Pipeline has data dependency cycle."""

pass


def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskFactory | None = None) -> bool:
"""Check whether tasks in pipeline are correctly ordered.

@@ -80,134 +53,66 @@ def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskFactory | None = None) -> bool:

Parameters
----------
pipeline : `pipe.base.Pipeline`
pipeline : `Pipeline` or `collections.abc.Iterable` [ `TaskDef` ]
Pipeline description.
taskFactory: `pipe.base.TaskFactory`, optional
Instance of an object which knows how to import task classes. It is
only used if pipeline task definitions do not define task classes.
taskFactory : `TaskFactory`, optional
Ignored; present only for backwards compatibility.

Returns
-------
True for correctly ordered pipeline, False otherwise.
is_ordered : `bool`
True for correctly ordered pipeline, False otherwise.

Raises
------
ImportError
Raised when task class cannot be imported.
DuplicateOutputError
Raised when there is more than one producer for a dataset type.
MissingTaskFactoryError
Raised when TaskFactory is needed but not provided.
"""
# Build a map of DatasetType name to producer's index in a pipeline
producerIndex = {}
for idx, taskDef in enumerate(pipeline):
for attr in iterConnections(taskDef.connections, "outputs"):
if attr.name in producerIndex:
raise DuplicateOutputError(
"DatasetType `{}' appears more than once as output".format(attr.name)
)
producerIndex[attr.name] = idx

# check all inputs that are also someone's outputs
for idx, taskDef in enumerate(pipeline):
# get task input DatasetTypes, this can only be done via class method
inputs = {name: getattr(taskDef.connections, name) for name in taskDef.connections.inputs}
for dsTypeDescr in inputs.values():
# all pre-existing datasets have effective index -1
prodIdx = producerIndex.get(dsTypeDescr.name, -1)
if prodIdx >= idx:
# not good, producer is downstream
return False

if isinstance(pipeline, Pipeline):
graph = pipeline.to_graph()
else:
graph = PipelineGraph()
for task_def in pipeline:
graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
# Can't use graph.is_sorted because that requires sorted dataset type names
# as well as sorted tasks.
tasks_xgraph = graph.make_task_xgraph()
seen: set[str] = set()
for task_label in tasks_xgraph:
successors = set(tasks_xgraph.successors(task_label))
if not successors.isdisjoint(seen):
return False
seen.add(task_label)
return True
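The new `isPipelineOrdered` body reduces the check to graph traversal: walking tasks in pipeline order, the pipeline is ordered exactly when no task's successor (a consumer of its outputs) has already been visited. An illustrative sketch of that invariant on a toy `networkx` digraph (not LSST code; the labels are made up):

```python
import networkx as nx

def is_ordered(order, graph):
    """Return True if `order` visits every node before any of its successors."""
    seen = set()
    for label in order:
        # A successor already in `seen` means a consumer ran before its producer.
        if not set(graph.successors(label)).isdisjoint(seen):
            return False
        seen.add(label)
    return True

g = nx.DiGraph([("a", "b"), ("b", "c")])  # a feeds b, b feeds c

assert is_ordered(["a", "b", "c"], g)
assert not is_ordered(["b", "a", "c"], g)  # "a" appears after its consumer "b"
```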


def orderPipeline(pipeline: list[TaskDef]) -> list[TaskDef]:
def orderPipeline(pipeline: Pipeline | Iterable[TaskDef]) -> list[TaskDef]:
"""Re-order tasks in pipeline to satisfy data dependencies.

When possible, the new ordering keeps the original relative order of the tasks.

Parameters
----------
pipeline : `list` of `pipe.base.TaskDef`
pipeline : `Pipeline` or `collections.abc.Iterable` [ `TaskDef` ]
Pipeline description.

Returns
-------
Correctly ordered pipeline (`list` of `pipe.base.TaskDef` objects).
ordered : `list` [ `TaskDef` ]
Correctly ordered pipeline.

Raises
------
`DuplicateOutputError` is raised when there is more than one producer for a
dataset type.
`PipelineDataCycleError` is also raised when pipeline has dependency
cycles. `MissingTaskFactoryError` is raised when `TaskFactory` is needed
but not provided.
DuplicateOutputError
Raised when there is more than one producer for a dataset type.
PipelineDataCycleError
Raised when the pipeline has dependency cycles.
"""
# This is a modified version of Kahn's algorithm that preserves order

# build mapping of the tasks to their inputs and outputs
inputs = {} # maps task index to its input DatasetType names
outputs = {} # maps task index to its output DatasetType names
allInputs = set() # all inputs of all tasks
allOutputs = set() # all outputs of all tasks
dsTypeTaskLabels: dict[str, str] = {} # maps DatasetType name to the label of its parent task
for idx, taskDef in enumerate(pipeline):
# task outputs
dsMap = {name: getattr(taskDef.connections, name) for name in taskDef.connections.outputs}
for dsTypeDescr in dsMap.values():
if dsTypeDescr.name in allOutputs:
raise DuplicateOutputError(
f"DatasetType `{dsTypeDescr.name}' in task `{taskDef.label}' already appears as an "
f"output in task `{dsTypeTaskLabels[dsTypeDescr.name]}'."
)
dsTypeTaskLabels[dsTypeDescr.name] = taskDef.label
outputs[idx] = set(dsTypeDescr.name for dsTypeDescr in dsMap.values())
allOutputs.update(outputs[idx])

# task inputs
connectionInputs = itertools.chain(taskDef.connections.inputs, taskDef.connections.prerequisiteInputs)
inputs[idx] = set(getattr(taskDef.connections, name).name for name in connectionInputs)
allInputs.update(inputs[idx])

# for simplicity add pseudo-node which is a producer for all pre-existing
# inputs, its index is -1
preExisting = allInputs - allOutputs
outputs[-1] = preExisting

# Set of nodes with no incoming edges, initially set to pseudo-node
queue = [-1]
result = []
while queue:
# move to final list, drop -1
idx = queue.pop(0)
if idx >= 0:
result.append(idx)

# remove task outputs from other tasks inputs
thisTaskOutputs = outputs.get(idx, set())
for taskInputs in inputs.values():
taskInputs -= thisTaskOutputs

# find all nodes with no incoming edges and move them to the queue
topNodes = [key for key, value in inputs.items() if not value]
queue += topNodes
for key in topNodes:
del inputs[key]

# keep queue ordered
queue.sort()

# if there is something left it means cycles
if inputs:
# format it in usable way
loops = []
for idx, inputNames in inputs.items():
taskName = pipeline[idx].label
outputNames = outputs[idx]
edge = " {} -> {} -> {}".format(inputNames, taskName, outputNames)
loops.append(edge)
raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops))

return [pipeline[idx] for idx in result]
if isinstance(pipeline, Pipeline):
graph = pipeline.to_graph()
else:
graph = PipelineGraph()
for task_def in pipeline:
graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
graph.sort()
return list(graph._iter_task_defs())
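`orderPipeline` now delegates to `PipelineGraph.sort()`, but the contract is the one the deleted hand-rolled code implemented: a topological sort that, among the tasks currently ready to run, prefers whichever came first in the original order. A self-contained sketch of that order-preserving Kahn's algorithm, as an illustration of the idea rather than the `PipelineGraph` implementation:

```python
import heapq

def stable_topo_sort(order, edges):
    """Topologically sort `order` subject to (producer, consumer) `edges`,
    preferring the original relative order among ready nodes."""
    index = {label: i for i, label in enumerate(order)}
    consumers = {label: [] for label in order}
    indegree = {label: 0 for label in order}
    for producer, consumer in edges:
        consumers[producer].append(consumer)
        indegree[consumer] += 1
    # Ready set as a min-heap of original positions: ties break toward
    # whatever came first in the input, mirroring the old queue.sort().
    ready = [index[label] for label in order if indegree[label] == 0]
    heapq.heapify(ready)
    result = []
    while ready:
        label = order[heapq.heappop(ready)]
        result.append(label)
        for consumer in consumers[label]:
            indegree[consumer] -= 1
            if indegree[consumer] == 0:
                heapq.heappush(ready, index[consumer])
    if len(result) != len(order):
        raise ValueError("pipeline has a data-dependency cycle")
    return result

# "b" consumes "a"'s output, so ["b", "a", "c"] must become ["a", "b", "c"];
# the unrelated "c" keeps its relative position.
assert stable_topo_sort(["b", "a", "c"], [("a", "b")]) == ["a", "b", "c"]
```

Leaving unprocessed nodes behind when the heap drains is exactly the cycle condition that `PipelineDataCycleError` covers in the real code.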