diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py index e2ab4335..8bfd6357 100644 --- a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py +++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py @@ -47,7 +47,7 @@ from ._mapping_views import DatasetTypeMappingView, TaskMappingView from ._nodes import NodeKey, NodeType from ._task_subsets import TaskSubset -from ._tasks import TaskInitNode, TaskNode, _TaskNodeImportedData +from ._tasks import TaskImportMode, TaskInitNode, TaskNode, _TaskNodeImportedData if TYPE_CHECKING: from ..config import PipelineTaskConfig @@ -1002,11 +1002,7 @@ def make_dataset_type_xgraph(self, init: bool = False) -> networkx.DiGraph: @classmethod def _read_stream( - cls, - stream: BinaryIO, - import_and_configure: bool = True, - check_edges_unchanged: bool = False, - assume_edges_unchanged: bool = False, + cls, stream: BinaryIO, import_mode: TaskImportMode = TaskImportMode.REQUIRE_CONSISTENT_EDGES ) -> PipelineGraph: """Read a serialized `PipelineGraph` from a file-like object. @@ -1015,15 +1011,11 @@ def _read_stream( stream : `BinaryIO` File-like object opened for binary reading, containing gzip-compressed JSON. - import_and_configure : `bool`, optional - If `True`, import and configure all tasks immediately (see the - `import_and_configure` method). If `False`, some `TaskNode` and - `TaskInitNode` attributes will not be available, but reading may be - much faster. - check_edges_unchanged : `bool`, optional - Forwarded to `import_and_configure` after reading. - assume_edges_unchanged : `bool`, optional - Forwarded to `import_and_configure` after reading. + import_mode : `TaskImportMode`, optional + Whether to import tasks, and how to reconcile any differences + between the imported task's connections and those that were + persisted with the graph. Default is to check that they are the + same. 
Returns ------- @@ -1035,8 +1027,9 @@ def _read_stream( PipelineGraphReadError Raised if the serialized `PipelineGraph` is not self-consistent. EdgesChangedError - Raised if ``check_edges_unchanged=True`` and the edges of a task do - change after import and reconfiguration. + Raised if ``import_mode`` is + `TaskImportMode.REQUIRE_CONSISTENT_EDGES` and the edges of a task + did change after import and reconfiguration. Notes ----- @@ -1049,19 +1042,13 @@ def _read_stream( with gzip.open(stream, "rb") as uncompressed_stream: data = json.load(uncompressed_stream) serialized_graph = SerializedPipelineGraph.parse_obj(data) - return serialized_graph.deserialize( - import_and_configure=import_and_configure, - check_edges_unchanged=check_edges_unchanged, - assume_edges_unchanged=assume_edges_unchanged, - ) + return serialized_graph.deserialize(import_mode) @classmethod def _read_uri( cls, uri: ResourcePathExpression, - import_and_configure: bool = True, - check_edges_unchanged: bool = False, - assume_edges_unchanged: bool = False, + import_mode: TaskImportMode = TaskImportMode.REQUIRE_CONSISTENT_EDGES, ) -> PipelineGraph: """Read a serialized `PipelineGraph` from a file at a URI. @@ -1070,15 +1057,11 @@ def _read_uri( uri : convertible to `lsst.resources.ResourcePath` URI to a gzip-compressed JSON file containing a serialized pipeline graph. - import_and_configure : `bool`, optional - If `True`, import and configure all tasks immediately (see - the `import_and_configure` method). If `False`, some `TaskNode` - and `TaskInitNode` attributes will not be available, but reading - may be much faster. - check_edges_unchanged : `bool`, optional - Forwarded to `import_and_configure` after reading. - assume_edges_unchanged : `bool`, optional - Forwarded to `import_and_configure` after reading. 
+ import_mode : `TaskImportMode`, optional + Whether to import tasks, and how to reconcile any differences + between the imported task's connections and those that were + persisted with the graph. Default is to check that they are the + same. Returns ------- @@ -1090,8 +1073,9 @@ def _read_uri( PipelineGraphReadError Raised if the serialized `PipelineGraph` is not self-consistent. EdgesChangedError - Raised if ``check_edges_unchanged=True`` and the edges of a task do - change after import and reconfiguration. + Raised if ``import_mode`` is + `TaskImportMode.REQUIRE_CONSISTENT_EDGES` and the edges of a task + did change after import and reconfiguration. Notes ----- @@ -1101,12 +1085,7 @@ def _read_uri( """ uri = ResourcePath(uri) with uri.open("rb") as stream: - return cls._read_stream( - cast(BinaryIO, stream), - import_and_configure=import_and_configure, - check_edges_unchanged=check_edges_unchanged, - assume_edges_unchanged=assume_edges_unchanged, - ) + return cls._read_stream(cast(BinaryIO, stream), import_mode=import_mode) def _write_stream(self, stream: BinaryIO) -> None: """Write the pipeline to a file-like object. @@ -1164,31 +1143,26 @@ def _write_uri(self, uri: ResourcePathExpression) -> None: self._write_stream(cast(BinaryIO, stream)) def _import_and_configure( - self, check_edges_unchanged: bool = False, assume_edges_unchanged: bool = False + self, import_mode: TaskImportMode = TaskImportMode.REQUIRE_CONSISTENT_EDGES ) -> None: """Import the `PipelineTask` classes referenced by all task nodes and update those nodes accordingly. Parameters ---------- - check_edges_unchanged : `bool`, optional - If `True`, require the edges (connections) of the modified tasks to - remain unchanged after importing and configuring each task, and - verify that this is the case. 
- assume_edges_unchanged : `bool`, optional - If `True`, the caller declares that the edges (connections) of the - modified tasks will remain unchanged importing and configuring each - task, and that it is unnecessary to check this. + import_mode : `TaskImportMode`, optional + Whether to import tasks, and how to reconcile any differences + between the imported task's connections and those that were + persisted with the graph. Default is to check that they are the + same. This method does nothing if this is + `TaskImportMode.DO_NOT_IMPORT`. Raises ------ - ValueError - Raised if ``assume_edges_unchanged`` and ``check_edges_unchanged`` - are both `True`, or if a full config is provided for a task after - another full config or an override has already been provided. EdgesChangedError - Raised if ``check_edges_unchanged=True`` and the edges of a task do - change. + Raised if ``import_mode`` is + `TaskImportMode.REQUIRE_CONSISTENT_EDGES` and the edges of a task + did change after import and reconfiguration. Notes ----- @@ -1202,13 +1176,19 @@ def _import_and_configure( usually because the software used to read a serialized graph is newer than the software used to write it (e.g. a new config option has been added, or the task was moved to a new module with a forwarding alias - left behind). These changes are allowed by ``check=True``. + left behind). These changes are allowed by + `TaskImportMode.REQUIRE_CONSISTENT_EDGES`. If importing and configuring a task causes its edges to change, any dataset type nodes linked to those edges will be reset to the unresolved state. 
""" - rebuild = check_edges_unchanged or not assume_edges_unchanged + if import_mode is TaskImportMode.DO_NOT_IMPORT: + return + rebuild = ( + import_mode is TaskImportMode.REQUIRE_CONSISTENT_EDGES + or import_mode is TaskImportMode.OVERRIDE_EDGES + ) updates: dict[str, TaskNode] = {} node_key: NodeKey for node_key, node_state in self._xgraph.nodes.items(): @@ -1219,8 +1199,8 @@ def _import_and_configure( updates[task_node.label] = new_task_node self._replace_task_nodes( updates, - check_edges_unchanged=check_edges_unchanged, - assume_edges_unchanged=assume_edges_unchanged, + check_edges_unchanged=(import_mode is TaskImportMode.REQUIRE_CONSISTENT_EDGES), + assume_edges_unchanged=(import_mode is TaskImportMode.ASSUME_CONSISTENT_EDGES), message_header=( "In task with label {task_label!r}, persisted edges (A)" "differ from imported and configured edges (B):" diff --git a/python/lsst/pipe/base/pipeline_graph/_tasks.py b/python/lsst/pipe/base/pipeline_graph/_tasks.py index 8c3b4e2b..d1d7b236 100644 --- a/python/lsst/pipe/base/pipeline_graph/_tasks.py +++ b/python/lsst/pipe/base/pipeline_graph/_tasks.py @@ -20,9 +20,10 @@ # along with this program. If not, see . from __future__ import annotations -__all__ = ("TaskNode", "TaskInitNode") +__all__ = ("TaskNode", "TaskInitNode", "TaskImportMode") import dataclasses +import enum from collections.abc import Iterator, Mapping from typing import TYPE_CHECKING, Any, cast @@ -43,6 +44,38 @@ from ..pipelineTask import PipelineTask +class TaskImportMode(enum.Enum): + """Enumeration of the ways to handle importing tasks when reading a + serialized PipelineGraph. + """ + + DO_NOT_IMPORT = enum.auto() + """Do not import tasks or instantiate their configs and connections.""" + + REQUIRE_CONSISTENT_EDGES = enum.auto() + """Import tasks and instantiate their config and connection objects, and + check that the connections still define the same edges. 
+ """ + + ASSUME_CONSISTENT_EDGES = enum.auto() + """Import tasks and instantiate their config and connection objects, but do + not check that the connections still define the same edges. + + This is safe only when the caller knows the task definition has not changed + since the pipeline graph was persisted, such as when it was saved and + loaded with the same pipeline version. + """ + + OVERRIDE_EDGES = enum.auto() + """Import tasks and instantiate their config and connection objects, and + allow the edges defined in those connections to override those in the + persisted graph. + + This may cause dataset type nodes to be unresolved, since resolutions + consistent with the original edges may be invalidated. + """ + + @dataclasses.dataclass(frozen=True) class _TaskNodeImportedData: """An internal struct that holds `TaskNode` and `TaskInitNode` state that diff --git a/python/lsst/pipe/base/pipeline_graph/io.py b/python/lsst/pipe/base/pipeline_graph/io.py index 53013280..02506ac3 100644 --- a/python/lsst/pipe/base/pipeline_graph/io.py +++ b/python/lsst/pipe/base/pipeline_graph/io.py @@ -45,7 +45,7 @@ from ._nodes import NodeKey, NodeType from ._pipeline_graph import PipelineGraph from ._task_subsets import TaskSubset -from ._tasks import TaskInitNode, TaskNode +from ._tasks import TaskImportMode, TaskInitNode, TaskNode _U = TypeVar("_U") @@ -527,9 +527,7 @@ def serialize(cls, target: PipelineGraph) -> SerializedPipelineGraph: def deserialize( self, - import_and_configure: bool = True, - check_edges_unchanged: bool = False, - assume_edges_unchanged: bool = False, + import_mode: TaskImportMode, ) -> PipelineGraph: """Transform a `SerializedPipelineGraph` into a `PipelineGraph`.""" universe: DimensionUniverse | None = None @@ -615,9 +613,5 @@ def deserialize( universe=universe, data_id=self.data_id, ) - if import_and_configure: - result._import_and_configure( - check_edges_unchanged=check_edges_unchanged, - assume_edges_unchanged=assume_edges_unchanged, - ) + 
result._import_and_configure(import_mode) return result diff --git a/tests/test_pipeline_graph.py b/tests/test_pipeline_graph.py index 479e9202..44e568b0 100644 --- a/tests/test_pipeline_graph.py +++ b/tests/test_pipeline_graph.py @@ -41,6 +41,7 @@ NodeType, PipelineGraph, PipelineGraphError, + TaskImportMode, UnresolvedGraphError, ) from lsst.pipe.base.tests.mocks import ( @@ -154,12 +155,12 @@ def test_unresolved_deferred_import_io(self) -> None: stream = io.BytesIO() self.graph._write_stream(stream) stream.seek(0) - roundtripped = PipelineGraph._read_stream(stream, import_and_configure=False) + roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT) self.check_make_xgraph(roundtripped, resolved=False, imported_and_configured=False) # Check that we can still resolve the graph without importing tasks. roundtripped.resolve(MockRegistry(self.dimensions, {})) self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False) - roundtripped._import_and_configure(assume_edges_unchanged=True) + roundtripped._import_and_configure(TaskImportMode.ASSUME_CONSISTENT_EDGES) self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True) def test_resolved_accessors(self) -> None: @@ -221,9 +222,9 @@ def test_resolved_deferred_import_io(self) -> None: stream = io.BytesIO() self.graph._write_stream(stream) stream.seek(0) - roundtripped = PipelineGraph._read_stream(stream, import_and_configure=False) + roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT) self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False) - roundtripped._import_and_configure(check_edges_unchanged=True) + roundtripped._import_and_configure(TaskImportMode.REQUIRE_CONSISTENT_EDGES) self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True) def test_unresolved_copies(self) -> None: