diff --git a/python/lsst/pipe/base/__init__.py b/python/lsst/pipe/base/__init__.py
index e440d9eff..ed58c5935 100644
--- a/python/lsst/pipe/base/__init__.py
+++ b/python/lsst/pipe/base/__init__.py
@@ -12,9 +12,9 @@
from .graphBuilder import *
from .pipeline import *
-# We import the main PipelineGraph types and the module (above), but we don't
+# We import the main PipelineGraph type and the module (above), but we don't
# lift all symbols to package scope.
-from .pipeline_graph import MutablePipelineGraph, ResolvedPipelineGraph
+from .pipeline_graph import PipelineGraph
from .pipelineTask import *
from .struct import *
from .task import *
diff --git a/python/lsst/pipe/base/pipeTools.py b/python/lsst/pipe/base/pipeTools.py
index ed94d28d2..3fd907322 100644
--- a/python/lsst/pipe/base/pipeTools.py
+++ b/python/lsst/pipe/base/pipeTools.py
@@ -33,7 +33,7 @@
from .pipeline import Pipeline, TaskDef
# Exceptions re-exported here for backwards compatibility.
-from .pipeline_graph import DuplicateOutputError, MutablePipelineGraph, PipelineDataCycleError # noqa: F401
+from .pipeline_graph import DuplicateOutputError, PipelineDataCycleError, PipelineGraph # noqa: F401
if TYPE_CHECKING:
from .taskFactory import TaskFactory
@@ -73,7 +73,7 @@ def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskF
if isinstance(pipeline, Pipeline):
graph = pipeline.to_graph()
else:
- graph = MutablePipelineGraph()
+ graph = PipelineGraph()
for task_def in pipeline:
graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
# Can't use graph.is_sorted because that requires sorted dataset type names
@@ -111,7 +111,7 @@ def orderPipeline(pipeline: Pipeline | Iterable[TaskDef]) -> list[TaskDef]:
if isinstance(pipeline, Pipeline):
graph = pipeline.to_graph()
else:
- graph = MutablePipelineGraph()
+ graph = PipelineGraph()
for task_def in pipeline:
graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
graph.sort()
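A minimal sketch of the pattern both helpers above now share, where `task_defs` stands in for any iterable of existing `TaskDef` objects:

    from lsst.pipe.base.pipeline_graph import PipelineGraph

    graph = PipelineGraph()
    for task_def in task_defs:
        graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
    graph.sort()
    assert graph.is_sorted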
diff --git a/python/lsst/pipe/base/pipeline.py b/python/lsst/pipe/base/pipeline.py
index 3b6f011b7..dea33448f 100644
--- a/python/lsst/pipe/base/pipeline.py
+++ b/python/lsst/pipe/base/pipeline.py
@@ -53,13 +53,7 @@
# -----------------------------
# Imports for other modules --
-from lsst.daf.butler import (
- DataCoordinate,
- DatasetType,
- DimensionUniverse,
- NamedValueSet,
- Registry,
-)
+from lsst.daf.butler import DataCoordinate, DatasetType, DimensionUniverse, NamedValueSet, Registry
from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils import doImportType
from lsst.utils.introspection import get_full_type_name
@@ -757,7 +751,7 @@ def write_to_uri(self, uri: ResourcePathExpression) -> None:
"""
self._pipelineIR.write_to_uri(uri)
- def to_graph(self) -> pipeline_graph.MutablePipelineGraph:
+ def to_graph(self) -> pipeline_graph.PipelineGraph:
"""Construct a pipeline graph from this pipeline.
Constructing a graph applies all configuration overrides, freezes all
@@ -767,10 +761,10 @@ def to_graph(self) -> pipeline_graph.MutablePipelineGraph:
Returns
-------
- graph : `pipeline_graph.MutablePipelineGraph`
+ graph : `pipeline_graph.PipelineGraph`
Representation of the pipeline as a graph.
"""
- graph = pipeline_graph.MutablePipelineGraph()
+ graph = pipeline_graph.PipelineGraph()
graph.description = self._pipelineIR.description
for label in self._pipelineIR.tasks:
self._add_task_to_graph(label, graph)
@@ -810,7 +804,7 @@ def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
"""
yield from self.to_graph()._iter_task_defs()
- def _add_task_to_graph(self, label: str, graph: pipeline_graph.MutablePipelineGraph) -> None:
+ def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) -> None:
"""Add a single task from this pipeline to a pipeline graph that is
under construction.
@@ -818,7 +812,7 @@ def _add_task_to_graph(self, label: str, graph: pipeline_graph.MutablePipelineGr
----------
label : `str`
Label for the task to be added.
- graph : `pipeline_graph.MutablePipelineGraph`
+ graph : `pipeline_graph.PipelineGraph`
Graph to add the task to.
"""
if (taskIR := self._pipelineIR.tasks.get(label)) is None:
@@ -845,7 +839,7 @@ def __getitem__(self, item: str) -> TaskDef:
# Making a whole graph and then making a TaskDef from that is pretty
# backwards, but I'm hoping to deprecate this method shortly in favor
# of making the graph explicitly and working with its node objects.
- graph = pipeline_graph.MutablePipelineGraph()
+ graph = pipeline_graph.PipelineGraph()
self._add_task_to_graph(item, graph)
(result,) = graph._iter_task_defs()
return result
@@ -971,17 +965,19 @@ def fromTaskDef(
# the whole class soon, but for now and before it's actually removed
# it's more important to avoid duplication with PipelineGraph's dataset
# type resolution logic.
- mgraph = pipeline_graph.MutablePipelineGraph()
- mgraph.add_task(taskDef.label, taskDef.taskClass, taskDef.config, taskDef.connections)
- rgraph = mgraph.resolved(registry)
- (task_node,) = rgraph.tasks.values()
- return cls._from_graph_nodes(task_node, rgraph.dataset_types)
+ graph = pipeline_graph.PipelineGraph()
+ graph.add_task(taskDef.label, taskDef.taskClass, taskDef.config, taskDef.connections)
+ graph.resolve(registry)
+ (task_node,) = graph.tasks.values()
+ return cls._from_graph_nodes(
+ task_node, cast(Mapping[str, pipeline_graph.DatasetTypeNode], graph.dataset_types)
+ )
@classmethod
def _from_graph_nodes(
cls,
task_node: pipeline_graph.TaskNode,
- dataset_type_nodes: Mapping[str, pipeline_graph.ResolvedDatasetTypeNode],
+ dataset_type_nodes: Mapping[str, pipeline_graph.DatasetTypeNode],
include_configs: bool = True,
) -> TaskDatasetTypes:
"""Construct from `PipelineGraph` nodes.
@@ -1146,17 +1142,17 @@ def fromPipeline(
of the same `Pipeline`.
"""
if isinstance(pipeline, Pipeline):
- mgraph = pipeline.to_graph()
+ graph = pipeline.to_graph()
else:
- mgraph = pipeline_graph.MutablePipelineGraph()
+ graph = pipeline_graph.PipelineGraph()
for task_def in pipeline:
- mgraph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
- rgraph = mgraph.resolved(registry)
+ graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
+ graph.resolve(registry)
byTask = dict()
- for task_node in rgraph.tasks.values():
+ for task_node in graph.tasks.values():
byTask[task_node.label] = TaskDatasetTypes._from_graph_nodes(
task_node,
- rgraph.dataset_types,
+ cast(Mapping[str, pipeline_graph.DatasetTypeNode], graph.dataset_types),
include_configs=include_configs,
)
result = cls(
@@ -1177,8 +1173,9 @@ def fromPipeline(
# PipelineGraph does, by putting components in the edge objects). But
# including all components as well is what this code has done in the
# past and changing that would break downstream code.
- for dataset_type_node in rgraph.dataset_types.values():
- if consumers := rgraph.consumers_of(dataset_type_node.name):
+ for dataset_type_node in graph.dataset_types.values():
+ assert dataset_type_node is not None, "Graph is expected to be resolved."
+ if consumers := graph.consumers_of(dataset_type_node.name):
dataset_types = [
(
dataset_type_node.dataset_type.makeComponentDatasetType(edge.component)
@@ -1188,21 +1185,21 @@ def fromPipeline(
for edge in consumers.values()
]
if any(edge.is_init for edge in consumers.values()):
- if rgraph.producer_of(dataset_type_node.name) is None:
+ if graph.producer_of(dataset_type_node.name) is None:
result.initInputs.update(dataset_types)
else:
result.initIntermediates.update(dataset_types)
else:
if dataset_type_node.is_prerequisite:
result.prerequisites.update(dataset_types)
- elif rgraph.producer_of(dataset_type_node.name) is None:
+ elif graph.producer_of(dataset_type_node.name) is None:
result.inputs.update(dataset_types)
if dataset_type_node.is_initial_query_constraint:
result.queryConstraints.add(dataset_type_node.dataset_type)
- elif rgraph.consumers_of(dataset_type_node.name):
+ elif graph.consumers_of(dataset_type_node.name):
result.intermediates.update(dataset_types)
else:
- producer = rgraph.producer_of(dataset_type_node.name)
+ producer = graph.producer_of(dataset_type_node.name)
assert (
producer is not None
), "Dataset type must have either a producer or consumers to be in graph."
@@ -1254,15 +1251,15 @@ def initOutputNames(
Name of the dataset type.
"""
if isinstance(pipeline, Pipeline):
- mgraph = pipeline.to_graph()
+ graph = pipeline.to_graph()
else:
- mgraph = pipeline_graph.MutablePipelineGraph()
+ graph = pipeline_graph.PipelineGraph()
for task_def in pipeline:
- mgraph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
+ graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
if include_packages:
# Package versions dataset type
yield cls.packagesDatasetName
- for task_node in mgraph.tasks.values():
+ for task_node in graph.tasks.values():
edges = task_node.init.iter_all_outputs() if include_configs else task_node.init.outputs
for edge in edges:
yield edge.dataset_type_name
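A minimal sketch of the resolve-in-place flow the hunks above switch to, assuming `pipeline` is a `Pipeline` and `registry` is a butler `Registry`:

    graph = pipeline.to_graph()  # unresolved PipelineGraph
    graph.resolve(registry)      # resolves dimensions and dataset types in place
    # After resolution, dataset_types values are DatasetTypeNode instances rather
    # than None, and producer/consumer lookups work by parent dataset type name.
    overall_inputs = [name for name in graph.dataset_types if graph.producer_of(name) is None]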
diff --git a/python/lsst/pipe/base/pipeline_graph/__init__.py b/python/lsst/pipe/base/pipeline_graph/__init__.py
index fa4c211f1..3cf7a8101 100644
--- a/python/lsst/pipe/base/pipeline_graph/__init__.py
+++ b/python/lsst/pipe/base/pipeline_graph/__init__.py
@@ -23,9 +23,7 @@
from ._dataset_types import *
from ._edges import *
from ._exceptions import *
-from ._mutable_pipeline_graph import *
from ._nodes import *
from ._pipeline_graph import *
-from ._resolved_pipeline_graph import *
from ._task_subsets import *
from ._tasks import *
diff --git a/python/lsst/pipe/base/pipeline_graph/_dataset_types.py b/python/lsst/pipe/base/pipeline_graph/_dataset_types.py
index b738583aa..1a547ff9b 100644
--- a/python/lsst/pipe/base/pipeline_graph/_dataset_types.py
+++ b/python/lsst/pipe/base/pipeline_graph/_dataset_types.py
@@ -20,70 +20,85 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
-__all__ = (
- "DatasetTypeNode",
- "ResolvedDatasetTypeNode",
-)
+__all__ = ("DatasetTypeNode",)
from typing import TYPE_CHECKING, Any
import networkx
-from lsst.daf.butler import DatasetRef, DatasetType, Registry
+from lsst.daf.butler import DatasetRef, DatasetType, DimensionGraph, Registry, StorageClass
from lsst.daf.butler.registry import MissingDatasetTypeError
from pydantic import BaseModel
-from ._nodes import Node, NodeKey
+from ._nodes import NodeKey, NodeType
if TYPE_CHECKING:
from ._edges import ReadEdge, WriteEdge
from ._tasks import TaskInitNode, TaskNode
-class DatasetTypeNode(Node):
- """A node in a pipeline graph that represents a dataset type.
+class DatasetTypeNode:
+ """A node in a resolved pipeline graph that represents a dataset type.
Parameters
----------
- node : `NodeKey`
- Key for this node in the graph.
+ dataset_type : `DatasetType`
+ Common definition of this dataset type for the graph.
+    is_prerequisite : `bool`
+ Whether this dataset type is a prerequisite input that must exist in
+ the Registry before graph creation.
+ is_initial_query_constraint : `bool`
+ Whether this dataset should be included as a constraint in the initial
+ query for data IDs in QuantumGraph generation.
+
+ This is only `True` for dataset types that are overall regular inputs,
+ and also `False` if all such connections had
+ ``deferQueryConstraint=True``.
+ is_registered : `bool`
+ Whether this dataset type was registered in the data repository when it
+ was resolved.
+
+ When `is_registered` is `True`, the storage class is guaranteed to
+ match the data repository definition.
Notes
-----
- This class only holds information that can be pulled unambiguously from
- `.PipelineTask` a single definitions, without input from the data
- repository or other tasks - which amounts to just the parent dataset type
- name. The `ResolvedDatasetTypeNode` subclass also includes information
- from the data repository and holds an actual `DatasetType` instance.
-
A dataset type node represents a common definition of the dataset type
- across the entire graph, which means it never refers to a component.
+    across the entire graph - it is never a component, and its storage class is
+    the registry dataset type's storage class when the dataset type is
+    registered in the data repository, or the one defined by the producing
+    task otherwise.
Dataset type nodes are intentionally not equality comparable, since there
- are many different (and useful) ways to compare its resolved variant, with
- no clear winner as the most obvious behavior.
+ are many different (and useful) ways to compare these objects with no clear
+ winner as the most obvious behavior.
"""
- @property
- def name(self) -> str:
- """Name of the dataset type.
-
- This is always the parent dataset type, never that of a component.
- """
- return str(self.key)
+ def __init__(
+ self,
+ *,
+ dataset_type: DatasetType,
+ is_prerequisite: bool,
+ is_initial_query_constraint: bool,
+ is_registered: bool,
+ ):
+ self.dataset_type = dataset_type
+ self.is_prerequisite = is_prerequisite
+ self.is_initial_query_constraint = is_initial_query_constraint
+ self.is_registered = is_registered
- def _resolved(self, xgraph: networkx.DiGraph, registry: Registry) -> ResolvedDatasetTypeNode:
- # Docstring inherited.
+ @classmethod
+ def _from_edges(cls, key: NodeKey, xgraph: networkx.DiGraph, registry: Registry) -> DatasetTypeNode:
try:
- dataset_type = registry.getDatasetType(self.name)
- in_data_repo = True
+ dataset_type = registry.getDatasetType(key.name)
+ is_registered = True
except MissingDatasetTypeError:
dataset_type = None
- in_data_repo = False
+ is_registered = False
is_initial_query_constraint = True
is_prerequisite: bool | None = None
producer: str | None = None
write_edge: WriteEdge
- for _, _, write_edge in xgraph.in_edges(self.key, data="instance"): # will iterate zero or one time
+ for _, _, write_edge in xgraph.in_edges(key, data="instance"): # will iterate zero or one time
task_node: TaskNode | TaskInitNode = xgraph.nodes[write_edge.task_key]["instance"]
connection_map = task_node._get_connection_map()
dataset_type = write_edge._resolve_dataset_type(
@@ -95,7 +110,7 @@ def _resolved(self, xgraph: networkx.DiGraph, registry: Registry) -> ResolvedDat
is_initial_query_constraint = False
read_edge: ReadEdge
consumers: list[str] = []
- for _, _, read_edge in xgraph.out_edges(self.key, data="instance"):
+ for _, _, read_edge in xgraph.out_edges(key, data="instance"):
task_node = xgraph.nodes[read_edge.task_key]["instance"]
connection_map = task_node._get_connection_map()
dataset_type, is_initial_query_constraint, is_prerequisite = read_edge._resolve_dataset_type(
@@ -104,80 +119,62 @@ def _resolved(self, xgraph: networkx.DiGraph, registry: Registry) -> ResolvedDat
universe=registry.dimensions,
is_initial_query_constraint=is_initial_query_constraint,
is_prerequisite=is_prerequisite,
- in_data_repo=in_data_repo,
+ is_registered=is_registered,
producer=producer,
consumers=consumers,
)
consumers.append(read_edge.task_label)
assert dataset_type is not None, "Graph structure guarantees at least one edge."
assert is_prerequisite is not None, "Having at least one edge guarantees is_prerequisite is known."
- return ResolvedDatasetTypeNode(
- key=self.key,
+ return DatasetTypeNode(
dataset_type=dataset_type,
is_initial_query_constraint=is_initial_query_constraint,
is_prerequisite=is_prerequisite,
+ is_registered=is_registered,
)
- def _unresolved(self) -> DatasetTypeNode:
- # Docstring inherited.
- return self
-
- def _serialize(self) -> SerializedDatasetTypeNode:
- # Docstring inherited.
- return SerializedDatasetTypeNode.construct()
-
- def _to_xgraph_state(self) -> dict[str, Any]:
- # Docstring inherited.
- return {
- "bipartite": self.key.node_type.bipartite,
- }
-
+ def _resolved(self, registry: Registry) -> DatasetTypeNode:
+ """Resolve an existing DatasetTypeNode against current data repository
+ content.
-class ResolvedDatasetTypeNode(DatasetTypeNode):
- """A node in a resolved pipeline graph that represents a dataset type.
+        Since `DatasetTypeNode` instances are updated or replaced with `None`
+        whenever new edges are added to the graph, the only thing that might
+        have changed when this method is called is the dataset type's
+        registration state in the data repository.
+        """
+ try:
+ dataset_type = registry.getDatasetType(self.dataset_type.name)
+ is_registered = True
+ except MissingDatasetTypeError:
+ dataset_type = self.dataset_type
+ is_registered = False
+ if is_registered == self.is_registered:
+ return self
+ return DatasetTypeNode(
+ is_prerequisite=self.is_prerequisite,
+ dataset_type=dataset_type,
+ is_initial_query_constraint=self.is_initial_query_constraint,
+            is_registered=is_registered,
+ )
- Parameters
- ----------
- node : `NodeKey`
- Key for this node in the graph.
- is_prerequisite: `bool`
- Whether this dataset type is a prerequisite input that must exist in
- the Registry before graph creation.
- dataset_type : `DatasetType`
- Common definition of this dataset type for the graph.
- is_initial_query_constraint : `bool`
- Whether this dataset should be included as a constraint in the initial
- query for data IDs in QuantumGraph generation.
+ @property
+ def name(self) -> str:
+ """Name of the dataset type.
- This is only `True` for dataset types that are overall regular inputs,
- and also `False` if all such connections had
- ``deferQueryConstraint=True``.
+ This is always the parent dataset type, never that of a component.
+ """
+ return self.dataset_type.name
- Notes
- -----
- A dataset type node represents a common definition of the dataset type
- across the entire graph - it is never a component, and when storage class
- information is present (in `ResolvedDatasetTypeNode`) this is the registry
- dataset type's storage class or (if there isn't one) the one defined by the
- producing task.
+    @property
+    def dimensions(self) -> DimensionGraph:
+        """Dimensions of the dataset type."""
+        return self.dataset_type.dimensions
- Dataset type nodes are intentionally not equality comparable, since there
- are many different (and useful) ways to compare these objects with no clear
- winner as the most obvious behavior.
- """
+    @property
+    def storage_class_name(self) -> str:
+        """String name of the storage class for this dataset type."""
+        return self.dataset_type.storageClass_name
- def __init__(
- self,
- key: NodeKey,
- *,
- is_prerequisite: bool,
- dataset_type: DatasetType,
- is_initial_query_constraint: bool,
- ):
- super().__init__(key)
- self.dataset_type = dataset_type
- self.is_initial_query_constraint = is_initial_query_constraint
- self.is_prerequisite = is_prerequisite
+    @property
+    def storage_class(self) -> StorageClass:
+        """Storage class for this dataset type."""
+        return self.dataset_type.storageClass
dataset_type: DatasetType
"""Common definition of this dataset type for the graph.
@@ -196,13 +193,13 @@ def __init__(
the Registry before graph creation.
"""
- def _resolved(self, xgraph: networkx.DiGraph, registry: Registry) -> ResolvedDatasetTypeNode:
- # Docstring inherited.
- return self
+ is_registered: bool
+ """Whether this dataset type was registered in the data repository when
+ it was resolved.
- def _unresolved(self) -> DatasetTypeNode:
- # Docstring inherited.
- return DatasetTypeNode(key=self.key)
+ When `is_registered` is `True`, the storage class is guaranteed to match
+ the data repository definition.
+ """
def generalize_ref(self, ref: DatasetRef) -> DatasetRef:
"""Convert a `~lsst.daf.butler.DatasetRef` with the dataset type
@@ -223,6 +220,7 @@ def _serialize(self) -> SerializedDatasetTypeNode:
is_calibration=self.dataset_type.isCalibration(),
is_initial_query_constraint=self.is_initial_query_constraint,
is_prerequisite=self.is_prerequisite,
+ is_registered=self.is_registered,
)
def _to_xgraph_state(self) -> dict[str, Any]:
@@ -231,9 +229,10 @@ def _to_xgraph_state(self) -> dict[str, Any]:
"dataset_type": self.dataset_type,
"is_initial_query_constraint": self.is_initial_query_constraint,
"is_prerequisite": self.is_prerequisite,
+ "is_registered": self.is_registered,
"dimensions": self.dataset_type.dimensions,
"storage_class_name": self.dataset_type.storageClass_name,
- "bipartite": self.key.node_type.bipartite,
+ "bipartite": NodeType.DATASET_TYPE.bipartite,
}
@@ -243,4 +242,5 @@ class SerializedDatasetTypeNode(BaseModel):
is_calibration: bool = False
is_initial_query_constraint: bool = False
is_prerequisite: bool = False
+ is_registered: bool = False
index: int | None = None
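A short sketch of reading the attributes defined above from a resolved graph, where `graph` is a resolved `PipelineGraph`, `name` is a parent dataset type name, and `registry` is a butler `Registry` (all assumed to exist):

    node = graph.dataset_types[name]
    assert node is not None, "values are None only while the graph is unresolved"
    print(node.name, node.dimensions, node.storage_class_name)
    if not node.is_registered:
        # The storage class came from the producing task, not the data repository.
        registry.registerDatasetType(node.dataset_type)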
diff --git a/python/lsst/pipe/base/pipeline_graph/_edges.py b/python/lsst/pipe/base/pipeline_graph/_edges.py
index eb0897430..700cbc296 100644
--- a/python/lsst/pipe/base/pipeline_graph/_edges.py
+++ b/python/lsst/pipe/base/pipeline_graph/_edges.py
@@ -151,8 +151,10 @@ def adapt_dataset_ref(self, ref: DatasetRef) -> DatasetRef:
"""
raise NotImplementedError()
- def _check_dataset_type(
- self, xgraph: networkx.DiGraph, dataset_type_node: DatasetTypeNode
+ def _update_dataset_type(
+ self,
+ xgraph: networkx.DiGraph,
+ dataset_type_node: DatasetTypeNode | None,
) -> DatasetTypeNode | None:
"""Check the a potential graph-wide definition of a dataset type for
consistency with this edge.
@@ -161,16 +163,16 @@ def _check_dataset_type(
-----------
xgraph : `networkx.DiGraph`
Directed bipartite graph representing the full pipeline.
- dataset_type_node : `DatasetTypeNode`
+ dataset_type_node : `DatasetTypeNode` or `None`
Dataset type node to be checked and possibly updated.
Returns
-------
- updated : `bool`
- New `DatasetTypeNode` if it needs to be changed, or `None` if it
- does not.
+ updated : `DatasetTypeNode` or `None`
+ Possibly-updated node to include in the graph after the addition
+ of this edge.
"""
- pass
+ return None
@abstractmethod
def _serialize(self) -> BaseModel:
@@ -325,7 +327,7 @@ def _resolve_dataset_type(
universe: DimensionUniverse,
producer: str | None,
consumers: Sequence[str],
- in_data_repo: bool,
+ is_registered: bool,
) -> tuple[DatasetType, bool, bool]:
"""Participate in the construction of the graph-wide `DatasetType`
object associated with this edge.
@@ -356,7 +358,7 @@ def _resolve_dataset_type(
consumers : `Sequence` [ `str` ]
Labels for other consuming tasks that have already participated in
this dataset type's resolution.
- in_data_repo : `bool`
+ is_registered : `bool`
            Whether a registration for this dataset type was found in the
data repository.
@@ -416,7 +418,7 @@ def _resolve_dataset_type(
)
def report_current_origin() -> str:
- if in_data_repo:
+ if is_registered:
return "data repository"
elif producer is not None:
return f"producing task {producer!r}"
@@ -541,13 +543,18 @@ def _from_connection_map(
connection_name=connection_name,
)
- def _check_dataset_type(self, xgraph: networkx.DiGraph, dataset_type_node: DatasetTypeNode) -> None:
+ def _update_dataset_type(
+ self,
+ xgraph: networkx.DiGraph,
+ dataset_type_node: DatasetTypeNode | None,
+ ) -> DatasetTypeNode | None:
# Docstring inherited.
- for existing_producer in xgraph.predecessors(dataset_type_node.key):
+ for existing_producer in xgraph.predecessors(self.dataset_type_key):
raise DuplicateOutputError(
- f"Dataset type {dataset_type_node.name} is produced by both {self.task_label!r} "
+ f"Dataset type {self.parent_dataset_type_name!r} is produced by both {self.task_label!r} "
f"and {existing_producer!r}."
)
+ return None
def _resolve_dataset_type(
self, *, connection: BaseConnection, current: DatasetType | None, universe: DimensionUniverse
diff --git a/python/lsst/pipe/base/pipeline_graph/_exceptions.py b/python/lsst/pipe/base/pipeline_graph/_exceptions.py
index c4164a632..467513aa4 100644
--- a/python/lsst/pipe/base/pipeline_graph/_exceptions.py
+++ b/python/lsst/pipe/base/pipeline_graph/_exceptions.py
@@ -27,8 +27,9 @@
"PipelineDataCycleError",
"PipelineGraphError",
"PipelineGraphReadError",
- "TaskNotImportedError",
"ReadInconsistencyError",
+ "UnresolvedGraphError",
+ "TaskNotImportedError",
)
@@ -61,6 +62,12 @@ class IncompatibleDatasetTypeError(PipelineGraphError):
"""
+class UnresolvedGraphError(PipelineGraphError):
+ """Exception raised when an operation requires dimensions or dataset types
+ to have been resolved, but they have not been.
+ """
+
+
class PipelineGraphReadError(PipelineGraphError, IOError):
"""Exception raised when a serialized PipelineGraph cannot be read."""
diff --git a/python/lsst/pipe/base/pipeline_graph/_extract_helper.py b/python/lsst/pipe/base/pipeline_graph/_extract_helper.py
deleted file mode 100644
index b591e6a71..000000000
--- a/python/lsst/pipe/base/pipeline_graph/_extract_helper.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# This file is part of pipe_base.
-#
-# Developed for the LSST Data Management System.
-# This product includes software developed by the LSST Project
-# (http://www.lsst.org).
-# See the COPYRIGHT file at the top-level directory of this distribution
-# for details of code ownership.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see <https://www.gnu.org/licenses/>.
-from __future__ import annotations
-
-__all__ = ("ExtractHelper",)
-
-from collections.abc import Iterable
-from types import EllipsisType
-from typing import TYPE_CHECKING, Generic, TypeVar
-
-import networkx
-import networkx.algorithms.bipartite
-import networkx.algorithms.dag
-from lsst.utils.iteration import ensure_iterable
-
-from ._nodes import Node, NodeKey, NodeType
-
-if TYPE_CHECKING:
- from ._mutable_pipeline_graph import MutablePipelineGraph
- from ._pipeline_graph import PipelineGraph
-
-
-_P = TypeVar("_P", bound="PipelineGraph", covariant=True)
-
-
-class ExtractHelper(Generic[_P]):
- def __init__(self, parent: _P) -> None:
- self._parent = parent
- self._run_xgraph: networkx.DiGraph | None = None
- self._task_keys: set[NodeKey] = set()
-
- def include_tasks(self, labels: str | Iterable[str] | EllipsisType = ...) -> None:
- if labels is ...:
- self._task_keys.update(key for key in self._parent._xgraph if key.node_type is NodeType.TASK)
- else:
- self._task_keys.update(
- NodeKey(NodeType.TASK, task_label) for task_label in ensure_iterable(labels)
- )
-
- def exclude_tasks(self, labels: str | Iterable[str]) -> None:
- self._task_keys.difference_update(
- NodeKey(NodeType.TASK, task_label) for task_label in ensure_iterable(labels)
- )
-
- def include_subset(self, label: str) -> None:
- self._task_keys.update(node.key for node in self._parent.task_subsets[label].values())
-
- def exclude_subset(self, label: str) -> None:
- self._task_keys.difference_update(node.key for node in self._parent.task_subsets[label].values())
-
- def start_after(self, names: str | Iterable[str], node_type: NodeType) -> None:
- to_exclude: set[NodeKey] = set()
- for name in ensure_iterable(names):
- key = NodeKey(node_type, name)
- to_exclude.update(networkx.algorithms.dag.ancestors(self._get_run_xgraph(), key))
- to_exclude.add(key)
- self._task_keys.difference_update(to_exclude)
-
- def stop_at(self, names: str | Iterable[str], node_type: NodeType) -> None:
- to_exclude: set[NodeKey] = set()
- for name in ensure_iterable(names):
- key = NodeKey(node_type, name)
- to_exclude.update(networkx.algorithms.dag.descendants(self._get_run_xgraph(), key))
- self._task_keys.difference_update(to_exclude)
-
- def finish(self, description: str | None = None) -> MutablePipelineGraph:
- from ._mutable_pipeline_graph import MutablePipelineGraph
-
- if description is None:
- description = self._parent._description
- # Combine the task_keys we're starting with and the keys for their init
- # nodes.
- keys = self._task_keys | {NodeKey(NodeType.TASK_INIT, key.name) for key in self._task_keys}
- # Also add the keys for the adjacent dataset type nodes.
- keys.update(networkx.node_boundary(self._parent._xgraph.to_undirected(as_view=True), keys))
- # Make the new backing networkx graph.
- xgraph: networkx.DiGraph = self._parent._xgraph.subgraph(keys).copy()
- for state in xgraph.nodes.values():
- node: Node = state["instance"]
- state["instance"] = node._unresolved()
- result = MutablePipelineGraph.__new__(MutablePipelineGraph)
- result._init_from_args(xgraph, None, description=description)
- return result
-
- def _get_run_xgraph(self) -> networkx.DiGraph:
- if self._run_xgraph is None:
- self._run_xgraph = self._parent.make_bipartite_xgraph(init=False)
- return self._run_xgraph
diff --git a/python/lsst/pipe/base/pipeline_graph/_io.py b/python/lsst/pipe/base/pipeline_graph/_io.py
index 9ced70f92..57572e6de 100644
--- a/python/lsst/pipe/base/pipeline_graph/_io.py
+++ b/python/lsst/pipe/base/pipeline_graph/_io.py
@@ -22,22 +22,20 @@
import os
import tarfile
-from abc import abstractmethod
from collections.abc import Sequence
-from typing import Any, BinaryIO, Generic, TypeVar
+from typing import Any, BinaryIO, TypeVar
import networkx
import pydantic
+from lsst.daf.butler import DatasetType, DimensionConfig, DimensionGraph, DimensionUniverse
-from ._dataset_types import SerializedDatasetTypeNode
+from ._dataset_types import DatasetTypeNode, SerializedDatasetTypeNode
from ._edges import ReadEdge, SerializedEdge, WriteEdge
from ._exceptions import PipelineGraphReadError
-from ._mapping_views import _D, _T
from ._nodes import NodeKey, NodeType
from ._task_subsets import SerializedTaskSubset, TaskSubset
-from ._tasks import SerializedTaskInitNode, SerializedTaskNode, TaskInitNode
+from ._tasks import SerializedTaskInitNode, SerializedTaskNode, TaskInitNode, TaskNode
-_S = TypeVar("_S", bound="TaskSubset", covariant=True)
_U = TypeVar("_U")
_IO_VERSION_INFO = (0, 0, 1)
@@ -90,33 +88,47 @@ def read_stream(cls, stream: BinaryIO) -> SerializedPipelineGraph:
return serialized_graph
-class PipelineGraphReader(Generic[_T, _D, _S]):
+class PipelineGraphReader:
def __init__(self) -> None:
self.xgraph = networkx.DiGraph()
self.sort_keys: Sequence[NodeKey] | None = None
- self.task_subsets: dict[str, _S] = {}
+ self.task_subsets: dict[str, TaskSubset] = {}
self.description: str = ""
+ self.universe: DimensionUniverse | None = None
def deserialize_graph(self, serialized_graph: SerializedPipelineGraph) -> None:
+        if serialized_graph.dimensions is not None:
+            self.universe = DimensionUniverse(
+                config=DimensionConfig(serialized_graph.dimensions)
+            )
sort_index_map: dict[int, NodeKey] = {}
for dataset_type_name, serialized_dataset_type in serialized_graph.dataset_types.items():
- dataset_type_node = self.deserialize_dataset_type(dataset_type_name, serialized_dataset_type)
+ dataset_type_key = NodeKey(NodeType.DATASET_TYPE, dataset_type_name)
+ dataset_type_node = self.deserialize_dataset_type(dataset_type_key, serialized_dataset_type)
self.xgraph.add_node(
- dataset_type_node.key, instance=dataset_type_node, bipartite=NodeType.DATASET_TYPE.value
+ dataset_type_key, instance=dataset_type_node, bipartite=NodeType.DATASET_TYPE.value
)
if serialized_dataset_type.index is not None:
- sort_index_map[serialized_dataset_type.index] = dataset_type_node.key
+ sort_index_map[serialized_dataset_type.index] = dataset_type_key
for task_label, serialized_task in serialized_graph.tasks.items():
- task_node = self.deserialize_task(task_label, serialized_task)
+ task_key = NodeKey(NodeType.TASK, task_label)
+ task_init_key = NodeKey(NodeType.TASK_INIT, task_label)
+ task_node = self.deserialize_task(task_key, task_init_key, serialized_task)
if serialized_task.index is not None:
- sort_index_map[serialized_task.index] = task_node.key
+ sort_index_map[serialized_task.index] = task_key
if serialized_task.init.index is not None:
- sort_index_map[serialized_task.init.index] = task_node.init.key
- self.xgraph.add_node(task_node.key, instance=task_node, bipartite=NodeType.TASK.bipartite)
+ sort_index_map[serialized_task.init.index] = task_init_key
+ self.xgraph.add_node(task_key, instance=task_node, bipartite=NodeType.TASK.bipartite)
self.xgraph.add_node(
- task_node.init.key, instance=task_node.init, bipartite=NodeType.TASK_INIT.bipartite
+ task_init_key, instance=task_node.init, bipartite=NodeType.TASK_INIT.bipartite
)
- self.xgraph.add_edge(task_node.init.key, task_node.key, instance=None)
+ self.xgraph.add_edge(task_init_key, task_key, instance=None)
for read_edge in task_node.init.iter_all_inputs():
self.xgraph.add_edge(read_edge.dataset_type_key, read_edge.task_key, instance=read_edge)
for write_edge in task_node.init.iter_all_outputs():
@@ -131,58 +143,42 @@ def deserialize_graph(self, serialized_graph: SerializedPipelineGraph) -> None:
self.sort_keys = [sort_index_map[i] for i in range(len(self.xgraph))]
self.description = serialized_graph.description
- @abstractmethod
- def deserialize_dataset_type(self, name: str, serialized_dataset_type: SerializedDatasetTypeNode) -> _D:
- raise NotImplementedError()
-
- @abstractmethod
- def deserialize_task(self, label: str, serialized_task: SerializedTaskNode) -> _T:
- raise NotImplementedError()
-
- @abstractmethod
- def deserialize_task_subset(self, label: str, serialized_task_subset: SerializedTaskSubset) -> _S:
- raise NotImplementedError()
-
- def deserialize_task_init(
- self,
- label: str,
- serialized_task_init: SerializedTaskInitNode,
- task_class_name: str,
- config_str: str,
- ) -> TaskInitNode:
- key = NodeKey(NodeType.TASK_INIT, label)
- return TaskInitNode(
- key,
- inputs={
- self.deserialize_read_edge(key, parent_dataset_type_name, serialized_edge)
- for parent_dataset_type_name, serialized_edge in serialized_task_init.inputs.items()
- },
- outputs={
- self.deserialize_write_edge(key, parent_dataset_type_name, serialized_edge)
- for parent_dataset_type_name, serialized_edge in serialized_task_init.outputs.items()
- },
- config_output=self.deserialize_write_edge(
- key,
+ def deserialize_dataset_type(
+ self, key: NodeKey, serialized_dataset_type: SerializedDatasetTypeNode
+ ) -> DatasetTypeNode | None:
+ if serialized_dataset_type.dimensions is not None:
+ dataset_type = DatasetType(
+ key.name,
expect_not_none(
- serialized_task_init.config_output.dataset_type_name,
- "Serialized task config edges should always have a dataset type.",
+ serialized_dataset_type.dimensions,
+ f"Serialized dataset type {key.name!r} has no dimensions.",
),
- serialized_task_init.config_output,
- ),
- task_class_name=task_class_name,
- config_str=config_str,
- )
+ storageClass=expect_not_none(
+ serialized_dataset_type.storage_class,
+ f"Serialized dataset type {key.name!r} has no storage class.",
+ ),
+ isCalibration=serialized_dataset_type.is_calibration,
+ universe=self.universe,
+ )
+ return DatasetTypeNode(
+ dataset_type=dataset_type,
+ is_prerequisite=serialized_dataset_type.is_prerequisite,
+ is_initial_query_constraint=serialized_dataset_type.is_initial_query_constraint,
+ is_registered=serialized_dataset_type.is_registered,
+ )
+ return None
- def deserialize_task_args(self, label: str, serialized_task: SerializedTaskNode) -> dict[str, Any]:
+ def deserialize_task(
+ self, key: NodeKey, init_key: NodeKey, serialized_task: SerializedTaskNode
+ ) -> TaskNode:
init = self.deserialize_task_init(
- label,
+ init_key,
serialized_task.init,
task_class_name=serialized_task.task_class,
config_str=expect_not_none(
- serialized_task.config_str, f"No serialized config file for task with label {label!r}."
+ serialized_task.config_str, f"No serialized config file for task with label {key.name!r}."
),
)
- key = NodeKey(NodeType.TASK, label)
inputs = {
self.deserialize_read_edge(key, parent_dataset_type_name, serialized_edge)
for parent_dataset_type_name, serialized_edge in serialized_task.inputs.items()
@@ -214,15 +210,53 @@ def deserialize_task_args(self, label: str, serialized_task: SerializedTaskNode)
),
serialized_task.metadata_output,
)
- return dict(
- key=key,
+ dimensions: DimensionGraph | None = None
+ if serialized_task.dimensions is not None:
+ dimensions = expect_not_none(
+ self.universe,
+ f"Dimensions for task {key.name} were persisted, but dimension universe was not.",
+ ).extract(serialized_task.dimensions)
+ return TaskNode(
init=init,
inputs=inputs,
prerequisite_inputs=prerequisite_inputs,
outputs=outputs,
log_output=log_output,
metadata_output=metadata_output,
- dimensions=None,
+ dimensions=dimensions,
+ )
+
+ def deserialize_task_subset(self, label: str, serialized_task_subset: SerializedTaskSubset) -> TaskSubset:
+ members = set(serialized_task_subset.tasks)
+ return TaskSubset(self.xgraph, label, members, serialized_task_subset.description)
+
+ def deserialize_task_init(
+ self,
+ key: NodeKey,
+ serialized_task_init: SerializedTaskInitNode,
+ task_class_name: str,
+ config_str: str,
+ ) -> TaskInitNode:
+ return TaskInitNode(
+ key,
+ inputs={
+ self.deserialize_read_edge(key, parent_dataset_type_name, serialized_edge)
+ for parent_dataset_type_name, serialized_edge in serialized_task_init.inputs.items()
+ },
+ outputs={
+ self.deserialize_write_edge(key, parent_dataset_type_name, serialized_edge)
+ for parent_dataset_type_name, serialized_edge in serialized_task_init.outputs.items()
+ },
+ config_output=self.deserialize_write_edge(
+ key,
+ expect_not_none(
+ serialized_task_init.config_output.dataset_type_name,
+ "Serialized task config edges should always have a dataset type.",
+ ),
+ serialized_task_init.config_output,
+ ),
+ task_class_name=task_class_name,
+ config_str=config_str,
)
def deserialize_read_edge(
@@ -232,12 +266,7 @@ def deserialize_read_edge(
serialized_edge: SerializedEdge,
is_prerequisite: bool = False,
) -> ReadEdge:
- # Look up dataset type key in the graph, both to validate as we read
- # and to reduce the number of distinct but equivalent NodeKey instances
- # present in the graph.
- dataset_type_key = self.xgraph.nodes[NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name)][
- "instance"
- ].key
+ dataset_type_key = NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name)
return ReadEdge(
dataset_type_key,
task_key,
@@ -253,12 +282,7 @@ def deserialize_write_edge(
parent_dataset_type_name: str,
serialized_edge: SerializedEdge,
) -> WriteEdge:
- # Look up dataset type key in the graph, both to validate as we read
- # and to reduce the number of distinct but equivalent NodeKey instances
- # present in the graph.
- dataset_type_key = self.xgraph.nodes[NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name)][
- "instance"
- ].key
+ dataset_type_key = NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name)
return WriteEdge(
task_key=task_key,
dataset_type_key=dataset_type_key,
diff --git a/python/lsst/pipe/base/pipeline_graph/_mapping_views.py b/python/lsst/pipe/base/pipeline_graph/_mapping_views.py
index 244ff60c9..fee5fa7c8 100644
--- a/python/lsst/pipe/base/pipeline_graph/_mapping_views.py
+++ b/python/lsst/pipe/base/pipeline_graph/_mapping_views.py
@@ -26,12 +26,10 @@
import networkx
from ._dataset_types import DatasetTypeNode
-from ._nodes import Node, NodeKey, NodeType
+from ._nodes import NodeKey, NodeType
from ._tasks import TaskInitNode, TaskNode
-_N = TypeVar("_N", bound=Node, covariant=True)
-_T = TypeVar("_T", bound=TaskNode, covariant=True)
-_D = TypeVar("_D", bound=DatasetTypeNode, covariant=True)
+_N = TypeVar("_N", covariant=True)
class MappingView(Mapping[str, _N]):
@@ -97,7 +95,7 @@ def _make_keys(self, parent_keys: Iterable[NodeKey]) -> list[str]:
return [str(k) for k in parent_keys if k.node_type is self._NODE_TYPE]
-class TaskMappingView(MappingView[_T]):
+class TaskMappingView(MappingView[TaskNode]):
_NODE_TYPE = NodeType.TASK
@@ -105,5 +103,5 @@ class TaskInitMappingView(MappingView[TaskInitNode]):
_NODE_TYPE = NodeType.TASK_INIT
-class DatasetTypeMappingView(MappingView[_D]):
+class DatasetTypeMappingView(MappingView[DatasetTypeNode | None]):
_NODE_TYPE = NodeType.DATASET_TYPE
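With these view changes, `graph.tasks` values are always `TaskNode`, while `graph.dataset_types` values may be `None` until the graph is resolved, so callers need a guard; a small illustrative example ("raw" is a placeholder dataset type name):

    dt_node = graph.dataset_types["raw"]
    dims = dt_node.dimensions if dt_node is not None else None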
diff --git a/python/lsst/pipe/base/pipeline_graph/_mutable_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_mutable_pipeline_graph.py
deleted file mode 100644
index 3cd74f916..000000000
--- a/python/lsst/pipe/base/pipeline_graph/_mutable_pipeline_graph.py
+++ /dev/null
@@ -1,298 +0,0 @@
-# This file is part of pipe_base.
-#
-# Developed for the LSST Data Management System.
-# This product includes software developed by the LSST Project
-# (http://www.lsst.org).
-# See the COPYRIGHT file at the top-level directory of this distribution
-# for details of code ownership.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see <https://www.gnu.org/licenses/>.
-from __future__ import annotations
-
-__all__ = ("MutablePipelineGraph",)
-
-from collections.abc import Iterable
-from typing import TYPE_CHECKING, Any, BinaryIO, cast, final
-
-import networkx
-import networkx.algorithms.bipartite
-import networkx.algorithms.dag
-from lsst.daf.butler import Registry
-from lsst.resources import ResourcePathExpression
-
-from ._dataset_types import DatasetTypeNode, SerializedDatasetTypeNode
-from ._edges import Edge
-from ._exceptions import PipelineDataCycleError
-from ._io import PipelineGraphReader, SerializedPipelineGraph
-from ._nodes import Node, NodeKey, NodeType
-from ._pipeline_graph import PipelineGraph
-from ._task_subsets import MutableTaskSubset, SerializedTaskSubset
-from ._tasks import SerializedTaskNode, TaskNode, _TaskNodeImportedData
-
-if TYPE_CHECKING:
- from ..config import PipelineTaskConfig
- from ..connections import PipelineTaskConnections
- from ..pipelineTask import PipelineTask
- from ._resolved_pipeline_graph import ResolvedPipelineGraph
-
-
-@final
-class MutablePipelineGraph(PipelineGraph[TaskNode, DatasetTypeNode, MutableTaskSubset]):
- """A pipeline graph that can be modified in place.
-
- Notes
- -----
- Mutable pipeline graphs are not automatically sorted and are not checked
- for cycles until they are sorted, but they do remember when they've been
- sorted so repeated calls to `sort` with no modifications in between are
- fast.
-
- Mutable pipeline graphs never carry around resolved dimensions and dataset
- types, since the process of resolving dataset types in particular depends
- in subtle ways on having the full graph available. In other words, a graph
- that has its dataset types resolved as tasks are added to it could end up
- with different dataset types from a complete graph that is resolved all at
- once, and we don't want to deal with that kind of inconsistency.
- """
-
- @classmethod
- def read_stream(
- cls, stream: BinaryIO, import_and_configure: bool = True, check_edges: bool = True
- ) -> MutablePipelineGraph:
- serialized_graph = SerializedPipelineGraph.read_stream(stream)
- reader = MutablePipelineGraphReader()
- reader.deserialize_graph(serialized_graph)
- result = MutablePipelineGraph.__new__(MutablePipelineGraph)
- result._init_from_args(reader.xgraph, reader.sort_keys, reader.task_subsets, reader.description)
- if import_and_configure:
- result.import_and_configure_in_place(check_edges=check_edges)
- return result
-
- @classmethod
- def read_uri(
- cls, uri: ResourcePathExpression, import_and_configure: bool = True, check_edges: bool = True
- ) -> MutablePipelineGraph:
- return cast(
- MutablePipelineGraph,
- super().read_uri(uri, import_and_configure=import_and_configure, check_edges=check_edges),
- )
-
- @property
- def description(self) -> str:
- # Docstring inherited.
- return self._description
-
- @description.setter
- def description(self, value: str) -> None:
- # Docstring inherited.
- self._description = value
-
- def copy(self) -> MutablePipelineGraph:
- # Docstring inherited.
- xgraph = self._xgraph.copy()
- result = MutablePipelineGraph.__new__(MutablePipelineGraph)
- result._init_from_args(
- xgraph,
- self._sorted_keys,
- task_subsets={k: v._mutable_copy(xgraph) for k, v in self._task_subsets.items()},
- description=self._description,
- )
- return result
-
- def resolved(self, registry: Registry, *, redo: bool = False) -> ResolvedPipelineGraph:
- # Docstring inherited.
- from ._resolved_pipeline_graph import ResolvedPipelineGraph
-
- xgraph = self._xgraph.copy()
- for state in xgraph.nodes.values():
- node: Node = state["instance"]
- state["instance"] = node._resolved(xgraph, registry)
- result = ResolvedPipelineGraph.__new__(ResolvedPipelineGraph)
- result._init_from_args(
- xgraph,
- self._sorted_keys,
- task_subsets={k: v._resolved(xgraph) for k, v in self._task_subsets.items()},
- description=self._description,
- )
- result.universe = registry.dimensions
- return result
-
- def mutable_copy(self) -> MutablePipelineGraph:
- # Docstring inherited.
- return self.copy()
-
- def add_task(
- self,
- label: str,
- task_class: type[PipelineTask],
- config: PipelineTaskConfig,
- connections: PipelineTaskConnections | None = None,
- ) -> None:
- """Add a new task to the graph.
-
- Parameters
- ----------
- label : `str`
- Label for the task in the pipeline.
- task_class : `type` [ `PipelineTask` ]
- Class object for the task.
- config : `PipelineTaskConfig`
- Configuration for the task.
- connections : `PipelineTaskConnections`, optional
- Object that describes the dataset types used by the task. If not
- provided, one will be constructed from the given configuration. If
- provided, it is assumed that ``config`` has already been validated
- and frozen.
-
- Raises
- ------
- ConnectionTypeConsistencyError
- Raised if the task defines an edge's ``is_init`` or
- ``is_prerequisite`` flags in a way that is inconsistent with some
- other task in the graph.
- IncompatibleDatasetTypeError
- Raised if the task defines a dataset type differently from some
- other task in the graph. Note that checks for dataset type
- dimension consistency do not occur until the graph is resolved.
- ValueError
- Raised if configuration validation failed when constructing.
- ``connections``.
- PipelineDataCycleError
- Raised if the graph is cyclic after this addition.
- RuntimeError
- Raised if an unexpected exception (which will be chained) occurred
- at a stage that may have left the graph in an inconsistent state.
- Other exceptions should leave the graph unchanged.
- """
- # Make the task node, the corresponding state dict that will be held
- # by the networkx graph (which includes the node instance), and the
- # state dicts for the edges
- task_node = TaskNode._from_imported_data(
- label, _TaskNodeImportedData.configure(label, task_class, config, connections)
- )
- node_data: list[tuple[NodeKey, dict[str, Any]]] = [
- (task_node.key, {"instance": task_node, "bipartite": task_node.key.node_type.bipartite}),
- (
- task_node.init.key,
- {"instance": task_node.init, "bipartite": task_node.init.key.node_type.bipartite},
- ),
- ]
- # Convert the edge objects attached to the task node to networkx form.
- edge_data: list[tuple[NodeKey, NodeKey, dict[str, Any]]] = []
- for read_edge in task_node.init.iter_all_inputs():
- self._append_graph_from_edge(node_data, edge_data, read_edge)
- for write_edge in task_node.init.iter_all_outputs():
- self._append_graph_from_edge(node_data, edge_data, write_edge)
- for read_edge in task_node.prerequisite_inputs:
- self._append_graph_from_edge(node_data, edge_data, read_edge)
- for read_edge in task_node.inputs:
- self._append_graph_from_edge(node_data, edge_data, read_edge)
- for write_edge in task_node.iter_all_outputs():
- self._append_graph_from_edge(node_data, edge_data, write_edge)
- # Add a special edge (with no Edge instance) that connects the
- # TaskInitNode to the runtime TaskNode.
- edge_data.append((task_node.init.key, task_node.key, {"instance": None}))
- # Checks complete; time to start the actual modification, during which
- # it's hard to provide strong exception safety.
- self._reset()
- try:
- self._xgraph.add_nodes_from(node_data)
- self._xgraph.add_edges_from(edge_data)
- if not networkx.algorithms.dag.is_directed_acyclic_graph(self._xgraph):
- cycle = networkx.find_cycle(self._xgraph)
- raise PipelineDataCycleError(f"Cycle detected while adding task {label} graph: {cycle}.")
- except Exception:
- # First try to roll back our changes.
- try:
- self._xgraph.remove_edges_from(edge_data)
- self._xgraph.remove_nodes_from(key for key, _ in node_data)
- except Exception as err:
- raise RuntimeError(
- "Error while attempting to revert PipelineGraph modification has left the graph in "
- "an inconsistent state."
- ) from err
- # Successfully rolled back; raise the original exception.
- raise
-
- def add_task_subset(self, subset_label: str, task_labels: Iterable[str], description: str = "") -> None:
- """Add a label for a set of tasks that are already in the pipeline.
-
- Parameters
- ----------
- subset_label : `str`
- Label for this set of tasks.
- task_labels : `~collections.abc.Iterable` [ `str` ]
- Labels of the tasks to include in the set. All must already be
- included in the graph.
- description : `str`, optional
- String description to associate with this label.
- """
- subset = MutableTaskSubset(self._xgraph, subset_label, set(task_labels), description)
- self._task_subsets[subset_label] = subset
-
- def _append_graph_from_edge(
- self,
- node_data: list[tuple[NodeKey, dict[str, Any]]],
- edge_data: list[tuple[NodeKey, NodeKey, dict[str, Any]]],
- edge: Edge,
- ) -> None:
- """Append networkx state dictionaries for an edge and the corresponding
- dataset type node.
-
- Parameters
- ----------
- node_data : `list`
- List of node keys and state dictionaries. A node is appended if
- one does not already exist for this dataset type.
- edge_data : `list`
- List of node key pairs and state dictionaries for edges.
- edge : `Edge`
- New edge being processed.
- """
- if (existing_dataset_type_state := self._xgraph.nodes.get(edge.dataset_type_key)) is not None:
- dataset_type_node = existing_dataset_type_state["instance"]
- edge._check_dataset_type(self._xgraph, dataset_type_node)
- else:
- dataset_type_node = DatasetTypeNode(edge.dataset_type_key)
- node_data.append(
- (
- edge.dataset_type_key,
- {
- "instance": dataset_type_node,
- "bipartite": NodeType.DATASET_TYPE.bipartite,
- },
- )
- )
- edge_data.append(edge.key + ({"instance": edge},))
-
-
-class MutablePipelineGraphReader(PipelineGraphReader[TaskNode, DatasetTypeNode, MutableTaskSubset]):
- def deserialize_dataset_type(
- self, name: str, serialized_dataset_type: SerializedDatasetTypeNode
- ) -> DatasetTypeNode:
- return DatasetTypeNode(NodeKey(NodeType.DATASET_TYPE, name))
-
- def deserialize_task(self, label: str, serialized_task: SerializedTaskNode) -> TaskNode:
- return TaskNode(**self.deserialize_task_args(label, serialized_task))
-
- def deserialize_task_subset(
- self, label: str, serialized_task_subset: SerializedTaskSubset
- ) -> MutableTaskSubset:
- members = set(serialized_task_subset.tasks)
- return MutableTaskSubset(self.xgraph, label, members, serialized_task_subset.description)
-
- def finish(self) -> MutablePipelineGraph:
- result = MutablePipelineGraph.__new__(MutablePipelineGraph)
- result._init_from_args(self.xgraph, self.sort_keys, self.task_subsets, self.description)
- return result
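With the mutable/resolved split removed, construction and resolution happen on a single class; a minimal sketch, assuming `add_task` keeps the signature used by the callers above (`ExampleTask` is a hypothetical `PipelineTask`):

    graph = PipelineGraph(universe=registry.dimensions)  # universe is optional
    graph.add_task("example", ExampleTask, ExampleTask.ConfigClass())
    graph.resolve(registry)
    task_node = graph.tasks["example"]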
diff --git a/python/lsst/pipe/base/pipeline_graph/_nodes.py b/python/lsst/pipe/base/pipeline_graph/_nodes.py
index d726a32a6..10c303abb 100644
--- a/python/lsst/pipe/base/pipeline_graph/_nodes.py
+++ b/python/lsst/pipe/base/pipeline_graph/_nodes.py
@@ -21,21 +21,12 @@
from __future__ import annotations
__all__ = (
- "Node",
"NodeKey",
"NodeType",
)
import enum
-from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, NamedTuple
-
-import networkx
-from lsst.daf.butler import Registry
-from lsst.utils.classes import immutable
-
-if TYPE_CHECKING:
- from pydantic import BaseModel
+from typing import NamedTuple
class NodeType(enum.Enum):
@@ -84,75 +75,3 @@ def __repr__(self) -> str:
def __str__(self) -> str:
return self.name
-
-
-@immutable
-class Node(ABC):
- """Base class for nodes in a pipeline graph.
-
- Parameters
- ----------
- key : `NodeKey`
- The key for this node in networkx graphs.
- """
-
- def __init__(self, key: NodeKey):
- self.key = key
-
- key: NodeKey
- """The key for this node in networkx graphs."""
-
- @abstractmethod
- def _resolved(self, xgraph: networkx.DiGraph, registry: Registry) -> Node:
- """Resolve any dataset type and dimension names in this graph.
-
- Parameters
- ----------
- xgraph : `networkx.DiGraph`
- Directed bipartite graph representing the full pipeline. Should
- not be modified.
- registry : `lsst.daf.butler.Registry`
- Registry that provides dimension and dataset type information.
-
- Returns
- -------
- node : `Node`
- Resolved version of this node. May be self if the node is already
- resolved.
- """
- raise NotImplementedError()
-
- @abstractmethod
- def _unresolved(self) -> Node:
- """Revert this node to a form that just holds names for dataset types
- and dimensions, allowing `_reresolve` to have an effect if called
- again.
-
- Returns
- -------
- node : `Node`
- Resolved version of this node. May be self if the node is already
- resolved.
- """
- raise NotImplementedError()
-
- @abstractmethod
- def _serialize(self) -> BaseModel:
- raise NotImplementedError()
-
- @abstractmethod
- def _to_xgraph_state(self) -> dict[str, Any]:
- """Unpack the content of this node into a dictionary that can be used
- as the state dictionary for an external networkx graph.
-
- Unlike `_serialize`, this may hold types that are not directly suitable
- for JSON conversion, and it does not need to hold any edge state. Like
- `_serialize`, this should not include the node's key, as that is always
- included in the graph separately.
-
- Returns
- -------
- state : `dict`
- Dictionary for an external networkx graph.
- """
- raise NotImplementedError()
diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
index fdf62ae79..81fbfd642 100644
--- a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
+++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
@@ -25,82 +25,78 @@
import io
import os
import tarfile
-from abc import abstractmethod
-from collections.abc import Iterator, Mapping, Sequence
+from collections.abc import Iterable, Iterator, Mapping, Sequence
from datetime import datetime
-from typing import TYPE_CHECKING, Any, BinaryIO, Generic, TypeVar, cast
+from typing import TYPE_CHECKING, Any, BinaryIO, Literal, cast
import networkx
import networkx.algorithms.bipartite
import networkx.algorithms.dag
-from lsst.daf.butler import Registry
+from lsst.daf.butler import DimensionGraph, DimensionUniverse, Registry
from lsst.resources import ResourcePath, ResourcePathExpression
+from ._dataset_types import DatasetTypeNode, SerializedDatasetTypeNode
from ._edges import Edge, ReadEdge, WriteEdge
-from ._exceptions import PipelineDataCycleError
-from ._io import SerializedPipelineGraph
-from ._mapping_views import _D, _T, DatasetTypeMappingView, TaskMappingView
-from ._nodes import Node, NodeKey, NodeType
+from ._exceptions import PipelineDataCycleError, ReadInconsistencyError, UnresolvedGraphError
+from ._io import PipelineGraphReader, SerializedPipelineGraph
+from ._mapping_views import DatasetTypeMappingView, TaskMappingView
+from ._nodes import NodeKey, NodeType
from ._task_subsets import TaskSubset
+from ._tasks import TaskInitNode, TaskNode, _TaskNodeImportedData
if TYPE_CHECKING:
+ from ..config import PipelineTaskConfig
+ from ..connections import PipelineTaskConnections
from ..pipeline import TaskDef
- from ._extract_helper import ExtractHelper
- from ._mutable_pipeline_graph import MutablePipelineGraph
- from ._resolved_pipeline_graph import ResolvedPipelineGraph
+ from ..pipelineTask import PipelineTask
-_S = TypeVar("_S", bound="TaskSubset", covariant=True)
-_P = TypeVar("_P", bound="PipelineGraph", covariant=True)
-
-
-class PipelineGraph(Generic[_T, _D, _S]):
- """A base class for directed acyclic graph of `PipelineTask` definitions.
-
- This abstract base class should not be inherited from outside its package;
- it exists to share code and interfaces between `MutablePipelineGraph` and
- `ResolvedPipelineGraph`.
- """
-
- def __init__(self) -> None:
- self._init_from_args()
+class PipelineGraph:
+    """A directed acyclic graph of `PipelineTask` definitions and the dataset
+    types they consume and produce.
+    """
+
+ def __init__(self, universe: DimensionUniverse | None = None) -> None:
+ self._init_from_args(
+ xgraph=None, sorted_keys=None, task_subsets=None, description="", universe=universe
+ )
def _init_from_args(
self,
- xgraph: networkx.DiGraph | None = None,
- sorted_keys: Sequence[NodeKey] | None = None,
- task_subsets: dict[str, _S] | None = None,
- description: str = "",
+ xgraph: networkx.DiGraph | None,
+ sorted_keys: Sequence[NodeKey] | None,
+ task_subsets: dict[str, TaskSubset] | None,
+ description: str,
+ universe: DimensionUniverse | None,
) -> None:
"""Initialize the graph with possibly-nontrivial arguments.
Parameters
----------
- xgraph : `networkx.DiGraph` or `None`, optional
+ xgraph : `networkx.DiGraph` or `None`
The backing networkx graph, or `None` to create an empty one.
- sorted_keys : `Sequence` [ `NodeKey` ] or `None`, optional
+ sorted_keys : `Sequence` [ `NodeKey` ] or `None`
Topologically sorted sequence of node keys, or `None` if the graph
is not sorted.
- task_subsets : `dict` [ `str`, `TaskSubsetMapping` ], optional
+ task_subsets : `dict` [ `str`, `TaskSubset` ]
Labeled subsets of tasks. Values must be constructed with
``xgraph`` as their parent graph.
- description : `str`, optional
+ description : `str`
String description for this pipeline.
+ universe : `lsst.daf.butler.DimensionUniverse` or `None`
+ Definitions of all dimensions.
Notes
-----
- Only empty `PipelineGraph` [subclass] instances should be constructed
- directly by users, which sets the signature of ``__init__`` itself, but
- methods on `PipelineGraph` and its helper classes need to be able to
- create them with state. Those methods can call this after calling
- ``__new__`` manually.
+        Only empty `PipelineGraph` instances should be constructed directly by
+        users, which is why the signature of ``__init__`` itself is so simple,
+        but methods on `PipelineGraph` and its helper classes need to be able
+        to create graphs with existing state. Those methods can call this
+        method after calling ``__new__`` manually.
"""
self._xgraph = xgraph if xgraph is not None else networkx.DiGraph()
- self._sorted_keys: Sequence[NodeKey] | None = None
+        self._sorted_keys: Sequence[NodeKey] | None = None
self._task_subsets = task_subsets if task_subsets is not None else {}
self._description = description
- self._tasks = TaskMappingView[_T](self._xgraph)
- self._dataset_types = DatasetTypeMappingView[_D](self._xgraph)
+ self._tasks = TaskMappingView(self._xgraph)
+ self._dataset_types = DatasetTypeMappingView(self._xgraph)
+ self._universe = universe
if sorted_keys is not None:
self._reorder(sorted_keys)
@@ -112,16 +108,26 @@ def description(self) -> str:
"""String description for this pipeline."""
return self._description
+ @description.setter
+ def description(self, value: str) -> None:
+        # Docstring in getter.
+ self._description = value
+
+ @property
+ def universe(self) -> DimensionUniverse | None:
+ """Definitions for all dimensions."""
+ return self._universe
+
@property
- def tasks(self) -> TaskMappingView[_T]:
+ def tasks(self) -> TaskMappingView:
return self._tasks
@property
- def dataset_types(self) -> DatasetTypeMappingView[_D]:
+ def dataset_types(self) -> DatasetTypeMappingView:
return self._dataset_types
@property
- def task_subsets(self) -> Mapping[str, _S]:
+ def task_subsets(self) -> Mapping[str, TaskSubset]:
"""Mapping of all labeled subsets of tasks.
Keys are subset labels, values are Task-only graphs (subgraphs of
@@ -135,29 +141,32 @@ def iter_edges(self, init: bool = False) -> Iterator[Edge]:
if edge is not None and edge.is_init == init:
yield edge
- def iter_nodes(self) -> Iterator[Node]:
+ def iter_nodes(
+ self,
+ ) -> Iterator[
+ tuple[Literal[NodeType.TASK_INIT], str, TaskInitNode]
+        | tuple[Literal[NodeType.TASK], str, TaskNode]
+ | tuple[Literal[NodeType.DATASET_TYPE], str, DatasetTypeNode | None]
+ ]:
+ key: NodeKey
if self._sorted_keys is not None:
for key in self._sorted_keys:
- yield self._xgraph.nodes[key]["instance"]
+ yield key.node_type, key.name, self._xgraph.nodes[key]["instance"] # type: ignore
else:
- for _, node in self._xgraph.nodes(data="instance"):
- yield node
+ for key, node in self._xgraph.nodes(data="instance"):
+ yield key.node_type, key.name, node # type: ignore
- def iter_overall_inputs(self) -> Iterator[_D]:
+ def iter_overall_inputs(self) -> Iterator[tuple[str, DatasetTypeNode | None]]:
for generation in networkx.algorithms.dag.topological_generations(self._xgraph):
+ key: NodeKey
for key in generation:
# While we expect all tasks to have at least one input and
# hence never appear in the first topological generation, that
# is not true of task init nodes.
if key.node_type is NodeType.DATASET_TYPE:
- yield self._xgraph.nodes[key]["instance"]
+ yield key.name, self._xgraph.nodes[key]["instance"]
return
- def import_and_configure_in_place(self, check_edges: bool = True) -> None:
- # TODO: docs
- for task in self.tasks.values():
- task.import_and_configure(check_edges=check_edges)
-
def make_xgraph(self) -> networkx.DiGraph:
return self._transform_xgraph_state(self._xgraph.copy())
@@ -192,14 +201,62 @@ def _make_bipartite_xgraph_internal(self, init: bool) -> networkx.DiGraph:
def _transform_xgraph_state(self, xgraph: networkx.DiGraph) -> networkx.DiGraph:
state: dict[str, Any]
for state in xgraph.nodes.values():
- node: Node = state.pop("instance")
- state.update(node._to_xgraph_state())
+ node_value: TaskInitNode | TaskNode | DatasetTypeNode | None = state.pop("instance")
+ if node_value is not None:
+ state.update(node_value._to_xgraph_state())
for _, _, state in xgraph.edges(data=True):
edge: Edge | None = state.pop("instance", None)
if edge is not None:
state.update(edge._to_xgraph_state())
return xgraph
+ def group_by_dimensions(
+ self, prerequisites: bool = False
+ ) -> dict[DimensionGraph, tuple[list[TaskNode], list[DatasetTypeNode]]]:
+ """Group this graph's tasks and runtime dataset types by their
+ dimensions.
+
+ Parameters
+ ----------
+ prerequisites : `bool`, optional
+ If `True`, include prerequisite dataset types as well as regular
+ input and output datasets (including intermediates).
+
+ Returns
+ -------
+ groups : `dict` [ `DimensionGraph`, `tuple` ]
+            A dictionary of groups keyed by `DimensionGraph`, with each value
+            a tuple of:
+
+            - a `list` of `TaskNode` instances
+            - a `list` of `DatasetTypeNode` instances
+
+ that have those dimensions.
+
+ Notes
+ -----
+ Init inputs and outputs are always included, but always have empty
+ dimensions and are hence easily filtered out.
+ """
+ result: dict[DimensionGraph, tuple[list[TaskNode], list[DatasetTypeNode]]] = {}
+ next_new_value: tuple[list[TaskNode], list[DatasetTypeNode]] = ([], [])
+ for task_label, task_node in self.tasks.items():
+ if task_node.dimensions is None:
+ raise UnresolvedGraphError(f"Task with label {task_label!r} has not been resolved.")
+ if (group := result.setdefault(task_node.dimensions, next_new_value)) is next_new_value:
+ next_new_value = ([], []) # make new lists for next time
+ group[0].append(task_node)
+ for dataset_type_name, dataset_type_node in self.dataset_types.items():
+ if dataset_type_node is None:
+ raise UnresolvedGraphError(f"Dataset type {dataset_type_name!r} has not been resolved.")
+ if not dataset_type_node.is_prerequisite or prerequisites:
+ if (
+ group := result.setdefault(dataset_type_node.dataset_type.dimensions, next_new_value)
+ ) is next_new_value:
+ next_new_value = ([], []) # make new lists for next time
+ group[1].append(dataset_type_node)
+ return result
+
@property
def is_sorted(self) -> bool:
"""Whether this graph's tasks and dataset types are topologically
@@ -288,63 +345,116 @@ def consumers_of(self, dataset_type_name: str) -> dict[str, ReadEdge]:
)
}
- def extract(self) -> ExtractHelper:
- """Create a new `MutablePipelineGraph` containing just the tasks that
- match the given criteria.
- """
- from ._extract_helper import ExtractHelper
+ def add_task(
+ self,
+ label: str,
+ task_class: type[PipelineTask],
+ config: PipelineTaskConfig,
+ connections: PipelineTaskConnections | None = None,
+ ) -> None:
+ """Add a new task to the graph.
- return ExtractHelper(self)
+ Parameters
+ ----------
+ label : `str`
+ Label for the task in the pipeline.
+ task_class : `type` [ `PipelineTask` ]
+ Class object for the task.
+ config : `PipelineTaskConfig`
+ Configuration for the task.
+ connections : `PipelineTaskConnections`, optional
+ Object that describes the dataset types used by the task. If not
+ provided, one will be constructed from the given configuration. If
+ provided, it is assumed that ``config`` has already been validated
+ and frozen.
- def _reorder(self, sorted_keys: Sequence[NodeKey]) -> None:
- """Set the order of all views of this graph from the given sorted
- sequence of task labels and dataset type names.
+ Raises
+ ------
+ ConnectionTypeConsistencyError
+ Raised if the task defines an edge's ``is_init`` or
+ ``is_prerequisite`` flags in a way that is inconsistent with some
+ other task in the graph.
+ IncompatibleDatasetTypeError
+ Raised if the task defines a dataset type differently from some
+ other task in the graph. Note that checks for dataset type
+ dimension consistency do not occur until the graph is resolved.
+ ValueError
+            Raised if configuration validation failed when constructing
+            ``connections``.
+ PipelineDataCycleError
+ Raised if the graph is cyclic after this addition.
+ RuntimeError
+ Raised if an unexpected exception (which will be chained) occurred
+ at a stage that may have left the graph in an inconsistent state.
+ Other exceptions should leave the graph unchanged.
"""
- self._sorted_keys = sorted_keys
- self._tasks._reorder(sorted_keys)
- self._dataset_types._reorder(sorted_keys)
+ key = NodeKey(NodeType.TASK, label)
+ init_key = NodeKey(NodeType.TASK_INIT, label)
+ task_node = TaskNode._from_imported_data(
+ key,
+ init_key,
+ _TaskNodeImportedData.configure(label, task_class, config, connections),
+ universe=self.universe,
+ )
+ self._add_task_nodes([(key, task_node)])
- def _reset(self) -> None:
- """Reset the all views of this graph following a modification that
- might invalidate them.
+    def remove_task(self, label: str) -> None:
+        """Remove the task with the given label from the graph."""
+        key = NodeKey(NodeType.TASK, label)
+        self._remove_task_nodes({key: self._xgraph.nodes[key]["instance"]})
+
+ def add_task_subset(self, subset_label: str, task_labels: Iterable[str], description: str = "") -> None:
+ """Add a label for a set of tasks that are already in the pipeline.
+
+ Parameters
+ ----------
+ subset_label : `str`
+ Label for this set of tasks.
+ task_labels : `~collections.abc.Iterable` [ `str` ]
+ Labels of the tasks to include in the set. All must already be
+ included in the graph.
+ description : `str`, optional
+ String description to associate with this label.
"""
- self._sorted_keys = None
- self._tasks._reset()
- self._dataset_types._reset()
+ subset = TaskSubset(self._xgraph, subset_label, set(task_labels), description)
+ self._task_subsets[subset_label] = subset
- @abstractmethod
- def copy(self: _P) -> _P:
+ def copy(self) -> PipelineGraph:
"""Return a copy of this graph that copies all mutable state."""
- raise NotImplementedError()
+ xgraph = self._xgraph.copy()
+ result = PipelineGraph.__new__(PipelineGraph)
+ result._init_from_args(
+ xgraph,
+ self._sorted_keys,
+ task_subsets={
+ k: TaskSubset(xgraph, v.label, set(v._members), v.description)
+ for k, v in self._task_subsets.items()
+ },
+ description=self._description,
+ universe=self.universe,
+ )
+ return result
- def __copy__(self: _P) -> _P:
+ def __copy__(self) -> PipelineGraph:
# Fully shallow copies are dangerous; we don't want shared mutable
# state to lead to broken class invariants.
return self.copy()
- def __deepcopy__(self: _P, memo: dict) -> _P:
- # Genuine deep copies are sometimes unnecessary, since we should only
- # ever care that mutable state is copied.
+ def __deepcopy__(self, memo: dict) -> PipelineGraph:
+ # Genuine deep copies are unnecessary, since we should only ever care
+ # that mutable state is copied.
return self.copy()
- @abstractmethod
- def resolved(self, registry: Registry, *, redo: bool = False) -> ResolvedPipelineGraph:
- """Return a version of this graph with all dimensions and dataset types
- resolved according to the given butler registry.
+ def import_and_configure(self, check: bool = True, rebuild: bool = False) -> None:
+ self._import_and_configure(check=check, rebuild=rebuild, universe=self._universe)
+
+ def resolve(self, registry: Registry, *, check: bool = True, rebuild: bool = False) -> None:
+ """Resolve all dimensions and dataset types.
Parameters
----------
registry : `lsst.daf.butler.Registry`
Client for the data repository to resolve against.
- redo : `bool`, optional
- If `True`, re-do the resolution even if the graph has already been
- resolved to pick up changes in the registry. If `False` (default)
- and the graph is already resolved, this method returns ``self``.
-
- Returns
- -------
- resolved : `ResolvedPipelineGraph`
- A resolved version of this graph. Always sorted and immutable.
+ TODO
Raises
------
@@ -358,69 +468,59 @@ def resolved(self, registry: Registry, *, redo: bool = False) -> ResolvedPipelin
are consumed with different storage classes or as components by
tasks in the pipeline.
"""
- raise NotImplementedError()
-
- @abstractmethod
- def mutable_copy(self) -> MutablePipelineGraph:
- """Return a mutable copy of this graph.
-
- Returns
- -------
- mutable : `MutablePipelineGraph`
- A mutable copy of this graph. This drops all dimension and
- dataset type resolutions that may be present in ``self``. See
- docs for `MutablePipelineGraph` for details.
- """
- raise NotImplementedError()
-
- def _iter_task_defs(self) -> Iterator[TaskDef]:
- """Iterate over this pipeline as a sequence of `TaskDef` instances.
-
- Notes
- -----
- This is a package-private method intended to aid in the transition to a
- codebase more fully integrated with the `PipelineGraph` class, in which
- both `TaskDef` and `PipelineDatasetTypes` are expected to go away, and
- much of the functionality on the `Pipeline` class will be moved to
- `PipelineGraph` as well.
- """
- from ..pipeline import TaskDef
-
- for node in self._tasks.values():
- yield TaskDef(
- config=node.config,
- taskClass=node.task_class,
- label=node.label,
- connections=node._get_connections(),
- )
+ self._import_and_configure(check=check, rebuild=rebuild, universe=registry.dimensions)
+ self.sort()
+ node_key: NodeKey
+ updates: dict[NodeKey, TaskNode | DatasetTypeNode] = {}
+ for node_key, node_state in self._xgraph.nodes.items():
+ match node_key.node_type:
+ case NodeType.TASK:
+ task_node: TaskNode = node_state["instance"]
+ new_task_node = task_node._resolved(registry.dimensions)
+ if new_task_node is not task_node:
+ updates[node_key] = new_task_node
+ case NodeType.DATASET_TYPE:
+ dataset_type_node: DatasetTypeNode | None = node_state["instance"]
+ if dataset_type_node is None:
+ updates[node_key] = DatasetTypeNode._from_edges(node_key, self._xgraph, registry)
+ else:
+ new_dataset_type_node = dataset_type_node._resolved(registry)
+ if new_dataset_type_node is not dataset_type_node:
+ updates[node_key] = new_dataset_type_node
+ for node_key, node_value in updates.items():
+ self._xgraph.nodes[node_key]["instance"] = node_value
+ self._universe = registry.dimensions
@classmethod
def read_stream(
- cls, stream: BinaryIO, import_and_configure: bool = True, check_edges: bool = True
- ) -> MutablePipelineGraph | ResolvedPipelineGraph:
- from ._mutable_pipeline_graph import MutablePipelineGraphReader
- from ._resolved_pipeline_graph import ResolvedPipelineGraphReader
-
+ cls, stream: BinaryIO, import_and_configure: bool = True, check: bool = True, rebuild: bool = False
+ ) -> PipelineGraph:
serialized_graph = SerializedPipelineGraph.read_stream(stream)
- reader: MutablePipelineGraphReader | ResolvedPipelineGraphReader
- if serialized_graph.dimensions is None:
- reader = MutablePipelineGraphReader()
- else:
- reader = ResolvedPipelineGraphReader()
+ reader = PipelineGraphReader()
reader.deserialize_graph(serialized_graph)
- result = reader.finish()
+ result = PipelineGraph.__new__(PipelineGraph)
+ result._init_from_args(
+ reader.xgraph, reader.sort_keys, reader.task_subsets, reader.description, universe=reader.universe
+ )
if import_and_configure:
- result.import_and_configure_in_place(check_edges=check_edges)
+ result._import_and_configure(check=check, rebuild=rebuild, universe=result.universe)
return result
@classmethod
def read_uri(
- cls, uri: ResourcePathExpression, import_and_configure: bool = True, check_edges: bool = True
- ) -> MutablePipelineGraph | ResolvedPipelineGraph:
+ cls,
+ uri: ResourcePathExpression,
+ import_and_configure: bool = True,
+ check: bool = True,
+ rebuild: bool = False,
+ ) -> PipelineGraph:
uri = ResourcePath(uri)
with uri.open("rb") as stream:
return cls.read_stream(
- cast(BinaryIO, stream), import_and_configure=import_and_configure, check_edges=check_edges
+ cast(BinaryIO, stream),
+ import_and_configure=import_and_configure,
+ check=check,
+ rebuild=rebuild,
)
def write_stream(self, stream: BinaryIO, basename: str = "pipeline", compression: str = "gz") -> None:
@@ -524,8 +624,12 @@ def _serialize(self) -> SerializedPipelineGraph:
result = SerializedPipelineGraph.construct(
description=self.description,
tasks={label: node._serialize() for label, node in self.tasks.items()},
- dataset_types={name: node._serialize() for name, node in self.dataset_types.items()},
+ dataset_types={
+ name: node._serialize() if node is not None else SerializedDatasetTypeNode()
+ for name, node in self.dataset_types.items()
+ },
task_subsets={label: subset._serialize() for label, subset in self.task_subsets.items()},
+ dimensions=self.universe.dimensionConfig.toDict() if self.universe is not None else None,
)
if self._sorted_keys:
for index, node_key in enumerate(self._sorted_keys):
@@ -537,3 +641,180 @@ def _serialize(self) -> SerializedPipelineGraph:
case NodeType.TASK_INIT:
result.tasks[node_key.name].init.index = index
return result
+
+ def _iter_task_defs(self) -> Iterator[TaskDef]:
+ """Iterate over this pipeline as a sequence of `TaskDef` instances.
+
+ Notes
+ -----
+ This is a package-private method intended to aid in the transition to a
+ codebase more fully integrated with the `PipelineGraph` class, in which
+ both `TaskDef` and `PipelineDatasetTypes` are expected to go away, and
+ much of the functionality on the `Pipeline` class will be moved to
+ `PipelineGraph` as well.
+ """
+ from ..pipeline import TaskDef
+
+ for node in self._tasks.values():
+ yield TaskDef(
+ config=node.config,
+ taskClass=node.task_class,
+ label=node.label,
+ connections=node._get_connections(),
+ )
+
+ def _import_and_configure(self, check: bool, rebuild: bool, universe: DimensionUniverse | None) -> None:
+ if rebuild:
+ check = False
+ updates: dict[NodeKey, TaskNode] = {}
+ node_key: NodeKey
+ for node_key, node_state in self._xgraph.nodes.items():
+ if node_key.node_type is NodeType.TASK:
+ task_node: TaskNode = node_state["instance"]
+ new_task_node = task_node._imported_and_configured(node_key, rebuild or check, universe)
+ if new_task_node is not task_node:
+ updates[node_key] = new_task_node
+ if check:
+ messages = new_task_node.diff(task_node)
+ if messages:
+ messages.insert(
+ 0,
+ f"Imported and reconfigured edges for task {node_key.name!r} "
+ "differ from those persisted:",
+ )
+ raise ReadInconsistencyError("\n".join(messages))
+ if rebuild:
+ self._remove_task_nodes(updates)
+ self._add_task_nodes(updates.items())
+ else:
+ for node_key, task_node in updates.items():
+ self._xgraph.nodes[node_key]["instance"] = task_node
+ self._xgraph.nodes[task_node.init._key]["instance"] = task_node.init
+
+    def _add_task_nodes(self, nodes: Iterable[tuple[NodeKey, TaskNode]]) -> None:
+        node_data: list[tuple[NodeKey, dict[str, Any]]] = []
+        edge_data: list[tuple[NodeKey, NodeKey, dict[str, Any]]] = []
+        for key, task_node in nodes:
+            node_data.append((key, {"instance": task_node, "bipartite": key.node_type.bipartite}))
+            node_data.append(
+                (
+                    task_node.init._key,
+                    {"instance": task_node.init, "bipartite": task_node.init._key.node_type.bipartite},
+                )
+            )
+            # Convert the edge objects attached to this task node to networkx.
+            for read_edge in task_node.init.iter_all_inputs():
+                self._append_graph_data_from_edge(node_data, edge_data, read_edge)
+            for write_edge in task_node.init.iter_all_outputs():
+                self._append_graph_data_from_edge(node_data, edge_data, write_edge)
+            for read_edge in task_node.prerequisite_inputs:
+                self._append_graph_data_from_edge(node_data, edge_data, read_edge)
+            for read_edge in task_node.inputs:
+                self._append_graph_data_from_edge(node_data, edge_data, read_edge)
+            for write_edge in task_node.iter_all_outputs():
+                self._append_graph_data_from_edge(node_data, edge_data, write_edge)
+            # Add a special edge (with no Edge instance) that connects the
+            # TaskInitNode to the runtime TaskNode.
+            edge_data.append((task_node.init._key, key, {"instance": None}))
+ if not node_data and not edge_data:
+ return
+ # Checks complete; time to start the actual modification, during which
+ # it's hard to provide strong exception safety.
+ self._reset()
+ try:
+ self._xgraph.add_nodes_from(node_data)
+ self._xgraph.add_edges_from(edge_data)
+ if not networkx.algorithms.dag.is_directed_acyclic_graph(self._xgraph):
+ cycle = networkx.find_cycle(self._xgraph)
+                raise PipelineDataCycleError(f"Cycle detected while adding task {key.name!r} to graph: {cycle}.")
+ except Exception:
+ # First try to roll back our changes.
+ try:
+ self._xgraph.remove_edges_from(edge_data)
+ self._xgraph.remove_nodes_from(key for key, _ in node_data)
+ except Exception as err:
+ raise RuntimeError(
+ "Error while attempting to revert PipelineGraph modification has left the graph in "
+ "an inconsistent state."
+ ) from err
+ # Successfully rolled back; raise the original exception.
+ raise
+
+ def _remove_task_nodes(self, nodes: Mapping[NodeKey, TaskNode]) -> None:
+ dataset_types: set[NodeKey] = set()
+ for task_key, task_node in nodes.items():
+            dataset_types.update(self._xgraph.predecessors(task_key))
+            dataset_types.update(self._xgraph.successors(task_key))
+            dataset_types.update(self._xgraph.predecessors(task_node.init._key))
+            dataset_types.update(self._xgraph.successors(task_node.init._key))
+ # Since there's an edge between the task and its init node, we'll
+ # have added those two nodes here, too, and we don't want that.
+ dataset_types.remove(task_node.init._key)
+ dataset_types.remove(task_key)
+ to_remove = list(nodes.keys())
+ to_unresolve: list[NodeKey] = []
+ for dataset_type_key in dataset_types:
+ related_tasks = set()
+ related_tasks.update(self._xgraph.predecessors(dataset_type_key))
+ related_tasks.update(self._xgraph.successors(dataset_type_key))
+ related_tasks.difference_update(nodes.keys())
+ if not related_tasks:
+ to_remove.append(dataset_type_key)
+ else:
+ to_unresolve.append(dataset_type_key)
+ for dataset_type_key in to_unresolve:
+ self._xgraph.nodes[dataset_type_key]["instance"] = None
+ if to_remove:
+ self._reset()
+ self._xgraph.remove_nodes_from(to_remove)
+
+ def _append_graph_data_from_edge(
+ self,
+ node_data: list[tuple[NodeKey, dict[str, Any]]],
+ edge_data: list[tuple[NodeKey, NodeKey, dict[str, Any]]],
+ edge: Edge,
+ ) -> None:
+ """Append networkx state dictionaries for an edge and the corresponding
+ dataset type node.
+
+ Parameters
+ ----------
+ node_data : `list`
+ List of node keys and state dictionaries. A node is appended if
+ one does not already exist for this dataset type.
+ edge_data : `list`
+ List of node key pairs and state dictionaries for edges.
+ edge : `Edge`
+ New edge being processed.
+ """
+ if (existing_dataset_type_state := self._xgraph.nodes.get(edge.dataset_type_key)) is not None:
+ existing_dataset_type_state["instance"] = edge._update_dataset_type(
+ self._xgraph, existing_dataset_type_state["instance"]
+ )
+ else:
+ node_data.append(
+ (
+ edge.dataset_type_key,
+ {
+ "instance": None,
+ "bipartite": NodeType.DATASET_TYPE.bipartite,
+ },
+ )
+ )
+ edge_data.append(edge.key + ({"instance": edge},))
+
+ def _reorder(self, sorted_keys: Sequence[NodeKey]) -> None:
+ """Set the order of all views of this graph from the given sorted
+ sequence of task labels and dataset type names.
+ """
+ self._sorted_keys = sorted_keys
+ self._tasks._reorder(sorted_keys)
+ self._dataset_types._reorder(sorted_keys)
+
+ def _reset(self) -> None:
+ """Reset the all views of this graph following a modification that
+ might invalidate them.
+ """
+ self._sorted_keys = None
+ self._tasks._reset()
+ self._dataset_types._reset()
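
A minimal usage sketch of the unified `PipelineGraph` API above, using the no-dimensions test task from `lsst.pipe.base.tests` and an assumed butler repository path; it exercises `add_task`, `add_task_subset`, in-place `resolve`, and `group_by_dimensions`:

    from lsst.daf.butler import Butler
    from lsst.pipe.base.pipeline_graph import PipelineGraph
    from lsst.pipe.base.tests import no_dimensions

    butler = Butler("REPO_PATH", writeable=False)  # assumed repository path

    # Build a two-task graph by hand (Pipeline.to_graph() does the same for a
    # real pipeline definition).
    graph = PipelineGraph()
    graph.description = "Example pipeline graph."
    config_a = no_dimensions.NoDimensionsTestConfig()
    config_a.connections.output = "intermediate"
    graph.add_task("a", no_dimensions.NoDimensionsTestTask, config_a)
    config_b = no_dimensions.NoDimensionsTestConfig()
    config_b.connections.input = "intermediate"
    graph.add_task("b", no_dimensions.NoDimensionsTestTask, config_b)
    graph.add_task_subset("only_b", ["b"])

    # Resolve dimensions and dataset types in place against the repository;
    # this also sorts the graph and fills in dataset_types values that were
    # previously None.
    graph.resolve(butler.registry)
    print({name for name, _ in graph.iter_overall_inputs()})
    for dims, (task_nodes, dataset_type_nodes) in graph.group_by_dimensions().items():
        print(dims, [t.label for t in task_nodes], [d.dataset_type.name for d in dataset_type_nodes])
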
diff --git a/python/lsst/pipe/base/pipeline_graph/_resolved_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_resolved_pipeline_graph.py
deleted file mode 100644
index 1ea260f2c..000000000
--- a/python/lsst/pipe/base/pipeline_graph/_resolved_pipeline_graph.py
+++ /dev/null
@@ -1,231 +0,0 @@
-# This file is part of pipe_base.
-#
-# Developed for the LSST Data Management System.
-# This product includes software developed by the LSST Project
-# (http://www.lsst.org).
-# See the COPYRIGHT file at the top-level directory of this distribution
-# for details of code ownership.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see .
-from __future__ import annotations
-
-__all__ = ("ResolvedPipelineGraph",)
-
-from collections.abc import Sequence
-from typing import TYPE_CHECKING, BinaryIO, cast, final
-
-import networkx
-import networkx.algorithms.bipartite
-import networkx.algorithms.dag
-from lsst.daf.butler import DatasetType, DimensionConfig, DimensionGraph, DimensionUniverse, Registry
-from lsst.resources import ResourcePathExpression
-
-from ._dataset_types import ResolvedDatasetTypeNode, SerializedDatasetTypeNode
-from ._io import PipelineGraphReader, SerializedPipelineGraph, expect_not_none
-from ._nodes import Node, NodeKey, NodeType
-from ._pipeline_graph import PipelineGraph
-from ._task_subsets import ResolvedTaskSubset, SerializedTaskSubset
-from ._tasks import SerializedTaskNode, TaskNode
-
-if TYPE_CHECKING:
- from ._mutable_pipeline_graph import MutablePipelineGraph
-
-
-@final
-class ResolvedPipelineGraph(PipelineGraph[TaskNode, ResolvedDatasetTypeNode, ResolvedTaskSubset]):
- """An immutable pipeline graph with resolved dimensions and dataset types.
-
- Resolved pipeline graphs are sorted at construction and cannot be modified,
- so calling `sort` on them does nothing.
- """
-
- def __init__(self, universe: DimensionUniverse) -> None:
- super().__init__()
- self.universe = universe
-
- def _init_from_args(
- self,
- xgraph: networkx.DiGraph | None = None,
- sorted_keys: Sequence[NodeKey] | None = None,
- task_subsets: dict[str, ResolvedTaskSubset] | None = None,
- description: str = "",
- ) -> None:
- super()._init_from_args(xgraph, sorted_keys, task_subsets, description)
- super().sort()
-
- def sort(self) -> None:
- # Docstring inherited.
- assert self.is_sorted, "Sorted at construction and immutable."
-
- def copy(self) -> ResolvedPipelineGraph:
- # Docstring inherited.
- # Immutable types shouldn't actually be copied, since there's nothing
- # one could do with the copy that couldn't be done with the original.
- return self
-
- def resolved(self, registry: Registry, *, redo: bool = False) -> ResolvedPipelineGraph:
- # Docstring inherited.
- if redo:
- return self.mutable_copy().resolved(registry)
- return self
-
- def mutable_copy(self) -> MutablePipelineGraph:
- # Docstring inherited.
- from ._mutable_pipeline_graph import MutablePipelineGraph
-
- xgraph = self._xgraph.copy()
- for state in xgraph.nodes.values():
- node: Node = state["instance"]
- state["instance"] = node._unresolved()
- result = MutablePipelineGraph.__new__(MutablePipelineGraph)
- result._init_from_args(
- xgraph,
- self._sorted_keys,
- task_subsets={k: v._mutable_copy(xgraph) for k, v in self._task_subsets.items()},
- description=self._description,
- )
- return result
-
- def group_by_dimensions(
- self, prerequisites: bool = False
- ) -> dict[DimensionGraph, tuple[list[TaskNode], list[ResolvedDatasetTypeNode]]]:
- """Group this graph's tasks and runtime dataset types by their
- dimensions.
-
- Parameters
- ----------
- prerequisites : `bool`, optional
- If `True`, include prerequisite dataset types as well as regular
- input and output datasets (including intermediates).
-
- Returns
- -------
- groups : `dict` [ `DimensionGraph`, `tuple` ]
- A dictionary of groups keyed by `DimensionGraph`, which each value
- a tuple of:
-
- - a `list` of `TaskNode` instances
- - a `list` of `ResolvedDatasetTypeNode` instances
-
- that have those dimensions.
-
- Notes
- -----
- Init inputs and outputs are always included, but always have empty
- dimensions and are hence easily filtered out.
- """
- result: dict[DimensionGraph, tuple[list[TaskNode], list[ResolvedDatasetTypeNode]]] = {}
- next_new_value: tuple[list[TaskNode], list[ResolvedDatasetTypeNode]] = ([], [])
- for task_node in self.tasks.values():
- if (
- group := result.setdefault(cast(DimensionGraph, task_node.dimensions), next_new_value)
- ) is next_new_value:
- next_new_value = ([], []) # make new lists for next time
- group[0].append(task_node)
- for dataset_type_node in self.dataset_types.values():
- if not dataset_type_node.is_prerequisite or prerequisites:
- if (
- group := result.setdefault(dataset_type_node.dataset_type.dimensions, next_new_value)
- ) is next_new_value:
- next_new_value = ([], []) # make new lists for next time
- group[1].append(dataset_type_node)
- return result
-
- def _serialize(self) -> SerializedPipelineGraph:
- # Docstring inherited.
- result = super()._serialize()
- result.dimensions = self.universe.dimensionConfig.toDict()
- return result
-
- @classmethod
- def read_stream(
- cls, stream: BinaryIO, import_and_configure: bool = True, check_edges: bool = True
- ) -> ResolvedPipelineGraph:
- serialized_graph = SerializedPipelineGraph.read_stream(stream)
- reader = ResolvedPipelineGraphReader()
- reader.deserialize_graph(serialized_graph)
- result = ResolvedPipelineGraph.__new__(ResolvedPipelineGraph)
- result._init_from_args(reader.xgraph, reader.sort_keys, reader.task_subsets, reader.description)
- result.universe = reader.universe
- if import_and_configure:
- result.import_and_configure_in_place(check_edges=check_edges)
- return result
-
- @classmethod
- def read_uri(
- cls, uri: ResourcePathExpression, import_and_configure: bool = True, check_edges: bool = True
- ) -> ResolvedPipelineGraph:
- return cast(
- ResolvedPipelineGraph,
- super().read_uri(uri, import_and_configure=import_and_configure, check_edges=check_edges),
- )
-
-
-class ResolvedPipelineGraphReader(PipelineGraphReader[TaskNode, ResolvedDatasetTypeNode, ResolvedTaskSubset]):
- def deserialize_graph(
- self,
- serialized_graph: SerializedPipelineGraph,
- ) -> None:
- self.universe = DimensionUniverse(
- config=DimensionConfig(
- expect_not_none(
- serialized_graph.dimensions,
- "Serialized pipeline graph has not been resolved; "
- "load it is a MutablePipelineGraph instead.",
- )
- )
- )
- super().deserialize_graph(serialized_graph)
-
- def deserialize_dataset_type(
- self, name: str, serialized_dataset_type: SerializedDatasetTypeNode
- ) -> ResolvedDatasetTypeNode:
- dataset_type = DatasetType(
- name,
- expect_not_none(
- serialized_dataset_type.dimensions, f"Serialized dataset type {name!r} has no dimensions."
- ),
- storageClass=expect_not_none(
- serialized_dataset_type.storage_class,
- f"Serialized dataset type {name!r} has no storage class.",
- ),
- isCalibration=serialized_dataset_type.is_calibration,
- universe=self.universe,
- )
- return ResolvedDatasetTypeNode(
- key=NodeKey(NodeType.DATASET_TYPE, name),
- dataset_type=dataset_type,
- is_prerequisite=serialized_dataset_type.is_prerequisite,
- is_initial_query_constraint=serialized_dataset_type.is_initial_query_constraint,
- )
-
- def deserialize_task(self, label: str, serialized_task: SerializedTaskNode) -> TaskNode:
- kwargs = self.deserialize_task_args(label, serialized_task)
- kwargs["dimensions"] = self.universe.extract(
- expect_not_none(
- serialized_task.dimensions, f"Serialized task with label {label!r} has no dimensions."
- )
- )
- return TaskNode(**kwargs)
-
- def deserialize_task_subset(
- self, label: str, serialized_task_subset: SerializedTaskSubset
- ) -> ResolvedTaskSubset:
- members = set(serialized_task_subset.tasks)
- return ResolvedTaskSubset(self.xgraph, label, members, serialized_task_subset.description)
-
- def finish(self) -> ResolvedPipelineGraph:
- result = ResolvedPipelineGraph.__new__(ResolvedPipelineGraph)
- result._init_from_args(self.xgraph, self.sort_keys, self.task_subsets, self.description)
- return result
diff --git a/python/lsst/pipe/base/pipeline_graph/_task_subsets.py b/python/lsst/pipe/base/pipeline_graph/_task_subsets.py
index 215c88973..b5cdea09e 100644
--- a/python/lsst/pipe/base/pipeline_graph/_task_subsets.py
+++ b/python/lsst/pipe/base/pipeline_graph/_task_subsets.py
@@ -20,9 +20,9 @@
# along with this program. If not, see .
from __future__ import annotations
-__all__ = ("TaskSubset", "MutableTaskSubset", "ResolvedTaskSubset", "SerializedTaskSubset")
+__all__ = ("TaskSubset", "SerializedTaskSubset")
-from collections.abc import Iterator, MutableSet, Set
+from collections.abc import Iterator, MutableSet
import networkx
import networkx.algorithms.boundary
@@ -32,7 +32,7 @@
from ._nodes import NodeKey, NodeType
-class TaskSubset(Set[str]):
+class TaskSubset(MutableSet[str]):
"""An abstract base class whose instances represent a labeled subset of the
tasks in a pipeline.
@@ -70,6 +70,10 @@ def description(self) -> str:
"""Description string associated with this labeled subset."""
return self._description
+ @description.setter
+ def description(self, value: str) -> None:
+ self._description = value
+
def __str__(self) -> str:
return f"{self.label}: {self.description}, tasks={', '.join(iter(self))}"
@@ -82,53 +86,6 @@ def __len__(self) -> int:
def __iter__(self) -> Iterator[str]:
return iter(self._members)
- def _resolved(self, parent_xgraph: networkx.DiGraph) -> ResolvedTaskSubset:
- """Return a version of this view appropriate for a resolved pipeline
- graph.
-
- Parameters
- ----------
- parent_xgraph : `networkx.DiGraph`
- The new parent networkx graph that will back the new view.
-
- Returns
- -------
- resolved : `ResolvedTaskSubsetGraph`
- A resolved version of this object.
- """
- return ResolvedTaskSubset(parent_xgraph, self.label, self._members.copy(), self._description)
-
- def _mutable_copy(self, parent_xgraph: networkx.DiGraph) -> MutableTaskSubset:
- """Return a copy of this view appropriate for a mutable pipeline
- graph.
-
- Parameters
- ----------
- parent_xgraph : `networkx.DiGraph`
- The new parent networkx graph that will back the new view.
-
- Returns
- -------
- mutable : `MutableTaskSubsetGraph`
- A mutable version of this object.
- """
- return MutableTaskSubset(parent_xgraph, self.label, self._members.copy(), self._description)
-
- def _serialize(self) -> SerializedTaskSubset:
- return SerializedTaskSubset.construct(description=self._description, tasks=list(sorted(self)))
-
-
-class MutableTaskSubset(TaskSubset, MutableSet[str]):
- @property
- def description(self) -> str:
- # Docstring inherited.
- return self._description
-
- @description.setter
- def description(self, value: str) -> None:
- # Docstring inherited.
- self._description = value
-
def add(self, task_label: str) -> None:
"""Add a new task to this subset.
@@ -146,9 +103,8 @@ def add(self, task_label: str) -> None:
def discard(self, task_label: str) -> None:
self._members.discard(task_label)
-
-class ResolvedTaskSubset(TaskSubset):
- pass
+ def _serialize(self) -> SerializedTaskSubset:
+ return SerializedTaskSubset.construct(description=self._description, tasks=list(sorted(self)))
class SerializedTaskSubset(BaseModel):
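
Continuing the sketch after the `_pipeline_graph.py` hunks, `TaskSubset` is now directly mutable and its description is settable (a graph containing tasks "a" and "b" and the "only_b" subset is assumed):

    subset = graph.task_subsets["only_b"]
    subset.description = "Tasks needed for the second processing step."
    subset.add("a")        # the label must belong to a task already in the graph
    subset.discard("b")
    print(sorted(subset))  # ['a']
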
diff --git a/python/lsst/pipe/base/pipeline_graph/_tasks.py b/python/lsst/pipe/base/pipeline_graph/_tasks.py
index f52dd62d4..326555d09 100644
--- a/python/lsst/pipe/base/pipeline_graph/_tasks.py
+++ b/python/lsst/pipe/base/pipeline_graph/_tasks.py
@@ -28,7 +28,7 @@
from typing import TYPE_CHECKING, Any
import networkx
-from lsst.daf.butler import DimensionGraph, Registry
+from lsst.daf.butler import DimensionGraph, DimensionUniverse, Registry
from lsst.utils.classes import immutable
from lsst.utils.doImport import doImportType
from lsst.utils.introspection import get_full_type_name
@@ -38,8 +38,8 @@
from ..connections import PipelineTaskConnections
from ..connectionTypes import BaseConnection, InitOutput, Output
from ._edges import ReadEdge, SerializedEdge, WriteEdge
-from ._exceptions import ReadInconsistencyError, TaskNotImportedError
-from ._nodes import Node, NodeKey, NodeType
+from ._exceptions import TaskNotImportedError
+from ._nodes import NodeKey, NodeType
if TYPE_CHECKING:
from ..config import PipelineTaskConfig
@@ -122,14 +122,12 @@ def configure(
@immutable
-class TaskInitNode(Node):
+class TaskInitNode:
"""A node in a pipeline graph that represents the construction of a
`PipelineTask`.
Parameters
----------
- key : `NodeKey`
- Key for this node in the graph.
inputs : `~collections.abc.Set` [ `ReadEdge` ]
Graph edges that represent inputs required just to construct an
instance of this task.
@@ -157,7 +155,7 @@ def __init__(
task_class_name: str | None = None,
config_str: str | None = None,
):
- super().__init__(key)
+ self._key = key
self.inputs = inputs
self.outputs = outputs
self.config_output = config_output
@@ -195,7 +193,7 @@ def __init__(
@property
def label(self) -> str:
"""Label of this configuration of a task in the pipeline."""
- return str(self.key)
+ return str(self._key)
@property
def is_imported(self) -> bool:
@@ -286,7 +284,7 @@ def _serialize(self) -> SerializedTaskInitNode:
def _to_xgraph_state(self) -> dict[str, Any]:
# Docstring inherited.
- result = {"task_class_name": self.task_class_name, "bipartite": self.key.node_type.bipartite}
+ result = {"task_class_name": self.task_class_name, "bipartite": NodeType.TASK_INIT.bipartite}
if hasattr(self, "_imported_data"):
result["task_class"] = self.task_class
result["config"] = self.config
@@ -316,14 +314,12 @@ class SerializedTaskInitNode(BaseModel):
@immutable
-class TaskNode(Node):
+class TaskNode:
"""A node in a pipeline graph that represents a labeled configuration of a
`PipelineTask`.
Parameters
----------
- key : `NodeKey`
- Key for this node in the graph.
init : `TaskInitNode`
Node representing the initialization of this task.
prerequisite_inputs : `~collections.abc.Set` [ `ReadEdge` ]
@@ -357,7 +353,6 @@ class TaskNode(Node):
def __init__(
self,
- key: NodeKey,
init: TaskInitNode,
*,
prerequisite_inputs: Set[ReadEdge],
@@ -367,7 +362,6 @@ def __init__(
metadata_output: WriteEdge,
dimensions: DimensionGraph | None,
):
- super().__init__(key)
self.init = init
self.prerequisite_inputs = prerequisite_inputs
self.inputs = inputs
@@ -378,25 +372,26 @@ def __init__(
@staticmethod
def _from_imported_data(
- label: str,
+ key: NodeKey,
+ init_key: NodeKey,
data: _TaskNodeImportedData,
+ universe: DimensionUniverse | None,
) -> TaskNode:
"""Construct from a `PipelineTask` type and its configuration.
Parameters
----------
- label : `str`
- Label for the task in the pipeline.
+ TODO
+
data : `_TaskNodeImportedData`
Internal data for the node.
+ universe : `lsst.daf.butler.DimensionUniverse` or `None`
+ Definitions of all dimensions.
Returns
-------
node : `TaskNode`
New task node.
- state: `dict` [ `str`, `Any` ]
- State object for the networkx representation of this node. The
- returned ``node`` object is the value of the "instance" key.
Raises
------
@@ -416,8 +411,7 @@ def _from_imported_data(
at a stage that may have left the graph in an inconsistent state.
All other exceptions should leave the graph unchanged.
"""
- key = NodeKey(NodeType.TASK, label)
- init_key = NodeKey(NodeType.TASK_INIT, label)
+
init_inputs = {
ReadEdge._from_connection_map(init_key, name, data.connection_map)
for name in data.connections.initInputs
@@ -447,7 +441,6 @@ def _from_imported_data(
imported_data=data,
)
instance = TaskNode(
- key=key,
init=init,
prerequisite_inputs=prerequisite_inputs,
inputs=inputs,
@@ -460,7 +453,7 @@ def _from_imported_data(
metadata_output=WriteEdge._from_connection_map(
key, acc.METADATA_OUTPUT_CONNECTION_NAME, data.connection_map
),
- dimensions=None,
+ dimensions=None if universe is None else universe.extract(data.connections.dimensions),
)
return instance
@@ -498,7 +491,7 @@ def _from_imported_data(
@property
def label(self) -> str:
"""Label of this configuration of a task in the pipeline."""
- return str(self.key)
+ return self.init.label
@property
def task_class(self) -> type[PipelineTask]:
@@ -545,66 +538,67 @@ def iter_all_outputs(self) -> Iterator[WriteEdge]:
if self.log_output is not None:
yield self.log_output
- def import_and_configure(self, check_edges: bool = True) -> None:
+ def _imported_and_configured(
+ self, key: NodeKey, rebuild: bool, universe: DimensionUniverse | None
+ ) -> TaskNode:
# TODO: docs
from ..pipelineTask import PipelineTask
if self.is_imported:
- return
+ return self
task_class = doImportType(self.task_class_name)
if not issubclass(task_class, PipelineTask):
raise TypeError(f"{self.task_class_name!r} is not a PipelineTask subclass.")
config = task_class.ConfigClass()
config.loadFromString(self.get_config_str())
- imported_data = _TaskNodeImportedData.configure(self.label, task_class, config)
- if check_edges:
- if messages := self.diff(self._from_imported_data(self.label, imported_data)):
- messages.insert(
- 0,
- f"Inconsistency between serialized and configured edges for task {self.label!r}:",
- )
- raise ReadInconsistencyError("\n".join(messages))
- self.init._imported_data = imported_data
+ imported_data = _TaskNodeImportedData.configure(key.name, task_class, config)
+ if rebuild:
+ return self._from_imported_data(
+ key,
+ self.init._key,
+ imported_data,
+ universe=universe,
+ )
+ else:
+ return TaskNode(
+ TaskInitNode(
+ self.init._key,
+ inputs=self.init.inputs,
+ outputs=self.init.outputs,
+ config_output=self.init.config_output,
+ imported_data=imported_data,
+ ),
+ prerequisite_inputs=self.prerequisite_inputs,
+ inputs=self.inputs,
+ outputs=self.outputs,
+ log_output=self.log_output,
+ metadata_output=self.metadata_output,
+ dimensions=(
+ universe.extract(self._get_connections().dimensions) if universe is not None else None
+ ),
+ )
def diff(self, other: TaskNode) -> list[str]:
# TODO: docs
return self.init.diff(other.init)
- def _get_connections(self) -> PipelineTaskConnections:
- # TODO: docs
- return self.init._get_connections()
-
- def _resolved(self, xgraph: networkx.DiGraph, registry: Registry) -> TaskNode:
- # Docstring inherited.
- if self.dimensions is not None:
- if self.dimensions.universe is registry.dimensions:
- return self
- return TaskNode(
- key=self.key,
- init=self.init,
- prerequisite_inputs=self.prerequisite_inputs,
- inputs=self.inputs,
- outputs=self.outputs,
- log_output=self.log_output,
- metadata_output=self.metadata_output,
- dimensions=registry.dimensions.extract(self._get_connections().dimensions),
- )
-
- def _unresolved(self) -> TaskNode:
- # Docstring inherited.
- if self.dimensions is None:
+ def _resolved(self, universe: DimensionUniverse) -> TaskNode:
+ if self.dimensions is not None and self.dimensions.universe is universe:
return self
return TaskNode(
- key=self.key,
init=self.init,
prerequisite_inputs=self.prerequisite_inputs,
inputs=self.inputs,
outputs=self.outputs,
log_output=self.log_output,
metadata_output=self.metadata_output,
- dimensions=None,
+ dimensions=universe.extract(self._get_connections().dimensions),
)
+ def _get_connections(self) -> PipelineTaskConnections:
+ # TODO: docs
+ return self.init._get_connections()
+
def _serialize(self) -> SerializedTaskNode:
# Docstring inherited.
return SerializedTaskNode.construct(
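
The `_imported_and_configured` changes above back the deferred-import reading path: a serialized graph can be loaded without importing its task classes and imported later, either checking the reconstructed edges against the persisted ones or rebuilding them. A sketch, with the file name assumed to point at a graph written by `write_uri`:

    from lsst.pipe.base.pipeline_graph import PipelineGraph

    graph = PipelineGraph.read_uri("pipeline_graph.tar.gz", import_and_configure=False)
    # At this point the task classes have not been imported; is_imported is
    # False for the task nodes.
    graph.import_and_configure(check=True)  # import and verify against persisted edges
    # graph.import_and_configure(rebuild=True) would instead rebuild the edges
    # from the re-imported connections rather than checking them.
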
diff --git a/python/lsst/pipe/base/tests/pipelineStepTester.py b/python/lsst/pipe/base/tests/pipelineStepTester.py
index 00f4523ec..a63ec8aa1 100644
--- a/python/lsst/pipe/base/tests/pipelineStepTester.py
+++ b/python/lsst/pipe/base/tests/pipelineStepTester.py
@@ -26,9 +26,11 @@
import dataclasses
import unittest
+from typing import cast
from lsst.daf.butler import Butler, DatasetType
from lsst.pipe.base import Pipeline
+from lsst.pipe.base.pipeline_graph import DatasetTypeNode
@dataclasses.dataclass
@@ -88,25 +90,22 @@ def run(self, butler: Butler, test_case: unittest.TestCase) -> None:
pure_inputs: dict[str, str] = dict()
for suffix in self.step_suffixes:
- step_graph = Pipeline.from_uri(self.filename + suffix).to_graph().resolved(butler.registry)
+ step_graph = Pipeline.from_uri(self.filename + suffix).to_graph()
+ step_graph.resolve(butler.registry)
pure_inputs.update(
- {
- node.name: suffix
- for node in step_graph.iter_overall_inputs()
- if node.name not in all_outputs
- }
+ {name: suffix for name, _ in step_graph.iter_overall_inputs() if name not in all_outputs}
)
all_outputs.update(
{
- name: node.dataset_type
+ name: cast(DatasetTypeNode, node).dataset_type
for name, node in step_graph.dataset_types.items()
if step_graph.producer_of(name) is not None
}
)
for node in step_graph.dataset_types.values():
- butler.registry.registerDatasetType(node.dataset_type)
+ butler.registry.registerDatasetType(cast(DatasetTypeNode, node).dataset_type)
if not pure_inputs.keys() <= self.expected_inputs:
missing = [f"{k} ({pure_inputs[k]})" for k in pure_inputs.keys() - self.expected_inputs]
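
The `cast` calls above reflect that `dataset_types` values are now typed `DatasetTypeNode | None` and are only guaranteed to be non-None after `resolve` has been called; an equivalent sketch that checks explicitly instead of casting:

    for name, node in step_graph.dataset_types.items():
        if node is None:
            raise RuntimeError(f"Dataset type {name!r} is unexpectedly unresolved.")
        butler.registry.registerDatasetType(node.dataset_type)
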
diff --git a/tests/test_pipeline_graph.py b/tests/test_pipeline_graph.py
index a0a3fb6cf..7ff704ec7 100644
--- a/tests/test_pipeline_graph.py
+++ b/tests/test_pipeline_graph.py
@@ -31,13 +31,7 @@
import lsst.utils.tests
from lsst.daf.butler import DatasetType, DimensionUniverse
from lsst.daf.butler.registry import MissingDatasetTypeError
-from lsst.pipe.base.pipeline_graph import (
- MutablePipelineGraph,
- NodeKey,
- NodeType,
- PipelineGraph,
- ResolvedPipelineGraph,
-)
+from lsst.pipe.base.pipeline_graph import NodeKey, NodeType, PipelineGraph
from lsst.pipe.base.tests import no_dimensions
_LOG = logging.getLogger(__name__)
@@ -62,109 +56,99 @@ def setUp(self) -> None:
# any of those. We add tasks in reverse order to better test sorting.
# There is one labeled task subset, 'only_b', with just 'b' in it.
self.description = "A pipeline for PipelineGraph unit tests."
- self.mgraph = MutablePipelineGraph()
- self.mgraph.description = self.description
+ self.graph = PipelineGraph()
+ self.graph.description = self.description
self.b_config = no_dimensions.NoDimensionsTestConfig()
self.b_config.connections.input = "intermediate"
- self.mgraph.add_task("b", no_dimensions.NoDimensionsTestTask, self.b_config)
+ self.graph.add_task("b", no_dimensions.NoDimensionsTestTask, self.b_config)
self.a_config = no_dimensions.NoDimensionsTestConfig()
self.a_config.connections.output = "intermediate"
- self.mgraph.add_task("a", no_dimensions.NoDimensionsTestTask, self.a_config)
- self.mgraph.add_task_subset("only_b", ["b"])
+ self.graph.add_task("a", no_dimensions.NoDimensionsTestTask, self.a_config)
+ self.graph.add_task_subset("only_b", ["b"])
self.dimensions = DimensionUniverse()
self.maxDiff = None
- def test_mutable_accessors(self) -> None:
- self.check_base_accessors(self.mgraph)
- self.assertTrue(repr(self.mgraph).startswith(f"MutablePipelineGraph({self.description!r}, tasks="))
+ def test_unresolved_accessors(self) -> None:
+ self.check_base_accessors(self.graph)
+ self.assertTrue(repr(self.graph).startswith(f"PipelineGraph({self.description!r}, tasks="))
def test_sorting(self) -> None:
- """Test sort methods on MutablePipelineGraph."""
- self.assertFalse(self.mgraph.has_been_sorted)
- self.assertFalse(self.mgraph.is_sorted)
- self.mgraph.sort()
- self.check_sorted(self.mgraph)
-
- def test_mutable_xgraph_export(self) -> None:
- self.check_make_xgraph(self.mgraph, resolved=False)
- self.check_make_bipartite_xgraph(self.mgraph, resolved=False)
- self.check_make_task_xgraph(self.mgraph, resolved=False)
- self.check_make_dataset_type_xgraph(self.mgraph, resolved=False)
-
- def test_mutable_stream_io(self) -> None:
+ """Test sort methods on PipelineGraph."""
+ self.assertFalse(self.graph.has_been_sorted)
+ self.assertFalse(self.graph.is_sorted)
+ self.graph.sort()
+ self.check_sorted(self.graph)
+
+ def test_unresolved_xgraph_export(self) -> None:
+ self.check_make_xgraph(self.graph, resolved=False)
+ self.check_make_bipartite_xgraph(self.graph, resolved=False)
+ self.check_make_task_xgraph(self.graph, resolved=False)
+ self.check_make_dataset_type_xgraph(self.graph, resolved=False)
+
+ def test_unresolved_stream_io(self) -> None:
stream = io.BytesIO()
- self.mgraph.write_stream(stream)
+ self.graph.write_stream(stream)
stream.seek(0)
- roundtripped = MutablePipelineGraph.read_stream(stream)
+ roundtripped = PipelineGraph.read_stream(stream)
self.check_make_xgraph(roundtripped, resolved=False)
- def test_mutable_file_io(self) -> None:
+ def test_unresolved_file_io(self) -> None:
with lsst.utils.tests.getTempFilePath(".tar.gz") as filename:
- self.mgraph.write_uri(filename)
- roundtripped = MutablePipelineGraph.read_uri(filename)
+ self.graph.write_uri(filename)
+ roundtripped = PipelineGraph.read_uri(filename)
self.check_make_xgraph(roundtripped, resolved=False)
def test_resolved_accessors(self) -> None:
"""Test resolving a pipeline graph against a data repository."""
- rgraph = self.mgraph.resolved(MockRegistry(self.dimensions, {}))
- self.check_base_accessors(rgraph)
- self.check_sorted(rgraph)
- self.assertTrue(repr(rgraph).startswith(f"ResolvedPipelineGraph({self.description!r}, tasks="))
+ self.graph.resolve(MockRegistry(self.dimensions, {}))
+ self.check_base_accessors(self.graph)
+ self.check_sorted(self.graph)
+ self.assertTrue(repr(self.graph).startswith(f"PipelineGraph({self.description!r}, tasks="))
def test_resolved_xgraph_export(self) -> None:
- rgraph = self.mgraph.resolved(MockRegistry(self.dimensions, {}))
- self.check_make_xgraph(rgraph, resolved=True)
- self.check_make_bipartite_xgraph(rgraph, resolved=True)
- self.check_make_task_xgraph(rgraph, resolved=True)
- self.check_make_dataset_type_xgraph(rgraph, resolved=True)
+ self.graph.resolve(MockRegistry(self.dimensions, {}))
+ self.check_make_xgraph(self.graph, resolved=True)
+ self.check_make_bipartite_xgraph(self.graph, resolved=True)
+ self.check_make_task_xgraph(self.graph, resolved=True)
+ self.check_make_dataset_type_xgraph(self.graph, resolved=True)
def test_resolved_stream_io(self) -> None:
- rgraph = self.mgraph.resolved(MockRegistry(self.dimensions, {}))
+ self.graph.resolve(MockRegistry(self.dimensions, {}))
stream = io.BytesIO()
- rgraph.write_stream(stream)
+ self.graph.write_stream(stream)
stream.seek(0)
- roundtripped = ResolvedPipelineGraph.read_stream(stream)
+ roundtripped = PipelineGraph.read_stream(stream)
self.check_make_xgraph(roundtripped, resolved=True)
def test_resolved_file_io(self) -> None:
- rgraph = self.mgraph.resolved(MockRegistry(self.dimensions, {}))
+ self.graph.resolve(MockRegistry(self.dimensions, {}))
with lsst.utils.tests.getTempFilePath(".tar.gz") as filename:
- rgraph.write_uri(filename)
- roundtripped = ResolvedPipelineGraph.read_uri(filename)
+ self.graph.write_uri(filename)
+ roundtripped = PipelineGraph.read_uri(filename)
self.check_make_xgraph(roundtripped, resolved=True)
- def test_mixed_io(self) -> None:
- """Test writing a ResolvedPipelineGraph and reading it as a
- MutablePipelineGraph.
- """
- rgraph = self.mgraph.resolved(MockRegistry(self.dimensions, {}))
- stream = io.BytesIO()
- rgraph.write_stream(stream)
- stream.seek(0)
- roundtripped = MutablePipelineGraph.read_stream(stream)
- self.check_make_xgraph(roundtripped, resolved=False)
-
- def test_mutable_copies(self) -> None:
- mcopy = self.mgraph.mutable_copy()
- self.assertIsNot(mcopy, self.mgraph)
- self.check_make_xgraph(mcopy, resolved=False)
- mcopy = copy.copy(self.mgraph)
- self.assertIsNot(mcopy, self.mgraph)
- self.check_make_xgraph(mcopy, resolved=False)
- mcopy = copy.deepcopy(self.mgraph)
- self.assertIsNot(mcopy, self.mgraph)
- self.check_make_xgraph(mcopy, resolved=False)
+ def test_unresolved_copies(self) -> None:
+ copy1 = self.graph.copy()
+ self.assertIsNot(copy1, self.graph)
+ self.check_make_xgraph(copy1, resolved=False)
+ copy2 = copy.copy(self.graph)
+ self.assertIsNot(copy2, self.graph)
+ self.check_make_xgraph(copy2, resolved=False)
+ copy3 = copy.deepcopy(self.graph)
+ self.assertIsNot(copy3, self.graph)
+ self.check_make_xgraph(copy3, resolved=False)
def test_resolved_copies(self) -> None:
- rgraph = self.mgraph.resolved(MockRegistry(self.dimensions, {}))
- self.assertIs(rgraph, rgraph.resolved(MockRegistry(self.dimensions, {})))
- self.assertIs(rgraph, copy.copy(rgraph))
- self.assertIs(rgraph, copy.deepcopy(rgraph))
- rcopy = rgraph.resolved(MockRegistry(self.dimensions, {}), redo=True)
- self.assertIsNot(rgraph, rcopy)
- self.check_make_xgraph(rcopy, resolved=True)
- mcopy = rgraph.mutable_copy()
- self.check_make_xgraph(mcopy, resolved=False)
+ self.graph.resolve(MockRegistry(self.dimensions, {}))
+ copy1 = self.graph.copy()
+ self.assertIsNot(copy1, self.graph)
+ self.check_make_xgraph(copy1, resolved=True)
+ copy2 = copy.copy(self.graph)
+ self.assertIsNot(copy2, self.graph)
+ self.check_make_xgraph(copy2, resolved=True)
+ copy3 = copy.deepcopy(self.graph)
+ self.assertIsNot(copy3, self.graph)
+ self.check_make_xgraph(copy3, resolved=True)
def check_base_accessors(self, graph: PipelineGraph) -> None:
self.assertEqual(graph.description, self.description)
@@ -205,7 +189,7 @@ def check_base_accessors(self, graph: PipelineGraph) -> None:
},
)
self.assertEqual(
- {node.key for node in graph.iter_nodes()},
+ {(node_type, name) for node_type, name, _ in graph.iter_nodes()},
{
NodeKey(NodeType.TASK, "a"),
NodeKey(NodeType.TASK, "b"),
@@ -222,7 +206,7 @@ def check_base_accessors(self, graph: PipelineGraph) -> None:
NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
},
)
- self.assertEqual({node.name for node in graph.iter_overall_inputs()}, {"input"})
+ self.assertEqual({name for name, _ in graph.iter_overall_inputs()}, {"input"})
self.assertEqual({label for label in graph.consumers_of("input")}, {"a"})
self.assertEqual({label for label in graph.consumers_of("intermediate")}, {"b"})
self.assertEqual({label for label in graph.consumers_of("output")}, set())
@@ -234,7 +218,7 @@ def check_sorted(self, graph: PipelineGraph) -> None:
self.assertTrue(graph.has_been_sorted)
self.assertTrue(graph.is_sorted)
self.assertEqual(
- [node.key for node in graph.iter_nodes()],
+ [(node_type, name) for node_type, name, _ in graph.iter_nodes()],
[
# We only advertise that the order is topological and
# deterministic, so this test is slightly over-specified; there
@@ -274,39 +258,42 @@ def check_sorted(self, graph: PipelineGraph) -> None:
def check_make_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
xgraph = graph.make_xgraph()
- self.assertEqual(
- set(xgraph.edges),
+ expected_edges = (
{edge.key for edge in graph.iter_edges()}
| {edge.key for edge in graph.iter_edges(init=True)}
| {
(NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK, "a")),
(NodeKey(NodeType.TASK_INIT, "b"), NodeKey(NodeType.TASK, "b")),
- },
- )
- self.assertEqual(
- dict(xgraph.nodes.items()),
- {
- NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
- NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
- NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
- NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
- NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
- NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
- NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
- NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
- NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
- NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
- NodeKey(NodeType.DATASET_TYPE, "input"): self.get_expected_connection_node(
- "input", resolved, True
- ),
- NodeKey(NodeType.DATASET_TYPE, "intermediate"): self.get_expected_connection_node(
- "intermediate", resolved, False
- ),
- NodeKey(NodeType.DATASET_TYPE, "output"): self.get_expected_connection_node(
- "output", resolved, False
- ),
- },
+ }
)
+ test_edges = set(xgraph.edges)
+ self.assertEqual(test_edges, expected_edges)
+ expected_nodes = {
+ NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
+ NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
+ NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
+ NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "input"): self.get_expected_connection_node(
+ "input", resolved, True
+ ),
+ NodeKey(NodeType.DATASET_TYPE, "intermediate"): self.get_expected_connection_node(
+ "intermediate", resolved, False
+ ),
+ NodeKey(NodeType.DATASET_TYPE, "output"): self.get_expected_connection_node(
+ "output", resolved, False
+ ),
+ }
+ test_nodes = dict(xgraph.nodes.items())
+ self.assertEqual(set(test_nodes.keys()), set(expected_nodes.keys()))
+ for key, expected_node in expected_nodes.items():
+ test_node = test_nodes[key]
+ self.assertEqual(expected_node, test_node, key)
def check_make_bipartite_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
run_xgraph = graph.make_bipartite_xgraph()
@@ -439,6 +426,7 @@ def get_expected_config_node(self, label: str, resolved: bool) -> dict[str, Any]
),
"is_initial_query_constraint": False,
"is_prerequisite": False,
+ "is_registered": False,
"dimensions": self.dimensions.empty,
"storage_class_name": acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
"bipartite": 0,
@@ -456,6 +444,7 @@ def get_expected_log_node(self, label: str, resolved: bool) -> dict[str, Any]:
),
"is_initial_query_constraint": False,
"is_prerequisite": False,
+ "is_registered": False,
"dimensions": self.dimensions.empty,
"storage_class_name": acc.LOG_OUTPUT_STORAGE_CLASS,
"bipartite": 0,
@@ -473,6 +462,7 @@ def get_expected_metadata_node(self, label: str, resolved: bool) -> dict[str, An
),
"is_initial_query_constraint": False,
"is_prerequisite": False,
+ "is_registered": False,
"dimensions": self.dimensions.empty,
"storage_class_name": acc.METADATA_OUTPUT_STORAGE_CLASS,
"bipartite": 0,
@@ -492,6 +482,7 @@ def get_expected_connection_node(
),
"is_initial_query_constraint": is_initial_query_constraint,
"is_prerequisite": False,
+ "is_registered": False,
"dimensions": self.dimensions.empty,
"storage_class_name": "StructuredDataDict",
"bipartite": 0,