Skip to content

Commit

Permalink
Integrate PipelineGraph with Pipeline and use it for sorting.
Browse files Browse the repository at this point in the history
Much of the code changed here is actually stuff I want to deprecate in
the future, once PipelineGraph has been integrated with more things.
In the meantime, this addresses much the duplication caused by adding
PipelineGraph.
  • Loading branch information
TallJimbo committed Mar 13, 2023
1 parent 2fe2618 commit c706735
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 159 deletions.
148 changes: 27 additions & 121 deletions python/lsst/pipe/base/pipeTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,50 +27,23 @@
# No one should do import * from this module
__all__ = ["isPipelineOrdered", "orderPipeline"]

# -------------------------------
# Imports of standard modules --
# -------------------------------
import itertools
from typing import TYPE_CHECKING, Iterable

# -----------------------------
# Imports for other modules --
# -----------------------------
from .connections import iterConnections
from .pipeline import Pipeline, TaskDef

# Exceptions re-exported here for backwards compatibility.
from .pipeline_graph import DuplicateOutputError, MutablePipelineGraph, PipelineDataCycleError # noqa: F401

if TYPE_CHECKING:
from .pipeline import Pipeline, TaskDef
from .taskFactory import TaskFactory

# ----------------------------------
# Local non-exported definitions --
# ----------------------------------

# ------------------------
# Exported definitions --
# ------------------------


class MissingTaskFactoryError(Exception):
"""Exception raised when client fails to provide TaskFactory instance."""

pass


class DuplicateOutputError(Exception):
"""Exception raised when Pipeline has more than one task for the same
output.
"""

pass


class PipelineDataCycleError(Exception):
"""Exception raised when Pipeline has data dependency cycle."""

pass


def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskFactory | None = None) -> bool:
"""Checks whether tasks in pipeline are correctly ordered.
Expand All @@ -96,35 +69,27 @@ def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskF
DuplicateOutputError
Raised when there is more than one producer for a dataset type.
"""
# Build a map of DatasetType name to producer's index in a pipeline
producerIndex = {}
for idx, taskDef in enumerate(pipeline):
for attr in iterConnections(taskDef.connections, "outputs"):
if attr.name in producerIndex:
raise DuplicateOutputError(
"DatasetType `{}' appears more than once as output".format(attr.name)
)
producerIndex[attr.name] = idx

# check all inputs that are also someone's outputs
for idx, taskDef in enumerate(pipeline):
# get task input DatasetTypes, this can only be done via class method
inputs = {name: getattr(taskDef.connections, name) for name in taskDef.connections.inputs}
for dsTypeDescr in inputs.values():
# all pre-existing datasets have effective index -1
prodIdx = producerIndex.get(dsTypeDescr.name, -1)
if prodIdx >= idx:
# not good, producer is downstream
return False

if isinstance(pipeline, Pipeline):
graph = pipeline.to_graph()
else:
graph = MutablePipelineGraph()
for task_def in pipeline:
graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
# Can't use graph.is_sorted because that requires sorted dataset type names
# as well as sorted tasks.
tasks_xgraph = graph.make_task_xgraph()
seen: set[str] = set()
for task_label in tasks_xgraph:
successors = set(tasks_xgraph.successors(task_label))
if not successors.isdisjoint(seen):
return False
seen.add(task_label)
return True


def orderPipeline(pipeline: Pipeline | Iterable[TaskDef]) -> list[TaskDef]:
"""Re-order tasks in pipeline to satisfy data dependencies.
When possible new ordering keeps original relative order of the tasks.
Parameters
----------
pipeline : `pipe.base.Pipeline` or `collections.abc.Iterable` [ `TaskDef` ]
Expand All @@ -142,70 +107,11 @@ def orderPipeline(pipeline: Pipeline | Iterable[TaskDef]) -> list[TaskDef]:
PipelineDataCycleError
Raised when the pipeline has dependency cycles.
"""

# This is a modified version of Kahn's algorithm that preserves order

# build mapping of the tasks to their inputs and outputs
inputs = {} # maps task index to its input DatasetType names
outputs = {} # maps task index to its output DatasetType names
allInputs = set() # all inputs of all tasks
allOutputs = set() # all outputs of all tasks
dsTypeTaskLabels: dict[str, str] = {} # maps DatasetType name to the label of its parent task
for idx, taskDef in enumerate(pipeline):
# task outputs
dsMap = {name: getattr(taskDef.connections, name) for name in taskDef.connections.outputs}
for dsTypeDescr in dsMap.values():
if dsTypeDescr.name in allOutputs:
raise DuplicateOutputError(
f"DatasetType `{dsTypeDescr.name}' in task `{taskDef.label}' already appears as an "
f"output in task `{dsTypeTaskLabels[dsTypeDescr.name]}'."
)
dsTypeTaskLabels[dsTypeDescr.name] = taskDef.label
outputs[idx] = set(dsTypeDescr.name for dsTypeDescr in dsMap.values())
allOutputs.update(outputs[idx])

