Integrate PipelineGraph with Pipeline and use it for sorting.
Much of the code changed here is actually stuff I want to deprecate in
the future, once PipelineGraph has been integrated with more things.
In the meantime, this addresses much of the duplication caused by adding
PipelineGraph.
TallJimbo committed Jun 23, 2023
1 parent 4f92f7a commit 386ecd1
Showing 3 changed files with 105 additions and 174 deletions.
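
A minimal sketch (not part of this commit) of the flow it enables: Pipeline.to_graph() builds a PipelineGraph, which is then sorted topologically. Pipeline.from_uri and the pipeline file path are assumptions for illustration; the other calls appear in the diffs below.

from lsst.pipe.base import Pipeline

pipeline = Pipeline.from_uri("example_pipeline.yaml")  # hypothetical pipeline file
graph = pipeline.to_graph()  # applies overrides, freezes config, checks contracts
graph.sort()                 # topological sort of tasks and dataset types
print([task_def.label for task_def in graph._iter_task_defs()])  # run order
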
159 changes: 33 additions & 126 deletions python/lsst/pipe/base/pipeTools.py
@@ -2,7 +2,7 @@
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
@@ -27,51 +27,24 @@
# No one should do import * from this module
__all__ = ["isPipelineOrdered", "orderPipeline"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
import itertools
from collections.abc import Iterable, Sequence
from collections.abc import Iterable
from typing import TYPE_CHECKING

# -----------------------------
# Imports for other modules --
# -----------------------------
from .connections import iterConnections
from .pipeline import Pipeline, TaskDef

# Exceptions re-exported here for backwards compatibility.
from .pipeline_graph import DuplicateOutputError, PipelineDataCycleError, PipelineGraph # noqa: F401

if TYPE_CHECKING:
from .pipeline import Pipeline, TaskDef
from .taskFactory import TaskFactory

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


class MissingTaskFactoryError(Exception):
"""Exception raised when client fails to provide TaskFactory instance."""

pass


class DuplicateOutputError(Exception):
"""Exception raised when Pipeline has more than one task for the same
output.
"""

pass


class PipelineDataCycleError(Exception):
"""Exception raised when Pipeline has data dependency cycle."""

pass


def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskFactory | None = None) -> bool:
"""Check whether tasks in pipeline are correctly ordered.
@@ -80,9 +53,9 @@ def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskF
Parameters
----------
pipeline : `pipe.base.Pipeline` or `collections.abc.Iterable` [ `TaskDef` ]
pipeline : `Pipeline` or `collections.abc.Iterable` [ `TaskDef` ]
Pipeline description.
taskFactory: `pipe.base.TaskFactory`, optional
taskFactory: `TaskFactory`, optional
Ignored; present only for backwards compatibility.
Returns
@@ -97,38 +70,30 @@ def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskF
DuplicateOutputError
Raised when there is more than one producer for a dataset type.
"""
# Build a map of DatasetType name to producer's index in a pipeline
producerIndex = {}
for idx, taskDef in enumerate(pipeline):
for attr in iterConnections(taskDef.connections, "outputs"):
if attr.name in producerIndex:
raise DuplicateOutputError(
"DatasetType `{}' appears more than once as output".format(attr.name)
)
producerIndex[attr.name] = idx

# check all inputs that are also someone's outputs
for idx, taskDef in enumerate(pipeline):
# get task input DatasetTypes, this can only be done via class method
inputs = {name: getattr(taskDef.connections, name) for name in taskDef.connections.inputs}
for dsTypeDescr in inputs.values():
# all pre-existing datasets have effective index -1
prodIdx = producerIndex.get(dsTypeDescr.name, -1)
if prodIdx >= idx:
# not good, producer is downstream
return False

if isinstance(pipeline, Pipeline):
graph = pipeline.to_graph()
else:
graph = PipelineGraph()
for task_def in pipeline:
graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
# Can't use graph.is_sorted because that requires sorted dataset type names
# as well as sorted tasks.
tasks_xgraph = graph.make_task_xgraph()
seen: set[str] = set()
for task_label in tasks_xgraph:
successors = set(tasks_xgraph.successors(task_label))
if not successors.isdisjoint(seen):
return False
seen.add(task_label)
return True
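
The loop above amounts to a generic check that node iteration order is a valid topological order. A self-contained illustration of the same successor-scan on a plain networkx DiGraph (the task labels are made up; networkx guarantees that iteration follows insertion order):

import networkx as nx

g = nx.DiGraph()
g.add_edge("isr", "calibrate")    # calibrate consumes isr's outputs
g.add_edge("calibrate", "coadd")  # coadd consumes calibrate's outputs

seen: set[str] = set()
is_ordered = True
for node in g:  # insertion order: isr, calibrate, coadd
    if not set(g.successors(node)).isdisjoint(seen):
        is_ordered = False  # a consumer appeared before its producer
        break
    seen.add(node)
print(is_ordered)  # True: this order respects the dependencies
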


def orderPipeline(pipeline: Sequence[TaskDef]) -> list[TaskDef]:
def orderPipeline(pipeline: Pipeline | Iterable[TaskDef]) -> list[TaskDef]:
"""Re-order tasks in pipeline to satisfy data dependencies.
When possible, the new ordering keeps the original relative order of the tasks.
Parameters
----------
pipeline : `pipe.base.Pipeline` or `collections.abc.Iterable` [ `TaskDef` ]
pipeline : `Pipeline` or `collections.abc.Iterable` [ `TaskDef` ]
Pipeline description.
Returns
@@ -143,69 +108,11 @@ def orderPipeline(pipeline: Sequence[TaskDef]) -> list[TaskDef]:
PipelineDataCycleError
Raised when the pipeline has dependency cycles.
"""
# This is a modified version of Kahn's algorithm that preserves order

# build mapping of the tasks to their inputs and outputs
inputs = {} # maps task index to its input DatasetType names
outputs = {} # maps task index to its output DatasetType names
allInputs = set() # all inputs of all tasks
allOutputs = set() # all outputs of all tasks
dsTypeTaskLabels: dict[str, str] = {} # maps DatasetType name to the label of its parent task
for idx, taskDef in enumerate(pipeline):
# task outputs
dsMap = {name: getattr(taskDef.connections, name) for name in taskDef.connections.outputs}
for dsTypeDescr in dsMap.values():
if dsTypeDescr.name in allOutputs:
raise DuplicateOutputError(
f"DatasetType `{dsTypeDescr.name}' in task `{taskDef.label}' already appears as an "
f"output in task `{dsTypeTaskLabels[dsTypeDescr.name]}'."
)
dsTypeTaskLabels[dsTypeDescr.name] = taskDef.label
outputs[idx] = set(dsTypeDescr.name for dsTypeDescr in dsMap.values())
allOutputs.update(outputs[idx])

# task inputs
connectionInputs = itertools.chain(taskDef.connections.inputs, taskDef.connections.prerequisiteInputs)
inputs[idx] = set(getattr(taskDef.connections, name).name for name in connectionInputs)
allInputs.update(inputs[idx])

# for simplicity add pseudo-node which is a producer for all pre-existing
# inputs, its index is -1
preExisting = allInputs - allOutputs
outputs[-1] = preExisting

# Set of nodes with no incoming edges, initially set to pseudo-node
queue = [-1]
result = []
while queue:
# move to final list, drop -1
idx = queue.pop(0)
if idx >= 0:
result.append(idx)

# remove task outputs from other tasks inputs
thisTaskOutputs = outputs.get(idx, set())
for taskInputs in inputs.values():
taskInputs -= thisTaskOutputs

# find all nodes with no incoming edges and move them to the queue
topNodes = [key for key, value in inputs.items() if not value]
queue += topNodes
for key in topNodes:
del inputs[key]

# keep queue ordered
queue.sort()

# if there is something left it means cycles
if inputs:
# format it in usable way
loops = []
for idx, inputNames in inputs.items():
taskName = pipeline[idx].label
outputNames = outputs[idx]
edge = " {} -> {} -> {}".format(inputNames, taskName, outputNames)
loops.append(edge)
raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops))

return [pipeline[idx] for idx in result]
if isinstance(pipeline, Pipeline):
graph = pipeline.to_graph()
else:
graph = PipelineGraph()
for task_def in pipeline:
graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
graph.sort()
return list(graph._iter_task_defs())
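
The deleted body above was a Kahn's-algorithm variant that preserved the original relative order of tasks wherever the dependencies allowed; graph.sort() now supplies that behavior. A compact re-sketch of the idea (illustrative only, not PipelineGraph's actual implementation):

import networkx as nx

def stable_topological_sort(g: nx.DiGraph) -> list[str]:
    # Rank nodes by insertion order so ties are broken toward the original order.
    rank = {node: i for i, node in enumerate(g)}
    return list(nx.lexicographical_topological_sort(g, key=rank.__getitem__))

g = nx.DiGraph()
g.add_nodes_from(["b", "a", "c"])  # original relative order: b, a, c
g.add_edge("c", "a")               # ...but a depends on c's output
print(stable_topological_sort(g))  # ['b', 'c', 'a']
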
100 changes: 72 additions & 28 deletions python/lsst/pipe/base/pipeline.py
@@ -53,14 +53,12 @@
from lsst.utils.introspection import get_full_type_name

from . import automatic_connection_constants as acc
from . import pipelineIR, pipeTools
from . import pipeline_graph, pipelineIR
from ._instrument import Instrument as PipeBaseInstrument
from ._task_metadata import TaskMetadata
from .config import PipelineTaskConfig
from .connections import PipelineTaskConnections, iterConnections
from .connectionTypes import Input
from .pipelineTask import PipelineTask
from .task import _TASK_METADATA_TYPE

if TYPE_CHECKING: # Imports needed only for type annotations; may be circular.
from lsst.obs.base import Instrument
@@ -749,6 +747,47 @@ def write_to_uri(self, uri: ResourcePathExpression) -> None:
"""
self._pipelineIR.write_to_uri(uri)

def to_graph(self) -> pipeline_graph.PipelineGraph:
"""Construct a pipeline graph from this pipeline.
Constructing a graph applies all configuration overrides, freezes all
configuration, checks all contracts, and checks for dataset type
consistency between tasks (as much as possible without access to a data
repository). It cannot be reversed.
Returns
-------
graph : `pipeline_graph.PipelineGraph`
Representation of the pipeline as a graph.
"""
instrument_class_name = self._pipelineIR.instrument
data_id = {}
if instrument_class_name is not None:
instrument_class = doImportType(instrument_class_name)
if instrument_class is not None:
data_id["instrument"] = instrument_class.getName()
graph = pipeline_graph.PipelineGraph(data_id=data_id)
graph.description = self._pipelineIR.description
for label in self._pipelineIR.tasks:
self._add_task_to_graph(label, graph)
if self._pipelineIR.contracts is not None:
label_to_config = {x.label: x.config for x in graph.tasks.values()}
for contract in self._pipelineIR.contracts:
# execute this in its own line so it can raise a good error
# message if there were problems with the eval
success = eval(contract.contract, None, label_to_config)
if not success:
extra_info = f": {contract.msg}" if contract.msg is not None else ""
raise pipelineIR.ContractError(
f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
)
for label, subset in self._pipelineIR.labeled_subsets.items():
graph.add_task_subset(
label, subset.subset, subset.description if subset.description is not None else ""
)
graph.sort()
return graph
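
The contract check above reduces to evaluating each contract expression with task labels bound to their configs. A self-contained illustration of that pattern (labels, attributes, and the contract string are hypothetical stand-ins):

from types import SimpleNamespace

label_to_config = {
    "isr": SimpleNamespace(doDark=True),              # stand-in for a frozen config
    "calibrate": SimpleNamespace(doAstrometry=True),  # stand-in for a frozen config
}
contract = "isr.doDark and calibrate.doAstrometry"  # hypothetical contract expression
success = eval(contract, None, label_to_config)     # labels resolve to configs
print(success)  # True
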

def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
r"""Return a generator of `TaskDef`\s which can be used to create
quantum graphs.
@@ -765,31 +804,22 @@ def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
If a dataId is supplied in a config block. This is in place for
future use
"""
taskDefs = []
for label in self._pipelineIR.tasks:
taskDefs.append(self._buildTaskDef(label))

# lets evaluate the contracts
if self._pipelineIR.contracts is not None:
label_to_config = {x.label: x.config for x in taskDefs}
for contract in self._pipelineIR.contracts:
# execute this in its own line so it can raise a good error
# message if there were problems with the eval
success = eval(contract.contract, None, label_to_config)
if not success:
extra_info = f": {contract.msg}" if contract.msg is not None else ""
raise pipelineIR.ContractError(
f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
)
yield from self.to_graph()._iter_task_defs()

taskDefs = sorted(taskDefs, key=lambda x: x.label)
yield from pipeTools.orderPipeline(taskDefs)
def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) -> None:
"""Add a single task from this pipeline to a pipeline graph that is
under construction.
def _buildTaskDef(self, label: str) -> TaskDef:
Parameters
----------
label : `str`
Label for the task to be added.
graph : `pipeline_graph.PipelineGraph`
Graph to add the task to.
"""
if (taskIR := self._pipelineIR.tasks.get(label)) is None:
raise NameError(f"Label {label} does not appear in this pipeline")
taskClass: type[PipelineTask] = doImportType(taskIR.klass)
taskName = get_full_type_name(taskClass)
config = taskClass.ConfigClass()
instrument: PipeBaseInstrument | None = None
if (instrumentName := self._pipelineIR.instrument) is not None:
@@ -802,13 +832,19 @@ def _buildTaskDef(self, label: str) -> TaskDef:
self._pipelineIR.parameters,
label,
)
return TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label)
graph.add_task(label, taskClass, config)

def __iter__(self) -> Generator[TaskDef, None, None]:
return self.toExpandedPipeline()

def __getitem__(self, item: str) -> TaskDef:
return self._buildTaskDef(item)
# Making a whole graph and then making a TaskDef from that is pretty
# backwards, but I'm hoping to deprecate this method shortly in favor
# of making the graph explicitly and working with its node objects.
graph = pipeline_graph.PipelineGraph()
self._add_task_to_graph(item, graph)
(result,) = graph._iter_task_defs()
return result

def __len__(self) -> int:
return len(self._pipelineIR.tasks)
@@ -1082,7 +1118,7 @@ def makeDatasetTypesSet(
DatasetType(
taskDef.configDatasetName,
registry.dimensions.empty,
storageClass="Config",
storageClass=acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
)
)
initOutputs.freeze()
@@ -1100,15 +1136,23 @@ def makeDatasetTypesSet(
current = registry.getDatasetType(taskDef.metadataDatasetName)
except KeyError:
# No previous definition so use the default.
storageClass = "TaskMetadata" if _TASK_METADATA_TYPE is TaskMetadata else "PropertySet"
storageClass = acc.METADATA_OUTPUT_STORAGE_CLASS
else:
storageClass = current.storageClass.name
outputs.update({DatasetType(taskDef.metadataDatasetName, dimensions, storageClass)})

if taskDef.logOutputDatasetName is not None:
# Log output dimensions correspond to a task quantum.
dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
outputs.update({DatasetType(taskDef.logOutputDatasetName, dimensions, "ButlerLogRecords")})
outputs.update(
{
DatasetType(
taskDef.logOutputDatasetName,
dimensions,
acc.LOG_OUTPUT_STORAGE_CLASS,
)
}
)

outputs.freeze()

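
The three storage-class literals replaced in the hunks above now come from shared constants in automatic_connection_constants. A quick check of the pattern (expected values are inferred from the strings this diff removes, not independently verified):

from lsst.pipe.base import automatic_connection_constants as acc

print(acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS)  # expected: "Config"
print(acc.METADATA_OUTPUT_STORAGE_CLASS)     # expected: "TaskMetadata"
print(acc.LOG_OUTPUT_STORAGE_CLASS)          # expected: "ButlerLogRecords"
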
