From e8d33490185802d3dbdc9103e299adb38fa4aafc Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Sat, 10 Jun 2023 15:26:53 -0400 Subject: [PATCH] Switch to a single PipelineGraph class. --- python/lsst/pipe/base/__init__.py | 4 +- python/lsst/pipe/base/pipeTools.py | 6 +- python/lsst/pipe/base/pipeline.py | 67 +- .../lsst/pipe/base/pipeline_graph/__init__.py | 2 - .../base/pipeline_graph/_dataset_types.py | 194 +++--- .../lsst/pipe/base/pipeline_graph/_edges.py | 33 +- .../pipe/base/pipeline_graph/_exceptions.py | 9 +- .../base/pipeline_graph/_extract_helper.py | 106 ---- python/lsst/pipe/base/pipeline_graph/_io.py | 174 +++--- .../base/pipeline_graph/_mapping_views.py | 10 +- .../pipeline_graph/_mutable_pipeline_graph.py | 298 --------- .../lsst/pipe/base/pipeline_graph/_nodes.py | 83 +-- .../base/pipeline_graph/_pipeline_graph.py | 573 +++++++++++++----- .../_resolved_pipeline_graph.py | 231 ------- .../pipe/base/pipeline_graph/_task_subsets.py | 62 +- .../lsst/pipe/base/pipeline_graph/_tasks.py | 116 ++-- .../pipe/base/tests/pipelineStepTester.py | 15 +- tests/test_pipeline_graph.py | 207 +++---- 18 files changed, 863 insertions(+), 1327 deletions(-) delete mode 100644 python/lsst/pipe/base/pipeline_graph/_extract_helper.py delete mode 100644 python/lsst/pipe/base/pipeline_graph/_mutable_pipeline_graph.py delete mode 100644 python/lsst/pipe/base/pipeline_graph/_resolved_pipeline_graph.py diff --git a/python/lsst/pipe/base/__init__.py b/python/lsst/pipe/base/__init__.py index e440d9eff..ed58c5935 100644 --- a/python/lsst/pipe/base/__init__.py +++ b/python/lsst/pipe/base/__init__.py @@ -12,9 +12,9 @@ from .graphBuilder import * from .pipeline import * -# We import the main PipelineGraph types and the module (above), but we don't +# We import the main PipelineGraph type and the module (above), but we don't # lift all symbols to package scope. -from .pipeline_graph import MutablePipelineGraph, ResolvedPipelineGraph +from .pipeline_graph import PipelineGraph from .pipelineTask import * from .struct import * from .task import * diff --git a/python/lsst/pipe/base/pipeTools.py b/python/lsst/pipe/base/pipeTools.py index ed94d28d2..3fd907322 100644 --- a/python/lsst/pipe/base/pipeTools.py +++ b/python/lsst/pipe/base/pipeTools.py @@ -33,7 +33,7 @@ from .pipeline import Pipeline, TaskDef # Exceptions re-exported here for backwards compatibility. 
-from .pipeline_graph import DuplicateOutputError, MutablePipelineGraph, PipelineDataCycleError # noqa: F401 +from .pipeline_graph import DuplicateOutputError, PipelineDataCycleError, PipelineGraph # noqa: F401 if TYPE_CHECKING: from .taskFactory import TaskFactory @@ -73,7 +73,7 @@ def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskF if isinstance(pipeline, Pipeline): graph = pipeline.to_graph() else: - graph = MutablePipelineGraph() + graph = PipelineGraph() for task_def in pipeline: graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections) # Can't use graph.is_sorted because that requires sorted dataset type names @@ -111,7 +111,7 @@ def orderPipeline(pipeline: Pipeline | Iterable[TaskDef]) -> list[TaskDef]: if isinstance(pipeline, Pipeline): graph = pipeline.to_graph() else: - graph = MutablePipelineGraph() + graph = PipelineGraph() for task_def in pipeline: graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections) graph.sort() diff --git a/python/lsst/pipe/base/pipeline.py b/python/lsst/pipe/base/pipeline.py index 3b6f011b7..dea33448f 100644 --- a/python/lsst/pipe/base/pipeline.py +++ b/python/lsst/pipe/base/pipeline.py @@ -53,13 +53,7 @@ # ----------------------------- # Imports for other modules -- -from lsst.daf.butler import ( - DataCoordinate, - DatasetType, - DimensionUniverse, - NamedValueSet, - Registry, -) +from lsst.daf.butler import DataCoordinate, DatasetType, DimensionUniverse, NamedValueSet, Registry from lsst.resources import ResourcePath, ResourcePathExpression from lsst.utils import doImportType from lsst.utils.introspection import get_full_type_name @@ -757,7 +751,7 @@ def write_to_uri(self, uri: ResourcePathExpression) -> None: """ self._pipelineIR.write_to_uri(uri) - def to_graph(self) -> pipeline_graph.MutablePipelineGraph: + def to_graph(self) -> pipeline_graph.PipelineGraph: """Construct a pipeline graph from this pipeline. Constructing a graph applies all configuration overrides, freezes all @@ -767,10 +761,10 @@ def to_graph(self) -> pipeline_graph.MutablePipelineGraph: Returns ------- - graph : `pipeline_graph.MutablePipelineGraph` + graph : `pipeline_graph.PipelineGraph` Representation of the pipeline as a graph. """ - graph = pipeline_graph.MutablePipelineGraph() + graph = pipeline_graph.PipelineGraph() graph.description = self._pipelineIR.description for label in self._pipelineIR.tasks: self._add_task_to_graph(label, graph) @@ -810,7 +804,7 @@ def toExpandedPipeline(self) -> Generator[TaskDef, None, None]: """ yield from self.to_graph()._iter_task_defs() - def _add_task_to_graph(self, label: str, graph: pipeline_graph.MutablePipelineGraph) -> None: + def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) -> None: """Add a single task from this pipeline to a pipeline graph that is under construction. @@ -818,7 +812,7 @@ def _add_task_to_graph(self, label: str, graph: pipeline_graph.MutablePipelineGr ---------- label : `str` Label for the task to be added. - graph : `pipeline_graph.MutablePipelineGraph` + graph : `pipeline_graph.PipelineGraph` Graph to add the task to. """ if (taskIR := self._pipelineIR.tasks.get(label)) is None: @@ -845,7 +839,7 @@ def __getitem__(self, item: str) -> TaskDef: # Making a whole graph and then making a TaskDef from that is pretty # backwards, but I'm hoping to deprecate this method shortly in favor # of making the graph explicitly and working with its node objects. 
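The node-object pattern that comment refers to is a thin layer over the API this patch introduces; a minimal sketch (the helper function is illustrative, not part of the change):

    from lsst.pipe.base import Pipeline

    def show_task_inputs(pipeline: Pipeline, label: str) -> None:
        # Build the graph once (applies config overrides and freezes configs),
        # then work with the TaskNode object directly instead of a TaskDef.
        graph = pipeline.to_graph()
        task_node = graph.tasks[label]
        for edge in task_node.inputs:  # regular runtime input edges
            # Parent dataset type name (never a component).
            print(edge.parent_dataset_type_name)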
- graph = pipeline_graph.MutablePipelineGraph() + graph = pipeline_graph.PipelineGraph() self._add_task_to_graph(item, graph) (result,) = graph._iter_task_defs() return result @@ -971,17 +965,19 @@ def fromTaskDef( # the whole class soon, but for now and before it's actually removed # it's more important to avoid duplication with PipelineGraph's dataset # type resolution logic. - mgraph = pipeline_graph.MutablePipelineGraph() - mgraph.add_task(taskDef.label, taskDef.taskClass, taskDef.config, taskDef.connections) - rgraph = mgraph.resolved(registry) - (task_node,) = rgraph.tasks.values() - return cls._from_graph_nodes(task_node, rgraph.dataset_types) + graph = pipeline_graph.PipelineGraph() + graph.add_task(taskDef.label, taskDef.taskClass, taskDef.config, taskDef.connections) + graph.resolve(registry) + (task_node,) = graph.tasks.values() + return cls._from_graph_nodes( + task_node, cast(Mapping[str, pipeline_graph.DatasetTypeNode], graph.dataset_types) + ) @classmethod def _from_graph_nodes( cls, task_node: pipeline_graph.TaskNode, - dataset_type_nodes: Mapping[str, pipeline_graph.ResolvedDatasetTypeNode], + dataset_type_nodes: Mapping[str, pipeline_graph.DatasetTypeNode], include_configs: bool = True, ) -> TaskDatasetTypes: """Construct from `PipelineGraph` nodes. @@ -1146,17 +1142,17 @@ def fromPipeline( of the same `Pipeline`. """ if isinstance(pipeline, Pipeline): - mgraph = pipeline.to_graph() + graph = pipeline.to_graph() else: - mgraph = pipeline_graph.MutablePipelineGraph() + graph = pipeline_graph.PipelineGraph() for task_def in pipeline: - mgraph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections) - rgraph = mgraph.resolved(registry) + graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections) + graph.resolve(registry) byTask = dict() - for task_node in rgraph.tasks.values(): + for task_node in graph.tasks.values(): byTask[task_node.label] = TaskDatasetTypes._from_graph_nodes( task_node, - rgraph.dataset_types, + cast(Mapping[str, pipeline_graph.DatasetTypeNode], graph.dataset_types), include_configs=include_configs, ) result = cls( @@ -1177,8 +1173,9 @@ def fromPipeline( # PipelineGraph does, by putting components in the edge objects). But # including all components as well is what this code has done in the # past and changing that would break downstream code. - for dataset_type_node in rgraph.dataset_types.values(): - if consumers := rgraph.consumers_of(dataset_type_node.name): + for dataset_type_node in graph.dataset_types.values(): + assert dataset_type_node is not None, "Graph is expected to be resolved." 
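The resolve-then-classify pattern used in this loop can also be applied on its own; a rough sketch that mirrors the branching below (the helper function is illustrative and assumes a registry is already in hand):

    from lsst.daf.butler import Registry
    from lsst.pipe.base import Pipeline

    def classify_dataset_types(pipeline: Pipeline, registry: Registry) -> dict[str, str]:
        graph = pipeline.to_graph()
        graph.resolve(registry)  # attaches DatasetTypeNode instances in place
        kinds: dict[str, str] = {}
        for name, node in graph.dataset_types.items():
            assert node is not None, "resolve() guarantees resolved nodes."
            if node.is_prerequisite:
                kinds[name] = "prerequisite"
            elif graph.producer_of(name) is None:
                kinds[name] = "input"
            elif graph.consumers_of(name):
                kinds[name] = "intermediate"
            else:
                kinds[name] = "output"
        return kinds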
+ if consumers := graph.consumers_of(dataset_type_node.name): dataset_types = [ ( dataset_type_node.dataset_type.makeComponentDatasetType(edge.component) @@ -1188,21 +1185,21 @@ def fromPipeline( for edge in consumers.values() ] if any(edge.is_init for edge in consumers.values()): - if rgraph.producer_of(dataset_type_node.name) is None: + if graph.producer_of(dataset_type_node.name) is None: result.initInputs.update(dataset_types) else: result.initIntermediates.update(dataset_types) else: if dataset_type_node.is_prerequisite: result.prerequisites.update(dataset_types) - elif rgraph.producer_of(dataset_type_node.name) is None: + elif graph.producer_of(dataset_type_node.name) is None: result.inputs.update(dataset_types) if dataset_type_node.is_initial_query_constraint: result.queryConstraints.add(dataset_type_node.dataset_type) - elif rgraph.consumers_of(dataset_type_node.name): + elif graph.consumers_of(dataset_type_node.name): result.intermediates.update(dataset_types) else: - producer = rgraph.producer_of(dataset_type_node.name) + producer = graph.producer_of(dataset_type_node.name) assert ( producer is not None ), "Dataset type must have either a producer or consumers to be in graph." @@ -1254,15 +1251,15 @@ def initOutputNames( Name of the dataset type. """ if isinstance(pipeline, Pipeline): - mgraph = pipeline.to_graph() + graph = pipeline.to_graph() else: - mgraph = pipeline_graph.MutablePipelineGraph() + graph = pipeline_graph.PipelineGraph() for task_def in pipeline: - mgraph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections) + graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections) if include_packages: # Package versions dataset type yield cls.packagesDatasetName - for task_node in mgraph.tasks.values(): + for task_node in graph.tasks.values(): edges = task_node.init.iter_all_outputs() if include_configs else task_node.init.outputs for edge in edges: yield edge.dataset_type_name diff --git a/python/lsst/pipe/base/pipeline_graph/__init__.py b/python/lsst/pipe/base/pipeline_graph/__init__.py index fa4c211f1..3cf7a8101 100644 --- a/python/lsst/pipe/base/pipeline_graph/__init__.py +++ b/python/lsst/pipe/base/pipeline_graph/__init__.py @@ -23,9 +23,7 @@ from ._dataset_types import * from ._edges import * from ._exceptions import * -from ._mutable_pipeline_graph import * from ._nodes import * from ._pipeline_graph import * -from ._resolved_pipeline_graph import * from ._task_subsets import * from ._tasks import * diff --git a/python/lsst/pipe/base/pipeline_graph/_dataset_types.py b/python/lsst/pipe/base/pipeline_graph/_dataset_types.py index b738583aa..1a547ff9b 100644 --- a/python/lsst/pipe/base/pipeline_graph/_dataset_types.py +++ b/python/lsst/pipe/base/pipeline_graph/_dataset_types.py @@ -20,70 +20,85 @@ # along with this program. If not, see . 
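The `initOutputNames` change above amounts to a simple iteration over init-output edges; an equivalent free-standing sketch (the generator function itself is illustrative) that works on an unresolved graph, since only dataset type names are needed:

    from collections.abc import Iterator

    from lsst.pipe.base import PipelineGraph

    def iter_init_output_names(graph: PipelineGraph, include_configs: bool = True) -> Iterator[str]:
        for task_node in graph.tasks.values():
            # Include per-task config outputs unless the caller opts out.
            edges = task_node.init.iter_all_outputs() if include_configs else task_node.init.outputs
            for edge in edges:
                yield edge.dataset_type_name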
from __future__ import annotations -__all__ = ( - "DatasetTypeNode", - "ResolvedDatasetTypeNode", -) +__all__ = ("DatasetTypeNode",) from typing import TYPE_CHECKING, Any import networkx -from lsst.daf.butler import DatasetRef, DatasetType, Registry +from lsst.daf.butler import DatasetRef, DatasetType, DimensionGraph, Registry, StorageClass from lsst.daf.butler.registry import MissingDatasetTypeError from pydantic import BaseModel -from ._nodes import Node, NodeKey +from ._nodes import NodeKey, NodeType if TYPE_CHECKING: from ._edges import ReadEdge, WriteEdge from ._tasks import TaskInitNode, TaskNode -class DatasetTypeNode(Node): - """A node in a pipeline graph that represents a dataset type. +class DatasetTypeNode: + """A node in a resolved pipeline graph that represents a dataset type. Parameters ---------- - node : `NodeKey` - Key for this node in the graph. + dataset_type : `DatasetType` + Common definition of this dataset type for the graph. + is_prerequisite: `bool` + Whether this dataset type is a prerequisite input that must exist in + the Registry before graph creation. + is_initial_query_constraint : `bool` + Whether this dataset should be included as a constraint in the initial + query for data IDs in QuantumGraph generation. + + This is only `True` for dataset types that are overall regular inputs, + and also `False` if all such connections had + ``deferQueryConstraint=True``. + is_registered : `bool` + Whether this dataset type was registered in the data repository when it + was resolved. + + When `is_registered` is `True`, the storage class is guaranteed to + match the data repository definition. Notes ----- - This class only holds information that can be pulled unambiguously from - `.PipelineTask` a single definitions, without input from the data - repository or other tasks - which amounts to just the parent dataset type - name. The `ResolvedDatasetTypeNode` subclass also includes information - from the data repository and holds an actual `DatasetType` instance. - A dataset type node represents a common definition of the dataset type - across the entire graph, which means it never refers to a component. + across the entire graph - it is never a component, and when storage class + information is present (in `DatasetTypeNode`) this is the registry dataset + type's storage class or (if there isn't one) the one defined by the + producing task. Dataset type nodes are intentionally not equality comparable, since there - are many different (and useful) ways to compare its resolved variant, with - no clear winner as the most obvious behavior. + are many different (and useful) ways to compare these objects with no clear + winner as the most obvious behavior. """ - @property - def name(self) -> str: - """Name of the dataset type. - - This is always the parent dataset type, never that of a component. - """ - return str(self.key) + def __init__( + self, + *, + dataset_type: DatasetType, + is_prerequisite: bool, + is_initial_query_constraint: bool, + is_registered: bool, + ): + self.dataset_type = dataset_type + self.is_prerequisite = is_prerequisite + self.is_initial_query_constraint = is_initial_query_constraint + self.is_registered = is_registered - def _resolved(self, xgraph: networkx.DiGraph, registry: Registry) -> ResolvedDatasetTypeNode: - # Docstring inherited. 
+ @classmethod + def _from_edges(cls, key: NodeKey, xgraph: networkx.DiGraph, registry: Registry) -> DatasetTypeNode: try: - dataset_type = registry.getDatasetType(self.name) - in_data_repo = True + dataset_type = registry.getDatasetType(key.name) + is_registered = True except MissingDatasetTypeError: dataset_type = None - in_data_repo = False + is_registered = False is_initial_query_constraint = True is_prerequisite: bool | None = None producer: str | None = None write_edge: WriteEdge - for _, _, write_edge in xgraph.in_edges(self.key, data="instance"): # will iterate zero or one time + for _, _, write_edge in xgraph.in_edges(key, data="instance"): # will iterate zero or one time task_node: TaskNode | TaskInitNode = xgraph.nodes[write_edge.task_key]["instance"] connection_map = task_node._get_connection_map() dataset_type = write_edge._resolve_dataset_type( @@ -95,7 +110,7 @@ def _resolved(self, xgraph: networkx.DiGraph, registry: Registry) -> ResolvedDat is_initial_query_constraint = False read_edge: ReadEdge consumers: list[str] = [] - for _, _, read_edge in xgraph.out_edges(self.key, data="instance"): + for _, _, read_edge in xgraph.out_edges(key, data="instance"): task_node = xgraph.nodes[read_edge.task_key]["instance"] connection_map = task_node._get_connection_map() dataset_type, is_initial_query_constraint, is_prerequisite = read_edge._resolve_dataset_type( @@ -104,80 +119,62 @@ def _resolved(self, xgraph: networkx.DiGraph, registry: Registry) -> ResolvedDat universe=registry.dimensions, is_initial_query_constraint=is_initial_query_constraint, is_prerequisite=is_prerequisite, - in_data_repo=in_data_repo, + is_registered=is_registered, producer=producer, consumers=consumers, ) consumers.append(read_edge.task_label) assert dataset_type is not None, "Graph structure guarantees at least one edge." assert is_prerequisite is not None, "Having at least one edge guarantees is_prerequisite is known." - return ResolvedDatasetTypeNode( - key=self.key, + return DatasetTypeNode( dataset_type=dataset_type, is_initial_query_constraint=is_initial_query_constraint, is_prerequisite=is_prerequisite, + is_registered=is_registered, ) - def _unresolved(self) -> DatasetTypeNode: - # Docstring inherited. - return self - - def _serialize(self) -> SerializedDatasetTypeNode: - # Docstring inherited. - return SerializedDatasetTypeNode.construct() - - def _to_xgraph_state(self) -> dict[str, Any]: - # Docstring inherited. - return { - "bipartite": self.key.node_type.bipartite, - } - + def _resolved(self, registry: Registry) -> DatasetTypeNode: + """Resolve an existing DatasetTypeNode against current data repository + content. -class ResolvedDatasetTypeNode(DatasetTypeNode): - """A node in a resolved pipeline graph that represents a dataset type. + Since DatasetTypeNodes are updated or replaced with `None` whenever new + edges are added to the graph, the only thing that might have changed + when this method is the registration in the data repository. + """ + try: + dataset_type = registry.getDatasetType(self.dataset_type.name) + is_registered = True + except MissingDatasetTypeError: + dataset_type = self.dataset_type + is_registered = False + if is_registered == self.is_registered: + return self + return DatasetTypeNode( + is_prerequisite=self.is_prerequisite, + dataset_type=dataset_type, + is_initial_query_constraint=self.is_initial_query_constraint, + is_registered=self.is_registered, + ) - Parameters - ---------- - node : `NodeKey` - Key for this node in the graph. 
- is_prerequisite: `bool` - Whether this dataset type is a prerequisite input that must exist in - the Registry before graph creation. - dataset_type : `DatasetType` - Common definition of this dataset type for the graph. - is_initial_query_constraint : `bool` - Whether this dataset should be included as a constraint in the initial - query for data IDs in QuantumGraph generation. + @property + def name(self) -> str: + """Name of the dataset type. - This is only `True` for dataset types that are overall regular inputs, - and also `False` if all such connections had - ``deferQueryConstraint=True``. + This is always the parent dataset type, never that of a component. + """ + return self.dataset_type.name - Notes - ----- - A dataset type node represents a common definition of the dataset type - across the entire graph - it is never a component, and when storage class - information is present (in `ResolvedDatasetTypeNode`) this is the registry - dataset type's storage class or (if there isn't one) the one defined by the - producing task. + @property + def dimensions(self) -> DimensionGraph: + return self.dataset_type.dimensions - Dataset type nodes are intentionally not equality comparable, since there - are many different (and useful) ways to compare these objects with no clear - winner as the most obvious behavior. - """ + @property + def storage_class_name(self) -> str: + return self.dataset_type.storageClass_name - def __init__( - self, - key: NodeKey, - *, - is_prerequisite: bool, - dataset_type: DatasetType, - is_initial_query_constraint: bool, - ): - super().__init__(key) - self.dataset_type = dataset_type - self.is_initial_query_constraint = is_initial_query_constraint - self.is_prerequisite = is_prerequisite + @property + def storage_class(self) -> StorageClass: + return self.dataset_type.storageClass dataset_type: DatasetType """Common definition of this dataset type for the graph. @@ -196,13 +193,13 @@ def __init__( the Registry before graph creation. """ - def _resolved(self, xgraph: networkx.DiGraph, registry: Registry) -> ResolvedDatasetTypeNode: - # Docstring inherited. - return self + is_registered: bool + """Whether this dataset type was registered in the data repository when + it was resolved. - def _unresolved(self) -> DatasetTypeNode: - # Docstring inherited. - return DatasetTypeNode(key=self.key) + When `is_registered` is `True`, the storage class is guaranteed to match + the data repository definition. 
+ """ def generalize_ref(self, ref: DatasetRef) -> DatasetRef: """Convert a `~lsst.daf.butler.DatasetRef` with the dataset type @@ -223,6 +220,7 @@ def _serialize(self) -> SerializedDatasetTypeNode: is_calibration=self.dataset_type.isCalibration(), is_initial_query_constraint=self.is_initial_query_constraint, is_prerequisite=self.is_prerequisite, + is_registered=self.is_registered, ) def _to_xgraph_state(self) -> dict[str, Any]: @@ -231,9 +229,10 @@ def _to_xgraph_state(self) -> dict[str, Any]: "dataset_type": self.dataset_type, "is_initial_query_constraint": self.is_initial_query_constraint, "is_prerequisite": self.is_prerequisite, + "is_registered": self.is_registered, "dimensions": self.dataset_type.dimensions, "storage_class_name": self.dataset_type.storageClass_name, - "bipartite": self.key.node_type.bipartite, + "bipartite": NodeType.DATASET_TYPE.bipartite, } @@ -243,4 +242,5 @@ class SerializedDatasetTypeNode(BaseModel): is_calibration: bool = False is_initial_query_constraint: bool = False is_prerequisite: bool = False + is_registered: bool = False index: int | None = None diff --git a/python/lsst/pipe/base/pipeline_graph/_edges.py b/python/lsst/pipe/base/pipeline_graph/_edges.py index eb0897430..700cbc296 100644 --- a/python/lsst/pipe/base/pipeline_graph/_edges.py +++ b/python/lsst/pipe/base/pipeline_graph/_edges.py @@ -151,8 +151,10 @@ def adapt_dataset_ref(self, ref: DatasetRef) -> DatasetRef: """ raise NotImplementedError() - def _check_dataset_type( - self, xgraph: networkx.DiGraph, dataset_type_node: DatasetTypeNode + def _update_dataset_type( + self, + xgraph: networkx.DiGraph, + dataset_type_node: DatasetTypeNode | None, ) -> DatasetTypeNode | None: """Check the a potential graph-wide definition of a dataset type for consistency with this edge. @@ -161,16 +163,16 @@ def _check_dataset_type( ----------- xgraph : `networkx.DiGraph` Directed bipartite graph representing the full pipeline. - dataset_type_node : `DatasetTypeNode` + dataset_type_node : `DatasetTypeNode` or `None` Dataset type node to be checked and possibly updated. Returns ------- - updated : `bool` - New `DatasetTypeNode` if it needs to be changed, or `None` if it - does not. + updated : `DatasetTypeNode` or `None` + Possibly-updated node to include in the graph after the addition + of this edge. """ - pass + return None @abstractmethod def _serialize(self) -> BaseModel: @@ -325,7 +327,7 @@ def _resolve_dataset_type( universe: DimensionUniverse, producer: str | None, consumers: Sequence[str], - in_data_repo: bool, + is_registered: bool, ) -> tuple[DatasetType, bool, bool]: """Participate in the construction of the graph-wide `DatasetType` object associated with this edge. @@ -356,7 +358,7 @@ def _resolve_dataset_type( consumers : `Sequence` [ `str` ] Labels for other consuming tasks that have already participated in this dataset type's resolution. - in_data_repo : `bool` + is_registered : `bool` Whether are registration for this dataset type was found in the data repository. 
@@ -416,7 +418,7 @@ def _resolve_dataset_type( ) def report_current_origin() -> str: - if in_data_repo: + if is_registered: return "data repository" elif producer is not None: return f"producing task {producer!r}" @@ -541,13 +543,18 @@ def _from_connection_map( connection_name=connection_name, ) - def _check_dataset_type(self, xgraph: networkx.DiGraph, dataset_type_node: DatasetTypeNode) -> None: + def _update_dataset_type( + self, + xgraph: networkx.DiGraph, + dataset_type_node: DatasetTypeNode | None, + ) -> DatasetTypeNode | None: # Docstring inherited. - for existing_producer in xgraph.predecessors(dataset_type_node.key): + for existing_producer in xgraph.predecessors(self.dataset_type_key): raise DuplicateOutputError( - f"Dataset type {dataset_type_node.name} is produced by both {self.task_label!r} " + f"Dataset type {self.parent_dataset_type_name!r} is produced by both {self.task_label!r} " f"and {existing_producer!r}." ) + return None def _resolve_dataset_type( self, *, connection: BaseConnection, current: DatasetType | None, universe: DimensionUniverse diff --git a/python/lsst/pipe/base/pipeline_graph/_exceptions.py b/python/lsst/pipe/base/pipeline_graph/_exceptions.py index c4164a632..467513aa4 100644 --- a/python/lsst/pipe/base/pipeline_graph/_exceptions.py +++ b/python/lsst/pipe/base/pipeline_graph/_exceptions.py @@ -27,8 +27,9 @@ "PipelineDataCycleError", "PipelineGraphError", "PipelineGraphReadError", - "TaskNotImportedError", "ReadInconsistencyError", + "UnresolvedGraphError", + "TaskNotImportedError", ) @@ -61,6 +62,12 @@ class IncompatibleDatasetTypeError(PipelineGraphError): """ +class UnresolvedGraphError(PipelineGraphError): + """Exception raised when an operation requires dimensions or dataset types + to have been resolved, but they have not been. + """ + + class PipelineGraphReadError(PipelineGraphError, IOError): """Exception raised when a serialized PipelineGraph cannot be read.""" diff --git a/python/lsst/pipe/base/pipeline_graph/_extract_helper.py b/python/lsst/pipe/base/pipeline_graph/_extract_helper.py deleted file mode 100644 index b591e6a71..000000000 --- a/python/lsst/pipe/base/pipeline_graph/_extract_helper.py +++ /dev/null @@ -1,106 +0,0 @@ -# This file is part of pipe_base. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (http://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . 
-from __future__ import annotations - -__all__ = ("ExtractHelper",) - -from collections.abc import Iterable -from types import EllipsisType -from typing import TYPE_CHECKING, Generic, TypeVar - -import networkx -import networkx.algorithms.bipartite -import networkx.algorithms.dag -from lsst.utils.iteration import ensure_iterable - -from ._nodes import Node, NodeKey, NodeType - -if TYPE_CHECKING: - from ._mutable_pipeline_graph import MutablePipelineGraph - from ._pipeline_graph import PipelineGraph - - -_P = TypeVar("_P", bound="PipelineGraph", covariant=True) - - -class ExtractHelper(Generic[_P]): - def __init__(self, parent: _P) -> None: - self._parent = parent - self._run_xgraph: networkx.DiGraph | None = None - self._task_keys: set[NodeKey] = set() - - def include_tasks(self, labels: str | Iterable[str] | EllipsisType = ...) -> None: - if labels is ...: - self._task_keys.update(key for key in self._parent._xgraph if key.node_type is NodeType.TASK) - else: - self._task_keys.update( - NodeKey(NodeType.TASK, task_label) for task_label in ensure_iterable(labels) - ) - - def exclude_tasks(self, labels: str | Iterable[str]) -> None: - self._task_keys.difference_update( - NodeKey(NodeType.TASK, task_label) for task_label in ensure_iterable(labels) - ) - - def include_subset(self, label: str) -> None: - self._task_keys.update(node.key for node in self._parent.task_subsets[label].values()) - - def exclude_subset(self, label: str) -> None: - self._task_keys.difference_update(node.key for node in self._parent.task_subsets[label].values()) - - def start_after(self, names: str | Iterable[str], node_type: NodeType) -> None: - to_exclude: set[NodeKey] = set() - for name in ensure_iterable(names): - key = NodeKey(node_type, name) - to_exclude.update(networkx.algorithms.dag.ancestors(self._get_run_xgraph(), key)) - to_exclude.add(key) - self._task_keys.difference_update(to_exclude) - - def stop_at(self, names: str | Iterable[str], node_type: NodeType) -> None: - to_exclude: set[NodeKey] = set() - for name in ensure_iterable(names): - key = NodeKey(node_type, name) - to_exclude.update(networkx.algorithms.dag.descendants(self._get_run_xgraph(), key)) - self._task_keys.difference_update(to_exclude) - - def finish(self, description: str | None = None) -> MutablePipelineGraph: - from ._mutable_pipeline_graph import MutablePipelineGraph - - if description is None: - description = self._parent._description - # Combine the task_keys we're starting with and the keys for their init - # nodes. - keys = self._task_keys | {NodeKey(NodeType.TASK_INIT, key.name) for key in self._task_keys} - # Also add the keys for the adjacent dataset type nodes. - keys.update(networkx.node_boundary(self._parent._xgraph.to_undirected(as_view=True), keys)) - # Make the new backing networkx graph. 
- xgraph: networkx.DiGraph = self._parent._xgraph.subgraph(keys).copy() - for state in xgraph.nodes.values(): - node: Node = state["instance"] - state["instance"] = node._unresolved() - result = MutablePipelineGraph.__new__(MutablePipelineGraph) - result._init_from_args(xgraph, None, description=description) - return result - - def _get_run_xgraph(self) -> networkx.DiGraph: - if self._run_xgraph is None: - self._run_xgraph = self._parent.make_bipartite_xgraph(init=False) - return self._run_xgraph diff --git a/python/lsst/pipe/base/pipeline_graph/_io.py b/python/lsst/pipe/base/pipeline_graph/_io.py index 9ced70f92..57572e6de 100644 --- a/python/lsst/pipe/base/pipeline_graph/_io.py +++ b/python/lsst/pipe/base/pipeline_graph/_io.py @@ -22,22 +22,20 @@ import os import tarfile -from abc import abstractmethod from collections.abc import Sequence -from typing import Any, BinaryIO, Generic, TypeVar +from typing import Any, BinaryIO, TypeVar import networkx import pydantic +from lsst.daf.butler import DatasetType, DimensionConfig, DimensionGraph, DimensionUniverse -from ._dataset_types import SerializedDatasetTypeNode +from ._dataset_types import DatasetTypeNode, SerializedDatasetTypeNode from ._edges import ReadEdge, SerializedEdge, WriteEdge from ._exceptions import PipelineGraphReadError -from ._mapping_views import _D, _T from ._nodes import NodeKey, NodeType from ._task_subsets import SerializedTaskSubset, TaskSubset -from ._tasks import SerializedTaskInitNode, SerializedTaskNode, TaskInitNode +from ._tasks import SerializedTaskInitNode, SerializedTaskNode, TaskInitNode, TaskNode -_S = TypeVar("_S", bound="TaskSubset", covariant=True) _U = TypeVar("_U") _IO_VERSION_INFO = (0, 0, 1) @@ -90,33 +88,47 @@ def read_stream(cls, stream: BinaryIO) -> SerializedPipelineGraph: return serialized_graph -class PipelineGraphReader(Generic[_T, _D, _S]): +class PipelineGraphReader: def __init__(self) -> None: self.xgraph = networkx.DiGraph() self.sort_keys: Sequence[NodeKey] | None = None - self.task_subsets: dict[str, _S] = {} + self.task_subsets: dict[str, TaskSubset] = {} self.description: str = "" + self.universe: DimensionUniverse | None = None def deserialize_graph(self, serialized_graph: SerializedPipelineGraph) -> None: + if serialized_graph.dimensions is not None: + self.universe = DimensionUniverse( + config=DimensionConfig( + expect_not_none( + serialized_graph.dimensions, + "Serialized pipeline graph has not been resolved; " + "load it is a MutablePipelineGraph instead.", + ) + ) + ) sort_index_map: dict[int, NodeKey] = {} for dataset_type_name, serialized_dataset_type in serialized_graph.dataset_types.items(): - dataset_type_node = self.deserialize_dataset_type(dataset_type_name, serialized_dataset_type) + dataset_type_key = NodeKey(NodeType.DATASET_TYPE, dataset_type_name) + dataset_type_node = self.deserialize_dataset_type(dataset_type_key, serialized_dataset_type) self.xgraph.add_node( - dataset_type_node.key, instance=dataset_type_node, bipartite=NodeType.DATASET_TYPE.value + dataset_type_key, instance=dataset_type_node, bipartite=NodeType.DATASET_TYPE.value ) if serialized_dataset_type.index is not None: - sort_index_map[serialized_dataset_type.index] = dataset_type_node.key + sort_index_map[serialized_dataset_type.index] = dataset_type_key for task_label, serialized_task in serialized_graph.tasks.items(): - task_node = self.deserialize_task(task_label, serialized_task) + task_key = NodeKey(NodeType.TASK, task_label) + task_init_key = NodeKey(NodeType.TASK_INIT, task_label) + task_node = 
self.deserialize_task(task_key, task_init_key, serialized_task) if serialized_task.index is not None: - sort_index_map[serialized_task.index] = task_node.key + sort_index_map[serialized_task.index] = task_key if serialized_task.init.index is not None: - sort_index_map[serialized_task.init.index] = task_node.init.key - self.xgraph.add_node(task_node.key, instance=task_node, bipartite=NodeType.TASK.bipartite) + sort_index_map[serialized_task.init.index] = task_init_key + self.xgraph.add_node(task_key, instance=task_node, bipartite=NodeType.TASK.bipartite) self.xgraph.add_node( - task_node.init.key, instance=task_node.init, bipartite=NodeType.TASK_INIT.bipartite + task_init_key, instance=task_node.init, bipartite=NodeType.TASK_INIT.bipartite ) - self.xgraph.add_edge(task_node.init.key, task_node.key, instance=None) + self.xgraph.add_edge(task_init_key, task_key, instance=None) for read_edge in task_node.init.iter_all_inputs(): self.xgraph.add_edge(read_edge.dataset_type_key, read_edge.task_key, instance=read_edge) for write_edge in task_node.init.iter_all_outputs(): @@ -131,58 +143,42 @@ def deserialize_graph(self, serialized_graph: SerializedPipelineGraph) -> None: self.sort_keys = [sort_index_map[i] for i in range(len(self.xgraph))] self.description = serialized_graph.description - @abstractmethod - def deserialize_dataset_type(self, name: str, serialized_dataset_type: SerializedDatasetTypeNode) -> _D: - raise NotImplementedError() - - @abstractmethod - def deserialize_task(self, label: str, serialized_task: SerializedTaskNode) -> _T: - raise NotImplementedError() - - @abstractmethod - def deserialize_task_subset(self, label: str, serialized_task_subset: SerializedTaskSubset) -> _S: - raise NotImplementedError() - - def deserialize_task_init( - self, - label: str, - serialized_task_init: SerializedTaskInitNode, - task_class_name: str, - config_str: str, - ) -> TaskInitNode: - key = NodeKey(NodeType.TASK_INIT, label) - return TaskInitNode( - key, - inputs={ - self.deserialize_read_edge(key, parent_dataset_type_name, serialized_edge) - for parent_dataset_type_name, serialized_edge in serialized_task_init.inputs.items() - }, - outputs={ - self.deserialize_write_edge(key, parent_dataset_type_name, serialized_edge) - for parent_dataset_type_name, serialized_edge in serialized_task_init.outputs.items() - }, - config_output=self.deserialize_write_edge( - key, + def deserialize_dataset_type( + self, key: NodeKey, serialized_dataset_type: SerializedDatasetTypeNode + ) -> DatasetTypeNode | None: + if serialized_dataset_type.dimensions is not None: + dataset_type = DatasetType( + key.name, expect_not_none( - serialized_task_init.config_output.dataset_type_name, - "Serialized task config edges should always have a dataset type.", + serialized_dataset_type.dimensions, + f"Serialized dataset type {key.name!r} has no dimensions.", ), - serialized_task_init.config_output, - ), - task_class_name=task_class_name, - config_str=config_str, - ) + storageClass=expect_not_none( + serialized_dataset_type.storage_class, + f"Serialized dataset type {key.name!r} has no storage class.", + ), + isCalibration=serialized_dataset_type.is_calibration, + universe=self.universe, + ) + return DatasetTypeNode( + dataset_type=dataset_type, + is_prerequisite=serialized_dataset_type.is_prerequisite, + is_initial_query_constraint=serialized_dataset_type.is_initial_query_constraint, + is_registered=serialized_dataset_type.is_registered, + ) + return None - def deserialize_task_args(self, label: str, serialized_task: 
SerializedTaskNode) -> dict[str, Any]: + def deserialize_task( + self, key: NodeKey, init_key: NodeKey, serialized_task: SerializedTaskNode + ) -> TaskNode: init = self.deserialize_task_init( - label, + init_key, serialized_task.init, task_class_name=serialized_task.task_class, config_str=expect_not_none( - serialized_task.config_str, f"No serialized config file for task with label {label!r}." + serialized_task.config_str, f"No serialized config file for task with label {key.name!r}." ), ) - key = NodeKey(NodeType.TASK, label) inputs = { self.deserialize_read_edge(key, parent_dataset_type_name, serialized_edge) for parent_dataset_type_name, serialized_edge in serialized_task.inputs.items() @@ -214,15 +210,53 @@ def deserialize_task_args(self, label: str, serialized_task: SerializedTaskNode) ), serialized_task.metadata_output, ) - return dict( - key=key, + dimensions: DimensionGraph | None = None + if serialized_task.dimensions is not None: + dimensions = expect_not_none( + self.universe, + f"Dimensions for task {key.name} were persisted, but dimension universe was not.", + ).extract(serialized_task.dimensions) + return TaskNode( init=init, inputs=inputs, prerequisite_inputs=prerequisite_inputs, outputs=outputs, log_output=log_output, metadata_output=metadata_output, - dimensions=None, + dimensions=dimensions, + ) + + def deserialize_task_subset(self, label: str, serialized_task_subset: SerializedTaskSubset) -> TaskSubset: + members = set(serialized_task_subset.tasks) + return TaskSubset(self.xgraph, label, members, serialized_task_subset.description) + + def deserialize_task_init( + self, + key: NodeKey, + serialized_task_init: SerializedTaskInitNode, + task_class_name: str, + config_str: str, + ) -> TaskInitNode: + return TaskInitNode( + key, + inputs={ + self.deserialize_read_edge(key, parent_dataset_type_name, serialized_edge) + for parent_dataset_type_name, serialized_edge in serialized_task_init.inputs.items() + }, + outputs={ + self.deserialize_write_edge(key, parent_dataset_type_name, serialized_edge) + for parent_dataset_type_name, serialized_edge in serialized_task_init.outputs.items() + }, + config_output=self.deserialize_write_edge( + key, + expect_not_none( + serialized_task_init.config_output.dataset_type_name, + "Serialized task config edges should always have a dataset type.", + ), + serialized_task_init.config_output, + ), + task_class_name=task_class_name, + config_str=config_str, ) def deserialize_read_edge( @@ -232,12 +266,7 @@ def deserialize_read_edge( serialized_edge: SerializedEdge, is_prerequisite: bool = False, ) -> ReadEdge: - # Look up dataset type key in the graph, both to validate as we read - # and to reduce the number of distinct but equivalent NodeKey instances - # present in the graph. - dataset_type_key = self.xgraph.nodes[NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name)][ - "instance" - ].key + dataset_type_key = NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name) return ReadEdge( dataset_type_key, task_key, @@ -253,12 +282,7 @@ def deserialize_write_edge( parent_dataset_type_name: str, serialized_edge: SerializedEdge, ) -> WriteEdge: - # Look up dataset type key in the graph, both to validate as we read - # and to reduce the number of distinct but equivalent NodeKey instances - # present in the graph. 
- dataset_type_key = self.xgraph.nodes[NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name)][ - "instance" - ].key + dataset_type_key = NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name) return WriteEdge( task_key=task_key, dataset_type_key=dataset_type_key, diff --git a/python/lsst/pipe/base/pipeline_graph/_mapping_views.py b/python/lsst/pipe/base/pipeline_graph/_mapping_views.py index 244ff60c9..fee5fa7c8 100644 --- a/python/lsst/pipe/base/pipeline_graph/_mapping_views.py +++ b/python/lsst/pipe/base/pipeline_graph/_mapping_views.py @@ -26,12 +26,10 @@ import networkx from ._dataset_types import DatasetTypeNode -from ._nodes import Node, NodeKey, NodeType +from ._nodes import NodeKey, NodeType from ._tasks import TaskInitNode, TaskNode -_N = TypeVar("_N", bound=Node, covariant=True) -_T = TypeVar("_T", bound=TaskNode, covariant=True) -_D = TypeVar("_D", bound=DatasetTypeNode, covariant=True) +_N = TypeVar("_N", covariant=True) class MappingView(Mapping[str, _N]): @@ -97,7 +95,7 @@ def _make_keys(self, parent_keys: Iterable[NodeKey]) -> list[str]: return [str(k) for k in parent_keys if k.node_type is self._NODE_TYPE] -class TaskMappingView(MappingView[_T]): +class TaskMappingView(MappingView[TaskNode]): _NODE_TYPE = NodeType.TASK @@ -105,5 +103,5 @@ class TaskInitMappingView(MappingView[TaskInitNode]): _NODE_TYPE = NodeType.TASK_INIT -class DatasetTypeMappingView(MappingView[_D]): +class DatasetTypeMappingView(MappingView[DatasetTypeNode | None]): _NODE_TYPE = NodeType.DATASET_TYPE diff --git a/python/lsst/pipe/base/pipeline_graph/_mutable_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_mutable_pipeline_graph.py deleted file mode 100644 index 3cd74f916..000000000 --- a/python/lsst/pipe/base/pipeline_graph/_mutable_pipeline_graph.py +++ /dev/null @@ -1,298 +0,0 @@ -# This file is part of pipe_base. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (http://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . 
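With this module removed, the old two-class workflow collapses into the single `PipelineGraph`; roughly, for downstream code (a sketch using only API shown elsewhere in this patch):

    from lsst.daf.butler import Registry
    from lsst.pipe.base import Pipeline, PipelineGraph

    def build_and_resolve(pipeline: Pipeline, registry: Registry) -> PipelineGraph:
        # Previously: mgraph = MutablePipelineGraph(); mgraph.add_task(...);
        # rgraph = mgraph.resolved(registry) returned a ResolvedPipelineGraph.
        graph = pipeline.to_graph()
        graph.resolve(registry)  # now resolves dimensions and dataset types in place
        return graph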
-from __future__ import annotations - -__all__ = ("MutablePipelineGraph",) - -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, BinaryIO, cast, final - -import networkx -import networkx.algorithms.bipartite -import networkx.algorithms.dag -from lsst.daf.butler import Registry -from lsst.resources import ResourcePathExpression - -from ._dataset_types import DatasetTypeNode, SerializedDatasetTypeNode -from ._edges import Edge -from ._exceptions import PipelineDataCycleError -from ._io import PipelineGraphReader, SerializedPipelineGraph -from ._nodes import Node, NodeKey, NodeType -from ._pipeline_graph import PipelineGraph -from ._task_subsets import MutableTaskSubset, SerializedTaskSubset -from ._tasks import SerializedTaskNode, TaskNode, _TaskNodeImportedData - -if TYPE_CHECKING: - from ..config import PipelineTaskConfig - from ..connections import PipelineTaskConnections - from ..pipelineTask import PipelineTask - from ._resolved_pipeline_graph import ResolvedPipelineGraph - - -@final -class MutablePipelineGraph(PipelineGraph[TaskNode, DatasetTypeNode, MutableTaskSubset]): - """A pipeline graph that can be modified in place. - - Notes - ----- - Mutable pipeline graphs are not automatically sorted and are not checked - for cycles until they are sorted, but they do remember when they've been - sorted so repeated calls to `sort` with no modifications in between are - fast. - - Mutable pipeline graphs never carry around resolved dimensions and dataset - types, since the process of resolving dataset types in particular depends - in subtle ways on having the full graph available. In other words, a graph - that has its dataset types resolved as tasks are added to it could end up - with different dataset types from a complete graph that is resolved all at - once, and we don't want to deal with that kind of inconsistency. - """ - - @classmethod - def read_stream( - cls, stream: BinaryIO, import_and_configure: bool = True, check_edges: bool = True - ) -> MutablePipelineGraph: - serialized_graph = SerializedPipelineGraph.read_stream(stream) - reader = MutablePipelineGraphReader() - reader.deserialize_graph(serialized_graph) - result = MutablePipelineGraph.__new__(MutablePipelineGraph) - result._init_from_args(reader.xgraph, reader.sort_keys, reader.task_subsets, reader.description) - if import_and_configure: - result.import_and_configure_in_place(check_edges=check_edges) - return result - - @classmethod - def read_uri( - cls, uri: ResourcePathExpression, import_and_configure: bool = True, check_edges: bool = True - ) -> MutablePipelineGraph: - return cast( - MutablePipelineGraph, - super().read_uri(uri, import_and_configure=import_and_configure, check_edges=check_edges), - ) - - @property - def description(self) -> str: - # Docstring inherited. - return self._description - - @description.setter - def description(self, value: str) -> None: - # Docstring inherited. - self._description = value - - def copy(self) -> MutablePipelineGraph: - # Docstring inherited. - xgraph = self._xgraph.copy() - result = MutablePipelineGraph.__new__(MutablePipelineGraph) - result._init_from_args( - xgraph, - self._sorted_keys, - task_subsets={k: v._mutable_copy(xgraph) for k, v in self._task_subsets.items()}, - description=self._description, - ) - return result - - def resolved(self, registry: Registry, *, redo: bool = False) -> ResolvedPipelineGraph: - # Docstring inherited. 
- from ._resolved_pipeline_graph import ResolvedPipelineGraph - - xgraph = self._xgraph.copy() - for state in xgraph.nodes.values(): - node: Node = state["instance"] - state["instance"] = node._resolved(xgraph, registry) - result = ResolvedPipelineGraph.__new__(ResolvedPipelineGraph) - result._init_from_args( - xgraph, - self._sorted_keys, - task_subsets={k: v._resolved(xgraph) for k, v in self._task_subsets.items()}, - description=self._description, - ) - result.universe = registry.dimensions - return result - - def mutable_copy(self) -> MutablePipelineGraph: - # Docstring inherited. - return self.copy() - - def add_task( - self, - label: str, - task_class: type[PipelineTask], - config: PipelineTaskConfig, - connections: PipelineTaskConnections | None = None, - ) -> None: - """Add a new task to the graph. - - Parameters - ---------- - label : `str` - Label for the task in the pipeline. - task_class : `type` [ `PipelineTask` ] - Class object for the task. - config : `PipelineTaskConfig` - Configuration for the task. - connections : `PipelineTaskConnections`, optional - Object that describes the dataset types used by the task. If not - provided, one will be constructed from the given configuration. If - provided, it is assumed that ``config`` has already been validated - and frozen. - - Raises - ------ - ConnectionTypeConsistencyError - Raised if the task defines an edge's ``is_init`` or - ``is_prerequisite`` flags in a way that is inconsistent with some - other task in the graph. - IncompatibleDatasetTypeError - Raised if the task defines a dataset type differently from some - other task in the graph. Note that checks for dataset type - dimension consistency do not occur until the graph is resolved. - ValueError - Raised if configuration validation failed when constructing. - ``connections``. - PipelineDataCycleError - Raised if the graph is cyclic after this addition. - RuntimeError - Raised if an unexpected exception (which will be chained) occurred - at a stage that may have left the graph in an inconsistent state. - Other exceptions should leave the graph unchanged. - """ - # Make the task node, the corresponding state dict that will be held - # by the networkx graph (which includes the node instance), and the - # state dicts for the edges - task_node = TaskNode._from_imported_data( - label, _TaskNodeImportedData.configure(label, task_class, config, connections) - ) - node_data: list[tuple[NodeKey, dict[str, Any]]] = [ - (task_node.key, {"instance": task_node, "bipartite": task_node.key.node_type.bipartite}), - ( - task_node.init.key, - {"instance": task_node.init, "bipartite": task_node.init.key.node_type.bipartite}, - ), - ] - # Convert the edge objects attached to the task node to networkx form. - edge_data: list[tuple[NodeKey, NodeKey, dict[str, Any]]] = [] - for read_edge in task_node.init.iter_all_inputs(): - self._append_graph_from_edge(node_data, edge_data, read_edge) - for write_edge in task_node.init.iter_all_outputs(): - self._append_graph_from_edge(node_data, edge_data, write_edge) - for read_edge in task_node.prerequisite_inputs: - self._append_graph_from_edge(node_data, edge_data, read_edge) - for read_edge in task_node.inputs: - self._append_graph_from_edge(node_data, edge_data, read_edge) - for write_edge in task_node.iter_all_outputs(): - self._append_graph_from_edge(node_data, edge_data, write_edge) - # Add a special edge (with no Edge instance) that connects the - # TaskInitNode to the runtime TaskNode. 
- edge_data.append((task_node.init.key, task_node.key, {"instance": None})) - # Checks complete; time to start the actual modification, during which - # it's hard to provide strong exception safety. - self._reset() - try: - self._xgraph.add_nodes_from(node_data) - self._xgraph.add_edges_from(edge_data) - if not networkx.algorithms.dag.is_directed_acyclic_graph(self._xgraph): - cycle = networkx.find_cycle(self._xgraph) - raise PipelineDataCycleError(f"Cycle detected while adding task {label} graph: {cycle}.") - except Exception: - # First try to roll back our changes. - try: - self._xgraph.remove_edges_from(edge_data) - self._xgraph.remove_nodes_from(key for key, _ in node_data) - except Exception as err: - raise RuntimeError( - "Error while attempting to revert PipelineGraph modification has left the graph in " - "an inconsistent state." - ) from err - # Successfully rolled back; raise the original exception. - raise - - def add_task_subset(self, subset_label: str, task_labels: Iterable[str], description: str = "") -> None: - """Add a label for a set of tasks that are already in the pipeline. - - Parameters - ---------- - subset_label : `str` - Label for this set of tasks. - task_labels : `~collections.abc.Iterable` [ `str` ] - Labels of the tasks to include in the set. All must already be - included in the graph. - description : `str`, optional - String description to associate with this label. - """ - subset = MutableTaskSubset(self._xgraph, subset_label, set(task_labels), description) - self._task_subsets[subset_label] = subset - - def _append_graph_from_edge( - self, - node_data: list[tuple[NodeKey, dict[str, Any]]], - edge_data: list[tuple[NodeKey, NodeKey, dict[str, Any]]], - edge: Edge, - ) -> None: - """Append networkx state dictionaries for an edge and the corresponding - dataset type node. - - Parameters - ---------- - node_data : `list` - List of node keys and state dictionaries. A node is appended if - one does not already exist for this dataset type. - edge_data : `list` - List of node key pairs and state dictionaries for edges. - edge : `Edge` - New edge being processed. 
- """ - if (existing_dataset_type_state := self._xgraph.nodes.get(edge.dataset_type_key)) is not None: - dataset_type_node = existing_dataset_type_state["instance"] - edge._check_dataset_type(self._xgraph, dataset_type_node) - else: - dataset_type_node = DatasetTypeNode(edge.dataset_type_key) - node_data.append( - ( - edge.dataset_type_key, - { - "instance": dataset_type_node, - "bipartite": NodeType.DATASET_TYPE.bipartite, - }, - ) - ) - edge_data.append(edge.key + ({"instance": edge},)) - - -class MutablePipelineGraphReader(PipelineGraphReader[TaskNode, DatasetTypeNode, MutableTaskSubset]): - def deserialize_dataset_type( - self, name: str, serialized_dataset_type: SerializedDatasetTypeNode - ) -> DatasetTypeNode: - return DatasetTypeNode(NodeKey(NodeType.DATASET_TYPE, name)) - - def deserialize_task(self, label: str, serialized_task: SerializedTaskNode) -> TaskNode: - return TaskNode(**self.deserialize_task_args(label, serialized_task)) - - def deserialize_task_subset( - self, label: str, serialized_task_subset: SerializedTaskSubset - ) -> MutableTaskSubset: - members = set(serialized_task_subset.tasks) - return MutableTaskSubset(self.xgraph, label, members, serialized_task_subset.description) - - def finish(self) -> MutablePipelineGraph: - result = MutablePipelineGraph.__new__(MutablePipelineGraph) - result._init_from_args(self.xgraph, self.sort_keys, self.task_subsets, self.description) - return result diff --git a/python/lsst/pipe/base/pipeline_graph/_nodes.py b/python/lsst/pipe/base/pipeline_graph/_nodes.py index d726a32a6..10c303abb 100644 --- a/python/lsst/pipe/base/pipeline_graph/_nodes.py +++ b/python/lsst/pipe/base/pipeline_graph/_nodes.py @@ -21,21 +21,12 @@ from __future__ import annotations __all__ = ( - "Node", "NodeKey", "NodeType", ) import enum -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, NamedTuple - -import networkx -from lsst.daf.butler import Registry -from lsst.utils.classes import immutable - -if TYPE_CHECKING: - from pydantic import BaseModel +from typing import NamedTuple class NodeType(enum.Enum): @@ -84,75 +75,3 @@ def __repr__(self) -> str: def __str__(self) -> str: return self.name - - -@immutable -class Node(ABC): - """Base class for nodes in a pipeline graph. - - Parameters - ---------- - key : `NodeKey` - The key for this node in networkx graphs. - """ - - def __init__(self, key: NodeKey): - self.key = key - - key: NodeKey - """The key for this node in networkx graphs.""" - - @abstractmethod - def _resolved(self, xgraph: networkx.DiGraph, registry: Registry) -> Node: - """Resolve any dataset type and dimension names in this graph. - - Parameters - ---------- - xgraph : `networkx.DiGraph` - Directed bipartite graph representing the full pipeline. Should - not be modified. - registry : `lsst.daf.butler.Registry` - Registry that provides dimension and dataset type information. - - Returns - ------- - node : `Node` - Resolved version of this node. May be self if the node is already - resolved. - """ - raise NotImplementedError() - - @abstractmethod - def _unresolved(self) -> Node: - """Revert this node to a form that just holds names for dataset types - and dimensions, allowing `_reresolve` to have an effect if called - again. - - Returns - ------- - node : `Node` - Resolved version of this node. May be self if the node is already - resolved. 
- """ - raise NotImplementedError() - - @abstractmethod - def _serialize(self) -> BaseModel: - raise NotImplementedError() - - @abstractmethod - def _to_xgraph_state(self) -> dict[str, Any]: - """Unpack the content of this node into a dictionary that can be used - as the state dictionary for an external networkx graph. - - Unlike `_serialize`, this may hold types that are not directly suitable - for JSON conversion, and it does not need to hold any edge state. Like - `_serialize`, this should not include the node's key, as that is always - included in the graph separately. - - Returns - ------- - state : `dict` - Dictionary for an external networkx graph. - """ - raise NotImplementedError() diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py index fdf62ae79..81fbfd642 100644 --- a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py +++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py @@ -25,82 +25,78 @@ import io import os import tarfile -from abc import abstractmethod -from collections.abc import Iterator, Mapping, Sequence +from collections.abc import Iterable, Iterator, Mapping, Sequence from datetime import datetime -from typing import TYPE_CHECKING, Any, BinaryIO, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Any, BinaryIO, Literal, cast import networkx import networkx.algorithms.bipartite import networkx.algorithms.dag -from lsst.daf.butler import Registry +from lsst.daf.butler import DimensionGraph, DimensionUniverse, Registry from lsst.resources import ResourcePath, ResourcePathExpression +from ._dataset_types import DatasetTypeNode, SerializedDatasetTypeNode from ._edges import Edge, ReadEdge, WriteEdge -from ._exceptions import PipelineDataCycleError -from ._io import SerializedPipelineGraph -from ._mapping_views import _D, _T, DatasetTypeMappingView, TaskMappingView -from ._nodes import Node, NodeKey, NodeType +from ._exceptions import PipelineDataCycleError, ReadInconsistencyError, UnresolvedGraphError +from ._io import PipelineGraphReader, SerializedPipelineGraph +from ._mapping_views import DatasetTypeMappingView, TaskMappingView +from ._nodes import NodeKey, NodeType from ._task_subsets import TaskSubset +from ._tasks import TaskInitNode, TaskNode, _TaskNodeImportedData if TYPE_CHECKING: + from ..config import PipelineTaskConfig + from ..connections import PipelineTaskConnections from ..pipeline import TaskDef - from ._extract_helper import ExtractHelper - from ._mutable_pipeline_graph import MutablePipelineGraph - from ._resolved_pipeline_graph import ResolvedPipelineGraph + from ..pipelineTask import PipelineTask -_S = TypeVar("_S", bound="TaskSubset", covariant=True) -_P = TypeVar("_P", bound="PipelineGraph", covariant=True) - - -class PipelineGraph(Generic[_T, _D, _S]): - """A base class for directed acyclic graph of `PipelineTask` definitions. - - This abstract base class should not be inherited from outside its package; - it exists to share code and interfaces between `MutablePipelineGraph` and - `ResolvedPipelineGraph`. 
- """ - - def __init__(self) -> None: - self._init_from_args() +class PipelineGraph: + def __init__(self, universe: DimensionUniverse | None = None) -> None: + self._init_from_args( + xgraph=None, sorted_keys=None, task_subsets=None, description="", universe=universe + ) def _init_from_args( self, - xgraph: networkx.DiGraph | None = None, - sorted_keys: Sequence[NodeKey] | None = None, - task_subsets: dict[str, _S] | None = None, - description: str = "", + xgraph: networkx.DiGraph | None, + sorted_keys: Sequence[NodeKey] | None, + task_subsets: dict[str, TaskSubset] | None, + description: str, + universe: DimensionUniverse | None, ) -> None: """Initialize the graph with possibly-nontrivial arguments. Parameters ---------- - xgraph : `networkx.DiGraph` or `None`, optional + xgraph : `networkx.DiGraph` or `None` The backing networkx graph, or `None` to create an empty one. - sorted_keys : `Sequence` [ `NodeKey` ] or `None`, optional + sorted_keys : `Sequence` [ `NodeKey` ] or `None` Topologically sorted sequence of node keys, or `None` if the graph is not sorted. - task_subsets : `dict` [ `str`, `TaskSubsetMapping` ], optional + task_subsets : `dict` [ `str`, `TaskSubset` ] Labeled subsets of tasks. Values must be constructed with ``xgraph`` as their parent graph. - description : `str`, optional + description : `str` String description for this pipeline. + universe : `lsst.daf.butler.DimensionUniverse` or `None` + Definitions of all dimensions. Notes ----- - Only empty `PipelineGraph` [subclass] instances should be constructed - directly by users, which sets the signature of ``__init__`` itself, but - methods on `PipelineGraph` and its helper classes need to be able to - create them with state. Those methods can call this after calling - ``__new__`` manually. + Only empty `PipelineGraph` instances should be constructed directly by + users, which sets the signature of ``__init__`` itself, but methods on + `PipelineGraph` and its helper classes need to be able to create them + with state. Those methods can call this after calling ``__new__`` + manually. """ self._xgraph = xgraph if xgraph is not None else networkx.DiGraph() - self._sorted_keys: Sequence[NodeKey] | None = None + self._sorted_keys: Sequence[NodeKey] | None self._task_subsets = task_subsets if task_subsets is not None else {} self._description = description - self._tasks = TaskMappingView[_T](self._xgraph) - self._dataset_types = DatasetTypeMappingView[_D](self._xgraph) + self._tasks = TaskMappingView(self._xgraph) + self._dataset_types = DatasetTypeMappingView(self._xgraph) + self._universe = universe if sorted_keys is not None: self._reorder(sorted_keys) @@ -112,16 +108,26 @@ def description(self) -> str: """String description for this pipeline.""" return self._description + @description.setter + def description(self, value: str) -> None: + # Docstring in setter. + self._description = value + + @property + def universe(self) -> DimensionUniverse | None: + """Definitions for all dimensions.""" + return self._universe + @property - def tasks(self) -> TaskMappingView[_T]: + def tasks(self) -> TaskMappingView: return self._tasks @property - def dataset_types(self) -> DatasetTypeMappingView[_D]: + def dataset_types(self) -> DatasetTypeMappingView: return self._dataset_types @property - def task_subsets(self) -> Mapping[str, _S]: + def task_subsets(self) -> Mapping[str, TaskSubset]: """Mapping of all labeled subsets of tasks. 
        Keys are subset labels, values are Task-only graphs (subgraphs of
@@ -135,29 +141,32 @@ def iter_edges(self, init: bool = False) -> Iterator[Edge]:
             if edge is not None and edge.is_init == init:
                 yield edge
 
-    def iter_nodes(self) -> Iterator[Node]:
+    def iter_nodes(
+        self,
+    ) -> Iterator[
+        tuple[Literal[NodeType.TASK_INIT], str, TaskInitNode]
+        | tuple[Literal[NodeType.TASK], str, TaskInitNode]
+        | tuple[Literal[NodeType.DATASET_TYPE], str, DatasetTypeNode | None]
+    ]:
+        key: NodeKey
         if self._sorted_keys is not None:
             for key in self._sorted_keys:
-                yield self._xgraph.nodes[key]["instance"]
+                yield key.node_type, key.name, self._xgraph.nodes[key]["instance"]  # type: ignore
         else:
-            for _, node in self._xgraph.nodes(data="instance"):
-                yield node
+            for key, node in self._xgraph.nodes(data="instance"):
+                yield key.node_type, key.name, node  # type: ignore
 
-    def iter_overall_inputs(self) -> Iterator[_D]:
+    def iter_overall_inputs(self) -> Iterator[tuple[str, DatasetTypeNode | None]]:
         for generation in networkx.algorithms.dag.topological_generations(self._xgraph):
+            key: NodeKey
             for key in generation:
                 # While we expect all tasks to have at least one input and
                 # hence never appear in the first topological generation, that
                 # is not true of task init nodes.
                 if key.node_type is NodeType.DATASET_TYPE:
-                    yield self._xgraph.nodes[key]["instance"]
+                    yield key.name, self._xgraph.nodes[key]["instance"]
             return
 
-    def import_and_configure_in_place(self, check_edges: bool = True) -> None:
-        # TODO: docs
-        for task in self.tasks.values():
-            task.import_and_configure(check_edges=check_edges)
-
     def make_xgraph(self) -> networkx.DiGraph:
         return self._transform_xgraph_state(self._xgraph.copy())
 
@@ -192,14 +201,62 @@ def _make_bipartite_xgraph_internal(self, init: bool) -> networkx.DiGraph:
     def _transform_xgraph_state(self, xgraph: networkx.DiGraph) -> networkx.DiGraph:
         state: dict[str, Any]
         for state in xgraph.nodes.values():
-            node: Node = state.pop("instance")
-            state.update(node._to_xgraph_state())
+            node_value: TaskInitNode | TaskNode | DatasetTypeNode | None = state.pop("instance")
+            if node_value is not None:
+                state.update(node_value._to_xgraph_state())
         for _, _, state in xgraph.edges(data=True):
            edge: Edge | None = state.pop("instance", None)
            if edge is not None:
                state.update(edge._to_xgraph_state())
        return xgraph

+    def group_by_dimensions(
+        self, prerequisites: bool = False
+    ) -> dict[DimensionGraph, tuple[list[TaskNode], list[DatasetTypeNode]]]:
+        """Group this graph's tasks and runtime dataset types by their
+        dimensions.
+
+        Parameters
+        ----------
+        prerequisites : `bool`, optional
+            If `True`, include prerequisite dataset types as well as regular
+            input and output datasets (including intermediates).
+
+        Returns
+        -------
+        groups : `dict` [ `DimensionGraph`, `tuple` ]
+            A dictionary of groups keyed by `DimensionGraph`, in which each
+            value is a tuple of:
+
+            - a `list` of `TaskNode` instances
+            - a `list` of `DatasetTypeNode` instances
+
+            that have those dimensions.
+
+        Notes
+        -----
+        Init inputs and outputs are always included, but always have empty
+        dimensions and are hence easily filtered out.
+ """ + result: dict[DimensionGraph, tuple[list[TaskNode], list[DatasetTypeNode]]] = {} + next_new_value: tuple[list[TaskNode], list[DatasetTypeNode]] = ([], []) + for task_label, task_node in self.tasks.items(): + if task_node.dimensions is None: + raise UnresolvedGraphError(f"Task with label {task_label!r} has not been resolved.") + if (group := result.setdefault(task_node.dimensions, next_new_value)) is next_new_value: + next_new_value = ([], []) # make new lists for next time + group[0].append(task_node) + for dataset_type_name, dataset_type_node in self.dataset_types.items(): + if dataset_type_node is None: + raise UnresolvedGraphError(f"Dataset type {dataset_type_name!r} has not been resolved.") + if not dataset_type_node.is_prerequisite or prerequisites: + if ( + group := result.setdefault(dataset_type_node.dataset_type.dimensions, next_new_value) + ) is next_new_value: + next_new_value = ([], []) # make new lists for next time + group[1].append(dataset_type_node) + return result + @property def is_sorted(self) -> bool: """Whether this graph's tasks and dataset types are topologically @@ -288,63 +345,116 @@ def consumers_of(self, dataset_type_name: str) -> dict[str, ReadEdge]: ) } - def extract(self) -> ExtractHelper: - """Create a new `MutablePipelineGraph` containing just the tasks that - match the given criteria. - """ - from ._extract_helper import ExtractHelper + def add_task( + self, + label: str, + task_class: type[PipelineTask], + config: PipelineTaskConfig, + connections: PipelineTaskConnections | None = None, + ) -> None: + """Add a new task to the graph. - return ExtractHelper(self) + Parameters + ---------- + label : `str` + Label for the task in the pipeline. + task_class : `type` [ `PipelineTask` ] + Class object for the task. + config : `PipelineTaskConfig` + Configuration for the task. + connections : `PipelineTaskConnections`, optional + Object that describes the dataset types used by the task. If not + provided, one will be constructed from the given configuration. If + provided, it is assumed that ``config`` has already been validated + and frozen. - def _reorder(self, sorted_keys: Sequence[NodeKey]) -> None: - """Set the order of all views of this graph from the given sorted - sequence of task labels and dataset type names. + Raises + ------ + ConnectionTypeConsistencyError + Raised if the task defines an edge's ``is_init`` or + ``is_prerequisite`` flags in a way that is inconsistent with some + other task in the graph. + IncompatibleDatasetTypeError + Raised if the task defines a dataset type differently from some + other task in the graph. Note that checks for dataset type + dimension consistency do not occur until the graph is resolved. + ValueError + Raised if configuration validation failed when constructing. + ``connections``. + PipelineDataCycleError + Raised if the graph is cyclic after this addition. + RuntimeError + Raised if an unexpected exception (which will be chained) occurred + at a stage that may have left the graph in an inconsistent state. + Other exceptions should leave the graph unchanged. 
""" - self._sorted_keys = sorted_keys - self._tasks._reorder(sorted_keys) - self._dataset_types._reorder(sorted_keys) + key = NodeKey(NodeType.TASK, label) + init_key = NodeKey(NodeType.TASK_INIT, label) + task_node = TaskNode._from_imported_data( + key, + init_key, + _TaskNodeImportedData.configure(label, task_class, config, connections), + universe=self.universe, + ) + self._add_task_nodes([(key, task_node)]) - def _reset(self) -> None: - """Reset the all views of this graph following a modification that - might invalidate them. + def remove_task(self, label: str) -> None: + key = NodeKey(NodeType.TASK, label) + self._remove_task_nodes({key: self._xgraph[key]["instance"]}) + + def add_task_subset(self, subset_label: str, task_labels: Iterable[str], description: str = "") -> None: + """Add a label for a set of tasks that are already in the pipeline. + + Parameters + ---------- + subset_label : `str` + Label for this set of tasks. + task_labels : `~collections.abc.Iterable` [ `str` ] + Labels of the tasks to include in the set. All must already be + included in the graph. + description : `str`, optional + String description to associate with this label. """ - self._sorted_keys = None - self._tasks._reset() - self._dataset_types._reset() + subset = TaskSubset(self._xgraph, subset_label, set(task_labels), description) + self._task_subsets[subset_label] = subset - @abstractmethod - def copy(self: _P) -> _P: + def copy(self) -> PipelineGraph: """Return a copy of this graph that copies all mutable state.""" - raise NotImplementedError() + xgraph = self._xgraph.copy() + result = PipelineGraph.__new__(PipelineGraph) + result._init_from_args( + xgraph, + self._sorted_keys, + task_subsets={ + k: TaskSubset(xgraph, v.label, set(v._members), v.description) + for k, v in self._task_subsets.items() + }, + description=self._description, + universe=self.universe, + ) + return result - def __copy__(self: _P) -> _P: + def __copy__(self) -> PipelineGraph: # Fully shallow copies are dangerous; we don't want shared mutable # state to lead to broken class invariants. return self.copy() - def __deepcopy__(self: _P, memo: dict) -> _P: - # Genuine deep copies are sometimes unnecessary, since we should only - # ever care that mutable state is copied. + def __deepcopy__(self, memo: dict) -> PipelineGraph: + # Genuine deep copies are unnecessary, since we should only ever care + # that mutable state is copied. return self.copy() - @abstractmethod - def resolved(self, registry: Registry, *, redo: bool = False) -> ResolvedPipelineGraph: - """Return a version of this graph with all dimensions and dataset types - resolved according to the given butler registry. + def import_and_configure(self, check: bool = True, rebuild: bool = False) -> None: + self._import_and_configure(check=check, rebuild=rebuild, universe=self._universe) + + def resolve(self, registry: Registry, *, check: bool = True, rebuild: bool = False) -> None: + """Resolve all dimensions and dataset types. Parameters ---------- registry : `lsst.daf.butler.Registry` Client for the data repository to resolve against. - redo : `bool`, optional - If `True`, re-do the resolution even if the graph has already been - resolved to pick up changes in the registry. If `False` (default) - and the graph is already resolved, this method returns ``self``. - - Returns - ------- - resolved : `ResolvedPipelineGraph` - A resolved version of this graph. Always sorted and immutable. 
+ TODO Raises ------ @@ -358,69 +468,59 @@ def resolved(self, registry: Registry, *, redo: bool = False) -> ResolvedPipelin are consumed with different storage classes or as components by tasks in the pipeline. """ - raise NotImplementedError() - - @abstractmethod - def mutable_copy(self) -> MutablePipelineGraph: - """Return a mutable copy of this graph. - - Returns - ------- - mutable : `MutablePipelineGraph` - A mutable copy of this graph. This drops all dimension and - dataset type resolutions that may be present in ``self``. See - docs for `MutablePipelineGraph` for details. - """ - raise NotImplementedError() - - def _iter_task_defs(self) -> Iterator[TaskDef]: - """Iterate over this pipeline as a sequence of `TaskDef` instances. - - Notes - ----- - This is a package-private method intended to aid in the transition to a - codebase more fully integrated with the `PipelineGraph` class, in which - both `TaskDef` and `PipelineDatasetTypes` are expected to go away, and - much of the functionality on the `Pipeline` class will be moved to - `PipelineGraph` as well. - """ - from ..pipeline import TaskDef - - for node in self._tasks.values(): - yield TaskDef( - config=node.config, - taskClass=node.task_class, - label=node.label, - connections=node._get_connections(), - ) + self._import_and_configure(check=check, rebuild=rebuild, universe=registry.dimensions) + self.sort() + node_key: NodeKey + updates: dict[NodeKey, TaskNode | DatasetTypeNode] = {} + for node_key, node_state in self._xgraph.nodes.items(): + match node_key.node_type: + case NodeType.TASK: + task_node: TaskNode = node_state["instance"] + new_task_node = task_node._resolved(registry.dimensions) + if new_task_node is not task_node: + updates[node_key] = new_task_node + case NodeType.DATASET_TYPE: + dataset_type_node: DatasetTypeNode | None = node_state["instance"] + if dataset_type_node is None: + updates[node_key] = DatasetTypeNode._from_edges(node_key, self._xgraph, registry) + else: + new_dataset_type_node = dataset_type_node._resolved(registry) + if new_dataset_type_node is not dataset_type_node: + updates[node_key] = new_dataset_type_node + for node_key, node_value in updates.items(): + self._xgraph.nodes[node_key]["instance"] = node_value + self._universe = registry.dimensions @classmethod def read_stream( - cls, stream: BinaryIO, import_and_configure: bool = True, check_edges: bool = True - ) -> MutablePipelineGraph | ResolvedPipelineGraph: - from ._mutable_pipeline_graph import MutablePipelineGraphReader - from ._resolved_pipeline_graph import ResolvedPipelineGraphReader - + cls, stream: BinaryIO, import_and_configure: bool = True, check: bool = True, rebuild: bool = False + ) -> PipelineGraph: serialized_graph = SerializedPipelineGraph.read_stream(stream) - reader: MutablePipelineGraphReader | ResolvedPipelineGraphReader - if serialized_graph.dimensions is None: - reader = MutablePipelineGraphReader() - else: - reader = ResolvedPipelineGraphReader() + reader = PipelineGraphReader() reader.deserialize_graph(serialized_graph) - result = reader.finish() + result = PipelineGraph.__new__(PipelineGraph) + result._init_from_args( + reader.xgraph, reader.sort_keys, reader.task_subsets, reader.description, universe=reader.universe + ) if import_and_configure: - result.import_and_configure_in_place(check_edges=check_edges) + result._import_and_configure(check=check, rebuild=rebuild, universe=result.universe) return result @classmethod def read_uri( - cls, uri: ResourcePathExpression, import_and_configure: bool = True, check_edges: 
bool = True - ) -> MutablePipelineGraph | ResolvedPipelineGraph: + cls, + uri: ResourcePathExpression, + import_and_configure: bool = True, + check: bool = True, + rebuild: bool = False, + ) -> PipelineGraph: uri = ResourcePath(uri) with uri.open("rb") as stream: return cls.read_stream( - cast(BinaryIO, stream), import_and_configure=import_and_configure, check_edges=check_edges + cast(BinaryIO, stream), + import_and_configure=import_and_configure, + check=check, + rebuild=rebuild, ) def write_stream(self, stream: BinaryIO, basename: str = "pipeline", compression: str = "gz") -> None: @@ -524,8 +624,12 @@ def _serialize(self) -> SerializedPipelineGraph: result = SerializedPipelineGraph.construct( description=self.description, tasks={label: node._serialize() for label, node in self.tasks.items()}, - dataset_types={name: node._serialize() for name, node in self.dataset_types.items()}, + dataset_types={ + name: node._serialize() if node is not None else SerializedDatasetTypeNode() + for name, node in self.dataset_types.items() + }, task_subsets={label: subset._serialize() for label, subset in self.task_subsets.items()}, + dimensions=self.universe.dimensionConfig.toDict() if self.universe is not None else None, ) if self._sorted_keys: for index, node_key in enumerate(self._sorted_keys): @@ -537,3 +641,180 @@ def _serialize(self) -> SerializedPipelineGraph: case NodeType.TASK_INIT: result.tasks[node_key.name].init.index = index return result + + def _iter_task_defs(self) -> Iterator[TaskDef]: + """Iterate over this pipeline as a sequence of `TaskDef` instances. + + Notes + ----- + This is a package-private method intended to aid in the transition to a + codebase more fully integrated with the `PipelineGraph` class, in which + both `TaskDef` and `PipelineDatasetTypes` are expected to go away, and + much of the functionality on the `Pipeline` class will be moved to + `PipelineGraph` as well. 
+ """ + from ..pipeline import TaskDef + + for node in self._tasks.values(): + yield TaskDef( + config=node.config, + taskClass=node.task_class, + label=node.label, + connections=node._get_connections(), + ) + + def _import_and_configure(self, check: bool, rebuild: bool, universe: DimensionUniverse | None) -> None: + if rebuild: + check = False + updates: dict[NodeKey, TaskNode] = {} + node_key: NodeKey + for node_key, node_state in self._xgraph.nodes.items(): + if node_key.node_type is NodeType.TASK: + task_node: TaskNode = node_state["instance"] + new_task_node = task_node._imported_and_configured(node_key, rebuild or check, universe) + if new_task_node is not task_node: + updates[node_key] = new_task_node + if check: + messages = new_task_node.diff(task_node) + if messages: + messages.insert( + 0, + f"Imported and reconfigured edges for task {node_key.name!r} " + "differ from those persisted:", + ) + raise ReadInconsistencyError("\n".join(messages)) + if rebuild: + self._remove_task_nodes(updates) + self._add_task_nodes(updates.items()) + else: + for node_key, task_node in updates.items(): + self._xgraph.nodes[node_key]["instance"] = task_node + self._xgraph.nodes[task_node.init._key]["instance"] = task_node.init + + def _add_task_nodes(self, nodes: Iterable[tuple[NodeKey, TaskNode]]) -> None: + node_data: list[tuple[NodeKey, dict[str, Any]]] = [] + for key, task_node in nodes: + node_data.append((key, {"instance": task_node, "bipartite": key.node_type.bipartite})) + node_data.append( + ( + task_node.init._key, + {"instance": task_node.init, "bipartite": task_node.init._key.node_type.bipartite}, + ) + ) + # Convert the edge objects attached to the task node to networkx. + edge_data: list[tuple[NodeKey, NodeKey, dict[str, Any]]] = [] + for read_edge in task_node.init.iter_all_inputs(): + self._append_graph_data_from_edge(node_data, edge_data, read_edge) + for write_edge in task_node.init.iter_all_outputs(): + self._append_graph_data_from_edge(node_data, edge_data, write_edge) + for read_edge in task_node.prerequisite_inputs: + self._append_graph_data_from_edge(node_data, edge_data, read_edge) + for read_edge in task_node.inputs: + self._append_graph_data_from_edge(node_data, edge_data, read_edge) + for write_edge in task_node.iter_all_outputs(): + self._append_graph_data_from_edge(node_data, edge_data, write_edge) + # Add a special edge (with no Edge instance) that connects the + # TaskInitNode to the runtime TaskNode. + edge_data.append((task_node.init._key, key, {"instance": None})) + if not node_data and not edge_data: + return + # Checks complete; time to start the actual modification, during which + # it's hard to provide strong exception safety. + self._reset() + try: + self._xgraph.add_nodes_from(node_data) + self._xgraph.add_edges_from(edge_data) + if not networkx.algorithms.dag.is_directed_acyclic_graph(self._xgraph): + cycle = networkx.find_cycle(self._xgraph) + raise PipelineDataCycleError(f"Cycle detected while adding task {key.name} graph: {cycle}.") + except Exception: + # First try to roll back our changes. + try: + self._xgraph.remove_edges_from(edge_data) + self._xgraph.remove_nodes_from(key for key, _ in node_data) + except Exception as err: + raise RuntimeError( + "Error while attempting to revert PipelineGraph modification has left the graph in " + "an inconsistent state." + ) from err + # Successfully rolled back; raise the original exception. 
+        raise
+
+    def _remove_task_nodes(self, nodes: Mapping[NodeKey, TaskNode]) -> None:
+        dataset_types: set[NodeKey] = set()
+        for task_key, task_node in nodes.items():
+            dataset_types.update(self._xgraph.predecessors(task_key))
+            dataset_types.update(self._xgraph.successors(task_key))
+            dataset_types.update(self._xgraph.predecessors(task_node.init._key))
+            dataset_types.update(self._xgraph.successors(task_node.init._key))
+            # Since there's an edge between the task and its init node, we'll
+            # have added those two nodes here, too, and we don't want that.
+            dataset_types.remove(task_node.init._key)
+            dataset_types.remove(task_key)
+        to_remove = list(nodes.keys())
+        to_unresolve: list[NodeKey] = []
+        for dataset_type_key in dataset_types:
+            related_tasks = set()
+            related_tasks.update(self._xgraph.predecessors(dataset_type_key))
+            related_tasks.update(self._xgraph.successors(dataset_type_key))
+            related_tasks.difference_update(nodes.keys())
+            if not related_tasks:
+                to_remove.append(dataset_type_key)
+            else:
+                to_unresolve.append(dataset_type_key)
+        for dataset_type_key in to_unresolve:
+            self._xgraph.nodes[dataset_type_key]["instance"] = None
+        if to_remove:
+            self._reset()
+            self._xgraph.remove_nodes_from(to_remove)
+
+    def _append_graph_data_from_edge(
+        self,
+        node_data: list[tuple[NodeKey, dict[str, Any]]],
+        edge_data: list[tuple[NodeKey, NodeKey, dict[str, Any]]],
+        edge: Edge,
+    ) -> None:
+        """Append networkx state dictionaries for an edge and the corresponding
+        dataset type node.
+
+        Parameters
+        ----------
+        node_data : `list`
+            List of node keys and state dictionaries. A node is appended if
+            one does not already exist for this dataset type.
+        edge_data : `list`
+            List of node key pairs and state dictionaries for edges.
+        edge : `Edge`
+            New edge being processed.
+        """
+        if (existing_dataset_type_state := self._xgraph.nodes.get(edge.dataset_type_key)) is not None:
+            existing_dataset_type_state["instance"] = edge._update_dataset_type(
+                self._xgraph, existing_dataset_type_state["instance"]
+            )
+        else:
+            node_data.append(
+                (
+                    edge.dataset_type_key,
+                    {
+                        "instance": None,
+                        "bipartite": NodeType.DATASET_TYPE.bipartite,
+                    },
+                )
+            )
+        edge_data.append(edge.key + ({"instance": edge},))
+
+    def _reorder(self, sorted_keys: Sequence[NodeKey]) -> None:
+        """Set the order of all views of this graph from the given sorted
+        sequence of task labels and dataset type names.
+        """
+        self._sorted_keys = sorted_keys
+        self._tasks._reorder(sorted_keys)
+        self._dataset_types._reorder(sorted_keys)
+
+    def _reset(self) -> None:
+        """Reset all views of this graph following a modification that
+        might invalidate them.
+        """
+        self._sorted_keys = None
+        self._tasks._reset()
+        self._dataset_types._reset()
diff --git a/python/lsst/pipe/base/pipeline_graph/_resolved_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_resolved_pipeline_graph.py
deleted file mode 100644
index 1ea260f2c..000000000
--- a/python/lsst/pipe/base/pipeline_graph/_resolved_pipeline_graph.py
+++ /dev/null
@@ -1,231 +0,0 @@
-# This file is part of pipe_base.
-#
-# Developed for the LSST Data Management System.
-# This product includes software developed by the LSST Project
-# (http://www.lsst.org).
-# See the COPYRIGHT file at the top-level directory of this distribution
-# for details of code ownership.
-# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . -from __future__ import annotations - -__all__ = ("ResolvedPipelineGraph",) - -from collections.abc import Sequence -from typing import TYPE_CHECKING, BinaryIO, cast, final - -import networkx -import networkx.algorithms.bipartite -import networkx.algorithms.dag -from lsst.daf.butler import DatasetType, DimensionConfig, DimensionGraph, DimensionUniverse, Registry -from lsst.resources import ResourcePathExpression - -from ._dataset_types import ResolvedDatasetTypeNode, SerializedDatasetTypeNode -from ._io import PipelineGraphReader, SerializedPipelineGraph, expect_not_none -from ._nodes import Node, NodeKey, NodeType -from ._pipeline_graph import PipelineGraph -from ._task_subsets import ResolvedTaskSubset, SerializedTaskSubset -from ._tasks import SerializedTaskNode, TaskNode - -if TYPE_CHECKING: - from ._mutable_pipeline_graph import MutablePipelineGraph - - -@final -class ResolvedPipelineGraph(PipelineGraph[TaskNode, ResolvedDatasetTypeNode, ResolvedTaskSubset]): - """An immutable pipeline graph with resolved dimensions and dataset types. - - Resolved pipeline graphs are sorted at construction and cannot be modified, - so calling `sort` on them does nothing. - """ - - def __init__(self, universe: DimensionUniverse) -> None: - super().__init__() - self.universe = universe - - def _init_from_args( - self, - xgraph: networkx.DiGraph | None = None, - sorted_keys: Sequence[NodeKey] | None = None, - task_subsets: dict[str, ResolvedTaskSubset] | None = None, - description: str = "", - ) -> None: - super()._init_from_args(xgraph, sorted_keys, task_subsets, description) - super().sort() - - def sort(self) -> None: - # Docstring inherited. - assert self.is_sorted, "Sorted at construction and immutable." - - def copy(self) -> ResolvedPipelineGraph: - # Docstring inherited. - # Immutable types shouldn't actually be copied, since there's nothing - # one could do with the copy that couldn't be done with the original. - return self - - def resolved(self, registry: Registry, *, redo: bool = False) -> ResolvedPipelineGraph: - # Docstring inherited. - if redo: - return self.mutable_copy().resolved(registry) - return self - - def mutable_copy(self) -> MutablePipelineGraph: - # Docstring inherited. - from ._mutable_pipeline_graph import MutablePipelineGraph - - xgraph = self._xgraph.copy() - for state in xgraph.nodes.values(): - node: Node = state["instance"] - state["instance"] = node._unresolved() - result = MutablePipelineGraph.__new__(MutablePipelineGraph) - result._init_from_args( - xgraph, - self._sorted_keys, - task_subsets={k: v._mutable_copy(xgraph) for k, v in self._task_subsets.items()}, - description=self._description, - ) - return result - - def group_by_dimensions( - self, prerequisites: bool = False - ) -> dict[DimensionGraph, tuple[list[TaskNode], list[ResolvedDatasetTypeNode]]]: - """Group this graph's tasks and runtime dataset types by their - dimensions. 
- - Parameters - ---------- - prerequisites : `bool`, optional - If `True`, include prerequisite dataset types as well as regular - input and output datasets (including intermediates). - - Returns - ------- - groups : `dict` [ `DimensionGraph`, `tuple` ] - A dictionary of groups keyed by `DimensionGraph`, which each value - a tuple of: - - - a `list` of `TaskNode` instances - - a `list` of `ResolvedDatasetTypeNode` instances - - that have those dimensions. - - Notes - ----- - Init inputs and outputs are always included, but always have empty - dimensions and are hence easily filtered out. - """ - result: dict[DimensionGraph, tuple[list[TaskNode], list[ResolvedDatasetTypeNode]]] = {} - next_new_value: tuple[list[TaskNode], list[ResolvedDatasetTypeNode]] = ([], []) - for task_node in self.tasks.values(): - if ( - group := result.setdefault(cast(DimensionGraph, task_node.dimensions), next_new_value) - ) is next_new_value: - next_new_value = ([], []) # make new lists for next time - group[0].append(task_node) - for dataset_type_node in self.dataset_types.values(): - if not dataset_type_node.is_prerequisite or prerequisites: - if ( - group := result.setdefault(dataset_type_node.dataset_type.dimensions, next_new_value) - ) is next_new_value: - next_new_value = ([], []) # make new lists for next time - group[1].append(dataset_type_node) - return result - - def _serialize(self) -> SerializedPipelineGraph: - # Docstring inherited. - result = super()._serialize() - result.dimensions = self.universe.dimensionConfig.toDict() - return result - - @classmethod - def read_stream( - cls, stream: BinaryIO, import_and_configure: bool = True, check_edges: bool = True - ) -> ResolvedPipelineGraph: - serialized_graph = SerializedPipelineGraph.read_stream(stream) - reader = ResolvedPipelineGraphReader() - reader.deserialize_graph(serialized_graph) - result = ResolvedPipelineGraph.__new__(ResolvedPipelineGraph) - result._init_from_args(reader.xgraph, reader.sort_keys, reader.task_subsets, reader.description) - result.universe = reader.universe - if import_and_configure: - result.import_and_configure_in_place(check_edges=check_edges) - return result - - @classmethod - def read_uri( - cls, uri: ResourcePathExpression, import_and_configure: bool = True, check_edges: bool = True - ) -> ResolvedPipelineGraph: - return cast( - ResolvedPipelineGraph, - super().read_uri(uri, import_and_configure=import_and_configure, check_edges=check_edges), - ) - - -class ResolvedPipelineGraphReader(PipelineGraphReader[TaskNode, ResolvedDatasetTypeNode, ResolvedTaskSubset]): - def deserialize_graph( - self, - serialized_graph: SerializedPipelineGraph, - ) -> None: - self.universe = DimensionUniverse( - config=DimensionConfig( - expect_not_none( - serialized_graph.dimensions, - "Serialized pipeline graph has not been resolved; " - "load it is a MutablePipelineGraph instead.", - ) - ) - ) - super().deserialize_graph(serialized_graph) - - def deserialize_dataset_type( - self, name: str, serialized_dataset_type: SerializedDatasetTypeNode - ) -> ResolvedDatasetTypeNode: - dataset_type = DatasetType( - name, - expect_not_none( - serialized_dataset_type.dimensions, f"Serialized dataset type {name!r} has no dimensions." 
- ), - storageClass=expect_not_none( - serialized_dataset_type.storage_class, - f"Serialized dataset type {name!r} has no storage class.", - ), - isCalibration=serialized_dataset_type.is_calibration, - universe=self.universe, - ) - return ResolvedDatasetTypeNode( - key=NodeKey(NodeType.DATASET_TYPE, name), - dataset_type=dataset_type, - is_prerequisite=serialized_dataset_type.is_prerequisite, - is_initial_query_constraint=serialized_dataset_type.is_initial_query_constraint, - ) - - def deserialize_task(self, label: str, serialized_task: SerializedTaskNode) -> TaskNode: - kwargs = self.deserialize_task_args(label, serialized_task) - kwargs["dimensions"] = self.universe.extract( - expect_not_none( - serialized_task.dimensions, f"Serialized task with label {label!r} has no dimensions." - ) - ) - return TaskNode(**kwargs) - - def deserialize_task_subset( - self, label: str, serialized_task_subset: SerializedTaskSubset - ) -> ResolvedTaskSubset: - members = set(serialized_task_subset.tasks) - return ResolvedTaskSubset(self.xgraph, label, members, serialized_task_subset.description) - - def finish(self) -> ResolvedPipelineGraph: - result = ResolvedPipelineGraph.__new__(ResolvedPipelineGraph) - result._init_from_args(self.xgraph, self.sort_keys, self.task_subsets, self.description) - return result diff --git a/python/lsst/pipe/base/pipeline_graph/_task_subsets.py b/python/lsst/pipe/base/pipeline_graph/_task_subsets.py index 215c88973..b5cdea09e 100644 --- a/python/lsst/pipe/base/pipeline_graph/_task_subsets.py +++ b/python/lsst/pipe/base/pipeline_graph/_task_subsets.py @@ -20,9 +20,9 @@ # along with this program. If not, see . from __future__ import annotations -__all__ = ("TaskSubset", "MutableTaskSubset", "ResolvedTaskSubset", "SerializedTaskSubset") +__all__ = ("TaskSubset", "SerializedTaskSubset") -from collections.abc import Iterator, MutableSet, Set +from collections.abc import Iterator, MutableSet import networkx import networkx.algorithms.boundary @@ -32,7 +32,7 @@ from ._nodes import NodeKey, NodeType -class TaskSubset(Set[str]): +class TaskSubset(MutableSet[str]): """An abstract base class whose instances represent a labeled subset of the tasks in a pipeline. @@ -70,6 +70,10 @@ def description(self) -> str: """Description string associated with this labeled subset.""" return self._description + @description.setter + def description(self, value: str) -> None: + self._description = value + def __str__(self) -> str: return f"{self.label}: {self.description}, tasks={', '.join(iter(self))}" @@ -82,53 +86,6 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[str]: return iter(self._members) - def _resolved(self, parent_xgraph: networkx.DiGraph) -> ResolvedTaskSubset: - """Return a version of this view appropriate for a resolved pipeline - graph. - - Parameters - ---------- - parent_xgraph : `networkx.DiGraph` - The new parent networkx graph that will back the new view. - - Returns - ------- - resolved : `ResolvedTaskSubsetGraph` - A resolved version of this object. - """ - return ResolvedTaskSubset(parent_xgraph, self.label, self._members.copy(), self._description) - - def _mutable_copy(self, parent_xgraph: networkx.DiGraph) -> MutableTaskSubset: - """Return a copy of this view appropriate for a mutable pipeline - graph. - - Parameters - ---------- - parent_xgraph : `networkx.DiGraph` - The new parent networkx graph that will back the new view. - - Returns - ------- - mutable : `MutableTaskSubsetGraph` - A mutable version of this object. 
- """ - return MutableTaskSubset(parent_xgraph, self.label, self._members.copy(), self._description) - - def _serialize(self) -> SerializedTaskSubset: - return SerializedTaskSubset.construct(description=self._description, tasks=list(sorted(self))) - - -class MutableTaskSubset(TaskSubset, MutableSet[str]): - @property - def description(self) -> str: - # Docstring inherited. - return self._description - - @description.setter - def description(self, value: str) -> None: - # Docstring inherited. - self._description = value - def add(self, task_label: str) -> None: """Add a new task to this subset. @@ -146,9 +103,8 @@ def add(self, task_label: str) -> None: def discard(self, task_label: str) -> None: self._members.discard(task_label) - -class ResolvedTaskSubset(TaskSubset): - pass + def _serialize(self) -> SerializedTaskSubset: + return SerializedTaskSubset.construct(description=self._description, tasks=list(sorted(self))) class SerializedTaskSubset(BaseModel): diff --git a/python/lsst/pipe/base/pipeline_graph/_tasks.py b/python/lsst/pipe/base/pipeline_graph/_tasks.py index f52dd62d4..326555d09 100644 --- a/python/lsst/pipe/base/pipeline_graph/_tasks.py +++ b/python/lsst/pipe/base/pipeline_graph/_tasks.py @@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, Any import networkx -from lsst.daf.butler import DimensionGraph, Registry +from lsst.daf.butler import DimensionGraph, DimensionUniverse, Registry from lsst.utils.classes import immutable from lsst.utils.doImport import doImportType from lsst.utils.introspection import get_full_type_name @@ -38,8 +38,8 @@ from ..connections import PipelineTaskConnections from ..connectionTypes import BaseConnection, InitOutput, Output from ._edges import ReadEdge, SerializedEdge, WriteEdge -from ._exceptions import ReadInconsistencyError, TaskNotImportedError -from ._nodes import Node, NodeKey, NodeType +from ._exceptions import TaskNotImportedError +from ._nodes import NodeKey, NodeType if TYPE_CHECKING: from ..config import PipelineTaskConfig @@ -122,14 +122,12 @@ def configure( @immutable -class TaskInitNode(Node): +class TaskInitNode: """A node in a pipeline graph that represents the construction of a `PipelineTask`. Parameters ---------- - key : `NodeKey` - Key for this node in the graph. inputs : `~collections.abc.Set` [ `ReadEdge` ] Graph edges that represent inputs required just to construct an instance of this task. @@ -157,7 +155,7 @@ def __init__( task_class_name: str | None = None, config_str: str | None = None, ): - super().__init__(key) + self._key = key self.inputs = inputs self.outputs = outputs self.config_output = config_output @@ -195,7 +193,7 @@ def __init__( @property def label(self) -> str: """Label of this configuration of a task in the pipeline.""" - return str(self.key) + return str(self._key) @property def is_imported(self) -> bool: @@ -286,7 +284,7 @@ def _serialize(self) -> SerializedTaskInitNode: def _to_xgraph_state(self) -> dict[str, Any]: # Docstring inherited. - result = {"task_class_name": self.task_class_name, "bipartite": self.key.node_type.bipartite} + result = {"task_class_name": self.task_class_name, "bipartite": NodeType.TASK_INIT.bipartite} if hasattr(self, "_imported_data"): result["task_class"] = self.task_class result["config"] = self.config @@ -316,14 +314,12 @@ class SerializedTaskInitNode(BaseModel): @immutable -class TaskNode(Node): +class TaskNode: """A node in a pipeline graph that represents a labeled configuration of a `PipelineTask`. Parameters ---------- - key : `NodeKey` - Key for this node in the graph. 
init : `TaskInitNode` Node representing the initialization of this task. prerequisite_inputs : `~collections.abc.Set` [ `ReadEdge` ] @@ -357,7 +353,6 @@ class TaskNode(Node): def __init__( self, - key: NodeKey, init: TaskInitNode, *, prerequisite_inputs: Set[ReadEdge], @@ -367,7 +362,6 @@ def __init__( metadata_output: WriteEdge, dimensions: DimensionGraph | None, ): - super().__init__(key) self.init = init self.prerequisite_inputs = prerequisite_inputs self.inputs = inputs @@ -378,25 +372,26 @@ def __init__( @staticmethod def _from_imported_data( - label: str, + key: NodeKey, + init_key: NodeKey, data: _TaskNodeImportedData, + universe: DimensionUniverse | None, ) -> TaskNode: """Construct from a `PipelineTask` type and its configuration. Parameters ---------- - label : `str` - Label for the task in the pipeline. + TODO + data : `_TaskNodeImportedData` Internal data for the node. + universe : `lsst.daf.butler.DimensionUniverse` or `None` + Definitions of all dimensions. Returns ------- node : `TaskNode` New task node. - state: `dict` [ `str`, `Any` ] - State object for the networkx representation of this node. The - returned ``node`` object is the value of the "instance" key. Raises ------ @@ -416,8 +411,7 @@ def _from_imported_data( at a stage that may have left the graph in an inconsistent state. All other exceptions should leave the graph unchanged. """ - key = NodeKey(NodeType.TASK, label) - init_key = NodeKey(NodeType.TASK_INIT, label) + init_inputs = { ReadEdge._from_connection_map(init_key, name, data.connection_map) for name in data.connections.initInputs @@ -447,7 +441,6 @@ def _from_imported_data( imported_data=data, ) instance = TaskNode( - key=key, init=init, prerequisite_inputs=prerequisite_inputs, inputs=inputs, @@ -460,7 +453,7 @@ def _from_imported_data( metadata_output=WriteEdge._from_connection_map( key, acc.METADATA_OUTPUT_CONNECTION_NAME, data.connection_map ), - dimensions=None, + dimensions=None if universe is None else universe.extract(data.connections.dimensions), ) return instance @@ -498,7 +491,7 @@ def _from_imported_data( @property def label(self) -> str: """Label of this configuration of a task in the pipeline.""" - return str(self.key) + return self.init.label @property def task_class(self) -> type[PipelineTask]: @@ -545,66 +538,67 @@ def iter_all_outputs(self) -> Iterator[WriteEdge]: if self.log_output is not None: yield self.log_output - def import_and_configure(self, check_edges: bool = True) -> None: + def _imported_and_configured( + self, key: NodeKey, rebuild: bool, universe: DimensionUniverse | None + ) -> TaskNode: # TODO: docs from ..pipelineTask import PipelineTask if self.is_imported: - return + return self task_class = doImportType(self.task_class_name) if not issubclass(task_class, PipelineTask): raise TypeError(f"{self.task_class_name!r} is not a PipelineTask subclass.") config = task_class.ConfigClass() config.loadFromString(self.get_config_str()) - imported_data = _TaskNodeImportedData.configure(self.label, task_class, config) - if check_edges: - if messages := self.diff(self._from_imported_data(self.label, imported_data)): - messages.insert( - 0, - f"Inconsistency between serialized and configured edges for task {self.label!r}:", - ) - raise ReadInconsistencyError("\n".join(messages)) - self.init._imported_data = imported_data + imported_data = _TaskNodeImportedData.configure(key.name, task_class, config) + if rebuild: + return self._from_imported_data( + key, + self.init._key, + imported_data, + universe=universe, + ) + else: + return 
TaskNode( + TaskInitNode( + self.init._key, + inputs=self.init.inputs, + outputs=self.init.outputs, + config_output=self.init.config_output, + imported_data=imported_data, + ), + prerequisite_inputs=self.prerequisite_inputs, + inputs=self.inputs, + outputs=self.outputs, + log_output=self.log_output, + metadata_output=self.metadata_output, + dimensions=( + universe.extract(self._get_connections().dimensions) if universe is not None else None + ), + ) def diff(self, other: TaskNode) -> list[str]: # TODO: docs return self.init.diff(other.init) - def _get_connections(self) -> PipelineTaskConnections: - # TODO: docs - return self.init._get_connections() - - def _resolved(self, xgraph: networkx.DiGraph, registry: Registry) -> TaskNode: - # Docstring inherited. - if self.dimensions is not None: - if self.dimensions.universe is registry.dimensions: - return self - return TaskNode( - key=self.key, - init=self.init, - prerequisite_inputs=self.prerequisite_inputs, - inputs=self.inputs, - outputs=self.outputs, - log_output=self.log_output, - metadata_output=self.metadata_output, - dimensions=registry.dimensions.extract(self._get_connections().dimensions), - ) - - def _unresolved(self) -> TaskNode: - # Docstring inherited. - if self.dimensions is None: + def _resolved(self, universe: DimensionUniverse) -> TaskNode: + if self.dimensions is not None and self.dimensions.universe is universe: return self return TaskNode( - key=self.key, init=self.init, prerequisite_inputs=self.prerequisite_inputs, inputs=self.inputs, outputs=self.outputs, log_output=self.log_output, metadata_output=self.metadata_output, - dimensions=None, + dimensions=universe.extract(self._get_connections().dimensions), ) + def _get_connections(self) -> PipelineTaskConnections: + # TODO: docs + return self.init._get_connections() + def _serialize(self) -> SerializedTaskNode: # Docstring inherited. 
return SerializedTaskNode.construct( diff --git a/python/lsst/pipe/base/tests/pipelineStepTester.py b/python/lsst/pipe/base/tests/pipelineStepTester.py index 00f4523ec..a63ec8aa1 100644 --- a/python/lsst/pipe/base/tests/pipelineStepTester.py +++ b/python/lsst/pipe/base/tests/pipelineStepTester.py @@ -26,9 +26,11 @@ import dataclasses import unittest +from typing import cast from lsst.daf.butler import Butler, DatasetType from lsst.pipe.base import Pipeline +from lsst.pipe.base.pipeline_graph import DatasetTypeNode @dataclasses.dataclass @@ -88,25 +90,22 @@ def run(self, butler: Butler, test_case: unittest.TestCase) -> None: pure_inputs: dict[str, str] = dict() for suffix in self.step_suffixes: - step_graph = Pipeline.from_uri(self.filename + suffix).to_graph().resolved(butler.registry) + step_graph = Pipeline.from_uri(self.filename + suffix).to_graph() + step_graph.resolve(butler.registry) pure_inputs.update( - { - node.name: suffix - for node in step_graph.iter_overall_inputs() - if node.name not in all_outputs - } + {name: suffix for name, _ in step_graph.iter_overall_inputs() if name not in all_outputs} ) all_outputs.update( { - name: node.dataset_type + name: cast(DatasetTypeNode, node).dataset_type for name, node in step_graph.dataset_types.items() if step_graph.producer_of(name) is not None } ) for node in step_graph.dataset_types.values(): - butler.registry.registerDatasetType(node.dataset_type) + butler.registry.registerDatasetType(cast(DatasetTypeNode, node).dataset_type) if not pure_inputs.keys() <= self.expected_inputs: missing = [f"{k} ({pure_inputs[k]})" for k in pure_inputs.keys() - self.expected_inputs] diff --git a/tests/test_pipeline_graph.py b/tests/test_pipeline_graph.py index a0a3fb6cf..7ff704ec7 100644 --- a/tests/test_pipeline_graph.py +++ b/tests/test_pipeline_graph.py @@ -31,13 +31,7 @@ import lsst.utils.tests from lsst.daf.butler import DatasetType, DimensionUniverse from lsst.daf.butler.registry import MissingDatasetTypeError -from lsst.pipe.base.pipeline_graph import ( - MutablePipelineGraph, - NodeKey, - NodeType, - PipelineGraph, - ResolvedPipelineGraph, -) +from lsst.pipe.base.pipeline_graph import NodeKey, NodeType, PipelineGraph from lsst.pipe.base.tests import no_dimensions _LOG = logging.getLogger(__name__) @@ -62,109 +56,99 @@ def setUp(self) -> None: # any of those. We add tasks in reverse order to better test sorting. # There is one labeled task subset, 'only_b', with just 'b' in it. self.description = "A pipeline for PipelineGraph unit tests." 
- self.mgraph = MutablePipelineGraph() - self.mgraph.description = self.description + self.graph = PipelineGraph() + self.graph.description = self.description self.b_config = no_dimensions.NoDimensionsTestConfig() self.b_config.connections.input = "intermediate" - self.mgraph.add_task("b", no_dimensions.NoDimensionsTestTask, self.b_config) + self.graph.add_task("b", no_dimensions.NoDimensionsTestTask, self.b_config) self.a_config = no_dimensions.NoDimensionsTestConfig() self.a_config.connections.output = "intermediate" - self.mgraph.add_task("a", no_dimensions.NoDimensionsTestTask, self.a_config) - self.mgraph.add_task_subset("only_b", ["b"]) + self.graph.add_task("a", no_dimensions.NoDimensionsTestTask, self.a_config) + self.graph.add_task_subset("only_b", ["b"]) self.dimensions = DimensionUniverse() self.maxDiff = None - def test_mutable_accessors(self) -> None: - self.check_base_accessors(self.mgraph) - self.assertTrue(repr(self.mgraph).startswith(f"MutablePipelineGraph({self.description!r}, tasks=")) + def test_unresolved_accessors(self) -> None: + self.check_base_accessors(self.graph) + self.assertTrue(repr(self.graph).startswith(f"PipelineGraph({self.description!r}, tasks=")) def test_sorting(self) -> None: - """Test sort methods on MutablePipelineGraph.""" - self.assertFalse(self.mgraph.has_been_sorted) - self.assertFalse(self.mgraph.is_sorted) - self.mgraph.sort() - self.check_sorted(self.mgraph) - - def test_mutable_xgraph_export(self) -> None: - self.check_make_xgraph(self.mgraph, resolved=False) - self.check_make_bipartite_xgraph(self.mgraph, resolved=False) - self.check_make_task_xgraph(self.mgraph, resolved=False) - self.check_make_dataset_type_xgraph(self.mgraph, resolved=False) - - def test_mutable_stream_io(self) -> None: + """Test sort methods on PipelineGraph.""" + self.assertFalse(self.graph.has_been_sorted) + self.assertFalse(self.graph.is_sorted) + self.graph.sort() + self.check_sorted(self.graph) + + def test_unresolved_xgraph_export(self) -> None: + self.check_make_xgraph(self.graph, resolved=False) + self.check_make_bipartite_xgraph(self.graph, resolved=False) + self.check_make_task_xgraph(self.graph, resolved=False) + self.check_make_dataset_type_xgraph(self.graph, resolved=False) + + def test_unresolved_stream_io(self) -> None: stream = io.BytesIO() - self.mgraph.write_stream(stream) + self.graph.write_stream(stream) stream.seek(0) - roundtripped = MutablePipelineGraph.read_stream(stream) + roundtripped = PipelineGraph.read_stream(stream) self.check_make_xgraph(roundtripped, resolved=False) - def test_mutable_file_io(self) -> None: + def test_unresolved_file_io(self) -> None: with lsst.utils.tests.getTempFilePath(".tar.gz") as filename: - self.mgraph.write_uri(filename) - roundtripped = MutablePipelineGraph.read_uri(filename) + self.graph.write_uri(filename) + roundtripped = PipelineGraph.read_uri(filename) self.check_make_xgraph(roundtripped, resolved=False) def test_resolved_accessors(self) -> None: """Test resolving a pipeline graph against a data repository.""" - rgraph = self.mgraph.resolved(MockRegistry(self.dimensions, {})) - self.check_base_accessors(rgraph) - self.check_sorted(rgraph) - self.assertTrue(repr(rgraph).startswith(f"ResolvedPipelineGraph({self.description!r}, tasks=")) + self.graph.resolve(MockRegistry(self.dimensions, {})) + self.check_base_accessors(self.graph) + self.check_sorted(self.graph) + self.assertTrue(repr(self.graph).startswith(f"PipelineGraph({self.description!r}, tasks=")) def test_resolved_xgraph_export(self) -> None: - rgraph = 
self.mgraph.resolved(MockRegistry(self.dimensions, {})) - self.check_make_xgraph(rgraph, resolved=True) - self.check_make_bipartite_xgraph(rgraph, resolved=True) - self.check_make_task_xgraph(rgraph, resolved=True) - self.check_make_dataset_type_xgraph(rgraph, resolved=True) + self.graph.resolve(MockRegistry(self.dimensions, {})) + self.check_make_xgraph(self.graph, resolved=True) + self.check_make_bipartite_xgraph(self.graph, resolved=True) + self.check_make_task_xgraph(self.graph, resolved=True) + self.check_make_dataset_type_xgraph(self.graph, resolved=True) def test_resolved_stream_io(self) -> None: - rgraph = self.mgraph.resolved(MockRegistry(self.dimensions, {})) + self.graph.resolve(MockRegistry(self.dimensions, {})) stream = io.BytesIO() - rgraph.write_stream(stream) + self.graph.write_stream(stream) stream.seek(0) - roundtripped = ResolvedPipelineGraph.read_stream(stream) + roundtripped = PipelineGraph.read_stream(stream) self.check_make_xgraph(roundtripped, resolved=True) def test_resolved_file_io(self) -> None: - rgraph = self.mgraph.resolved(MockRegistry(self.dimensions, {})) + self.graph.resolve(MockRegistry(self.dimensions, {})) with lsst.utils.tests.getTempFilePath(".tar.gz") as filename: - rgraph.write_uri(filename) - roundtripped = ResolvedPipelineGraph.read_uri(filename) + self.graph.write_uri(filename) + roundtripped = PipelineGraph.read_uri(filename) self.check_make_xgraph(roundtripped, resolved=True) - def test_mixed_io(self) -> None: - """Test writing a ResolvedPipelineGraph and reading it as a - MutablePipelineGraph. - """ - rgraph = self.mgraph.resolved(MockRegistry(self.dimensions, {})) - stream = io.BytesIO() - rgraph.write_stream(stream) - stream.seek(0) - roundtripped = MutablePipelineGraph.read_stream(stream) - self.check_make_xgraph(roundtripped, resolved=False) - - def test_mutable_copies(self) -> None: - mcopy = self.mgraph.mutable_copy() - self.assertIsNot(mcopy, self.mgraph) - self.check_make_xgraph(mcopy, resolved=False) - mcopy = copy.copy(self.mgraph) - self.assertIsNot(mcopy, self.mgraph) - self.check_make_xgraph(mcopy, resolved=False) - mcopy = copy.deepcopy(self.mgraph) - self.assertIsNot(mcopy, self.mgraph) - self.check_make_xgraph(mcopy, resolved=False) + def test_unresolved_copies(self) -> None: + copy1 = self.graph.copy() + self.assertIsNot(copy1, self.graph) + self.check_make_xgraph(copy1, resolved=False) + copy2 = copy.copy(self.graph) + self.assertIsNot(copy2, self.graph) + self.check_make_xgraph(copy2, resolved=False) + copy3 = copy.deepcopy(self.graph) + self.assertIsNot(copy3, self.graph) + self.check_make_xgraph(copy3, resolved=False) def test_resolved_copies(self) -> None: - rgraph = self.mgraph.resolved(MockRegistry(self.dimensions, {})) - self.assertIs(rgraph, rgraph.resolved(MockRegistry(self.dimensions, {}))) - self.assertIs(rgraph, copy.copy(rgraph)) - self.assertIs(rgraph, copy.deepcopy(rgraph)) - rcopy = rgraph.resolved(MockRegistry(self.dimensions, {}), redo=True) - self.assertIsNot(rgraph, rcopy) - self.check_make_xgraph(rcopy, resolved=True) - mcopy = rgraph.mutable_copy() - self.check_make_xgraph(mcopy, resolved=False) + self.graph.resolve(MockRegistry(self.dimensions, {})) + copy1 = self.graph.copy() + self.assertIsNot(copy1, self.graph) + self.check_make_xgraph(copy1, resolved=True) + copy2 = copy.copy(self.graph) + self.assertIsNot(copy2, self.graph) + self.check_make_xgraph(copy2, resolved=True) + copy3 = copy.deepcopy(self.graph) + self.assertIsNot(copy3, self.graph) + self.check_make_xgraph(copy3, resolved=True) def 
check_base_accessors(self, graph: PipelineGraph) -> None: self.assertEqual(graph.description, self.description) @@ -205,7 +189,7 @@ def check_base_accessors(self, graph: PipelineGraph) -> None: }, ) self.assertEqual( - {node.key for node in graph.iter_nodes()}, + {(node_type, name) for node_type, name, _ in graph.iter_nodes()}, { NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.TASK, "b"), @@ -222,7 +206,7 @@ def check_base_accessors(self, graph: PipelineGraph) -> None: NodeKey(NodeType.DATASET_TYPE, "b_metadata"), }, ) - self.assertEqual({node.name for node in graph.iter_overall_inputs()}, {"input"}) + self.assertEqual({name for name, _ in graph.iter_overall_inputs()}, {"input"}) self.assertEqual({label for label in graph.consumers_of("input")}, {"a"}) self.assertEqual({label for label in graph.consumers_of("intermediate")}, {"b"}) self.assertEqual({label for label in graph.consumers_of("output")}, set()) @@ -234,7 +218,7 @@ def check_sorted(self, graph: PipelineGraph) -> None: self.assertTrue(graph.has_been_sorted) self.assertTrue(graph.is_sorted) self.assertEqual( - [node.key for node in graph.iter_nodes()], + [(node_type, name) for node_type, name, _ in graph.iter_nodes()], [ # We only advertise that the order is topological and # deterministic, so this test is slightly over-specified; there @@ -274,39 +258,42 @@ def check_sorted(self, graph: PipelineGraph) -> None: def check_make_xgraph(self, graph: PipelineGraph, resolved: bool) -> None: xgraph = graph.make_xgraph() - self.assertEqual( - set(xgraph.edges), + expected_edges = ( {edge.key for edge in graph.iter_edges()} | {edge.key for edge in graph.iter_edges(init=True)} | { (NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK, "a")), (NodeKey(NodeType.TASK_INIT, "b"), NodeKey(NodeType.TASK, "b")), - }, - ) - self.assertEqual( - dict(xgraph.nodes.items()), - { - NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved), - NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved), - NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved), - NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved), - NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved), - NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved), - NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved), - NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved), - NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved), - NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved), - NodeKey(NodeType.DATASET_TYPE, "input"): self.get_expected_connection_node( - "input", resolved, True - ), - NodeKey(NodeType.DATASET_TYPE, "intermediate"): self.get_expected_connection_node( - "intermediate", resolved, False - ), - NodeKey(NodeType.DATASET_TYPE, "output"): self.get_expected_connection_node( - "output", resolved, False - ), - }, + } ) + test_edges = set(xgraph.edges) + self.assertEqual(test_edges, expected_edges) + expected_nodes = { + NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved), + NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved), + NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved), + NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "a_config"): 
self.get_expected_config_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "input"): self.get_expected_connection_node( + "input", resolved, True + ), + NodeKey(NodeType.DATASET_TYPE, "intermediate"): self.get_expected_connection_node( + "intermediate", resolved, False + ), + NodeKey(NodeType.DATASET_TYPE, "output"): self.get_expected_connection_node( + "output", resolved, False + ), + } + test_nodes = dict(xgraph.nodes.items()) + self.assertEqual(set(test_nodes.keys()), set(expected_nodes.keys())) + for key, expected_node in expected_nodes.items(): + test_node = test_nodes[key] + self.assertEqual(expected_node, test_node, key) def check_make_bipartite_xgraph(self, graph: PipelineGraph, resolved: bool) -> None: run_xgraph = graph.make_bipartite_xgraph() @@ -439,6 +426,7 @@ def get_expected_config_node(self, label: str, resolved: bool) -> dict[str, Any] ), "is_initial_query_constraint": False, "is_prerequisite": False, + "is_registered": False, "dimensions": self.dimensions.empty, "storage_class_name": acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, "bipartite": 0, @@ -456,6 +444,7 @@ def get_expected_log_node(self, label: str, resolved: bool) -> dict[str, Any]: ), "is_initial_query_constraint": False, "is_prerequisite": False, + "is_registered": False, "dimensions": self.dimensions.empty, "storage_class_name": acc.LOG_OUTPUT_STORAGE_CLASS, "bipartite": 0, @@ -473,6 +462,7 @@ def get_expected_metadata_node(self, label: str, resolved: bool) -> dict[str, An ), "is_initial_query_constraint": False, "is_prerequisite": False, + "is_registered": False, "dimensions": self.dimensions.empty, "storage_class_name": acc.METADATA_OUTPUT_STORAGE_CLASS, "bipartite": 0, @@ -492,6 +482,7 @@ def get_expected_connection_node( ), "is_initial_query_constraint": is_initial_query_constraint, "is_prerequisite": False, + "is_registered": False, "dimensions": self.dimensions.empty, "storage_class_name": "StructuredDataDict", "bipartite": 0,
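
For reference, a minimal sketch of the unified API after this change, mirroring the unit-test setup above. It is illustrative only: the `no_dimensions` test task comes from `lsst.pipe.base.tests`, and the commented-out `resolve()` call assumes a butler `Registry` is available (the tests substitute a `MockRegistry` built from `DimensionUniverse()`).

    import io

    from lsst.pipe.base import PipelineGraph
    from lsst.pipe.base.tests import no_dimensions

    # Build an unresolved graph directly and mutate it in place; there is no
    # longer a separate mutable class.
    graph = PipelineGraph()
    graph.description = "Example pipeline."
    config = no_dimensions.NoDimensionsTestConfig()
    graph.add_task("a", no_dimensions.NoDimensionsTestTask, config)
    graph.add_task_subset("only_a", ["a"])

    # Resolution now updates the same object instead of returning a separate
    # ResolvedPipelineGraph (requires a data repository):
    # graph.resolve(butler.registry)

    # Serialization round-trips through the same class whether or not the
    # graph has been resolved.
    stream = io.BytesIO()
    graph.write_stream(stream)
    stream.seek(0)
    roundtripped = PipelineGraph.read_stream(stream)
    assert roundtripped.description == graph.description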