# task inputs
connectionInputs = itertools.chain(taskDef.connections.inputs, taskDef.connections.prerequisiteInputs)
inputs[idx] = set(getattr(taskDef.connections, name).name for name in connectionInputs)
allInputs.update(inputs[idx])

# for simplicity add pseudo-node which is a producer for all pre-existing
# inputs, its index is -1
preExisting = allInputs - allOutputs
outputs[-1] = preExisting

# Set of nodes with no incoming edges, initially set to pseudo-node
queue = [-1]
result = []
while queue:
# move to final list, drop -1
idx = queue.pop(0)
if idx >= 0:
result.append(idx)

# remove task outputs from other tasks inputs
thisTaskOutputs = outputs.get(idx, set())
for taskInputs in inputs.values():
taskInputs -= thisTaskOutputs

# find all nodes with no incoming edges and move them to the queue
topNodes = [key for key, value in inputs.items() if not value]
queue += topNodes
for key in topNodes:
del inputs[key]

# keep queue ordered
queue.sort()

# if there is something left it means cycles
if inputs:
# format it in usable way
loops = []
for idx, inputNames in inputs.items():
taskName = pipeline[idx].label
outputNames = outputs[idx]
edge = " {} -> {} -> {}".format(inputNames, taskName, outputNames)
loops.append(edge)
raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops))

return [pipeline[idx] for idx in result]
if isinstance(pipeline, Pipeline):
graph = pipeline.to_graph()
else:
graph = MutablePipelineGraph()
for task_def in pipeline:
graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
graph.sort()
return list(graph._iter_task_defs())
121 changes: 84 additions & 37 deletions python/lsst/pipe/base/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,12 @@
from lsst.utils.introspection import get_full_type_name

from . import automatic_connection_constants as acc
from . import pipelineIR, pipeTools
from ._task_metadata import TaskMetadata
from . import pipeline_graph, pipelineIR
from .config import PipelineTaskConfig
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .connections import PipelineTaskConnections, iterConnections
from .connectionTypes import Input
from .pipelineTask import PipelineTask
from .task import _TASK_METADATA_TYPE

if TYPE_CHECKING: # Imports needed only for type annotations; may be circular.
from lsst.obs.base import Instrument
Expand Down Expand Up @@ -135,6 +133,11 @@ class TaskDef:
Task label, usually a short string unique in a pipeline. If not
provided, ``taskClass`` must be, and ``taskClass._DefaultName`` will
be used.
connections : `PipelineTaskConnections`, optional
Object that describes the dataset types used by the task. If not
provided, one will be constructed from the given configuration. If
provided, it is assumed that ``config`` has already been validated
and frozen.
"""

def __init__(
Expand All @@ -143,6 +146,7 @@ def __init__(
config: Optional[PipelineTaskConfig] = None,
taskClass: Optional[Type[PipelineTask]] = None,
label: Optional[str] = None,
connections: PipelineTaskConnections | None = None,
):
if taskName is None:
if taskClass is None:
Expand All @@ -159,16 +163,20 @@ def __init__(
raise ValueError("`taskClass` must be provided if `label` is not.")
label = taskClass._DefaultName
self.taskName = taskName
try:
config.validate()
except Exception:
_LOG.error("Configuration validation failed for task %s (%s)", label, taskName)
raise
config.freeze()
if connections is None:
# If we don't have connections yet, assume the config hasn't been
# validated yet.
try:
config.validate()
except Exception:
_LOG.error("Configuration validation failed for task %s (%s)", label, taskName)
raise
config.freeze()
connections = config.connections.ConnectionsClass(config=config)
self.config = config
self.taskClass = taskClass
self.label = label
self.connections = config.connections.ConnectionsClass(config=config)
self.connections = connections

@property
def configDatasetName(self) -> str:
Expand All @@ -181,7 +189,7 @@ def metadataDatasetName(self) -> Optional[str]:
metadata is not to be saved (`str`)
"""
if self.config.saveMetadata:
return self.makeMetadataDatasetName(self.label)
return self.makeMetadataDatasetName(label=self.label)
else:
return None

Expand Down Expand Up @@ -732,6 +740,40 @@ def write_to_uri(self, uri: ResourcePathExpression) -> None:
"""
self._pipelineIR.write_to_uri(uri)

def to_graph(self) -> pipeline_graph.MutablePipelineGraph:
"""Construct a pipeline graph from this pipeline.
Constructing a graph applies all configuration overrides, freezes all
configuration, checks all contracts, and checks for dataset type
consistency between tasks (as much as possible without access to a data
repository). It cannot be reversed.
Returns
-------
graph : `pipeline_graph.MutablePipelineGraph`
Representation of the pipeline as a graph.
"""
graph = pipeline_graph.MutablePipelineGraph()
for label in self._pipelineIR.tasks:
self._add_task_to_graph(label, graph)
if self._pipelineIR.contracts is not None:
label_to_config = {x.label: x.config for x in graph.tasks.values()}
for contract in self._pipelineIR.contracts:
# execute this in its own line so it can raise a good error
# message if there was problems with the eval
success = eval(contract.contract, None, label_to_config)
if not success:
extra_info = f": {contract.msg}" if contract.msg is not None else ""
raise pipelineIR.ContractError(
f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
)
for label, subset in self._pipelineIR.labeled_subsets.items():
graph.add_task_subset(
label, subset.subset, subset.description if subset.description is not None else ""
)
graph.sort()
return graph

def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
"""Returns a generator of TaskDefs which can be used to create quantum
graphs.
Expand All @@ -748,31 +790,22 @@ def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
If a dataId is supplied in a config block. This is in place for
future use
"""
taskDefs = []
for label in self._pipelineIR.tasks:
taskDefs.append(self._buildTaskDef(label))

# lets evaluate the contracts
if self._pipelineIR.contracts is not None:
label_to_config = {x.label: x.config for x in taskDefs}
for contract in self._pipelineIR.contracts:
# execute this in its own line so it can raise a good error
# message if there was problems with the eval
success = eval(contract.contract, None, label_to_config)
if not success:
extra_info = f": {contract.msg}" if contract.msg is not None else ""
raise pipelineIR.ContractError(
f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
)
yield from self.to_graph()._iter_task_defs()

taskDefs = sorted(taskDefs, key=lambda x: x.label)
yield from pipeTools.orderPipeline(taskDefs)
def _add_task_to_graph(self, label: str, graph: pipeline_graph.MutablePipelineGraph) -> None:
"""Add a single task from this pipeline to a pipeline graph that is
under construction.
def _buildTaskDef(self, label: str) -> TaskDef:
Parameters
----------
label : `str`
Label for the task to be added.
graph : `pipeline_graph.MutablePipelineGraph`
Graph to add the task to.
"""
if (taskIR := self._pipelineIR.tasks.get(label)) is None:
raise NameError(f"Label {label} does not appear in this pipeline")
taskClass: Type[PipelineTask] = doImportType(taskIR.klass)
taskName = get_full_type_name(taskClass)
config = taskClass.ConfigClass()
overrides = ConfigOverrides()
if self._pipelineIR.instrument is not None:
Expand All @@ -794,13 +827,19 @@ def _buildTaskDef(self, label: str) -> TaskDef:
for key, value in configIR.rest.items():
overrides.addValueOverride(key, value)
overrides.applyTo(config)
return TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label)
graph.add_task(label, taskClass, config)

def __iter__(self) -> Generator[TaskDef, None, None]:
return self.toExpandedPipeline()

def __getitem__(self, item: str) -> TaskDef:
return self._buildTaskDef(item)
# Making a whole graph and then making a TaskDef from that is pretty
# backwards, but I'm hoping to deprecate this method shortly in favor
# of making the graph explicitly and working with its node objects.
graph = pipeline_graph.MutablePipelineGraph()
self._add_task_to_graph(item, graph)
(result,) = graph._iter_task_defs()
return result

def __len__(self) -> int:
return len(self._pipelineIR.tasks)
Expand Down Expand Up @@ -1072,7 +1111,7 @@ def makeDatasetTypesSet(
DatasetType(
taskDef.configDatasetName,
registry.dimensions.empty,
storageClass="Config",
storageClass=acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
)
)
initOutputs.freeze()
Expand All @@ -1090,15 +1129,23 @@ def makeDatasetTypesSet(
current = registry.getDatasetType(taskDef.metadataDatasetName)
except KeyError:
# No previous definition so use the default.
storageClass = "TaskMetadata" if _TASK_METADATA_TYPE is TaskMetadata else "PropertySet"
storageClass = acc.METADATA_OUTPUT_STORAGE_CLASS
else:
storageClass = current.storageClass.name

outputs.update({DatasetType(taskDef.metadataDatasetName, dimensions, storageClass)})
if taskDef.logOutputDatasetName is not None:
# Log output dimensions correspond to a task quantum.
dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
outputs.update({DatasetType(taskDef.logOutputDatasetName, dimensions, "ButlerLogRecords")})
outputs.update(
{
DatasetType(
taskDef.logOutputDatasetName,
dimensions,
acc.LOG_OUTPUT_STORAGE_CLASS,
)
}
)

outputs.freeze()

Expand Down
Loading

0 comments on commit c706735

Please sign in to comment.