From c0486b3e14addd72b545c1ff6ae3f642d4feb76e Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Wed, 21 Jun 2023 13:01:07 -0400 Subject: [PATCH 01/16] Add test-utility PipelineTask with fully dynamic connections. --- .../pipe/base/tests/mocks/_pipeline_task.py | 340 +++++++++++++----- 1 file changed, 246 insertions(+), 94 deletions(-) diff --git a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py index d6b21e49..ffb2fdc0 100644 --- a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py +++ b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py @@ -20,19 +20,27 @@ # along with this program. If not, see . from __future__ import annotations -__all__ = ("MockPipelineTask", "MockPipelineTaskConfig", "mock_task_defs") +__all__ = ( + "DynamicConnectionConfig", + "DynamicTestPipelineTask", + "DynamicTestPipelineTaskConfig", + "MockPipelineTask", + "MockPipelineTaskConfig", + "mock_task_defs", +) import dataclasses import logging from collections.abc import Iterable, Mapping -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar -from lsst.daf.butler import DeferredDatasetHandle -from lsst.pex.config import ConfigurableField, Field, ListField +from lsst.daf.butler import DatasetRef, DeferredDatasetHandle +from lsst.pex.config import Config, ConfigDictField, ConfigurableField, Field, ListField from lsst.utils.doImport import doImportType from lsst.utils.introspection import get_full_type_name from lsst.utils.iteration import ensure_iterable +from ... import connectionTypes as cT from ...config import PipelineTaskConfig from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections from ...pipeline import TaskDef @@ -46,6 +54,9 @@ from ..._quantumContext import QuantumContext +_T = TypeVar("_T", bound=cT.BaseConnection) + + def mock_task_defs( originals: Iterable[TaskDef], unmocked_dataset_types: Iterable[str] = (), @@ -96,73 +107,11 @@ def mock_task_defs( return results -class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()): +class BaseTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()): pass -class MockPipelineDefaultTargetConfig( - PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections -): - pass - - -class MockPipelineDefaultTargetTask(PipelineTask): - """A `~lsst.pipe.base.PipelineTask` class used as the default target for - ``MockPipelineTaskConfig.original``. - - This is effectively a workaround for `lsst.pex.config.ConfigurableField` - not supporting ``optional=True``, but that is generally a reasonable - limitation for production code and it wouldn't make sense just to support - test utilities. - """ - - ConfigClass = MockPipelineDefaultTargetConfig - - -class MockPipelineTaskConnections(PipelineTaskConnections, dimensions=()): - def __init__(self, *, config: MockPipelineTaskConfig): - original: PipelineTaskConnections = config.original.connections.ConnectionsClass( - config=config.original.value - ) - self.dimensions.update(original.dimensions) - unmocked_dataset_types = frozenset(config.unmocked_dataset_types) - for name, connection in original.allConnections.items(): - if name in original.initInputs or name in original.initOutputs: - # We just ignore initInputs and initOutputs, because the task - # is never given DatasetRefs for those and hence can't create - # mocks. 
- continue - if connection.name not in unmocked_dataset_types: - # We register the mock storage class with the global singleton - # here, but can only put its name in the connection. That means - # the same global singleton (or one that also has these - # registrations) has to be available whenever this dataset type - # is used. - storage_class = MockStorageClass.get_or_register_mock(connection.storageClass) - kwargs = {} - if hasattr(connection, "dimensions"): - connection_dimensions = set(connection.dimensions) - # Replace the generic "skypix" placeholder with htm7, since - # that requires the dataset type to have already been - # registered. - if "skypix" in connection_dimensions: - connection_dimensions.remove("skypix") - connection_dimensions.add("htm7") - kwargs["dimensions"] = connection_dimensions - connection = dataclasses.replace( - connection, - name=get_mock_name(connection.name), - storageClass=storage_class.name, - **kwargs, - ) - elif name in original.outputs: - raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.") - setattr(self, name, connection) - - -class MockPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections): - """Configuration class for `MockPipelineTask`.""" - +class BaseTestPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=BaseTestPipelineTaskConnections): fail_condition = Field[str]( dtype=str, default="", @@ -181,29 +130,15 @@ class MockPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=MockPipelin ), ) - original: ConfigurableField = ConfigurableField( - doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask - ) - - unmocked_dataset_types = ListField[str]( - doc=( - "Names of input dataset types that should be used as-is instead " - "of being mocked. May include dataset types not relevant for " - "this task, which will be ignored." - ), - default=(), - optional=False, - ) - def data_id_match(self) -> DataIdMatch | None: if not self.fail_condition: return None return DataIdMatch(self.fail_condition) -class MockPipelineTask(PipelineTask): - """Implementation of `~lsst.pipe.base.PipelineTask` used for running a - mock pipeline. +class BaseTestPipelineTask(PipelineTask): + """A base class for test-utility `PipelineTask` classes that read and write + mock datasets `runQuantum`. Notes ----- @@ -213,20 +148,21 @@ class MockPipelineTask(PipelineTask): `MockDataset` inputs and simulates reading inputs of other types by creating `MockDataset` inputs from their DatasetRefs. - At present `MockPipelineTask` simply drops any ``initInput`` and - ``initOutput`` connections present on the original, since `MockDataset` - creation for those would have to happen in the code that executes the task, - not in the task itself. Because `MockPipelineTask` never instantiates the - mock task (just its connections class), this is a limitation on what the - mocks can be used to test, not anything deeper. + Subclasses are responsible for defining connections, but init-input and + init-output connections are not supported at runtime (they may be present + as long as the task is never constructed). All output connections must + use mock storage classes. `..Input` and `..PrerequisiteInput` connections + that do not use mock storage classes will be handled by constructing a + `MockDataset` from the `~lsst.daf.butler.DatasetRef` rather than actually + reading them. 
""" - ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig + ConfigClass: ClassVar[type[PipelineTaskConfig]] = BaseTestPipelineTaskConfig def __init__( self, *, - config: MockPipelineTaskConfig, + config: BaseTestPipelineTaskConfig, **kwargs: Any, ): super().__init__(config=config, **kwargs) @@ -235,7 +171,7 @@ def __init__( if self.data_id_match: self.fail_exception = doImportType(self.config.fail_exception) - config: MockPipelineTaskConfig + config: BaseTestPipelineTaskConfig def runQuantum( self, @@ -264,6 +200,7 @@ def runQuantum( ) for name, refs in inputRefs: inputs_list = [] + ref: DatasetRef for ref in ensure_iterable(refs): if isinstance(ref.datasetType.storageClass, MockStorageClass): input_dataset = butlerQC.get(ref) @@ -291,3 +228,218 @@ def runQuantum( butlerQC.put(output, ref) _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId) + + +class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()): + pass + + +class MockPipelineDefaultTargetConfig( + PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections +): + pass + + +class MockPipelineDefaultTargetTask(PipelineTask): + """A `~lsst.pipe.base.PipelineTask` class used as the default target for + ``MockPipelineTaskConfig.original``. + + This is effectively a workaround for `lsst.pex.config.ConfigurableField` + not supporting ``optional=True``, but that is generally a reasonable + limitation for production code and it wouldn't make sense just to support + test utilities. + """ + + ConfigClass = MockPipelineDefaultTargetConfig + + +class MockPipelineTaskConnections(BaseTestPipelineTaskConnections, dimensions=()): + """A connections class that creates mock connections from the connections + of a real PipelineTask. + """ + + def __init__(self, *, config: MockPipelineTaskConfig): + original: PipelineTaskConnections = config.original.connections.ConnectionsClass( + config=config.original.value + ) + self.dimensions.update(original.dimensions) + unmocked_dataset_types = frozenset(config.unmocked_dataset_types) + for name, connection in original.allConnections.items(): + if name in original.initInputs or name in original.initOutputs: + # We just ignore initInputs and initOutputs, because the task + # is never given DatasetRefs for those and hence can't create + # mocks. + continue + if connection.name not in unmocked_dataset_types: + # We register the mock storage class with the global singleton + # here, but can only put its name in the connection. That means + # the same global singleton (or one that also has these + # registrations) has to be available whenever this dataset type + # is used. + storage_class = MockStorageClass.get_or_register_mock(connection.storageClass) + kwargs = {} + if hasattr(connection, "dimensions"): + connection_dimensions = set(connection.dimensions) + # Replace the generic "skypix" placeholder with htm7, since + # that requires the dataset type to have already been + # registered. 
+ if "skypix" in connection_dimensions: + connection_dimensions.remove("skypix") + connection_dimensions.add("htm7") + kwargs["dimensions"] = connection_dimensions + connection = dataclasses.replace( + connection, + name=get_mock_name(connection.name), + storageClass=storage_class.name, + **kwargs, + ) + elif name in original.outputs: + raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.") + setattr(self, name, connection) + + +class MockPipelineTaskConfig(BaseTestPipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections): + """Configuration class for `MockPipelineTask`.""" + + original: ConfigurableField = ConfigurableField( + doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask + ) + + unmocked_dataset_types = ListField[str]( + doc=( + "Names of input dataset types that should be used as-is instead " + "of being mocked. May include dataset types not relevant for " + "this task, which will be ignored." + ), + default=(), + optional=False, + ) + + +class MockPipelineTask(BaseTestPipelineTask): + """A test-utility implementation of `PipelineTask` with connections + generated by mocking those of a real task. + + Notes + ----- + At present `MockPipelineTask` simply drops any ``initInput`` and + ``initOutput`` connections present on the original, since `MockDataset` + creation for those would have to happen in the code that executes the task, + not in the task itself. Because `MockPipelineTask` never instantiates the + mock task (just its connections class), this is a limitation on what the + mocks can be used to test, not anything deeper. + """ + + ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig + + +class DynamicConnectionConfig(Config): + """A config class that defines a completely dynamic connection.""" + + dataset_type_name = Field[str](doc="Name for the dataset type as seen by the butler.", dtype=str) + dimensions = ListField[str](doc="Dimensions for the dataset type.", dtype=str, default=[]) + storage_class = Field[str]( + doc="Name of the butler storage class for the dataset type.", dtype=str, default="StructuredDataDict" + ) + is_calibration = Field[bool](doc="Whether this dataset type is a calibration.", dtype=bool, default=False) + multiple = Field[bool]( + doc="Whether this connection gets or puts multiple datasets for each quantum.", + dtype=bool, + default=False, + ) + mock_storage_class = Field[bool]( + doc="Whether the storage class should actually be a mock of the storage class given.", + dtype=bool, + default=True, + ) + + def make_connection(self, cls: type[_T]) -> _T: + storage_class = self.storage_class + if self.mock_storage_class: + storage_class = MockStorageClass.get_or_register_mock(storage_class).name + if issubclass(cls, cT.DimensionedConnection): + return cls( # type: ignore + name=self.dataset_type_name, + storageClass=storage_class, + isCalibration=self.is_calibration, + multiple=self.multiple, + dimensions=frozenset(self.dimensions), + ) + else: + return cls( + name=self.dataset_type_name, + storageClass=storage_class, + multiple=self.multiple, + ) + + +class DynamicTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()): + """A connections class whose dimensions and connections are wholly + determined via configuration. 
+ """ + + def __init__(self, *, config: DynamicTestPipelineTaskConfig): + self.dimensions.update(config.dimensions) + connection_config: DynamicConnectionConfig + for connection_name, connection_config in config.init_inputs.items(): + setattr(self, connection_name, connection_config.make_connection(cT.InitInput)) + for connection_name, connection_config in config.init_outputs.items(): + setattr(self, connection_name, connection_config.make_connection(cT.InitOutput)) + for connection_name, connection_config in config.prerequisite_inputs.items(): + setattr(self, connection_name, connection_config.make_connection(cT.PrerequisiteInput)) + for connection_name, connection_config in config.inputs.items(): + setattr(self, connection_name, connection_config.make_connection(cT.Input)) + for connection_name, connection_config in config.outputs.items(): + setattr(self, connection_name, connection_config.make_connection(cT.Output)) + + +class DynamicTestPipelineTaskConfig( + PipelineTaskConfig, pipelineConnections=DynamicTestPipelineTaskConnections +): + """Configuration for DynamicTestPipelineTask.""" + + dimensions = ListField[str](doc="Dimensions for the task's quanta.", dtype=str, default=[]) + init_inputs = ConfigDictField( + doc=( + "Init-input connections, keyed by the connection name as seen by the task. " + "Must be empty if the task will be constructed." + ), + keytype=str, + itemtype=DynamicConnectionConfig, + default={}, + ) + init_outputs = ConfigDictField( + doc=( + "Init-output connections, keyed by the connection name as seen by the task. " + "Must be empty if the task will be constructed." + ), + keytype=str, + itemtype=DynamicConnectionConfig, + default={}, + ) + prerequisite_inputs = ConfigDictField( + doc="Prerequisite input connections, keyed by the connection name as seen by the task.", + keytype=str, + itemtype=DynamicConnectionConfig, + default={}, + ) + inputs = ConfigDictField( + doc="Regular input connections, keyed by the connection name as seen by the task.", + keytype=str, + itemtype=DynamicConnectionConfig, + default={}, + ) + outputs = ConfigDictField( + doc="Regular output connections, keyed by the connection name as seen by the task.", + keytype=str, + itemtype=DynamicConnectionConfig, + default={}, + ) + + +class DynamicTestPipelineTask(BaseTestPipelineTask): + """A test-utility implementation of `PipelineTask` with dimensions and + connections determined wholly from configuration. + """ + + ConfigClass: ClassVar[type[PipelineTaskConfig]] = DynamicTestPipelineTaskConfig From 3f6da4c414484cb481d79e080c87235f66769903 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Mon, 29 May 2023 11:49:41 -0400 Subject: [PATCH 02/16] Minor cleanups to docs and signatures in pipeTools. 
--- python/lsst/pipe/base/pipeTools.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/python/lsst/pipe/base/pipeTools.py b/python/lsst/pipe/base/pipeTools.py index 0a6f249f..8793e27b 100644 --- a/python/lsst/pipe/base/pipeTools.py +++ b/python/lsst/pipe/base/pipeTools.py @@ -31,7 +31,7 @@ # Imports of standard modules -- # ------------------------------- import itertools -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING # ----------------------------- @@ -80,15 +80,15 @@ def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskF Parameters ---------- - pipeline : `pipe.base.Pipeline` + pipeline : `pipe.base.Pipeline` or `collections.abc.Iterable` [ `TaskDef` ] Pipeline description. taskFactory: `pipe.base.TaskFactory`, optional - Instance of an object which knows how to import task classes. It is - only used if pipeline task definitions do not define task classes. + Ignored; present only for backwards compatibility. Returns ------- - True for correctly ordered pipeline, False otherwise. + is_ordered : `bool` + True for correctly ordered pipeline, False otherwise. Raises ------ @@ -96,8 +96,6 @@ def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskF Raised when task class cannot be imported. DuplicateOutputError Raised when there is more than one producer for a dataset type. - MissingTaskFactoryError - Raised when TaskFactory is needed but not provided. """ # Build a map of DatasetType name to producer's index in a pipeline producerIndex = {} @@ -123,27 +121,27 @@ def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskF return True -def orderPipeline(pipeline: list[TaskDef]) -> list[TaskDef]: +def orderPipeline(pipeline: Sequence[TaskDef]) -> list[TaskDef]: """Re-order tasks in pipeline to satisfy data dependencies. When possible new ordering keeps original relative order of the tasks. Parameters ---------- - pipeline : `list` of `pipe.base.TaskDef` + pipeline : `pipe.base.Pipeline` or `collections.abc.Iterable` [ `TaskDef` ] Pipeline description. Returns ------- - Correctly ordered pipeline (`list` of `pipe.base.TaskDef` objects). + ordered : `list` [ `TaskDef` ] + Correctly ordered pipeline. Raises ------ - `DuplicateOutputError` is raised when there is more than one producer for a - dataset type. - `PipelineDataCycleError` is also raised when pipeline has dependency - cycles. `MissingTaskFactoryError` is raised when `TaskFactory` is needed - but not provided. + DuplicateOutputError + Raised when there is more than one producer for a dataset type. + PipelineDataCycleError + Raised when the pipeline has dependency cycles. """ # This is a modified version of Kahn's algorithm that preserves order From bac56eabcc64ea1f3b5de0cfda35be27eecf59d3 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Mon, 29 May 2023 11:51:23 -0400 Subject: [PATCH 03/16] Add PipelineGraph package. 
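As a quick orientation, only the main class is lifted to package scope; the
rest of the new symbols stay in the pipeline_graph subpackage. The sketch below
uses only the import locations visible in the __init__.py changes in this
patch; PipelineGraph.resolve is referenced in the new docstrings but its
signature is not shown here:

    from lsst.pipe.base import PipelineGraph, pipeline_graph
    from lsst.pipe.base.pipeline_graph import DatasetTypeNode, ReadEdge, WriteEdge

    # Dimensions and dataset types are resolved against a data repository
    # explicitly, via PipelineGraph.resolve (see the new docstrings below).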
--- python/lsst/pipe/base/__init__.py | 6 +- python/lsst/pipe/base/pipeline.py | 26 +- .../lsst/pipe/base/pipeline_graph/__init__.py | 29 + .../base/pipeline_graph/_dataset_types.py | 221 +++ .../lsst/pipe/base/pipeline_graph/_edges.py | 714 +++++++++ .../pipe/base/pipeline_graph/_exceptions.py | 95 ++ .../base/pipeline_graph/_mapping_views.py | 197 +++ .../lsst/pipe/base/pipeline_graph/_nodes.py | 85 + .../base/pipeline_graph/_pipeline_graph.py | 1389 +++++++++++++++++ .../pipe/base/pipeline_graph/_task_subsets.py | 122 ++ .../lsst/pipe/base/pipeline_graph/_tasks.py | 871 +++++++++++ python/lsst/pipe/base/pipeline_graph/io.py | 578 +++++++ tests/test_pipeline_graph.py | 1258 +++++++++++++++ 13 files changed, 5582 insertions(+), 9 deletions(-) create mode 100644 python/lsst/pipe/base/pipeline_graph/__init__.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_dataset_types.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_edges.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_exceptions.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_mapping_views.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_nodes.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_task_subsets.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_tasks.py create mode 100644 python/lsst/pipe/base/pipeline_graph/io.py create mode 100644 tests/test_pipeline_graph.py diff --git a/python/lsst/pipe/base/__init__.py b/python/lsst/pipe/base/__init__.py index 74339da9..51652a71 100644 --- a/python/lsst/pipe/base/__init__.py +++ b/python/lsst/pipe/base/__init__.py @@ -1,4 +1,4 @@ -from . import automatic_connection_constants, connectionTypes, pipelineIR +from . import automatic_connection_constants, connectionTypes, pipeline_graph, pipelineIR from ._dataset_handle import * from ._instrument import * from ._observation_dimension_packer import * @@ -11,6 +11,10 @@ from .graph import * from .graphBuilder import * from .pipeline import * + +# We import the main PipelineGraph type and the module (above), but we don't +# lift all symbols to package scope. +from .pipeline_graph import PipelineGraph from .pipelineTask import * from .struct import * from .task import * diff --git a/python/lsst/pipe/base/pipeline.py b/python/lsst/pipe/base/pipeline.py index efec0c1f..65dfc4d0 100644 --- a/python/lsst/pipe/base/pipeline.py +++ b/python/lsst/pipe/base/pipeline.py @@ -58,7 +58,7 @@ from ._instrument import Instrument as PipeBaseInstrument from ._task_metadata import TaskMetadata from .config import PipelineTaskConfig -from .connections import iterConnections +from .connections import PipelineTaskConnections, iterConnections from .connectionTypes import Input from .pipelineTask import PipelineTask from .task import _TASK_METADATA_TYPE @@ -127,6 +127,11 @@ class TaskDef: Task label, usually a short string unique in a pipeline. If not provided, ``taskClass`` must be, and ``taskClass._DefaultName`` will be used. + connections : `PipelineTaskConnections`, optional + Object that describes the dataset types used by the task. If not + provided, one will be constructed from the given configuration. If + provided, it is assumed that ``config`` has already been validated + and frozen. 
""" def __init__( @@ -135,6 +140,7 @@ def __init__( config: PipelineTaskConfig | None = None, taskClass: type[PipelineTask] | None = None, label: str | None = None, + connections: PipelineTaskConnections | None = None, ): if taskName is None: if taskClass is None: @@ -151,16 +157,20 @@ def __init__( raise ValueError("`taskClass` must be provided if `label` is not.") label = taskClass._DefaultName self.taskName = taskName - try: - config.validate() - except Exception: - _LOG.error("Configuration validation failed for task %s (%s)", label, taskName) - raise - config.freeze() + if connections is None: + # If we don't have connections yet, assume the config hasn't been + # validated yet. + try: + config.validate() + except Exception: + _LOG.error("Configuration validation failed for task %s (%s)", label, taskName) + raise + config.freeze() + connections = config.connections.ConnectionsClass(config=config) self.config = config self.taskClass = taskClass self.label = label - self.connections = config.connections.ConnectionsClass(config=config) + self.connections = connections @property def configDatasetName(self) -> str: diff --git a/python/lsst/pipe/base/pipeline_graph/__init__.py b/python/lsst/pipe/base/pipeline_graph/__init__.py new file mode 100644 index 00000000..3cf7a810 --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/__init__.py @@ -0,0 +1,29 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +from ._dataset_types import * +from ._edges import * +from ._exceptions import * +from ._nodes import * +from ._pipeline_graph import * +from ._task_subsets import * +from ._tasks import * diff --git a/python/lsst/pipe/base/pipeline_graph/_dataset_types.py b/python/lsst/pipe/base/pipeline_graph/_dataset_types.py new file mode 100644 index 00000000..4d949edf --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_dataset_types.py @@ -0,0 +1,221 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ("DatasetTypeNode",) + +import dataclasses +from typing import TYPE_CHECKING, Any + +import networkx +from lsst.daf.butler import DatasetRef, DatasetType, DimensionGraph, Registry, StorageClass +from lsst.daf.butler.registry import MissingDatasetTypeError + +from ._exceptions import DuplicateOutputError +from ._nodes import NodeKey, NodeType + +if TYPE_CHECKING: + from ._edges import ReadEdge, WriteEdge + + +@dataclasses.dataclass(frozen=True, eq=False) +class DatasetTypeNode: + """A node in a pipeline graph that represents a resolved dataset type. + + Notes + ----- + A dataset type node represents a common definition of the dataset type + across the entire graph - it is never a component, and the storage class is + the registry dataset type's storage class or (if there isn't one) the one + defined by the producing task. + + Dataset type nodes are intentionally not equality comparable, since there + are many different (and useful) ways to compare these objects with no clear + winner as the most obvious behavior. + """ + + dataset_type: DatasetType + """Common definition of this dataset type for the graph. + """ + + is_initial_query_constraint: bool + """Whether this dataset should be included as a constraint in the initial + query for data IDs in QuantumGraph generation. + + This is only `True` for dataset types that are overall regular inputs, and + only if none of those input connections had ``deferQueryConstraint=True``. + """ + + is_prerequisite: bool + """Whether this dataset type is a prerequisite input that must exist in + the Registry before graph creation. + """ + + @classmethod + def _from_edges( + cls, key: NodeKey, xgraph: networkx.MultiDiGraph, registry: Registry, previous: DatasetTypeNode | None + ) -> DatasetTypeNode: + """Construct a dataset type node from its edges. + + Parameters + ---------- + key : `NodeKey` + Named tuple that holds the dataset type and serves as the node + object in the internal networkx graph. + xgraph : `networkx.MultiDiGraph` + The internal networkx graph. + registry : `lsst.daf.butler.Registry` + Registry client for the data repository. Only used to get + dataset type definitions and the dimension universe. + previous : `DatasetTypeNode` or `None` + Previous node for this dataset type. + + Returns + ------- + node : `DatasetTypeNode` + Node consistent with all edges pointing to it and the data + repository. + """ + try: + dataset_type = registry.getDatasetType(key.name) + is_registered = True + except MissingDatasetTypeError: + dataset_type = None + is_registered = False + if previous is not None and previous.dataset_type == dataset_type: + # This node was already resolved (with exactly the same edges + # contributing, since we clear resolutions when edges are added or + # removed). The only thing that might have changed was the + # definition in the registry, and it didn't. 
+ return previous + is_initial_query_constraint = True + is_prerequisite: bool | None = None + producer: str | None = None + write_edge: WriteEdge + # Iterate over the incoming edges to this node, which represent the + # output connections of tasks that write this dataset type; these take + # precedence over the inputs in determining the graph-wide dataset type + # definition (and hence which storage class we register when using the + # graph to register dataset types). There should only be one such + # connection, but we won't necessarily have checked that rule until + # here. As a result there can be at most one iteration of this loop. + for _, _, write_edge in xgraph.in_edges(key, data="instance"): + if producer is not None: + raise DuplicateOutputError( + f"Dataset type {key.name!r} is produced by both {write_edge.task_label!r} " + f"and {producer!r}." + ) + producer = write_edge.task_label + dataset_type = write_edge._resolve_dataset_type(dataset_type, universe=registry.dimensions) + is_prerequisite = False + is_initial_query_constraint = False + read_edge: ReadEdge + consumers: list[str] = [] + read_edges = list(read_edge for _, _, read_edge in xgraph.out_edges(key, data="instance")) + # Put edges that are not component datasets before any edges that are. + read_edges.sort(key=lambda read_edge: read_edge.component is not None) + for read_edge in read_edges: + dataset_type, is_initial_query_constraint, is_prerequisite = read_edge._resolve_dataset_type( + current=dataset_type, + universe=registry.dimensions, + is_initial_query_constraint=is_initial_query_constraint, + is_prerequisite=is_prerequisite, + is_registered=is_registered, + producer=producer, + consumers=consumers, + ) + consumers.append(read_edge.task_label) + assert dataset_type is not None, "Graph structure guarantees at least one edge." + assert is_prerequisite is not None, "Having at least one edge guarantees is_prerequisite is known." + return DatasetTypeNode( + dataset_type=dataset_type, + is_initial_query_constraint=is_initial_query_constraint, + is_prerequisite=is_prerequisite, + ) + + @property + def name(self) -> str: + """Name of the dataset type. + + This is always the parent dataset type, never that of a component. + """ + return self.dataset_type.name + + @property + def key(self) -> NodeKey: + """Key that identifies this dataset type in internal and exported + networkx graphs. + """ + return NodeKey(NodeType.DATASET_TYPE, self.dataset_type.name) + + @property + def dimensions(self) -> DimensionGraph: + """Dimensions of the dataset type.""" + return self.dataset_type.dimensions + + @property + def storage_class_name(self) -> str: + """String name of the storage class for this dataset type.""" + return self.dataset_type.storageClass_name + + @property + def storage_class(self) -> StorageClass: + """Storage class for this dataset type.""" + return self.dataset_type.storageClass + + def __repr__(self) -> str: + return f"{self.name} ({self.storage_class_name}, {self.dimensions})" + + def generalize_ref(self, ref: DatasetRef) -> DatasetRef: + """Convert a `~lsst.daf.butler.DatasetRef` with the dataset type + associated with some task to one with the common dataset type defined + by this node. + + Parameters + ---------- + ref : `lsst.daf.butler.DatasetRef` + Reference whose dataset type is convertible to this node's, either + because it is a component with the node's dataset type as its + parent, or because it has a compatible storage class. 
+ + Returns + ------- + ref : `lsst.daf.butler.DatasetRef` + Reference with exactly this node's dataset type. + """ + if ref.isComponent(): + ref = ref.makeCompositeRef() + if ref.datasetType.storageClass_name != self.dataset_type.storageClass_name: + return ref.overrideStorageClass(self.dataset_type.storageClass_name) + return ref + + def _to_xgraph_state(self) -> dict[str, Any]: + """Convert this node's attributes into a dictionary suitable for use + in exported networkx graphs. + """ + return { + "dataset_type": self.dataset_type, + "is_initial_query_constraint": self.is_initial_query_constraint, + "is_prerequisite": self.is_prerequisite, + "dimensions": self.dataset_type.dimensions, + "storage_class_name": self.dataset_type.storageClass_name, + "bipartite": NodeType.DATASET_TYPE.bipartite, + } diff --git a/python/lsst/pipe/base/pipeline_graph/_edges.py b/python/lsst/pipe/base/pipeline_graph/_edges.py new file mode 100644 index 00000000..10ea6b11 --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_edges.py @@ -0,0 +1,714 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ("Edge", "ReadEdge", "WriteEdge") + +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from typing import Any, ClassVar, TypeVar + +from lsst.daf.butler import DatasetRef, DatasetType, DimensionUniverse, SkyPixDimension +from lsst.daf.butler.registry import MissingDatasetTypeError +from lsst.utils.classes import immutable + +from ..connectionTypes import BaseConnection +from ._exceptions import ConnectionTypeConsistencyError, IncompatibleDatasetTypeError +from ._nodes import NodeKey, NodeType + +_S = TypeVar("_S", bound="Edge") + + +@immutable +class Edge(ABC): + """Base class for edges in a pipeline graph. + + This represents the link between a task node and an input or output dataset + type. + + Parameters + ---------- + task_key : `NodeKey` + Key for the task node this edge is connected to. + dataset_type_key : `NodeKey` + Key for the dataset type node this edge is connected to. + storage_class_name : `str` + Name of the dataset type's storage class as seen by the task. + connection_name : `str` + Internal name for the connection as seen by the task. + is_calibration : `bool` + Whether this dataset type can be included in + `~lsst.daf.butler.CollectionType.CALIBRATION` collections. + raw_dimensions : `frozenset` [ `str` ] + Raw dimensions from the connection definition. 
+ """ + + def __init__( + self, + *, + task_key: NodeKey, + dataset_type_key: NodeKey, + storage_class_name: str, + connection_name: str, + is_calibration: bool, + raw_dimensions: frozenset[str], + ): + self.task_key = task_key + self.dataset_type_key = dataset_type_key + self.connection_name = connection_name + self.storage_class_name = storage_class_name + self.is_calibration = is_calibration + self.raw_dimensions = raw_dimensions + + INIT_TO_TASK_NAME: ClassVar[str] = "INIT" + """Edge key for the special edge that connects a task init node to the + task node itself (for regular edges, this would be the connection name). + """ + + task_key: NodeKey + """Task part of the key for this edge in networkx graphs.""" + + dataset_type_key: NodeKey + """Task part of the key for this edge in networkx graphs.""" + + connection_name: str + """Name used by the task to refer to this dataset type.""" + + storage_class_name: str + """Storage class expected by this task. + + If `ReadEdge.component` is not `None`, this is the component storage class, + not the parent storage class. + """ + + is_calibration: bool + """Whether this dataset type can be included in + `~lsst.daf.butler.CollectionType.CALIBRATION` collections. + """ + + raw_dimensions: frozenset[str] + """Raw dimensions in the task declaration. + + This can only be used safely for partial comparisons: two edges with the + same ``raw_dimensions`` (and the same parent dataset type name) always have + the same resolved dimensions, but edges with different ``raw_dimensions`` + may also have the same resolvd dimensions. + """ + + @property + def is_init(self) -> bool: + """Whether this dataset is read or written when the task is + constructed, not when it is run. + """ + return self.task_key.node_type is NodeType.TASK_INIT + + @property + def task_label(self) -> str: + """Label of the task.""" + return str(self.task_key) + + @property + def parent_dataset_type_name(self) -> str: + """Name of the parent dataset type. + + All dataset type nodes in a pipeline graph are for parent dataset + types; components are represented by additional `ReadEdge` state. + """ + return str(self.dataset_type_key) + + @property + @abstractmethod + def nodes(self) -> tuple[NodeKey, NodeKey]: + """The directed pair of `NodeKey` instances this edge connects. + + This tuple is ordered in the same direction as the pipeline flow: + `task_key` precedes `dataset_type_key` for writes, and the + reverse is true for reads. + """ + raise NotImplementedError() + + @property + def key(self) -> tuple[NodeKey, NodeKey, str]: + """Ordered tuple of node keys and connection name that uniquely + identifies this edge in a pipeline graph. + """ + return self.nodes + (self.connection_name,) + + def __repr__(self) -> str: + return f"{self.nodes[0]} -> {self.nodes[1]} ({self.connection_name})" + + @property + def dataset_type_name(self) -> str: + """Dataset type name seen by the task. + + This defaults to the parent dataset type name, which is appropriate + for all writes and most reads. + """ + return self.parent_dataset_type_name + + def diff(self: _S, other: _S, connection_type: str = "connection") -> list[str]: + """Compare this edge to another one from a possibly-different + configuration of the same task label. + + Parameters + ---------- + other : `Edge` + Another edge of the same type to compare to. + connection_type : `str` + Human-readable name of the connection type of this edge (e.g. + "init input", "output") for use in returned messages. 
+ + Returns + ------- + differences : `list` [ `str` ] + List of string messages describing differences between ``self`` and + ``other``. Will be empty if ``self == other`` or if the only + difference is in the task label or connection name (which are not + checked). Messages will use 'A' to refer to ``self`` and 'B' to + refer to ``other``. + """ + result = [] + if self.dataset_type_name != other.dataset_type_name: + result.append( + f"{connection_type.capitalize()} {self.connection_name!r} has dataset type " + f"{self.dataset_type_name!r} in A, but {other.dataset_type_name!r} in B." + ) + if self.storage_class_name != other.storage_class_name: + result.append( + f"{connection_type.capitalize()} {self.connection_name!r} has storage class " + f"{self.storage_class_name!r} in A, but {other.storage_class_name!r} in B." + ) + if self.raw_dimensions != other.raw_dimensions: + result.append( + f"{connection_type.capitalize()} {self.connection_name!r} has raw dimensions " + f"{set(self.raw_dimensions)} in A, but {set(other.raw_dimensions)} in B " + "(differences in raw dimensions may not lead to differences in resolved dimensions, " + "but this cannot be checked without re-resolving the dataset type)." + ) + if self.is_calibration != other.is_calibration: + result.append( + f"{connection_type.capitalize()} {self.connection_name!r} is marked as a calibration " + f"{'in A but not in B' if self.is_calibration else 'in B but not in A'}." + ) + return result + + @abstractmethod + def adapt_dataset_type(self, dataset_type: DatasetType) -> DatasetType: + """Transform the graph's definition of a dataset type (parent, with the + registry or producer's storage class) to the one seen by this task. + """ + raise NotImplementedError() + + @abstractmethod + def adapt_dataset_ref(self, ref: DatasetRef) -> DatasetRef: + """Transform the graph's definition of a dataset reference (parent + dataset type, with the registry or producer's storage class) to the one + seen by this task. + """ + raise NotImplementedError() + + def _to_xgraph_state(self) -> dict[str, Any]: + """Convert this edges's attributes into a dictionary suitable for use + in exported networkx graphs. + """ + return { + "parent_dataset_type_name": self.parent_dataset_type_name, + "storage_class_name": self.storage_class_name, + "is_init": bool, + } + + +class ReadEdge(Edge): + """Representation of an input connection (including init-inputs and + prerequisites) in a pipeline graph. + + Parameters + ---------- + dataset_type_key : `NodeKey` + Key for the dataset type node this edge is connected to. This should + hold the parent dataset type name for component dataset types. + task_key : `NodeKey` + Key for the task node this edge is connected to. + storage_class_name : `str` + Name of the dataset type's storage class as seen by the task. + connection_name : `str` + Internal name for the connection as seen by the task. + is_calibration : `bool` + Whether this dataset type can be included in + `~lsst.daf.butler.CollectionType.CALIBRATION` collections. + raw_dimensions : `frozenset` [ `str` ] + Raw dimensions from the connection definition. + is_prerequisite : `bool` + Whether this dataset must be present in the data repository prior to + `QuantumGraph` generation. + component : `str` or `None` + Component of the dataset type requested by the task. + defer_query_constraint : `bool` + If `True`, by default do not include this dataset type's existence as a + constraint on the initial data ID query in QuantumGraph generation. 
+ + Notes + ----- + When included in an exported `networkx` graph (e.g. + `PipelineGraph.make_xgraph`), read edges set the following edge attributes: + + - ``parent_dataset_type_name`` + - ``storage_class_name`` + - ``is_init`` + - ``component`` + - ``is_prerequisite`` + + As with `ReadEdge` instance attributes, these descriptions of dataset types + are those specific to a task, and may differ from the graph's resolved + dataset type or (if `PipelineGraph.resolve` has not been called) there may + not even be a consistent definition of the dataset type. + """ + + def __init__( + self, + dataset_type_key: NodeKey, + task_key: NodeKey, + *, + storage_class_name: str, + connection_name: str, + is_calibration: bool, + raw_dimensions: frozenset[str], + is_prerequisite: bool, + component: str | None, + defer_query_constraint: bool, + ): + super().__init__( + task_key=task_key, + dataset_type_key=dataset_type_key, + storage_class_name=storage_class_name, + connection_name=connection_name, + raw_dimensions=raw_dimensions, + is_calibration=is_calibration, + ) + self.is_prerequisite = is_prerequisite + self.component = component + self.defer_query_constraint = defer_query_constraint + + component: str | None + """Component to add to `parent_dataset_type_name` to form the dataset type + name seen by this task. + """ + + is_prerequisite: bool + """Whether this dataset must be present in the data repository prior to + `QuantumGraph` generation. + """ + + defer_query_constraint: bool + """If `True`, by default do not include this dataset type's existence as a + constraint on the initial data ID query in QuantumGraph generation. + """ + + @property + def nodes(self) -> tuple[NodeKey, NodeKey]: + # Docstring inherited. + return (self.dataset_type_key, self.task_key) + + @property + def dataset_type_name(self) -> str: + """Complete dataset type name, as seen by the task.""" + if self.component is not None: + return f"{self.parent_dataset_type_name}.{self.component}" + return self.parent_dataset_type_name + + def diff(self: ReadEdge, other: ReadEdge, connection_type: str = "connection") -> list[str]: + # Docstring inherited. + result = super().diff(other, connection_type) + if self.defer_query_constraint != other.defer_query_constraint: + result.append( + f"{connection_type.capitalize()} {self.connection_name!r} is marked as a deferred query " + f"constraint {'in A but not in B' if self.defer_query_constraint else 'in B but not in A'}." + ) + return result + + def adapt_dataset_type(self, dataset_type: DatasetType) -> DatasetType: + # Docstring inherited. + if self.component is not None: + assert ( + self.storage_class_name == dataset_type.storageClass.allComponents()[self.component].name + ), "components with storage class overrides are not supported" + return dataset_type.makeComponentDatasetType(self.component) + if self.storage_class_name != dataset_type.storageClass_name: + return dataset_type.overrideStorageClass(self.storage_class_name) + return dataset_type + + def adapt_dataset_ref(self, ref: DatasetRef) -> DatasetRef: + # Docstring inherited. 
+ if self.component is not None: + assert ( + self.storage_class_name == ref.datasetType.storageClass.allComponents()[self.component].name + ), "components with storage class overrides are not supported" + return ref.makeComponentRef(self.component) + if self.storage_class_name != ref.datasetType.storageClass_name: + return ref.overrideStorageClass(self.storage_class_name) + return ref + + @classmethod + def _from_connection_map( + cls, + task_key: NodeKey, + connection_name: str, + connection_map: Mapping[str, BaseConnection], + is_prerequisite: bool = False, + ) -> ReadEdge: + """Construct a `ReadEdge` instance from a `.BaseConnection` object. + + Parameters + ---------- + task_key : `NodeKey` + Key for the associated task node or task init node. + connection_name : `str` + Internal name for the connection as seen by the task,. + connection_map : Mapping [ `str`, `.BaseConnection` ] + Mapping of post-configuration object to draw dataset type + information from, keyed by connection name. + is_prerequisite : `bool`, optional + Whether this dataset must be present in the data repository prior + to `QuantumGraph` generation. + + Returns + ------- + edge : `ReadEdge` + New edge instance. + """ + connection = connection_map[connection_name] + parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name) + return cls( + dataset_type_key=NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name), + task_key=task_key, + component=component, + storage_class_name=connection.storageClass, + # InitInput connections don't have .isCalibration. + is_calibration=getattr(connection, "isCalibration", False), + is_prerequisite=is_prerequisite, + connection_name=connection_name, + # InitInput connections don't have a .dimensions because they + # always have empty dimensions. + raw_dimensions=frozenset(getattr(connection, "dimensions", frozenset())), + # PrerequisiteInput and InitInput connections don't have a + # .eferQueryConstraints, because they never constrain the initial + # data ID query. + defer_query_constraint=getattr(connection, "deferQueryConstraint", False), + ) + + def _resolve_dataset_type( + self, + *, + current: DatasetType | None, + is_initial_query_constraint: bool, + is_prerequisite: bool | None, + universe: DimensionUniverse, + producer: str | None, + consumers: Sequence[str], + is_registered: bool, + ) -> tuple[DatasetType, bool, bool]: + """Participate in the construction of the `DatasetTypeNode` object + associated with this edge. + + Parameters + ---------- + current : `lsst.daf.butler.DatasetType` or `None` + The current graph-wide `DatasetType`, or `None`. This will always + be the registry's definition of the parent dataset type, if one + exists. If not, it will be the dataset type definition from the + task in the graph that writes it, if there is one. If there is no + such task, this will be `None`. + is_initial_query_constraint : `bool` + Whether this dataset type is currently marked as a constraint on + the initial data ID query in QuantumGraph generation. + is_prerequisite : `bool` | None` + Whether this dataset type is marked as a prerequisite input in all + edges processed so far. `None` if this is the first edge. + universe : `lsst.daf.butler.DimensionUniverse` + Object that holds all dimension definitions. + producer : `str` or `None` + The label of the task that produces this dataset type in the + pipeline, or `None` if it is an overall input. 
+ consumers : `Sequence` [ `str` ] + Labels for other consuming tasks that have already participated in + this dataset type's resolution. + is_registered : `bool` + Whether a registration for this dataset type was found in the + data repository. + + Returns + ------- + dataset_type : `DatasetType` + The updated graph-wide dataset type. If ``current`` was provided, + this must be equal to it. + is_initial_query_constraint : `bool` + If `True`, this dataset type should be included as a constraint in + the initial data ID query during QuantumGraph generation; this + requires that ``is_initial_query_constraint`` also be `True` on + input. + is_prerequisite : `bool` + Whether this dataset type is marked as a prerequisite input in this + task and all other edges processed so far. + + Raises + ------ + MissingDatasetTypeError + Raised if ``current is None`` and this edge cannot define one on + its own. + IncompatibleDatasetTypeError + Raised if ``current is not None`` and this edge's definition is not + compatible with it. + ConnectionTypeConsistencyError + Raised if a prerequisite input for one task appears as a different + kind of connection in any other task. + """ + if "skypix" in self.raw_dimensions: + if current is None: + raise MissingDatasetTypeError( + f"DatasetType '{self.dataset_type_name}' referenced by " + f"{self.task_label!r} uses 'skypix' as a dimension " + f"placeholder, but has not been registered with the data repository. " + f"Note that reference catalog names are now used as the dataset " + f"type name instead of 'ref_cat'." + ) + rest1 = set(universe.extract(self.raw_dimensions - set(["skypix"])).names) + rest2 = set(dim.name for dim in current.dimensions if not isinstance(dim, SkyPixDimension)) + if rest1 != rest2: + raise IncompatibleDatasetTypeError( + f"Non-skypix dimensions for dataset type {self.dataset_type_name} declared in " + f"connections ({rest1}) are inconsistent with those in " + f"registry's version of this dataset ({rest2})." + ) + dimensions = current.dimensions + else: + dimensions = universe.extract(self.raw_dimensions) + is_initial_query_constraint = is_initial_query_constraint and not self.defer_query_constraint + if is_prerequisite is None: + is_prerequisite = self.is_prerequisite + elif is_prerequisite and not self.is_prerequisite: + raise ConnectionTypeConsistencyError( + f"Dataset type {self.parent_dataset_type_name!r} is a prerequisite input to {consumers}, " + f"but it is not a prerequisite to {self.task_label!r}." + ) + elif not is_prerequisite and self.is_prerequisite: + if producer is not None: + raise ConnectionTypeConsistencyError( + f"Dataset type {self.parent_dataset_type_name!r} is a prerequisite input to " + f"{self.task_label}, but it is produced by {producer!r}." + ) + else: + raise ConnectionTypeConsistencyError( + f"Dataset type {self.parent_dataset_type_name!r} is a prerequisite input to " + f"{self.task_label}, but it is a regular input to {consumers!r}." + ) + + def report_current_origin() -> str: + if is_registered: + return "data repository" + elif producer is not None: + return f"producing task {producer!r}" + else: + return f"consuming task(s) {consumers!r}" + + if self.component is not None: + if current is None: + raise MissingDatasetTypeError( + f"Dataset type {self.parent_dataset_type_name!r} is not registered and not produced by " + f"this pipeline, but it used by task {self.task_label!r}, via component " + f"{self.component!r}. This pipeline cannot be resolved until the parent dataset type is " + "registered." 
+ ) + all_current_components = current.storageClass.allComponents() + if self.component not in all_current_components: + raise IncompatibleDatasetTypeError( + f"Dataset type {self.parent_dataset_type_name!r} has storage class " + f"{current.storageClass_name!r} (from {report_current_origin()}), " + f"which does not include component {self.component!r} " + f"as requested by task {self.task_label!r}." + ) + if all_current_components[self.component].name != self.storage_class_name: + raise IncompatibleDatasetTypeError( + f"Dataset type '{self.parent_dataset_type_name}.{self.component}' has storage class " + f"{all_current_components[self.component].name!r} " + f"(from {report_current_origin()}), which does not match " + f"{self.storage_class_name!r}, as requested by task {self.task_label!r}. " + "Note that storage class conversions of components are not supported." + ) + return current, is_initial_query_constraint, is_prerequisite + else: + dataset_type = DatasetType( + self.parent_dataset_type_name, + dimensions, + storageClass=self.storage_class_name, + isCalibration=self.is_calibration, + ) + if current is not None: + if not is_registered and producer is None: + # Current definition comes from another consumer; we + # require the dataset types to be exactly equal (not just + # compatible), since neither connection should take + # precedence. + if dataset_type != current: + raise MissingDatasetTypeError( + f"Definitions differ for input dataset type {self.parent_dataset_type_name!r}; " + f"task {self.task_label!r} has {dataset_type}, but the definition " + f"from {report_current_origin()} is {current}. If the storage classes are " + "compatible but different, registering the dataset type in the data repository " + "in advance will avoid this error." + ) + elif not dataset_type.is_compatible_with(current): + raise IncompatibleDatasetTypeError( + f"Incompatible definition for input dataset type {self.parent_dataset_type_name!r}; " + f"task {self.task_label!r} has {dataset_type}, but the definition " + f"from {report_current_origin()} is {current}." + ) + return current, is_initial_query_constraint, is_prerequisite + else: + return dataset_type, is_initial_query_constraint, is_prerequisite + + def _to_xgraph_state(self) -> dict[str, Any]: + # Docstring inherited. + result = super()._to_xgraph_state() + result["component"] = self.component + result["is_prerequisite"] = self.is_prerequisite + return result + + +class WriteEdge(Edge): + """Representation of an output connection (including init-outputs) in a + pipeline graph. + + Notes + ----- + When included in an exported `networkx` graph (e.g. + `PipelineGraph.make_xgraph`), write edges set the following edge + attributes: + + - ``parent_dataset_type_name`` + - ``storage_class_name`` + - ``is_init`` + + As with `WRiteEdge` instance attributes, these descriptions of dataset + types are those specific to a task, and may differ from the graph's + resolved dataset type or (if `PipelineGraph.resolve` has not been called) + there may not even be a consistent definition of the dataset type. + """ + + @property + def nodes(self) -> tuple[NodeKey, NodeKey]: + # Docstring inherited. + return (self.task_key, self.dataset_type_key) + + def adapt_dataset_type(self, dataset_type: DatasetType) -> DatasetType: + # Docstring inherited. 
+ if self.storage_class_name != dataset_type.storageClass_name: + return dataset_type.overrideStorageClass(self.storage_class_name) + return dataset_type + + def adapt_dataset_ref(self, ref: DatasetRef) -> DatasetRef: + # Docstring inherited. + if self.storage_class_name != ref.datasetType.storageClass_name: + return ref.overrideStorageClass(self.storage_class_name) + return ref + + @classmethod + def _from_connection_map( + cls, + task_key: NodeKey, + connection_name: str, + connection_map: Mapping[str, BaseConnection], + ) -> WriteEdge: + """Construct a `WriteEdge` instance from a `.BaseConnection` object. + + Parameters + ---------- + task_key : `NodeKey` + Key for the associated task node or task init node. + connection_name : `str` + Internal name for the connection as seen by the task,. + connection_map : Mapping [ `str`, `.BaseConnection` ] + Mapping of post-configuration object to draw dataset type + information from, keyed by connection name. + + Returns + ------- + edge : `WriteEdge` + New edge instance. + """ + connection = connection_map[connection_name] + parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name) + if component is not None: + raise ValueError( + f"Illegal output component dataset {connection.name!r} in task {task_key.name!r}." + ) + return cls( + task_key=task_key, + dataset_type_key=NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name), + storage_class_name=connection.storageClass, + connection_name=connection_name, + # InitOutput connections don't have .isCalibration. + is_calibration=getattr(connection, "isCalibration", False), + # InitOutput connections don't have a .dimensions because they + # always have empty dimensions. + raw_dimensions=frozenset(getattr(connection, "dimensions", frozenset())), + ) + + def _resolve_dataset_type(self, current: DatasetType | None, universe: DimensionUniverse) -> DatasetType: + """Participate in the construction of the `DatasetTypeNode` object + associated with this edge. + + Parameters + ---------- + current : `lsst.daf.butler.DatasetType` or `None` + The current graph-wide `DatasetType`, or `None`. This will always + be the registry's definition of the parent dataset type, if one + exists. + universe : `lsst.daf.butler.DimensionUniverse` + Object that holds all dimension definitions. + + Returns + ------- + dataset_type : `DatasetType` + A dataset type compatible with this edge. If ``current`` was + provided, this must be equal to it. + + Raises + ------ + IncompatibleDatasetTypeError + Raised if ``current is not None`` and this edge's definition is not + compatible with it. + """ + dimensions = universe.extract(self.raw_dimensions) + dataset_type = DatasetType( + self.parent_dataset_type_name, + dimensions, + storageClass=self.storage_class_name, + isCalibration=self.is_calibration, + ) + if current is not None: + if not current.is_compatible_with(dataset_type): + raise IncompatibleDatasetTypeError( + f"Incompatible definition for output dataset type {self.parent_dataset_type_name!r}: " + f"task {self.task_label!r} has {current}, but data repository has {dataset_type}." + ) + return current + else: + return dataset_type diff --git a/python/lsst/pipe/base/pipeline_graph/_exceptions.py b/python/lsst/pipe/base/pipeline_graph/_exceptions.py new file mode 100644 index 00000000..8ed6cd16 --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_exceptions.py @@ -0,0 +1,95 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. 
+# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ( + "ConnectionTypeConsistencyError", + "DuplicateOutputError", + "IncompatibleDatasetTypeError", + "PipelineGraphExceptionSafetyError", + "PipelineDataCycleError", + "PipelineGraphError", + "PipelineGraphReadError", + "EdgesChangedError", + "UnresolvedGraphError", + "TaskNotImportedError", +) + + +class PipelineGraphError(RuntimeError): + """Base exception raised when there is a problem constructing or resolving + a pipeline graph. + """ + + +class DuplicateOutputError(PipelineGraphError): + """Exception raised when multiple tasks in one pipeline produce the same + output dataset type. + """ + + +class PipelineDataCycleError(PipelineGraphError): + """Exception raised when a pipeline graph contains a cycle.""" + + +class ConnectionTypeConsistencyError(PipelineGraphError): + """Exception raised when the tasks in a pipeline graph use different (and + incompatible) connection types for the same dataset type. + """ + + +class IncompatibleDatasetTypeError(PipelineGraphError): + """Exception raised when the tasks in a pipeline graph define dataset types + with the same name in incompatible ways, or when these are incompatible + with the data repository definition. + """ + + +class UnresolvedGraphError(PipelineGraphError): + """Exception raised when an operation requires dimensions or dataset types + to have been resolved, but they have not been. + """ + + +class PipelineGraphReadError(PipelineGraphError, IOError): + """Exception raised when a serialized PipelineGraph cannot be read.""" + + +class TaskNotImportedError(PipelineGraphError): + """Exception raised when accessing an attribute of a graph or graph node + that is not available unless the task class has been imported and + configured. + """ + + +class EdgesChangedError(PipelineGraphError): + """Exception raised when the edges in one version of a pipeline graph + are not consistent with those in another, but they were expected to be. + """ + + +class PipelineGraphExceptionSafetyError(PipelineGraphError): + """Exception raised when a PipelineGraph method could not provide strong + exception safety, and the graph may have been left in an inconsistent + state. + + The originating exception is always chained when this exception is raised. + """ diff --git a/python/lsst/pipe/base/pipeline_graph/_mapping_views.py b/python/lsst/pipe/base/pipeline_graph/_mapping_views.py new file mode 100644 index 00000000..6f12d42c --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_mapping_views.py @@ -0,0 +1,197 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). 
+# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +from collections.abc import Iterable, Iterator, Mapping +from typing import Any, ClassVar, Sequence, TypeVar, cast, overload + +import networkx + +from ._dataset_types import DatasetTypeNode +from ._exceptions import UnresolvedGraphError +from ._nodes import NodeKey, NodeType +from ._tasks import TaskInitNode, TaskNode + +_N = TypeVar("_N", covariant=True) +_T = TypeVar("_T") + + +class MappingView(Mapping[str, _N]): + """Base class for mapping views into nodes of certain types in a + `PipelineGraph`. + + + Parameters + ---------- + parent_xgraph : `networkx.MultiDiGraph` + Backing networkx graph for the `PipelineGraph` instance. + + Notes + ----- + Instances should only be constructed by `PipelineGraph` and its helper + classes. + + Iteration order is topologically sorted if and only if the backing + `PipelineGraph` has been sorted since its last modification. + """ + + def __init__(self, parent_xgraph: networkx.MultiDiGraph) -> None: + self._parent_xgraph = parent_xgraph + self._keys: list[str] | None = None + + _NODE_TYPE: ClassVar[NodeType] # defined by derived classes + + def __contains__(self, key: object) -> bool: + # The given key may not be a str, but if it isn't it'll just fail the + # check, which is what we want anyway. + return NodeKey(self._NODE_TYPE, cast(str, key)) in self._parent_xgraph + + def __iter__(self) -> Iterator[str]: + if self._keys is None: + self._keys = self._make_keys(self._parent_xgraph) + return iter(self._keys) + + def __getitem__(self, key: str) -> _N: + return self._parent_xgraph.nodes[NodeKey(self._NODE_TYPE, key)]["instance"] + + def __len__(self) -> int: + if self._keys is None: + self._keys = self._make_keys(self._parent_xgraph) + return len(self._keys) + + def __repr__(self) -> str: + return f"{type(self).__name__}({self!s})" + + def __str__(self) -> str: + return f"{{{', '.join(iter(self))}}}" + + def _reorder(self, parent_keys: Sequence[NodeKey]) -> None: + """Set this view's iteration order according to the given iterable of + parent keys. + + Parameters + ---------- + parent_keys : `~collections.abc.Sequence` [ `NodeKey` ] + Superset of the keys in this view, in the new order. + """ + self._keys = self._make_keys(parent_keys) + + def _reset(self) -> None: + """Reset all cached content. + + This should be called by the parent graph after any changes that could + invalidate the view, causing it to be reconstructed when next + requested. + """ + self._keys = None + + def _make_keys(self, parent_keys: Iterable[NodeKey]) -> list[str]: + """Make a sequence of keys for this view from an iterable of parent + keys. + + Parameters + ---------- + parent_keys : `~collections.abc.Iterable` [ `NodeKey` ] + Superset of the keys in this view. 
+        """
+        return [str(k) for k in parent_keys if k.node_type is self._NODE_TYPE]
+
+
+class TaskMappingView(MappingView[TaskNode]):
+    """A mapping view of the tasks in a `PipelineGraph`.
+
+    Notes
+    -----
+    Mapping keys are task labels and values are `TaskNode` instances.
+    Iteration order is topological if and only if the `PipelineGraph` has been
+    sorted since its last modification.
+    """
+
+    _NODE_TYPE = NodeType.TASK
+
+
+class TaskInitMappingView(MappingView[TaskInitNode]):
+    """A mapping view of the nodes representing task initialization in a
+    `PipelineGraph`.
+
+    Notes
+    -----
+    Mapping keys are task labels and values are `TaskInitNode` instances.
+    Iteration order is topological if and only if the `PipelineGraph` has been
+    sorted since its last modification.
+    """
+
+    _NODE_TYPE = NodeType.TASK_INIT
+
+
+class DatasetTypeMappingView(MappingView[DatasetTypeNode]):
+    """A mapping view of the dataset types in a `PipelineGraph`.
+
+    Notes
+    -----
+    Mapping keys are parent dataset type names and values are `DatasetTypeNode`
+    instances, but values are only available for nodes that have been resolved
+    (see `PipelineGraph.resolve`). Attempting to access an unresolved value
+    will result in `UnresolvedGraphError` being raised. Keys for unresolved
+    nodes are always present and iterable.
+
+    Iteration order is topological if and only if the `PipelineGraph` has been
+    sorted since its last modification.
+    """
+
+    _NODE_TYPE = NodeType.DATASET_TYPE
+
+    def __getitem__(self, key: str) -> DatasetTypeNode:
+        if (result := super().__getitem__(key)) is None:
+            raise UnresolvedGraphError(f"Node for dataset type {key!r} has not been resolved.")
+        return result
+
+    def is_resolved(self, key: str) -> bool:
+        """Test whether a node has been resolved."""
+        return super().__getitem__(key) is not None
+
+    @overload
+    def get_if_resolved(self, key: str) -> DatasetTypeNode | None:
+        ...  # pragma: nocover
+
+    @overload
+    def get_if_resolved(self, key: str, default: _T) -> DatasetTypeNode | _T:
+        ...  # pragma: nocover
+
+    def get_if_resolved(self, key: str, default: Any = None) -> DatasetTypeNode | Any:
+        """Get a node or return a default if it has not been resolved.
+
+        Parameters
+        ----------
+        key : `str`
+            Parent dataset type name.
+        default
+            Value to return if this dataset type has not been resolved.
+
+        Raises
+        ------
+        KeyError
+            Raised if the node is not present in the graph at all.
+        """
+        if (result := super().__getitem__(key)) is None:
+            return default  # type: ignore
+        return result
diff --git a/python/lsst/pipe/base/pipeline_graph/_nodes.py b/python/lsst/pipe/base/pipeline_graph/_nodes.py
new file mode 100644
index 00000000..b9ec00fc
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_nodes.py
@@ -0,0 +1,85 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ( + "NodeKey", + "NodeType", +) + +import enum +from typing import NamedTuple + + +class NodeType(enum.Enum): + """Enumeration of the types of nodes in a PipelineGraph.""" + + DATASET_TYPE = 0 + TASK_INIT = 1 + TASK = 2 + + @property + def bipartite(self) -> int: + """The integer used as the "bipartite" key in networkx exports of a + `PipelineGraph`. + + This key is used by the `networkx.algorithms.bipartite` module. + """ + return int(self is not NodeType.DATASET_TYPE) + + def __lt__(self, other: NodeType) -> bool: + # We define __lt__ only to be able to provide deterministic tiebreaking + # on top of topological ordering of `PipelineGraph`` and views thereof. + return self.value < other.value + + +class NodeKey(NamedTuple): + """A special key type for nodes in networkx graphs. + + Notes + ----- + Using a tuple for the key allows tasks labels and dataset type names with + the same string value to coexist in the graph. These only rarely appear in + `PipelineGraph` public interfaces; when the node type is implicit, bare + `str` task labels or dataset type names are used instead. + + NodeKey objects stringify to just their name, which is used both as a way + to convert to the `str` objects used in the main public interface and as an + easy way to usefully stringify containers returned directly by networkx + algorithms (especially in error messages). Note that this requires `repr`, + not just `str`, because Python builtin containers always use `repr` on + their items, even in their implementations for `str`. + """ + + node_type: NodeType + """Node type enum for this key.""" + + name: str + """Task label or dataset type name. + + This is always the parent dataset type name for component dataset types. + """ + + def __repr__(self) -> str: + return self.name + + def __str__(self) -> str: + return self.name diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py new file mode 100644 index 00000000..5e8ea5c1 --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py @@ -0,0 +1,1389 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
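The `NodeKey` and `NodeType` classes defined above are what allow a task label and a dataset type with the same string name to coexist as distinct graph nodes while still printing cleanly in error messages. The following is a minimal sketch of that behavior; the import from the private `_nodes` module and the "calibrate" name are illustrative assumptions, and this is not how `PipelineGraph` itself builds its backing graph.

import networkx

from lsst.pipe.base.pipeline_graph._nodes import NodeKey, NodeType

# A task label and a dataset type name may share the same string value; the
# NodeType member of the tuple keeps them distinct as networkx node keys.
task_key = NodeKey(NodeType.TASK, "calibrate")
dataset_type_key = NodeKey(NodeType.DATASET_TYPE, "calibrate")

xgraph = networkx.MultiDiGraph()
xgraph.add_node(task_key, bipartite=NodeType.TASK.bipartite)
xgraph.add_node(dataset_type_key, bipartite=NodeType.DATASET_TYPE.bipartite)
assert len(xgraph) == 2  # two distinct nodes despite the shared name

# Both keys stringify (str and repr) to just the name, so containers of keys
# returned by networkx algorithms stay readable in error messages.
assert str(task_key) == "calibrate"
assert repr(dataset_type_key) == "calibrate"
assert NodeType.DATASET_TYPE.bipartite == 0 and NodeType.TASK.bipartite == 1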
+from __future__ import annotations + +__all__ = ("PipelineGraph",) + +import gzip +import itertools +import json +from collections.abc import Iterable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, BinaryIO, Literal, TypeVar, cast + +import networkx +import networkx.algorithms.bipartite +import networkx.algorithms.dag +from lsst.daf.butler import DataCoordinate, DataId, DimensionGraph, DimensionUniverse, Registry +from lsst.resources import ResourcePath, ResourcePathExpression + +from ._dataset_types import DatasetTypeNode +from ._edges import Edge, ReadEdge, WriteEdge +from ._exceptions import ( + EdgesChangedError, + PipelineDataCycleError, + PipelineGraphError, + PipelineGraphExceptionSafetyError, + UnresolvedGraphError, +) +from ._mapping_views import DatasetTypeMappingView, TaskMappingView +from ._nodes import NodeKey, NodeType +from ._task_subsets import TaskSubset +from ._tasks import TaskInitNode, TaskNode, _TaskNodeImportedData + +if TYPE_CHECKING: + from ..config import PipelineTaskConfig + from ..connections import PipelineTaskConnections + from ..pipeline import TaskDef + from ..pipelineTask import PipelineTask + + +_G = TypeVar("_G", bound=networkx.DiGraph | networkx.MultiDiGraph) + + +class PipelineGraph: + """A graph representation of fully-configured pipeline. + + `PipelineGraph` instances are typically constructed by calling + `.Pipeline.to_graph`, but in rare cases constructing and then populating + an empty one may be preferable. + + Parameters + ---------- + description : `str`, optional + String description for this pipeline. + universe : `lsst.daf.butler.DimensionUniverse`, optional + Definitions for all butler dimensions. If not provided, some + attributes will not be available until `resolve` is called. + data_id : `lsst.daf.butler.DataCoordinate` or other data ID, optional + Data ID that represents a constraint on all quanta generated by this + pipeline. This typically just holds the instrument constraint included + in the pipeline definition, if there was one. + """ + + def __init__( + self, + *, + description: str = "", + universe: DimensionUniverse | None = None, + data_id: DataId | None = None, + ) -> None: + self._init_from_args( + xgraph=None, + sorted_keys=None, + task_subsets=None, + description=description, + universe=universe, + data_id=data_id, + ) + + def _init_from_args( + self, + xgraph: networkx.MultiDiGraph | None, + sorted_keys: Sequence[NodeKey] | None, + task_subsets: dict[str, TaskSubset] | None, + description: str, + universe: DimensionUniverse | None, + data_id: DataId | None, + ) -> None: + """Initialize the graph with possibly-nontrivial arguments. + + Parameters + ---------- + xgraph : `networkx.MultiDiGraph` or `None` + The backing networkx graph, or `None` to create an empty one. + This graph has `NodeKey` instances for nodes and the same structure + as the graph exported by `make_xgraph`, but its nodes and edges + have a single ``instance`` attribute that holds a `TaskNode`, + `TaskInitNode`, `DatasetTypeNode` (or `None`), `ReadEdge`, or + `WriteEdge` instance. + sorted_keys : `Sequence` [ `NodeKey` ] or `None` + Topologically sorted sequence of node keys, or `None` if the graph + is not sorted. + task_subsets : `dict` [ `str`, `TaskSubset` ] + Labeled subsets of tasks. Values must be constructed with + ``xgraph`` as their parent graph. + description : `str` + String description for this pipeline. + universe : `lsst.daf.butler.DimensionUniverse` or `None` + Definitions of all dimensions. 
+ data_id : `lsst.daf.butler.DataCoordinate` or other data ID mapping. + Data ID that represents a constraint on all quanta generated from + this pipeline. + + Notes + ----- + Only empty `PipelineGraph` instances should be constructed directly by + users, which sets the signature of ``__init__`` itself, but methods on + `PipelineGraph` and its helper classes need to be able to create them + with state. Those methods can call this after calling ``__new__`` + manually, skipping ``__init__``. + + `PipelineGraph` mutator methods provide strong exception safety (the + graph is left unchanged when an exception is raised and caught) unless + the exception raised is `PipelineGraphExceptionSafetyError`. + """ + self._xgraph = xgraph if xgraph is not None else networkx.MultiDiGraph() + self._sorted_keys: Sequence[NodeKey] | None = None + self._task_subsets = task_subsets if task_subsets is not None else {} + self._description = description + self._tasks = TaskMappingView(self._xgraph) + self._dataset_types = DatasetTypeMappingView(self._xgraph) + self._raw_data_id: dict[str, Any] + if isinstance(data_id, DataCoordinate): + universe = data_id.universe + self._raw_data_id = data_id.byName() + elif data_id is None: + self._raw_data_id = {} + else: + self._raw_data_id = dict(data_id) + self._universe = universe + if sorted_keys is not None: + self._reorder(sorted_keys) + + def __repr__(self) -> str: + return f"{type(self).__name__}({self.description!r}, tasks={self.tasks!s})" + + @property + def description(self) -> str: + """String description for this pipeline.""" + return self._description + + @description.setter + def description(self, value: str) -> None: + # Docstring in setter. + self._description = value + + @property + def universe(self) -> DimensionUniverse | None: + """Definitions for all butler dimensions.""" + return self._universe + + @property + def data_id(self) -> DataCoordinate: + """Data ID that represents a constraint on all quanta generated from + this pipeline. + + This is may not be available unless `universe` is not `None`. + """ + return DataCoordinate.standardize(self._raw_data_id, universe=self.universe) + + @property + def tasks(self) -> TaskMappingView: + """A mapping view of the tasks in the graph. + + This mapping has `str` task label keys and `TaskNode` values. Iteration + is topologically and deterministically ordered if and only if `sort` + has been called since the last modification to the graph. + """ + return self._tasks + + @property + def dataset_types(self) -> DatasetTypeMappingView: + """A mapping view of the dataset types in the graph. + + This mapping has `str` parent dataset type name keys, but only provides + access to its `DatasetTypeNode` values if `resolve` has been called + since the last modification involving a task that uses a dataset type. + See `DatasetTypeMappingView` for details. + """ + return self._dataset_types + + @property + def task_subsets(self) -> Mapping[str, TaskSubset]: + """A mapping of all labeled subsets of tasks. + + Keys are subset labels, values are sets of task labels. See + `TaskSubset` for more information. + + Use `add_task_subset` to add a new subset. The subsets themselves may + be modified in-place. + """ + return self._task_subsets + + def iter_edges(self, init: bool = False) -> Iterator[Edge]: + """Iterate over edges in the graph. 
+
+        Parameters
+        ----------
+        init : `bool`, optional
+            If `True` (`False` is default) iterate over the edges between task
+            initialization nodes and init input/output dataset types, instead
+            of the runtime task nodes and regular input/output/prerequisite
+            dataset types.
+
+        Returns
+        -------
+        edges : `~collections.abc.Iterator` [ `Edge` ]
+            A lazy iterator over `Edge` (`WriteEdge` or `ReadEdge`) instances.
+
+        Notes
+        -----
+        This method always returns _either_ init edges or runtime edges, never
+        both. The full (internal) graph that contains both also includes a
+        special edge that connects each task init node to its runtime node;
+        that is also never returned by this method, since it is never a part of
+        the init-only or runtime-only subgraphs.
+        """
+        edge: Edge
+        for _, _, edge in self._xgraph.edges(data="instance"):
+            if edge is not None and edge.is_init == init:
+                yield edge
+
+    def iter_nodes(
+        self,
+    ) -> Iterator[
+        tuple[Literal[NodeType.TASK_INIT], str, TaskInitNode]
+        | tuple[Literal[NodeType.TASK], str, TaskNode]
+        | tuple[Literal[NodeType.DATASET_TYPE], str, DatasetTypeNode | None]
+    ]:
+        """Iterate over nodes in the graph.
+
+        Returns
+        -------
+        nodes : `~collections.abc.Iterator` [ `tuple` ]
+            A lazy iterator over all of the nodes in the graph. Each yielded
+            element is a tuple of:
+
+            - the node type enum value (`NodeType`);
+            - the string name for the node (task label or parent dataset type
+              name);
+            - the node value (`TaskNode`, `TaskInitNode`, `DatasetTypeNode`,
+              or `None` for dataset type nodes that have not been resolved).
+        """
+        key: NodeKey
+        if self._sorted_keys is not None:
+            for key in self._sorted_keys:
+                yield key.node_type, key.name, self._xgraph.nodes[key]["instance"]  # type: ignore
+        else:
+            for key, node in self._xgraph.nodes(data="instance"):
+                yield key.node_type, key.name, node  # type: ignore
+
+    def iter_overall_inputs(self) -> Iterator[tuple[str, DatasetTypeNode | None]]:
+        """Iterate over all of the dataset types that are consumed but not
+        produced by the graph.
+
+        Returns
+        -------
+        dataset_types : `~collections.abc.Iterator` [ `tuple` ]
+            A lazy iterator over the overall-input dataset types (including
+            overall init inputs and prerequisites). Each yielded element is a
+            tuple of:
+
+            - the parent dataset type name;
+            - the resolved `DatasetTypeNode`, or `None` if the dataset type
+              has not been resolved.
+        """
+        for generation in networkx.algorithms.dag.topological_generations(self._xgraph):
+            key: NodeKey
+            for key in generation:
+                # While we expect all tasks to have at least one input and
+                # hence never appear in the first topological generation, that
+                # is not true of task init nodes.
+                if key.node_type is NodeType.DATASET_TYPE:
+                    yield key.name, self._xgraph.nodes[key]["instance"]
+            return
+
+    def make_xgraph(self) -> networkx.MultiDiGraph:
+        """Export a networkx representation of the full pipeline graph,
+        including both init and runtime edges.
+
+        Returns
+        -------
+        xgraph : `networkx.MultiDiGraph`
+            Directed acyclic graph with parallel edges.
+
+        Notes
+        -----
+        The returned graph uses `NodeKey` instances for nodes. Parallel edges
+        represent the same dataset type appearing in multiple connections for
+        the same task, and are hence rare. The connection name is used as the
+        edge key to disambiguate those parallel edges.
+
+        Almost all edges connect dataset type nodes to task or task init nodes
+        or vice versa, but there is also a special edge that connects each task
+        init node to its runtime node. The existence of these edges makes the
+        graph not quite bipartite, unlike its init-only and runtime-only
+        subgraphs.
+
+        See `TaskNode`, `TaskInitNode`, `DatasetTypeNode`, `ReadEdge`, and
+        `WriteEdge` for the descriptive node and edge attributes added.
+        """
+        return self._transform_xgraph_state(self._xgraph.copy(), skip_edges=False)
+
+    def make_bipartite_xgraph(self, init: bool = False) -> networkx.MultiDiGraph:
+        """Return a bipartite networkx representation of just the runtime or
+        init-time pipeline graph.
+
+        Parameters
+        ----------
+        init : `bool`, optional
+            If `True` (`False` is default) return the graph of task
+            initialization nodes and init input/output dataset types, instead
+            of the graph of runtime task nodes and regular
+            input/output/prerequisite dataset types.
+
+        Returns
+        -------
+        xgraph : `networkx.MultiDiGraph`
+            Directed acyclic graph with parallel edges.
+
+        Notes
+        -----
+        The returned graph uses `NodeKey` instances for nodes. Parallel edges
+        represent the same dataset type appearing in multiple connections for
+        the same task, and are hence rare. The connection name is used as the
+        edge key to disambiguate those parallel edges.
+
+        This graph is bipartite because each dataset type node only has edges
+        that connect it to a task [init] node, and vice versa.
+
+        See `TaskNode`, `TaskInitNode`, `DatasetTypeNode`, `ReadEdge`, and
+        `WriteEdge` for the descriptive node and edge attributes added.
+        """
+        return self._transform_xgraph_state(
+            self._make_bipartite_xgraph_internal(init).copy(), skip_edges=False
+        )
+
+    def make_task_xgraph(self, init: bool = False) -> networkx.DiGraph:
+        """Return a networkx representation of just the tasks in the pipeline.
+
+        Parameters
+        ----------
+        init : `bool`, optional
+            If `True` (`False` is default) return the graph of task
+            initialization nodes, instead of the graph of runtime task nodes.
+
+        Returns
+        -------
+        xgraph : `networkx.DiGraph`
+            Directed acyclic graph with no parallel edges.
+
+        Notes
+        -----
+        The returned graph uses `NodeKey` instances for nodes. The dataset
+        types that link these tasks are not represented at all; edges have no
+        attributes, and there are no parallel edges.
+
+        See `TaskNode` and `TaskInitNode` for the descriptive node attributes
+        added.
+        """
+        bipartite_xgraph = self._make_bipartite_xgraph_internal(init)
+        task_keys = [
+            key
+            for key, bipartite in bipartite_xgraph.nodes(data="bipartite")
+            if bipartite == NodeType.TASK.bipartite
+        ]
+        return self._transform_xgraph_state(
+            networkx.algorithms.bipartite.projected_graph(networkx.DiGraph(bipartite_xgraph), task_keys),
+            skip_edges=True,
+        )
+
+    def make_dataset_type_xgraph(self, init: bool = False) -> networkx.DiGraph:
+        """Return a networkx representation of just the dataset types in the
+        pipeline.
+
+        Parameters
+        ----------
+        init : `bool`, optional
+            If `True` (`False` is default) return the graph of init input and
+            output dataset types, instead of the graph of runtime (input,
+            output, prerequisite input) dataset types.
+
+        Returns
+        -------
+        xgraph : `networkx.DiGraph`
+            Directed acyclic graph with no parallel edges.
+
+        Notes
+        -----
+        The returned graph uses `NodeKey` instances for nodes. The tasks that
+        link these dataset types are not represented at all; edges have no
+        attributes, and there are no parallel edges.
+
+        See `DatasetTypeNode` for the descriptive node attributes added.
+        """
+        bipartite_xgraph = self._make_bipartite_xgraph_internal(init)
+        dataset_type_keys = [
+            key
+            for key, bipartite in bipartite_xgraph.nodes(data="bipartite")
+            if bipartite == NodeType.DATASET_TYPE.bipartite
+        ]
+        return self._transform_xgraph_state(
+            networkx.algorithms.bipartite.projected_graph(
+                networkx.DiGraph(bipartite_xgraph), dataset_type_keys
+            ),
+            skip_edges=True,
+        )
+
+    def _make_bipartite_xgraph_internal(self, init: bool) -> networkx.MultiDiGraph:
+        """Make a bipartite init-only or runtime-only internal subgraph.
+
+        See `make_bipartite_xgraph` for parameters and return values.
+
+        Notes
+        -----
+        This method returns a view of the `PipelineGraph` object's internal
+        backing graph, and hence should only be called in methods that copy the
+        result either explicitly or by running a copying algorithm before
+        returning it to the user.
+        """
+        return self._xgraph.edge_subgraph([edge.key for edge in self.iter_edges(init)])
+
+    def _transform_xgraph_state(self, xgraph: _G, skip_edges: bool) -> _G:
+        """Transform networkx graph attributes in-place from the internal
+        "instance" attributes to the documented exported attributes.
+
+        Parameters
+        ----------
+        xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph`
+            Graph whose state should be transformed.
+        skip_edges : `bool`
+            If `True`, do not transform edge state.
+
+        Returns
+        -------
+        xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph`
+            The same object passed in, after modification.
+
+        Notes
+        -----
+        This should be called after making a copy of the internal graph but
+        before any projection down to just task or dataset type nodes, since
+        it assumes stateful edges.
+        """
+        state: dict[str, Any]
+        for state in xgraph.nodes.values():
+            node_value: TaskInitNode | TaskNode | DatasetTypeNode | None = state.pop("instance")
+            if node_value is not None:
+                state.update(node_value._to_xgraph_state())
+        if not skip_edges:
+            for _, _, state in xgraph.edges(data=True):
+                edge: Edge | None = state.pop("instance", None)
+                if edge is not None:
+                    state.update(edge._to_xgraph_state())
+        return xgraph
+
+    def group_by_dimensions(
+        self, prerequisites: bool = False
+    ) -> dict[DimensionGraph, tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]]]:
+        """Group this graph's tasks and dataset types by their dimensions.
+
+        Parameters
+        ----------
+        prerequisites : `bool`, optional
+            If `True`, include prerequisite dataset types as well as regular
+            input and output datasets (including intermediates).
+
+        Returns
+        -------
+        groups : `dict` [ `DimensionGraph`, `tuple` ]
+            A dictionary of groups keyed by `DimensionGraph`, in which each
+            value is a tuple of:
+
+            - a `dict` of `TaskNode` instances, keyed by task label
+            - a `dict` of `DatasetTypeNode` instances, keyed by
+              dataset type name.
+
+            that have those dimensions.
+
+        Notes
+        -----
+        Init inputs and outputs are always included, but always have empty
+        dimensions and hence are all grouped together.
+ """ + result: dict[DimensionGraph, tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]]] = {} + next_new_value: tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]] = ({}, {}) + for task_label, task_node in self.tasks.items(): + if task_node.dimensions is None: + raise UnresolvedGraphError(f"Task with label {task_label!r} has not been resolved.") + if (group := result.setdefault(task_node.dimensions, next_new_value)) is next_new_value: + next_new_value = ({}, {}) # make new lists for next time + group[0][task_node.label] = task_node + for dataset_type_name, dataset_type_node in self.dataset_types.items(): + if dataset_type_node is None: + raise UnresolvedGraphError(f"Dataset type {dataset_type_name!r} has not been resolved.") + if not dataset_type_node.is_prerequisite or prerequisites: + if ( + group := result.setdefault(dataset_type_node.dataset_type.dimensions, next_new_value) + ) is next_new_value: + next_new_value = ({}, {}) # make new lists for next time + group[1][dataset_type_node.name] = dataset_type_node + return result + + @property + def is_sorted(self) -> bool: + """Whether this graph's tasks and dataset types are topologically + sorted with the exact same deterministic tiebreakers that `sort` would + apply. + + This may perform (and then discard) a full sort if `has_been_sorted` is + `False`. If the goal is to obtain a sorted graph, it is better to just + call `sort` without guarding that with an ``if not graph.is_sorted`` + check. + """ + if self._sorted_keys is not None: + return True + return all( + sorted == unsorted + for sorted, unsorted in zip(networkx.lexicographical_topological_sort(self._xgraph), self._xgraph) + ) + + @property + def has_been_sorted(self) -> bool: + """Whether this graph's tasks and dataset types have been + topologically sorted (with unspecified but deterministic tiebreakers) + since the last modification to the graph. + + This may return `False` if the graph *happens* to be sorted but `sort` + was never called, but it is potentially much faster than `is_sorted`, + which may attempt (and then discard) a full sort if `has_been_sorted` + is `False`. + """ + return self._sorted_keys is not None + + def sort(self) -> None: + """Sort this graph's nodes topologically with deterministic (but + unspecified) tiebreakers. + + This does nothing if the graph is already known to be sorted. + """ + if self._sorted_keys is None: + try: + sorted_keys: Sequence[NodeKey] = list(networkx.lexicographical_topological_sort(self._xgraph)) + except networkx.NetworkXUnfeasible as err: # pragma: no cover + # Should't be possible to get here, because we check for cycles + # when adding tasks, but we guard against it anyway. + cycle = networkx.find_cycle(self._xgraph) + raise PipelineDataCycleError( + f"Cycle detected while attempting to sort graph: {cycle}." + ) from err + self._reorder(sorted_keys) + + def producer_of(self, dataset_type_name: str) -> WriteEdge | None: + """Return the `WriteEdge` that links the producing task to the named + dataset type. + + Parameters + ---------- + dataset_type_name : `str` + Dataset type name. Must not be a component. + + Returns + ------- + edge : `WriteEdge` or `None` + Producing edge or `None` if there isn't one in this graph. 
+ """ + for _, _, edge in self._xgraph.in_edges( + NodeKey(NodeType.DATASET_TYPE, dataset_type_name), data="instance" + ): + return edge + return None + + def consumers_of(self, dataset_type_name: str) -> list[ReadEdge]: + """Return the `ReadEdge` objects that link the named dataset type to + the tasks that consume it. + + Parameters + ---------- + dataset_type_name : `str` + Dataset type name. Must not be a component. + + Returns + ------- + edges : `list` [ `ReadEdge` ] + Edges that connect this dataset type to the tasks that consume it. + """ + return [ + edge + for _, _, edge in self._xgraph.out_edges( + NodeKey(NodeType.DATASET_TYPE, dataset_type_name), data="instance" + ) + ] + + def add_task( + self, + label: str, + task_class: type[PipelineTask], + config: PipelineTaskConfig, + connections: PipelineTaskConnections | None = None, + ) -> TaskNode: + """Add a new task to the graph. + + Parameters + ---------- + label : `str` + Label for the task in the pipeline. + task_class : `type` [ `PipelineTask` ] + Class object for the task. + config : `PipelineTaskConfig` + Configuration for the task. + connections : `PipelineTaskConnections`, optional + Object that describes the dataset types used by the task. If not + provided, one will be constructed from the given configuration. If + provided, it is assumed that ``config`` has already been validated + and frozen. + + Returns + ------- + node : `TaskNode` + The new task node added to the graph. + + Raises + ------ + ValueError + Raised if configuration validation failed when constructing + ``connections``. + PipelineDataCycleError + Raised if the graph is cyclic after this addition. + RuntimeError + Raised if an unexpected exception (which will be chained) occurred + at a stage that may have left the graph in an inconsistent state. + Other exceptions should leave the graph unchanged. + + Notes + ----- + Checks for dataset type consistency and multiple producers do not occur + until `resolve` is called, since the resolution depends on both the + state of the data repository and all contributing tasks. + + Adding new tasks removes any existing resolutions of all dataset types + it references and marks the graph as unsorted. It is most effiecient + to add all tasks up front and only then resolve and/or sort the graph. + """ + key = NodeKey(NodeType.TASK, label) + init_key = NodeKey(NodeType.TASK_INIT, label) + task_node = TaskNode._from_imported_data( + key, + init_key, + _TaskNodeImportedData.configure(label, task_class, config, connections), + universe=self.universe, + ) + self.add_task_nodes([task_node]) + return task_node + + def add_task_nodes(self, nodes: Iterable[TaskNode]) -> None: + """Add one or more existing task nodes to the graph. + + Parameters + ---------- + nodes : `~collections.abc.Iterable` [ `TaskNode` ] + Iterable of task nodes to add. If any tasks have resolved + dimensions, they must have the same dimension universe as the rest + of the graph. + + Raises + ------ + PipelineDataCycleError + Raised if the graph is cyclic after this addition. + + Notes + ----- + Checks for dataset type consistency and multiple producers do not occur + until `resolve` is called, since the resolution depends on both the + state of the data repository and all contributing tasks. + + Adding new tasks removes any existing resolutions of all dataset types + it references and marks the graph as unsorted. It is most effiecient + to add all tasks up front and only then resolve and/or sort the graph. 
+ """ + node_data: list[tuple[NodeKey, dict[str, Any]]] = [] + edge_data: list[tuple[NodeKey, NodeKey, str, dict[str, Any]]] = [] + for task_node in nodes: + task_node = task_node._resolved(self._universe) + node_data.append( + (task_node.key, {"instance": task_node, "bipartite": task_node.key.node_type.bipartite}) + ) + node_data.append( + ( + task_node.init.key, + {"instance": task_node.init, "bipartite": task_node.init.key.node_type.bipartite}, + ) + ) + # Convert the edge objects attached to the task node to networkx. + for read_edge in task_node.init.iter_all_inputs(): + self._append_graph_data_from_edge(node_data, edge_data, read_edge) + for write_edge in task_node.init.iter_all_outputs(): + self._append_graph_data_from_edge(node_data, edge_data, write_edge) + for read_edge in task_node.iter_all_inputs(): + self._append_graph_data_from_edge(node_data, edge_data, read_edge) + for write_edge in task_node.iter_all_outputs(): + self._append_graph_data_from_edge(node_data, edge_data, write_edge) + # Add a special edge (with no Edge instance) that connects the + # TaskInitNode to the runtime TaskNode. + edge_data.append((task_node.init.key, task_node.key, Edge.INIT_TO_TASK_NAME, {"instance": None})) + if not node_data and not edge_data: + return + # Checks and preparation complete; time to start the actual + # modification, during which it's hard to provide strong exception + # safety. Start by resetting the sort ordering, if there is one. + self._reset() + try: + self._xgraph.add_nodes_from(node_data) + self._xgraph.add_edges_from(edge_data) + if not networkx.algorithms.dag.is_directed_acyclic_graph(self._xgraph): + cycle = networkx.find_cycle(self._xgraph) + raise PipelineDataCycleError(f"Cycle detected while adding tasks: {cycle}.") + except Exception: + # First try to roll back our changes. + try: + self._xgraph.remove_edges_from(edge_data) + self._xgraph.remove_nodes_from(key for key, _ in node_data) + except Exception as err: # pragma: no cover + # There's no known way to get here, but we want to make it + # clear it's a big problem if we do. + raise PipelineGraphExceptionSafetyError( + "Error while attempting to revert PipelineGraph modification has left the graph in " + "an inconsistent state." + ) from err + # Successfully rolled back; raise the original exception. + raise + + def reconfigure_tasks( + self, + *args: tuple[str, PipelineTaskConfig], + check_edges_unchanged: bool = False, + assume_edges_unchanged: bool = False, + **kwargs: PipelineTaskConfig, + ) -> None: + """Update the configuration for one or more tasks. + + Parameters + ---------- + *args : `tuple` [ `str`, `.PipelineTaskConfig` ] + Positional arguments are each a 2-tuple of task label and new + config object. Note that the same arguments may also be passed as + ``**kwargs``, which is usually more readable, but task labels in + ``*args`` are not required to be valid Python identifiers. + check_edges_unchanged : `bool`, optional + If `True`, require the edges (connections) of the modified tasks to + remain unchanged after the configuration updates, and verify that + this is the case. + assume_edges_unchanged : `bool`, optional + If `True`, the caller declares that the edges (connections) of the + modified tasks will remain unchanged after the configuration + updates, and that it is unnecessary to check this. + **kwargs : `.PipelineTaskConfig` + New config objects or overrides to apply to copies of the current + config objects, with task labels as the keywords. 
+ + Raises + ------ + ValueError + Raised if ``assume_edges_unchanged`` and ``check_edges_unchanged`` + are both `True`, or if the same task appears twice. + EdgesChangedError + Raised if ``check_edges_unchanged=True`` and the edges of a task do + change. + + Notes + ----- + If reconfiguring a task causes its edges to change, any dataset type + nodes connected to that task (not just those whose edges have changed!) + will be unresolved. + """ + new_configs: dict[str, PipelineTaskConfig] = {} + for task_label, config_update in itertools.chain(args, kwargs.items()): + if new_configs.setdefault(task_label, config_update) is not config_update: + raise ValueError(f"Config for {task_label!r} provided more than once.") + updates = { + task_label: self.tasks[task_label]._reconfigured(config, rebuild=not assume_edges_unchanged) + for task_label, config in new_configs.items() + } + self._replace_task_nodes( + updates, + check_edges_unchanged=check_edges_unchanged, + assume_edges_unchanged=assume_edges_unchanged, + message_header=( + "Unexpected change in edges for task {task_label!r} from original config (A) to " + "new configs (B):" + ), + ) + + def remove_tasks( + self, labels: Iterable[str], drop_from_subsets: bool = True + ) -> list[tuple[TaskNode, set[str]]]: + """Remove one or more tasks from the graph. + + Parameters + ---------- + labels : `~collections.abc.Iterable` [ `str` ] + Iterable of the labels of the tasks to remove. + drop_from_subsets : `bool`, optional + If `True`, drop each removed task from any subset in which it + currently appears. If `False`, raise `PipelineGraphError` if any + such subsets exist. + + Returns + ------- + nodes_and_subsets : `list` [ `tuple` [ `TaskNode`, `set` [ `str` ] ] ] + List of nodes removed and the labels of task subsets that + referenced them. + + Raises + ------ + PipelineGraphError + Raised if ``drop_from_subsets`` is `False` and the task is still + part of one or more subsets. + + Notes + ----- + Removing a task will cause dataset nodes with no other referencing + tasks to be removed. Any other dataset type nodes referenced by a + removed task will be reset to an "unresolved" state. + """ + task_nodes_and_subsets = [] + dataset_types: set[NodeKey] = set() + nodes_to_remove = set() + for label in labels: + task_node: TaskNode = self._xgraph.nodes[NodeKey(NodeType.TASK, label)]["instance"] + # Find task subsets that reference this task. + referencing_subsets = { + subset_label + for subset_label, task_subset in self.task_subsets.items() + if label in task_subset + } + if not drop_from_subsets and referencing_subsets: + raise PipelineGraphError( + f"Task {label!r} is still referenced by subset(s) {referencing_subsets}." + ) + task_nodes_and_subsets.append((task_node, referencing_subsets)) + # Find dataset types referenced by this task. + dataset_types.update(self._xgraph.predecessors(task_node.key)) + dataset_types.update(self._xgraph.successors(task_node.key)) + dataset_types.update(self._xgraph.predecessors(task_node.init.key)) + dataset_types.update(self._xgraph.successors(task_node.init.key)) + # Since there's an edge between the task and its init node, we'll + # have added those two nodes here, too, and we don't want that. + dataset_types.remove(task_node.init.key) + dataset_types.remove(task_node.key) + # Mark the task node and its init node for removal from the graph. + nodes_to_remove.add(task_node.key) + nodes_to_remove.add(task_node.init.key) + # Process the referenced datasets to see which ones are orphaned and + # need to be removed vs. 
just unresolved. + nodes_to_unresolve = [] + for dataset_type_key in dataset_types: + related_tasks = set() + related_tasks.update(self._xgraph.predecessors(dataset_type_key)) + related_tasks.update(self._xgraph.successors(dataset_type_key)) + related_tasks.difference_update(nodes_to_remove) + if not related_tasks: + nodes_to_remove.add(dataset_type_key) + else: + nodes_to_unresolve.append(dataset_type_key) + # Checks and preparation complete; time to start the actual + # modification, during which it's hard to provide strong exception + # safety. Start by resetting the sort ordering. + self._reset() + try: + for dataset_type_key in nodes_to_unresolve: + self._xgraph.nodes[dataset_type_key]["instance"] = None + for task_node, referencing_subsets in task_nodes_and_subsets: + for subset_label in referencing_subsets: + self._task_subsets[subset_label].remove(task_node.label) + self._xgraph.remove_nodes_from(nodes_to_remove) + except Exception as err: # pragma: no cover + # There's no known way to get here, but we want to make it + # clear it's a big problem if we do. + raise PipelineGraphExceptionSafetyError( + "Error during task removal has left the graph in an inconsistent state." + ) from err + return task_nodes_and_subsets + + def add_task_subset(self, subset_label: str, task_labels: Iterable[str], description: str = "") -> None: + """Add a label for a set of tasks that are already in the pipeline. + + Parameters + ---------- + subset_label : `str` + Label for this set of tasks. + task_labels : `~collections.abc.Iterable` [ `str` ] + Labels of the tasks to include in the set. All must already be + included in the graph. + description : `str`, optional + String description to associate with this label. + """ + subset = TaskSubset(self._xgraph, subset_label, set(task_labels), description) + self._task_subsets[subset_label] = subset + + def remove_task_subset(self, subset_label: str) -> None: + """Remove a labeled set of tasks.""" + del self._task_subsets[subset_label] + + def copy(self) -> PipelineGraph: + """Return a copy of this graph that copies all mutable state.""" + xgraph = self._xgraph.copy() + result = PipelineGraph.__new__(PipelineGraph) + result._init_from_args( + xgraph, + self._sorted_keys, + task_subsets={ + k: TaskSubset(xgraph, v.label, set(v._members), v.description) + for k, v in self._task_subsets.items() + }, + description=self._description, + universe=self.universe, + data_id=self._raw_data_id, + ) + return result + + def __copy__(self) -> PipelineGraph: + # Fully shallow copies are dangerous; we don't want shared mutable + # state to lead to broken class invariants. + return self.copy() + + def __deepcopy__(self, memo: dict) -> PipelineGraph: + # Genuine deep copies are unnecessary, since we should only ever care + # that mutable state is copied. + return self.copy() + + def import_and_configure( + self, check_edges_unchanged: bool = False, assume_edges_unchanged: bool = False + ) -> None: + """Import the `PipelineTask` classes referenced by all task nodes and + update those nodes accordingly. + + Parameters + ---------- + check_edges_unchanged : `bool`, optional + If `True`, require the edges (connections) of the modified tasks to + remain unchanged after importing and configuring each task, and + verify that this is the case. + assume_edges_unchanged : `bool`, optional + If `True`, the caller declares that the edges (connections) of the + modified tasks will remain unchanged importing and configuring each + task, and that it is unnecessary to check this. 
+ + Raises + ------ + ValueError + Raised if ``assume_edges_unchanged`` and ``check_edges_unchanged`` + are both `True`, or if a full config is provided for a task after + another full config or an override has already been provided. + EdgesChangedError + Raised if ``check_edges_unchanged=True`` and the edges of a task do + change. + + Notes + ----- + This method shouldn't need to be called unless the graph was + deserialized without importing and configuring immediately, which is + not the default behavior (but it can greatly speed up deserialization). + If all tasks have already been imported this does nothing. + + Importing and configuring a task can change its + `~TaskNode.task_class_name` or `~TaskClass.get_config_str` output, + usually because the software used to read a serialized graph is newer + than the software used to write it (e.g. a new config option has been + added, or the task was moved to a new module with a forwarding alias + left behind). These changes are allowed by ``check=True``. + + If importing and configuring a task causes its edges to change, any + dataset type nodes linked to those edges will be reset to the + unresolved state. + """ + rebuild = check_edges_unchanged or not assume_edges_unchanged + updates: dict[str, TaskNode] = {} + node_key: NodeKey + for node_key, node_state in self._xgraph.nodes.items(): + if node_key.node_type is NodeType.TASK: + task_node: TaskNode = node_state["instance"] + new_task_node = task_node._imported_and_configured(rebuild) + if new_task_node is not task_node: + updates[task_node.label] = new_task_node + self._replace_task_nodes( + updates, + check_edges_unchanged=check_edges_unchanged, + assume_edges_unchanged=assume_edges_unchanged, + message_header=( + "In task with label {task_label!r}, persisted edges (A)" + "differ from imported and configured edges (B):" + ), + ) + + def resolve(self, registry: Registry) -> None: + """Resolve all dimensions and dataset types and check them for + consistency. + + Resolving a graph also causes it to be sorted. + + Parameters + ---------- + registry : `lsst.daf.butler.Registry` + Client for the data repository to resolve against. + + Notes + ----- + The `universe` attribute are set to ``registry.dimensions`` and used to + set all `TaskNode.dimensions` attributes. Dataset type nodes are + resolved by first looking for a registry definition, then using the + producing task's definition, then looking for consistency between all + consuming task definitions. + + Raises + ------ + ConnectionTypeConsistencyError + Raised if a prerequisite input for one task appears as a different + kind of connection in any other task. + DuplicateOutputError + Raised if multiple tasks have the same dataset type as an output. + IncompatibleDatasetTypeError + Raised if different tasks have different definitions of a dataset + type. Different but compatible storage classes are permitted. + MissingDatasetTypeError + Raised if a dataset type definition is required to exist in the + data repository but none was found. This should only occur for + dataset types that are not produced by a task in the pipeline and + are consumed with different storage classes or as components by + tasks in the pipeline. + EdgesChangedError + Raised if ``check_edges_unchanged=True`` and the edges of a task do + change after import and reconfiguration. 
+ """ + node_key: NodeKey + updates: dict[NodeKey, TaskNode | DatasetTypeNode] = {} + for node_key, node_state in self._xgraph.nodes.items(): + match node_key.node_type: + case NodeType.TASK: + task_node: TaskNode = node_state["instance"] + new_task_node = task_node._resolved(registry.dimensions) + if new_task_node is not task_node: + updates[node_key] = new_task_node + case NodeType.DATASET_TYPE: + dataset_type_node: DatasetTypeNode | None = node_state["instance"] + new_dataset_type_node = DatasetTypeNode._from_edges( + node_key, self._xgraph, registry, previous=dataset_type_node + ) + if new_dataset_type_node is not dataset_type_node: + updates[node_key] = new_dataset_type_node + try: + for node_key, node_value in updates.items(): + self._xgraph.nodes[node_key]["instance"] = node_value + except Exception as err: # pragma: no cover + # There's no known way to get here, but we want to make it + # clear it's a big problem if we do. + raise PipelineGraphExceptionSafetyError( + "Error during dataset type resolution has left the graph in an inconsistent state." + ) from err + self.sort() + self._universe = registry.dimensions + + @classmethod + def read_stream( + cls, + stream: BinaryIO, + import_and_configure: bool = True, + check_edges_unchanged: bool = False, + assume_edges_unchanged: bool = False, + ) -> PipelineGraph: + """Read a serialized `PipelineGraph` from a file-like object. + + Parameters + ---------- + stream : `BinaryIO` + File-like object opened for binary reading, containing + gzip-compressed JSON. + import_and_configure : `bool`, optional + If `True`, import and configure all tasks immediately (see the + `import_and_configure` method). If `False`, some `TaskNode` and + `TaskInitNode` attributes will not be available, but reading may be + much faster. + check_edges_unchanged : `bool`, optional + Forwarded to `import_and_configure` after reading. + assume_edges_unchanged : `bool`, optional + Forwarded to `import_and_configure` after reading. + + Returns + ------- + graph : `PipelineGraph` + Deserialized pipeline graph. + + Raises + ------ + PipelineGraphReadError + Raised if the serialized `PipelineGraph` is not self-consistent. + EdgesChangedError + Raised if ``check_edges_unchanged=True`` and the edges of a task do + change after import and reconfiguration. + """ + from .io import SerializedPipelineGraph + + with gzip.open(stream, "rb") as uncompressed_stream: + data = json.load(uncompressed_stream) + serialized_graph = SerializedPipelineGraph.parse_obj(data) + return serialized_graph.deserialize( + import_and_configure=import_and_configure, + check_edges_unchanged=check_edges_unchanged, + assume_edges_unchanged=assume_edges_unchanged, + ) + + @classmethod + def read_uri( + cls, + uri: ResourcePathExpression, + import_and_configure: bool = True, + check_edges_unchanged: bool = False, + assume_edges_unchanged: bool = False, + ) -> PipelineGraph: + """Read a serialized `PipelineGraph` from a file at a URI. + + Parameters + ---------- + uri : convertible to `lsst.resources.ResourcePath` + URI to a gzip-compressed JSON file containing a serialized pipeline + graph. + import_and_configure : `bool`, optional + If `True`, import and configure all tasks immediately (see + the `import_and_configure` method). If `False`, some `TaskNode` + and `TaskInitNode` attributes will not be available, but reading + may be much faster. + check_edges_unchanged : `bool`, optional + Forwarded to `import_and_configure` after reading. 
+ assume_edges_unchanged : `bool`, optional + Forwarded to `import_and_configure` after reading. + + Returns + ------- + graph : `PipelineGraph` + Deserialized pipeline graph. + + Raises + ------ + PipelineGraphReadError + Raised if the serialized `PipelineGraph` is not self-consistent. + EdgesChangedError + Raised if ``check_edges_unchanged=True`` and the edges of a task do + change after import and reconfiguration. + """ + uri = ResourcePath(uri) + with uri.open("rb") as stream: + return cls.read_stream( + cast(BinaryIO, stream), + import_and_configure=import_and_configure, + check_edges_unchanged=check_edges_unchanged, + assume_edges_unchanged=assume_edges_unchanged, + ) + + def write_stream(self, stream: BinaryIO) -> None: + """Write the pipeline to a file-like object. + + Parameters + ---------- + stream + File-like object opened for binary writing. + + Notes + ----- + The file format is gzipped JSON, and is intended to be human-readable, + but it should not be considered a stable public interface for outside + code, which should always use `PipelineGraph` methods (or at least the + `io.SerializedPipelineGraph` class) to read these files. + """ + from .io import SerializedPipelineGraph + + with gzip.open(stream, mode="wb") as compressed_stream: + compressed_stream.write( + SerializedPipelineGraph.serialize(self).json(exclude_defaults=True, indent=2).encode("utf-8") + ) + + def write_uri(self, uri: ResourcePathExpression) -> None: + """Write the pipeline to a file given a URI. + + Parameters + ---------- + uri : convertible to `lsst.resources.ResourcePath` + URI to write to . May have ``.json.gz`` or no extension (which + will cause a ``.json.gz`` extension to be added). + + Notes + ----- + The file format is gzipped JSON, and is intended to be human-readable, + but it should not be considered a stable public interface for outside + code, which should always use `PipelineGraph` methods (or at least the + `io.SerializedPipelineGraph` class) to read these files. + """ + uri = ResourcePath(uri) + extension = uri.getExtension() + if not extension: + uri = uri.updatedExtension(".json.gz") + elif extension != ".json.gz": + raise ValueError("Expanded pipeline files should always have a .json.gz extension.") + with uri.open(mode="wb") as stream: + self.write_stream(cast(BinaryIO, stream)) + + def _iter_task_defs(self) -> Iterator[TaskDef]: + """Iterate over this pipeline as a sequence of `TaskDef` instances. + + Notes + ----- + This is a package-private method intended to aid in the transition to a + codebase more fully integrated with the `PipelineGraph` class, in which + both `TaskDef` and `PipelineDatasetTypes` are expected to go away, and + much of the functionality on the `Pipeline` class will be moved to + `PipelineGraph` as well. + + Raises + ------ + TaskNotImportedError + Raised if `TaskNode.is_imported` is `False` for any task. + """ + from ..pipeline import TaskDef + + for node in self._tasks.values(): + yield TaskDef( + config=node.config, + taskClass=node.task_class, + label=node.label, + connections=node._get_imported_data().connections, + ) + + def _replace_task_nodes( + self, + updates: Mapping[str, TaskNode], + check_edges_unchanged: bool, + assume_edges_unchanged: bool, + message_header: str, + ) -> None: + """Replace task nodes and update edges and dataset type nodes + accordingly. + + Parameters + ---------- + updates : `Mapping` [ `str`, `TaskNode` ] + New task nodes with task label keys. All keys must be task labels + that are already present in the graph. 
+ check_edges_unchanged : `bool`, optional
+            If `True`, require the edges (connections) of the modified tasks to
+            remain unchanged after importing and configuring each task, and
+            verify that this is the case.
+        assume_edges_unchanged : `bool`, optional
+            If `True`, the caller declares that the edges (connections) of the
+            modified tasks will remain unchanged after importing and
+            configuring each task, and that it is unnecessary to check this.
+        message_header : `str`
+            Template for `str.format` with a single ``task_label`` placeholder
+            to use as the first line in `EdgesChangedError` messages that show
+            the differences between new task edges and old task edges. Should
+            include the fact that the rest of the message will refer to the old
+            task as "A" and the new task as "B", and end with a colon.
+
+        Raises
+        ------
+        ValueError
+            Raised if ``assume_edges_unchanged`` and ``check_edges_unchanged``
+            are both `True`, or if a full config is provided for a task after
+            another full config or an override has already been provided.
+        EdgesChangedError
+            Raised if ``check_edges_unchanged=True`` and the edges of a task do
+            change.
+        """
+        deep: dict[str, TaskNode] = {}
+        shallow: dict[str, TaskNode] = {}
+        if assume_edges_unchanged:
+            if check_edges_unchanged:
+                raise ValueError("Cannot simultaneously assume and check that edges have not changed.")
+            shallow.update(updates)
+        else:
+            for task_label, new_task_node in updates.items():
+                old_task_node = self.tasks[task_label]
+                messages = old_task_node.diff_edges(new_task_node)
+                if messages:
+                    if check_edges_unchanged:
+                        messages.insert(0, message_header.format(task_label=task_label))
+                        raise EdgesChangedError("\n".join(messages))
+                    else:
+                        deep[task_label] = new_task_node
+                else:
+                    shallow[task_label] = new_task_node
+        try:
+            if deep:
+                removed = self.remove_tasks(deep.keys(), drop_from_subsets=True)
+                self.add_task_nodes(deep.values())
+                for replaced_task_node, referencing_subsets in removed:
+                    for subset_label in referencing_subsets:
+                        self._task_subsets[subset_label].add(replaced_task_node.label)
+            for task_node in shallow.values():
+                self._xgraph.nodes[task_node.key]["instance"] = task_node
+                self._xgraph.nodes[task_node.init.key]["instance"] = task_node.init
+        except PipelineGraphExceptionSafetyError:  # pragma: no cover
+            raise
+        except Exception as err:  # pragma: no cover
+            # There's no known way to get here, but we want to make it clear
+            # it's a big problem if we do.
+            raise PipelineGraphExceptionSafetyError(
+                "Error while replacing tasks has left the graph in an inconsistent state."
+            ) from err
+
+    def _append_graph_data_from_edge(
+        self,
+        node_data: list[tuple[NodeKey, dict[str, Any]]],
+        edge_data: list[tuple[NodeKey, NodeKey, str, dict[str, Any]]],
+        edge: Edge,
+    ) -> None:
+        """Append networkx state dictionaries for an edge and the corresponding
+        dataset type node.
+
+        Parameters
+        ----------
+        node_data : `list`
+            List of node keys and state dictionaries. A node is appended if
+            one does not already exist for this dataset type.
+        edge_data : `list`
+            List of node key pairs, connection names, and state dictionaries
+            for edges.
+        edge : `Edge`
+            New edge being processed.
+ """ + if (existing_dataset_type_state := self._xgraph.nodes.get(edge.dataset_type_key)) is not None: + existing_dataset_type_state["instance"] = None + else: + node_data.append( + ( + edge.dataset_type_key, + { + "instance": None, + "bipartite": NodeType.DATASET_TYPE.bipartite, + }, + ) + ) + edge_data.append( + edge.nodes + + ( + edge.connection_name, + {"instance": edge}, + ) + ) + + def _reorder(self, sorted_keys: Sequence[NodeKey]) -> None: + """Set the order of all views of this graph from the given sorted + sequence of task labels and dataset type names. + """ + self._sorted_keys = sorted_keys + self._tasks._reorder(sorted_keys) + self._dataset_types._reorder(sorted_keys) + + def _reset(self) -> None: + """Reset the all views of this graph following a modification that + might invalidate them. + """ + self._sorted_keys = None + self._tasks._reset() + self._dataset_types._reset() diff --git a/python/lsst/pipe/base/pipeline_graph/_task_subsets.py b/python/lsst/pipe/base/pipeline_graph/_task_subsets.py new file mode 100644 index 00000000..1c48ecab --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_task_subsets.py @@ -0,0 +1,122 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ("TaskSubset",) + +from collections.abc import Iterator, MutableSet + +import networkx +import networkx.algorithms.boundary + +from ._exceptions import PipelineGraphError +from ._nodes import NodeKey, NodeType + + +class TaskSubset(MutableSet[str]): + """A specialized set that represents a labeles subset of the tasks in a + pipeline graph. + + Instances of this class should never be constructed directly; they should + only be accessed via the `PipelineGraph.task_subsets` attribute and created + by the `PipelineGraph.add_task_subset` method. + + Parameters + ---------- + parent_xgraph : `networkx.DiGraph` + Parent networkx graph that this subgraph is part of. + label : `str` + Label associated with this subset of the pipeline. + members : `set` [ `str` ] + Labels of the tasks that are members of this subset. + description : `str`, optional + Description string associated with this labeled subset. + + Notes + ----- + Iteration order is arbitrary, even when the parent pipeline graph is + ordered (there is no guarantee that an ordering of the tasks in the graph + implies a consistent ordering of subsets). 
+ """ + + def __init__( + self, + parent_xgraph: networkx.DiGraph, + label: str, + members: set[str], + description: str, + ): + self._parent_xgraph = parent_xgraph + self._label = label + self._members = members + self._description = description + + @property + def label(self) -> str: + """Label associated with this subset of the pipeline.""" + return self._label + + @property + def description(self) -> str: + """Description string associated with this labeled subset.""" + return self._description + + @description.setter + def description(self, value: str) -> None: + # Docstring in getter. + self._description = value + + def __repr__(self) -> str: + return f"{self.label}: {self.description!r}, tasks={{{', '.join(iter(self))}}}" + + def __contains__(self, key: object) -> bool: + return key in self._members + + def __len__(self) -> int: + return len(self._members) + + def __iter__(self) -> Iterator[str]: + return iter(self._members) + + def add(self, task_label: str) -> None: + """Add a new task to this subset. + + Parameters + ---------- + task_label : `str` + Label for the task. Must already be present in the parent pipeline + graph. + """ + key = NodeKey(NodeType.TASK, task_label) + if key not in self._parent_xgraph: + raise PipelineGraphError(f"{task_label!r} is not a task in the parent pipeline.") + self._members.add(key.name) + + def discard(self, task_label: str) -> None: + """Remove a task from the subset if it is present. + + Parameters + ---------- + task_label : `str` + Label for the task. Must already be present in the parent pipeline + graph. + """ + self._members.discard(task_label) diff --git a/python/lsst/pipe/base/pipeline_graph/_tasks.py b/python/lsst/pipe/base/pipeline_graph/_tasks.py new file mode 100644 index 00000000..8c3b4e2b --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_tasks.py @@ -0,0 +1,871 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ("TaskNode", "TaskInitNode") + +import dataclasses +from collections.abc import Iterator, Mapping +from typing import TYPE_CHECKING, Any, cast + +from lsst.daf.butler import DimensionGraph, DimensionUniverse +from lsst.utils.classes import immutable +from lsst.utils.doImport import doImportType +from lsst.utils.introspection import get_full_type_name + +from .. 
import automatic_connection_constants as acc +from ..connections import PipelineTaskConnections +from ..connectionTypes import BaseConnection, InitOutput, Output +from ._edges import Edge, ReadEdge, WriteEdge +from ._exceptions import TaskNotImportedError, UnresolvedGraphError +from ._nodes import NodeKey, NodeType + +if TYPE_CHECKING: + from ..config import PipelineTaskConfig + from ..pipelineTask import PipelineTask + + +@dataclasses.dataclass(frozen=True) +class _TaskNodeImportedData: + """An internal struct that holds `TaskNode` and `TaskInitNode` state that + requires task classes to be imported. + """ + + task_class: type[PipelineTask] + """Type object for the task.""" + + config: PipelineTaskConfig + """Configuration object for the task.""" + + connection_map: dict[str, BaseConnection] + """Mapping from connection name to connection. + + In addition to ``connections.allConnections``, this also holds the + "automatic" config, log, and metadata connections using the names defined + in the `.automatic_connection_constants` module. + """ + + connections: PipelineTaskConnections + """Configured connections object for the task.""" + + @classmethod + def configure( + cls, + label: str, + task_class: type[PipelineTask], + config: PipelineTaskConfig, + connections: PipelineTaskConnections | None = None, + ) -> _TaskNodeImportedData: + """Construct while creating a `PipelineTaskConnections` instance if + necessary. + + Parameters + ---------- + label : `str` + Label for the task in the pipeline. Only used in error messages. + task_class : `type` [ `.PipelineTask` ] + Pipeline task `type` object. + config : `.PipelineTaskConfig` + Configuration for the task. + connections : `.PipelineTaskConnections`, optional + Object that describes the dataset types used by the task. If not + provided, one will be constructed from the given configuration. If + provided, it is assumed that ``config`` has already been validated + and frozen. + + Returns + ------- + data : `_TaskNodeImportedData` + Instance of this struct. + """ + if connections is None: + # If we don't have connections yet, assume the config hasn't been + # validated yet. + try: + config.validate() + except Exception as err: + raise ValueError( + f"Configuration validation failed for task {label!r} (see chained exception)." + ) from err + config.freeze() + connections = task_class.ConfigClass.ConnectionsClass(config=config) + connection_map = dict(connections.allConnections) + connection_map[acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME] = InitOutput( + acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label), + acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, + ) + if not config.saveMetadata: + raise ValueError(f"Metadata for task {label} cannot be disabled.") + connection_map[acc.METADATA_OUTPUT_CONNECTION_NAME] = Output( + acc.METADATA_OUTPUT_TEMPLATE.format(label=label), + acc.METADATA_OUTPUT_STORAGE_CLASS, + dimensions=set(connections.dimensions), + ) + if config.saveLogOutput: + connection_map[acc.LOG_OUTPUT_CONNECTION_NAME] = Output( + acc.LOG_OUTPUT_TEMPLATE.format(label=label), + acc.LOG_OUTPUT_STORAGE_CLASS, + dimensions=set(connections.dimensions), + ) + return cls(task_class, config, connection_map, connections) + + +@immutable +class TaskInitNode: + """A node in a pipeline graph that represents the construction of a + `PipelineTask`. + + Parameters + ---------- + inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ] + Graph edges that represent inputs required just to construct an + instance of this task, keyed by connection name. 
+ outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
+        Graph edges that represent outputs of this task that are available
+        after just constructing it, keyed by connection name.
+
+        This does not include the special `config_output` edge; use
+        `iter_all_outputs` to include that, too.
+    config_output : `WriteEdge`
+        The special init output edge that persists the task's configuration.
+    imported_data : `_TaskNodeImportedData`, optional
+        Internal struct that holds information that requires the task class to
+        have been imported.
+    task_class_name : `str`, optional
+        Fully-qualified name of the task class. Must be provided if
+        ``imported_data`` is not.
+    config_str : `str`, optional
+        Configuration for the task as a string of override statements. Must be
+        provided if ``imported_data`` is not.
+
+    Notes
+    -----
+    When included in an exported `networkx` graph (e.g.
+    `PipelineGraph.make_xgraph`), task initialization nodes set the following
+    node attributes:
+
+    - ``task_class_name``
+    - ``bipartite`` (see `NodeType.bipartite`)
+    - ``task_class`` (only if `is_imported` is `True`)
+    - ``config`` (only if `is_imported` is `True`)
+    """
+
+    def __init__(
+        self,
+        key: NodeKey,
+        *,
+        inputs: Mapping[str, ReadEdge],
+        outputs: Mapping[str, WriteEdge],
+        config_output: WriteEdge,
+        imported_data: _TaskNodeImportedData | None = None,
+        task_class_name: str | None = None,
+        config_str: str | None = None,
+    ):
+        self.key = key
+        self.inputs = inputs
+        self.outputs = outputs
+        self.config_output = config_output
+        # Instead of setting attributes to None, we do not set them at all;
+        # this works better with the @immutable decorator, which supports
+        # deferred initialization but not reassignment.
+        if task_class_name is not None:
+            self._task_class_name = task_class_name
+        if config_str is not None:
+            self._config_str = config_str
+        if imported_data is not None:
+            self._imported_data = imported_data
+        else:
+            assert (
+                self._task_class_name is not None and self._config_str is not None
+            ), "If imported_data is not present, task_class_name and config_str must be."
+
+    key: NodeKey
+    """Key that identifies this node in internal and exported networkx graphs.
+    """
+
+    inputs: Mapping[str, ReadEdge]
+    """Graph edges that represent inputs required just to construct an instance
+    of this task, keyed by connection name.
+    """
+
+    outputs: Mapping[str, WriteEdge]
+    """Graph edges that represent outputs of this task that are available after
+    just constructing it, keyed by connection name.
+
+    This does not include the special `config_output` edge; use
+    `iter_all_outputs` to include that, too.
+    """
+
+    config_output: WriteEdge
+    """The special output edge that persists the task's configuration.
+    """
+
+    @property
+    def label(self) -> str:
+        """Label of this configuration of a task in the pipeline."""
+        return str(self.key)
+
+    @property
+    def is_imported(self) -> bool:
+        """Whether the task type for this node has been imported and
+        its configuration overrides applied.
+
+        If this is `False`, the `task_class` and `config` attributes may not
+        be accessed.
+        """
+        return hasattr(self, "_imported_data")
+
+    @property
+    def task_class(self) -> type[PipelineTask]:
+        """Type object for the task.
+
+        Accessing this attribute when `is_imported` is `False` will raise
+        `TaskNotImportedError`, but accessing `task_class_name` will not.
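+
+        A hypothetical guard (``node`` is assumed to be a `TaskInitNode` that
+        may or may not have been imported)::
+
+            if node.is_imported:
+                task_class = node.task_class
+            else:
+                task_class = None  # only node.task_class_name is available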
+ """ + return self._get_imported_data().task_class + + @property + def task_class_name(self) -> str: + """The fully-qualified string name of the task class.""" + try: + return self._task_class_name + except AttributeError: + pass + self._task_class_name = get_full_type_name(self.task_class) + return self._task_class_name + + @property + def config(self) -> PipelineTaskConfig: + """Configuration for the task. + + This is always frozen. + + Accessing this attribute when `is_imported` is `False` will raise + `TaskNotImportedError`, but calling `get_config_str` will not. + """ + return self._get_imported_data().config + + def get_config_str(self) -> str: + """Return the configuration for this task as a string of override + statements. + + Returns + ------- + config_str : `str` + String containing configuration-overload statements. + """ + try: + return self._config_str + except AttributeError: + pass + self._config_str = self.config.saveToString() + return self._config_str + + def iter_all_inputs(self) -> Iterator[ReadEdge]: + """Iterate over all inputs required for construction. + + This is the same as iteration over ``inputs.values()``, but it will be + updated to include any automatic init-input connections added in the + future, while `inputs` will continue to hold only task-defined init + inputs. + """ + return iter(self.inputs.values()) + + def iter_all_outputs(self) -> Iterator[WriteEdge]: + """Iterate over all outputs available after construction, including + special ones. + """ + yield from self.outputs.values() + yield self.config_output + + def diff_edges(self, other: TaskInitNode) -> list[str]: + """Compare the edges of this task initialization node to those from the + same task label in a different pipeline. + + Parameters + ---------- + other : `TaskInitNode` + Other node to compare to. Must have the same task label, but need + not have the same configuration or even the same task class. + + Returns + ------- + differences : `list` [ `str` ] + List of string messages describing differences between ``self`` and + ``other``. Will be empty if the two nodes have the same edges. + Messages will use 'A' to refer to ``self`` and 'B' to refer to + ``other``. + """ + result = [] + result += _diff_edge_mapping(self.inputs, self.inputs, self.label, "init input") + result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "init output") + result += self.config_output.diff(other.config_output, "config init output") + return result + + def _to_xgraph_state(self) -> dict[str, Any]: + """Convert this nodes's attributes into a dictionary suitable for use + in exported networkx graphs. + """ + result = {"task_class_name": self.task_class_name, "bipartite": NodeType.TASK_INIT.bipartite} + if hasattr(self, "_imported_data"): + result["task_class"] = self.task_class + result["config"] = self.config + return result + + def _get_imported_data(self) -> _TaskNodeImportedData: + """Return the imported data struct. + + Returns + ------- + imported_data : `_TaskNodeImportedData` + Internal structure holding state that requires the task class to + have been imported. + + Raises + ------ + TaskNotImportedError + Raised if `is_imported` is `False`. + """ + try: + return self._imported_data + except AttributeError: + raise TaskNotImportedError( + f"Task class {self.task_class_name!r} for label {self.label!r} has not been imported " + "(see PipelineGraph.import_and_configure)." 
+ ) from None
+
+
+@immutable
+class TaskNode:
+    """A node in a pipeline graph that represents a labeled configuration of a
+    `PipelineTask`.
+
+    Parameters
+    ----------
+    key : `NodeKey`
+        Identifier for this node in networkx graphs.
+    init : `TaskInitNode`
+        Node representing the initialization of this task.
+    prerequisite_inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
+        Graph edges that represent prerequisite inputs to this task, keyed by
+        connection name.
+
+        Prerequisite inputs must already exist in the data repository when a
+        `QuantumGraph` is built, but have more flexibility in how they are
+        looked up than regular inputs.
+    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
+        Graph edges that represent regular runtime inputs to this task, keyed
+        by connection name.
+    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
+        Graph edges that represent regular runtime outputs of this task, keyed
+        by connection name.
+
+        This does not include the special `log_output` and `metadata_output`
+        edges; use `iter_all_outputs` to include those, too.
+    log_output : `WriteEdge` or `None`
+        The special runtime output that persists the task's logs.
+    metadata_output : `WriteEdge`
+        The special runtime output that persists the task's metadata.
+    dimensions : `lsst.daf.butler.DimensionGraph` or `frozenset`
+        Dimensions of the task. If a `frozenset`, the dimensions have not been
+        resolved by a `~lsst.daf.butler.DimensionUniverse` and cannot be safely
+        compared to other sets of dimensions.
+
+    Notes
+    -----
+    Task nodes are intentionally not equality comparable, since there are many
+    different (and useful) ways to compare these objects with no clear winner
+    as the most obvious behavior.
+
+    When included in an exported `networkx` graph (e.g.
+    `PipelineGraph.make_xgraph`), task nodes set the following node attributes:
+
+    - ``task_class_name``
+    - ``bipartite`` (see `NodeType.bipartite`)
+    - ``task_class`` (only if `is_imported` is `True`)
+    - ``config`` (only if `is_imported` is `True`)
+    """
+
+    def __init__(
+        self,
+        key: NodeKey,
+        init: TaskInitNode,
+        *,
+        prerequisite_inputs: Mapping[str, ReadEdge],
+        inputs: Mapping[str, ReadEdge],
+        outputs: Mapping[str, WriteEdge],
+        log_output: WriteEdge | None,
+        metadata_output: WriteEdge,
+        dimensions: DimensionGraph | frozenset,
+    ):
+        self.key = key
+        self.init = init
+        self.prerequisite_inputs = prerequisite_inputs
+        self.inputs = inputs
+        self.outputs = outputs
+        self.log_output = log_output
+        self.metadata_output = metadata_output
+        self._dimensions = dimensions
+
+    @staticmethod
+    def _from_imported_data(
+        key: NodeKey,
+        init_key: NodeKey,
+        data: _TaskNodeImportedData,
+        universe: DimensionUniverse | None,
+    ) -> TaskNode:
+        """Construct from a `PipelineTask` type and its configuration.
+
+        Parameters
+        ----------
+        key : `NodeKey`
+            Identifier for this node in networkx graphs.
+        init_key : `NodeKey`
+            Identifier for the node representing the initialization of this
+            task.
+        data : `_TaskNodeImportedData`
+            Internal struct that holds information that requires the task class
+            to have been imported.
+        universe : `lsst.daf.butler.DimensionUniverse` or `None`
+            Definitions of all dimensions.
+
+        Returns
+        -------
+        node : `TaskNode`
+            New task node.
+
+        Raises
+        ------
+        ValueError
+            Raised if configuration validation failed when constructing
+            ``connections``.
+ """ + init_inputs = { + name: ReadEdge._from_connection_map(init_key, name, data.connection_map) + for name in data.connections.initInputs + } + prerequisite_inputs = { + name: ReadEdge._from_connection_map(key, name, data.connection_map, is_prerequisite=True) + for name in data.connections.prerequisiteInputs + } + inputs = { + name: ReadEdge._from_connection_map(key, name, data.connection_map) + for name in data.connections.inputs + } + init_outputs = { + name: WriteEdge._from_connection_map(init_key, name, data.connection_map) + for name in data.connections.initOutputs + } + outputs = { + name: WriteEdge._from_connection_map(key, name, data.connection_map) + for name in data.connections.outputs + } + init = TaskInitNode( + key=init_key, + inputs=init_inputs, + outputs=init_outputs, + config_output=WriteEdge._from_connection_map( + init_key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, data.connection_map + ), + imported_data=data, + ) + instance = TaskNode( + key=key, + init=init, + prerequisite_inputs=prerequisite_inputs, + inputs=inputs, + outputs=outputs, + log_output=( + WriteEdge._from_connection_map(key, acc.LOG_OUTPUT_CONNECTION_NAME, data.connection_map) + if data.config.saveLogOutput + else None + ), + metadata_output=WriteEdge._from_connection_map( + key, acc.METADATA_OUTPUT_CONNECTION_NAME, data.connection_map + ), + dimensions=( + frozenset(data.connections.dimensions) + if universe is None + else universe.extract(data.connections.dimensions) + ), + ) + return instance + + key: NodeKey + """Key that identifies this node in internal and exported networkx graphs. + """ + + prerequisite_inputs: Mapping[str, ReadEdge] + """Graph edges that represent prerequisite inputs to this task. + + Prerequisite inputs must already exist in the data repository when a + `QuantumGraph` is built, but have more flexibility in how they are looked + up than regular inputs. + """ + + inputs: Mapping[str, ReadEdge] + """Graph edges that represent regular runtime inputs to this task. + """ + + outputs: Mapping[str, WriteEdge] + """Graph edges that represent regular runtime outputs of this task. + + This does not include the special `log_output` and `metadata_output` edges; + use `iter_all_outputs` to include that, too. + """ + + log_output: WriteEdge | None + """The special runtime output that persists the task's logs. + """ + + metadata_output: WriteEdge + """The special runtime output that persists the task's metadata. + """ + + @property + def label(self) -> str: + """Label of this configuration of a task in the pipeline.""" + return self.key.name + + @property + def is_imported(self) -> bool: + """Whether this the task type for this node has been imported and + its configuration overrides applied. + + If this is `False`, the `task_class` and `config` attributes may not + be accessed. + """ + return self.init.is_imported + + @property + def task_class(self) -> type[PipelineTask]: + """Type object for the task. + + Accessing this attribute when `is_imported` is `False` will raise + `TaskNotImportedError`, but accessing `task_class_name` will not. + """ + return self.init.task_class + + @property + def task_class_name(self) -> str: + """The fully-qualified string name of the task class.""" + return self.init.task_class_name + + @property + def config(self) -> PipelineTaskConfig: + """Configuration for the task. + + This is always frozen. + + Accessing this attribute when `is_imported` is `False` will raise + `TaskNotImportedError`, but calling `get_config_str` will not. 
+ """ + return self.init.config + + @property + def has_resolved_dimensions(self) -> bool: + """Whether the `dimensions` attribute may be accessed. + + If `False`, the `raw_dimensions` attribute may be used to obtain a + set of dimension names that has not been resolved by a + `~lsst.daf.butler.DimensionsUniverse`. + """ + return type(self._dimensions) is DimensionGraph + + @property + def dimensions(self) -> DimensionGraph: + """Standardized dimensions of the task.""" + if not self.has_resolved_dimensions: + raise UnresolvedGraphError(f"Dimensions for task {self.label!r} have not been resolved.") + return cast(DimensionGraph, self._dimensions) + + @property + def raw_dimensions(self) -> frozenset[str]: + """Raw dimensions of the task, with standardization by a + `~lsst.daf.butler.DimensionUniverse` not guaranteed. + """ + if self.has_resolved_dimensions: + return frozenset(cast(DimensionGraph, self._dimensions).names) + else: + return cast(frozenset[str], self._dimensions) + + def __repr__(self) -> str: + if self.has_resolved_dimensions: + return f"{self.label} ({self.task_class_name}, {self.dimensions})" + else: + return f"{self.label} ({self.task_class_name})" + + def get_config_str(self) -> str: + """Return the configuration for this task as a string of override + statements. + + Returns + ------- + config_str : `str` + String containing configuration-overload statements. + """ + return self.init.get_config_str() + + def iter_all_inputs(self) -> Iterator[ReadEdge]: + """Iterate over all runtime inputs, including both regular inputs and + prerequisites. + """ + yield from self.prerequisite_inputs.values() + yield from self.inputs.values() + + def iter_all_outputs(self) -> Iterator[WriteEdge]: + """Iterate over all runtime outputs, including special ones.""" + yield from self.outputs.values() + yield self.metadata_output + if self.log_output is not None: + yield self.log_output + + def diff_edges(self, other: TaskNode) -> list[str]: + """Compare the edges of this task node to those from the same task + label in a different pipeline. + + This also calls `TaskInitNode.diff_edges`. + + Parameters + ---------- + other : `TaskInitNode` + Other node to compare to. Must have the same task label, but need + not have the same configuration or even the same task class. + + Returns + ------- + differences : `list` [ `str` ] + List of string messages describing differences between ``self`` and + ``other``. Will be empty if the two nodes have the same edges. + Messages will use 'A' to refer to ``self`` and 'B' to refer to + ``other``. + """ + result = self.init.diff_edges(other.init) + result += _diff_edge_mapping( + self.prerequisite_inputs, other.prerequisite_inputs, self.label, "prerequisite input" + ) + result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "input") + result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "output") + if self.log_output is not None: + if other.log_output is not None: + result += self.log_output.diff(other.log_output, "log output") + else: + result.append("Log output is present in A, but not in B.") + elif other.log_output is not None: + result.append("Log output is present in B, but not in A.") + result += self.metadata_output.diff(other.metadata_output, "metadata output") + return result + + def _imported_and_configured(self, rebuild: bool) -> TaskNode: + """Import the task class and use it to construct a new instance. 
+ + Parameters + ---------- + rebuild : `bool` + If `True`, import the task class and configure its connections to + generate new edges that may differ from the current ones. If + `False`, import the task class but just update the `task_class` and + `config` attributes, and assume the edges have not changed. + + Returns + ------- + node : `TaskNode` + Task node instance for which `is_imported` is `True`. Will be + ``self`` if this is the case already. + """ + from ..pipelineTask import PipelineTask + + if self.is_imported: + return self + task_class = doImportType(self.task_class_name) + if not issubclass(task_class, PipelineTask): + raise TypeError(f"{self.task_class_name!r} is not a PipelineTask subclass.") + config = task_class.ConfigClass() + config.loadFromString(self.get_config_str()) + return self._reconfigured(config, rebuild=rebuild, task_class=task_class) + + def _reconfigured( + self, + config: PipelineTaskConfig, + rebuild: bool, + task_class: type[PipelineTask] | None = None, + ) -> TaskNode: + """Return a version of this node with new configuration. + + Parameters + ---------- + config : `.PipelineTaskConfig` + New configuration for the task. + rebuild : `bool` + If `True`, use the configured connections to generate new edges + that may differ from the current ones. If `False`, just update the + `task_class` and `config` attributes, and assume the edges have not + changed. + task_class : `type` [ `PipelineTask` ], optional + Subclass of `PipelineTask`. This defaults to ``self.task_class`, + but may be passed as an argument if that is not available because + the task class was not imported when ``self`` was constructed. + + Returns + ------- + node : `TaskNode` + Task node instance with the new config. + """ + if task_class is None: + task_class = self.task_class + imported_data = _TaskNodeImportedData.configure(self.key.name, task_class, config) + if rebuild: + return self._from_imported_data( + self.key, + self.init.key, + imported_data, + universe=self._dimensions.universe if type(self._dimensions) is DimensionGraph else None, + ) + else: + return TaskNode( + self.key, + TaskInitNode( + self.init.key, + inputs=self.init.inputs, + outputs=self.init.outputs, + config_output=self.init.config_output, + imported_data=imported_data, + ), + prerequisite_inputs=self.prerequisite_inputs, + inputs=self.inputs, + outputs=self.outputs, + log_output=self.log_output, + metadata_output=self.metadata_output, + dimensions=self._dimensions, + ) + + def _resolved(self, universe: DimensionUniverse | None) -> TaskNode: + """Return an otherwise-equivalent task node with resolved dimensions. + + Parameters + ---------- + universe : `lsst.daf.butler.DimensionUniverse` or `None` + Definitions for all dimensions. + + Returns + ------- + node : `TaskNode` + Task node instance with `dimensions` resolved by the given + universe. Will be ``self`` if this is the case already. 
+ """ + if self.has_resolved_dimensions: + if cast(DimensionGraph, self._dimensions).universe is universe: + return self + elif universe is None: + return self + return TaskNode( + key=self.key, + init=self.init, + prerequisite_inputs=self.prerequisite_inputs, + inputs=self.inputs, + outputs=self.outputs, + log_output=self.log_output, + metadata_output=self.metadata_output, + dimensions=( + universe.extract(self.raw_dimensions) if universe is not None else self.raw_dimensions + ), + ) + + def _to_xgraph_state(self) -> dict[str, Any]: + """Convert this nodes's attributes into a dictionary suitable for use + in exported networkx graphs. + """ + result = self.init._to_xgraph_state() + if self.has_resolved_dimensions: + result["dimensions"] = self._dimensions + result["raw_dimensions"] = self.raw_dimensions + return result + + def _get_imported_data(self) -> _TaskNodeImportedData: + """Return the imported data struct. + + Returns + ------- + imported_data : `_TaskNodeImportedData` + Internal structure holding state that requires the task class to + have been imported. + + Raises + ------ + TaskNotImportedError + Raised if `is_imported` is `False`. + """ + return self.init._get_imported_data() + + +def _diff_edge_mapping( + a_mapping: Mapping[str, Edge], b_mapping: Mapping[str, Edge], task_label: str, connection_type: str +) -> list[str]: + """Compare a pair of mappings of edges. + + Parameters + ---------- + a_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ] + First mapping to compare. Expected to have connection names as keys. + b_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ] + First mapping to compare. If keys differ from those of ``a_mapping``, + this will be reported as a difference (in addition to element-wise + comparisons). + task_label : `str` + Task label associated with both mappings. + connection_type : `str` + Type of connection (e.g. "input" or "init output") associated with both + connections. This is a human-readable string to include in difference + messages. + + Returns + ------- + differences : `list` [ `str` ] + List of string messages describing differences between the two + mappings. Will be empty if the two mappings have the same edges. + Messages will include "A" and "B", and are expected to be a preceded + by a message describing what "A" and "B" are in the context in which + this method is called. + + Notes + ----- + This is expected to be used to compare one edge-holding mapping attribute + of a task or task init node to the same attribute on another task or task + init node (i.e. any of `TaskNode.inputs`, `TaskNode.outputs`, + `TaskNode.prerequisite_inputs`, `TaskInitNode.inputs`, + `TaskInitNode.outputs`). + """ + results = [] + b_to_do = set(b_mapping.keys()) + for connection_name, a_edge in a_mapping.items(): + if (b_edge := b_mapping.get(connection_name)) is None: + results.append( + f"{connection_type.capitalize()} {connection_name!r} of task " + f"{task_label!r} exists in A, but not in B (or it may have a different connection type)." + ) + else: + results.extend(a_edge.diff(b_edge, connection_type)) + b_to_do.discard(connection_name) + for connection_name in b_to_do: + results.append( + f"{connection_type.capitalize()} {connection_name!r} of task " + f"{task_label!r} exists in A, but not in B (or it may have a different connection type)." 
+ ) + return results diff --git a/python/lsst/pipe/base/pipeline_graph/io.py b/python/lsst/pipe/base/pipeline_graph/io.py new file mode 100644 index 00000000..09e52df5 --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/io.py @@ -0,0 +1,578 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ( + "expect_not_none", + "SerializedEdge", + "SerializedTaskInitNode", + "SerializedTaskNode", + "SerializedDatasetTypeNode", + "SerializedTaskSubset", + "SerializedPipelineGraph", +) + +from typing import Any, TypeVar + +import networkx +import pydantic +from lsst.daf.butler import DatasetType, DimensionConfig, DimensionGraph, DimensionUniverse + +from .. import automatic_connection_constants as acc +from ._dataset_types import DatasetTypeNode +from ._edges import Edge, ReadEdge, WriteEdge +from ._exceptions import PipelineGraphReadError +from ._nodes import NodeKey, NodeType +from ._pipeline_graph import PipelineGraph +from ._task_subsets import TaskSubset +from ._tasks import TaskInitNode, TaskNode + +_U = TypeVar("_U") + +_IO_VERSION_INFO = (0, 0, 1) +"""Version tuple embedded in saved PipelineGraphs. +""" + + +def expect_not_none(value: _U | None, msg: str) -> _U: + """Check that a value is not `None` and return it. + + Parameters + ---------- + value + Value to check + msg + Error message for the case where ``value is None``. + + Returns + ------- + value + Value, guaranteed not to be `None`. + + Raises + ------ + PipelineGraphReadError + Raised with ``msg`` if ``value is None``. + """ + if value is None: + raise PipelineGraphReadError(msg) + return value + + +class SerializedEdge(pydantic.BaseModel): + """Struct used to represent a serialized `Edge` in a `PipelineGraph`. + + All `ReadEdge` and `WriteEdge` state not included here is instead + effectively serialized by the context in which a `SerializedEdge` appears + (e.g. the keys of the nested dictionaries in which it serves as the value + type). 
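+
+    A hypothetical example of the JSON this model produces (the field names
+    below are the model's real fields; the values are illustrative only)::
+
+        {
+            "dataset_type_name": "calexp",
+            "storage_class": "ExposureF",
+            "raw_dimensions": ["instrument", "visit", "detector"],
+            "is_calibration": false,
+            "defer_query_constraint": false
+        }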
+ """ + + dataset_type_name: str + """Full dataset type name (including component).""" + + storage_class: str + """Name of the storage class.""" + + raw_dimensions: list[str] + """Raw dimensions of the dataset type from the task connections.""" + + is_calibration: bool = False + """Whether this dataset type can be included in + `~lsst.daf.butler.CollectionType.CALIBRATION` collections.""" + + defer_query_constraint: bool = False + """If `True`, by default do not include this dataset type's existence as a + constraint on the initial data ID query in QuantumGraph generation.""" + + @classmethod + def serialize(cls, target: Edge) -> SerializedEdge: + """Transform an `Edge` to a `SerializedEdge`.""" + return SerializedEdge.construct( + storage_class=target.storage_class_name, + dataset_type_name=target.dataset_type_name, + raw_dimensions=sorted(target.raw_dimensions), + is_calibration=target.is_calibration, + defer_query_constraint=getattr(target, "defer_query_constraint", False), + ) + + def deserialize_read_edge( + self, + task_key: NodeKey, + connection_name: str, + is_prerequisite: bool = False, + ) -> ReadEdge: + """Transform a `SerializedEdge` to a `ReadEdge`.""" + parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(self.dataset_type_name) + dataset_type_key = NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name) + return ReadEdge( + dataset_type_key, + task_key, + storage_class_name=self.storage_class, + is_prerequisite=is_prerequisite, + component=component, + connection_name=connection_name, + is_calibration=self.is_calibration, + defer_query_constraint=self.defer_query_constraint, + raw_dimensions=frozenset(self.raw_dimensions), + ) + + def deserialize_write_edge( + self, + task_key: NodeKey, + connection_name: str, + ) -> WriteEdge: + """Transform a `SerializedEdge` to a `WriteEdge`.""" + dataset_type_key = NodeKey(NodeType.DATASET_TYPE, self.dataset_type_name) + return WriteEdge( + task_key=task_key, + dataset_type_key=dataset_type_key, + storage_class_name=self.storage_class, + connection_name=connection_name, + is_calibration=self.is_calibration, + raw_dimensions=frozenset(self.raw_dimensions), + ) + + +class SerializedTaskInitNode(pydantic.BaseModel): + """Struct used to represent a serialized `TaskInitNode` in a + `PipelineGraph`. + + The task label is serialized by the context in which a + `SerializedTaskInitNode` appears (e.g. the keys of the nested dictionary + in which it serves as the value type), and the task class name and config + string are save with the corresponding `SerializedTaskNode`. + """ + + inputs: dict[str, SerializedEdge] + """Mapping of serialized init-input edges, keyed by connection name.""" + + outputs: dict[str, SerializedEdge] + """Mapping of serialized init-output edges, keyed by connection name.""" + + config_output: SerializedEdge + """The serialized config init-output edge.""" + + index: int | None = None + """The index of this node in the sorted sequence of `PipelineGraph`. + + This is `None` if the `PipelineGraph` was not sorted when it was + serialized. 
+ """ + + @classmethod + def serialize(cls, target: TaskInitNode) -> SerializedTaskInitNode: + """Transform a `TaskInitNode` to a `SerializedTaskInitNode`.""" + return cls.construct( + inputs={ + connection_name: SerializedEdge.serialize(edge) + for connection_name, edge in sorted(target.inputs.items()) + }, + outputs={ + connection_name: SerializedEdge.serialize(edge) + for connection_name, edge in sorted(target.outputs.items()) + }, + config_output=SerializedEdge.serialize(target.config_output), + ) + + def deserialize( + self, + key: NodeKey, + task_class_name: str, + config_str: str, + ) -> TaskInitNode: + """Transform a `SerializedTaskInitNode` to a `TaskInitNode`.""" + return TaskInitNode( + key, + inputs={ + connection_name: serialized_edge.deserialize_read_edge(key, connection_name) + for connection_name, serialized_edge in self.inputs.items() + }, + outputs={ + connection_name: serialized_edge.deserialize_write_edge(key, connection_name) + for connection_name, serialized_edge in self.outputs.items() + }, + config_output=self.config_output.deserialize_write_edge( + key, + acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, + ), + task_class_name=task_class_name, + config_str=config_str, + ) + + +class SerializedTaskNode(pydantic.BaseModel): + """Struct used to represent a serialized `TaskNode` in a `PipelineGraph`. + + The task label is serialized by the context in which a + `SerializedTaskNode` appears (e.g. the keys of the nested dictionary in + which it serves as the value type). + """ + + task_class: str + """Fully-qualified name of the task class.""" + + init: SerializedTaskInitNode + """Serialized task initialization node.""" + + config_str: str + """Configuration for the task as a string of override statements.""" + + prerequisite_inputs: dict[str, SerializedEdge] + """Mapping of serialized prerequisiste input edges, keyed by connection + name. + """ + + inputs: dict[str, SerializedEdge] + """Mapping of serialized input edges, keyed by connection name.""" + + outputs: dict[str, SerializedEdge] + """Mapping of serialized output edges, keyed by connection name.""" + + metadata_output: SerializedEdge + """The serialized metadata output edge.""" + + dimensions: list[str] + """The task's dimensions, if they were resolved.""" + + log_output: SerializedEdge | None = None + """The serialized log output edge.""" + + index: int | None = None + """The index of this node in the sorted sequence of `PipelineGraph`. + + This is `None` if the `PipelineGraph` was not sorted when it was + serialized. 
+ """ + + @classmethod + def serialize(cls, target: TaskNode) -> SerializedTaskNode: + """Transform a `TaskNode` to a `SerializedTaskNode`.""" + return cls.construct( + task_class=target.task_class_name, + init=SerializedTaskInitNode.serialize(target.init), + config_str=target.get_config_str(), + dimensions=list(target.raw_dimensions), + prerequisite_inputs={ + connection_name: SerializedEdge.serialize(edge) + for connection_name, edge in sorted(target.prerequisite_inputs.items()) + }, + inputs={ + connection_name: SerializedEdge.serialize(edge) + for connection_name, edge in sorted(target.inputs.items()) + }, + outputs={ + connection_name: SerializedEdge.serialize(edge) + for connection_name, edge in sorted(target.outputs.items()) + }, + metadata_output=SerializedEdge.serialize(target.metadata_output), + log_output=( + SerializedEdge.serialize(target.log_output) if target.log_output is not None else None + ), + ) + + def deserialize(self, key: NodeKey, init_key: NodeKey, universe: DimensionUniverse | None) -> TaskNode: + """Transform a `SerializedTaskNode` to a `TaskNode`.""" + init = self.init.deserialize( + init_key, + task_class_name=self.task_class, + config_str=expect_not_none( + self.config_str, f"No serialized config file for task with label {key.name!r}." + ), + ) + inputs = { + connection_name: serialized_edge.deserialize_read_edge(key, connection_name) + for connection_name, serialized_edge in self.inputs.items() + } + prerequisite_inputs = { + connection_name: serialized_edge.deserialize_read_edge(key, connection_name, is_prerequisite=True) + for connection_name, serialized_edge in self.prerequisite_inputs.items() + } + outputs = { + connection_name: serialized_edge.deserialize_write_edge(key, connection_name) + for connection_name, serialized_edge in self.outputs.items() + } + if (serialized_log_output := self.log_output) is not None: + log_output = serialized_log_output.deserialize_write_edge(key, acc.LOG_OUTPUT_CONNECTION_NAME) + else: + log_output = None + metadata_output = self.metadata_output.deserialize_write_edge( + key, acc.METADATA_OUTPUT_CONNECTION_NAME + ) + dimensions: frozenset[str] | DimensionGraph + if universe is not None: + dimensions = universe.extract(self.dimensions) + else: + dimensions = frozenset(self.dimensions) + return TaskNode( + key=key, + init=init, + inputs=inputs, + prerequisite_inputs=prerequisite_inputs, + outputs=outputs, + log_output=log_output, + metadata_output=metadata_output, + dimensions=dimensions, + ) + + +class SerializedDatasetTypeNode(pydantic.BaseModel): + """Struct used to represent a serialized `DatasetTypeNode` in a + `PipelineGraph`. + + Unresolved dataset types are serialized as instances with at most the + `index` attribute set, and are typically converted to JSON with pydantic's + ``exclude_defaults=True`` option to keep this compact. + + The dataset typename is serialized by the context in which a + `SerializedDatasetTypeNode` appears (e.g. the keys of the nested dictionary + in which it serves as the value type). 
+ """ + + dimensions: list[str] | None = None + """Dimensions of the dataset type.""" + + storage_class: str | None = None + """Name of the storage class.""" + + is_calibration: bool = False + """Whether this dataset type is a calibration.""" + + is_initial_query_constraint: bool = False + """Whether this dataset type should be a query constraint during + `QuantumGraph` generation.""" + + is_prerequisite: bool = False + """Whether datasets of this dataset type must exist in the input collection + before `QuantumGraph` generation.""" + + index: int | None = None + """The index of this node in the sorted sequence of `PipelineGraph`. + + This is `None` if the `PipelineGraph` was not sorted when it was + serialized. + """ + + @classmethod + def serialize(cls, target: DatasetTypeNode | None) -> SerializedDatasetTypeNode: + """Transform a `DatasetTypeNode` to a `SerializedDatasetTypeNode`.""" + if target is None: + return cls.construct() + return cls.construct( + dimensions=list(target.dataset_type.dimensions.names), + storage_class=target.dataset_type.storageClass_name, + is_calibration=target.dataset_type.isCalibration(), + is_initial_query_constraint=target.is_initial_query_constraint, + is_prerequisite=target.is_prerequisite, + ) + + def deserialize(self, key: NodeKey, universe: DimensionUniverse | None) -> DatasetTypeNode | None: + """Transform a `SerializedDatasetTypeNode` to a `DatasetTypeNode`.""" + if self.dimensions is not None: + dataset_type = DatasetType( + key.name, + expect_not_none( + self.dimensions, + f"Serialized dataset type {key.name!r} has no dimensions.", + ), + storageClass=expect_not_none( + self.storage_class, + f"Serialized dataset type {key.name!r} has no storage class.", + ), + isCalibration=self.is_calibration, + universe=expect_not_none( + universe, + f"Serialized dataset type {key.name!r} has dimensions, " + "but no dimension universe was stored.", + ), + ) + return DatasetTypeNode( + dataset_type=dataset_type, + is_prerequisite=self.is_prerequisite, + is_initial_query_constraint=self.is_initial_query_constraint, + ) + return None + + +class SerializedTaskSubset(pydantic.BaseModel): + """Struct used to represent a serialized `TaskSubset` in a `PipelineGraph`. + + The subsetlabel is serialized by the context in which a + `SerializedDatasetTypeNode` appears (e.g. the keys of the nested dictionary + in which it serves as the value type). + """ + + description: str + """Description of the subset.""" + + tasks: list[str] + """Labels of tasks in the subset, sorted lexicographically for + determinism. 
+ """ + + @classmethod + def serialize(cls, target: TaskSubset) -> SerializedTaskSubset: + """Transform a `TaskSubset` into a `SerializedTaskSubset`.""" + return cls.construct(description=target._description, tasks=list(sorted(target))) + + def deserialize_task_subset(self, label: str, xgraph: networkx.MultiDiGraph) -> TaskSubset: + """Transform a `SerializedTaskSubset` into a `TaskSubset`.""" + members = set(self.tasks) + return TaskSubset(xgraph, label, members, self.description) + + +class SerializedPipelineGraph(pydantic.BaseModel): + """Struct used to represent a serialized `PipelineGraph`.""" + + version: str = ".".join(str(v) for v in _IO_VERSION_INFO) + """Serialization version.""" + + description: str + """Human-readable description of the pipeline.""" + + tasks: dict[str, SerializedTaskNode] = pydantic.Field(default_factory=dict) + """Mapping of serialized tasks, keyed by label.""" + + dataset_types: dict[str, SerializedDatasetTypeNode] = pydantic.Field(default_factory=dict) + """Mapping of serialized dataset types, keyed by parent dataset type name. + """ + + task_subsets: dict[str, SerializedTaskSubset] = pydantic.Field(default_factory=dict) + """Mapping of task subsets, keyed by subset label.""" + + dimensions: dict[str, Any] | None = None + """Dimension universe configuration.""" + + data_id: dict[str, Any] = pydantic.Field(default_factory=dict) + """Data ID that constrains all quanta generated from this pipeline.""" + + @classmethod + def serialize(cls, target: PipelineGraph) -> SerializedPipelineGraph: + """Transform a `PipelineGraph` into a `SerializedPipelineGraph`.""" + result = SerializedPipelineGraph.construct( + description=target.description, + tasks={label: SerializedTaskNode.serialize(node) for label, node in target.tasks.items()}, + dataset_types={ + name: SerializedDatasetTypeNode().serialize(target.dataset_types.get_if_resolved(name)) + for name in target.dataset_types.keys() + }, + task_subsets={ + label: SerializedTaskSubset.serialize(subset) for label, subset in target.task_subsets.items() + }, + dimensions=target.universe.dimensionConfig.toDict() if target.universe is not None else None, + data_id=target._raw_data_id, + ) + if target._sorted_keys: + for index, node_key in enumerate(target._sorted_keys): + match node_key.node_type: + case NodeType.TASK: + result.tasks[node_key.name].index = index + case NodeType.DATASET_TYPE: + result.dataset_types[node_key.name].index = index + case NodeType.TASK_INIT: + result.tasks[node_key.name].init.index = index + return result + + def deserialize( + self, + import_and_configure: bool = True, + check_edges_unchanged: bool = False, + assume_edges_unchanged: bool = False, + ) -> PipelineGraph: + """Transform a `SerializedPipelineGraph` into a `PipelineGraph`.""" + universe: DimensionUniverse | None = None + if self.dimensions is not None: + universe = DimensionUniverse( + config=DimensionConfig( + expect_not_none( + self.dimensions, + "Serialized pipeline graph has not been resolved; " + "load it is a MutablePipelineGraph instead.", + ) + ) + ) + xgraph = networkx.MultiDiGraph() + sort_index_map: dict[int, NodeKey] = {} + for dataset_type_name, serialized_dataset_type in self.dataset_types.items(): + dataset_type_key = NodeKey(NodeType.DATASET_TYPE, dataset_type_name) + dataset_type_node = serialized_dataset_type.deserialize(dataset_type_key, universe) + xgraph.add_node( + dataset_type_key, instance=dataset_type_node, bipartite=NodeType.DATASET_TYPE.value + ) + if serialized_dataset_type.index is not None: + 
sort_index_map[serialized_dataset_type.index] = dataset_type_key + for task_label, serialized_task in self.tasks.items(): + task_key = NodeKey(NodeType.TASK, task_label) + task_init_key = NodeKey(NodeType.TASK_INIT, task_label) + task_node = serialized_task.deserialize(task_key, task_init_key, universe) + if serialized_task.index is not None: + sort_index_map[serialized_task.index] = task_key + if serialized_task.init.index is not None: + sort_index_map[serialized_task.init.index] = task_init_key + xgraph.add_node(task_key, instance=task_node, bipartite=NodeType.TASK.bipartite) + xgraph.add_node(task_init_key, instance=task_node.init, bipartite=NodeType.TASK_INIT.bipartite) + xgraph.add_edge(task_init_key, task_key, Edge.INIT_TO_TASK_NAME, instance=None) + for read_edge in task_node.init.iter_all_inputs(): + xgraph.add_edge( + read_edge.dataset_type_key, + read_edge.task_key, + read_edge.connection_name, + instance=read_edge, + ) + for write_edge in task_node.init.iter_all_outputs(): + xgraph.add_edge( + write_edge.task_key, + write_edge.dataset_type_key, + write_edge.connection_name, + instance=write_edge, + ) + for read_edge in task_node.iter_all_inputs(): + xgraph.add_edge( + read_edge.dataset_type_key, + read_edge.task_key, + read_edge.connection_name, + instance=read_edge, + ) + for write_edge in task_node.iter_all_outputs(): + xgraph.add_edge( + write_edge.task_key, + write_edge.dataset_type_key, + write_edge.connection_name, + instance=write_edge, + ) + result = PipelineGraph.__new__(PipelineGraph) + result._init_from_args( + xgraph, + sorted_keys=[sort_index_map[i] for i in range(len(xgraph))] if sort_index_map else None, + task_subsets={ + subset_label: serialized_subset.deserialize_task_subset(subset_label, xgraph) + for subset_label, serialized_subset in self.task_subsets.items() + }, + description=self.description, + universe=universe, + data_id=self.data_id, + ) + if import_and_configure: + result.import_and_configure( + check_edges_unchanged=check_edges_unchanged, + assume_edges_unchanged=assume_edges_unchanged, + ) + return result diff --git a/tests/test_pipeline_graph.py b/tests/test_pipeline_graph.py new file mode 100644 index 00000000..6c232bc3 --- /dev/null +++ b/tests/test_pipeline_graph.py @@ -0,0 +1,1258 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+
+"""Tests of things related to the PipelineGraph class."""
+
+import copy
+import io
+import logging
+import unittest
+from typing import Any
+
+import lsst.pipe.base.automatic_connection_constants as acc
+import lsst.utils.tests
+from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse
+from lsst.daf.butler.registry import MissingDatasetTypeError
+from lsst.pipe.base.pipeline_graph import (
+    ConnectionTypeConsistencyError,
+    DuplicateOutputError,
+    Edge,
+    EdgesChangedError,
+    IncompatibleDatasetTypeError,
+    NodeKey,
+    NodeType,
+    PipelineGraph,
+    PipelineGraphError,
+    UnresolvedGraphError,
+)
+from lsst.pipe.base.tests.mocks import (
+    DynamicConnectionConfig,
+    DynamicTestPipelineTask,
+    DynamicTestPipelineTaskConfig,
+    get_mock_name,
+)
+
+_LOG = logging.getLogger(__name__)
+
+
+class MockRegistry:
+    """A test-utility stand-in for lsst.daf.butler.Registry that just knows
+    how to get dataset types.
+    """
+
+    def __init__(self, dimensions: DimensionUniverse, dataset_types: dict[str, DatasetType]) -> None:
+        self.dimensions = dimensions
+        self._dataset_types = dataset_types
+
+    def getDatasetType(self, name: str) -> DatasetType:
+        try:
+            return self._dataset_types[name]
+        except KeyError:
+            raise MissingDatasetTypeError(name)
+
+
+class PipelineGraphTestCase(unittest.TestCase):
+    """Tests for the `PipelineGraph` class.
+
+    Tests for `PipelineGraph.resolve` are mostly in
+    `PipelineGraphResolveTestCase` later in this file.
+    """
+
+    def setUp(self) -> None:
+        # Simple test pipeline has two tasks, 'a' and 'b', with dataset types
+        # 'input_1', 'intermediate_1', 'output_1', and 'schema'. There are no
+        # dimensions on any of those. We add tasks in reverse order to better
+        # test sorting.
+        # There is one labeled task subset, 'only_b', with just 'b' in it.
+        # We copy the configs so the originals (the instance attributes) can
+        # be modified and reused after the ones passed in to the graph are
+        # frozen.
+        self.description = "A pipeline for PipelineGraph unit tests."
+        self.graph = PipelineGraph()
+        self.graph.description = self.description
+        self.b_config = DynamicTestPipelineTaskConfig()
+        self.b_config.init_inputs["in_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
+        self.b_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
+        self.b_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="output_1")
+        self.graph.add_task("b", DynamicTestPipelineTask, copy.deepcopy(self.b_config))
+        self.a_config = DynamicTestPipelineTaskConfig()
+        self.a_config.init_outputs["out_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
+        self.a_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="input_1")
+        self.a_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
+        self.graph.add_task("a", DynamicTestPipelineTask, copy.deepcopy(self.a_config))
+        self.graph.add_task_subset("only_b", ["b"])
+        self.subset_description = "A subset with only task B in it."
+        self.graph.task_subsets["only_b"].description = self.subset_description
+        self.dimensions = DimensionUniverse()
+        self.maxDiff = None
+
+    def test_unresolved_accessors(self) -> None:
+        """Test attribute accessors, iteration, and simple methods on a graph
+        that has not had `PipelineGraph.resolve` called on it.
+ """ + self.check_base_accessors(self.graph) + self.assertEqual( + repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask)" + ) + + def test_sorting(self) -> None: + """Test sort methods on PipelineGraph.""" + self.assertFalse(self.graph.has_been_sorted) + self.assertFalse(self.graph.is_sorted) + self.graph.sort() + self.check_sorted(self.graph) + + def test_unresolved_xgraph_export(self) -> None: + """Test exporting an unresolved PipelineGraph to networkx in various + ways. + """ + self.check_make_xgraph(self.graph, resolved=False) + self.check_make_bipartite_xgraph(self.graph, resolved=False) + self.check_make_task_xgraph(self.graph, resolved=False) + self.check_make_dataset_type_xgraph(self.graph, resolved=False) + + def test_unresolved_stream_io(self) -> None: + """Test round-tripping an unresolved PipelineGraph through in-memory + serialization. + """ + stream = io.BytesIO() + self.graph.write_stream(stream) + stream.seek(0) + roundtripped = PipelineGraph.read_stream(stream) + self.check_make_xgraph(roundtripped, resolved=False) + + def test_unresolved_file_io(self) -> None: + """Test round-tripping an unresolved PipelineGraph through file + serialization. + """ + with lsst.utils.tests.getTempFilePath(".json.gz") as filename: + self.graph.write_uri(filename) + roundtripped = PipelineGraph.read_uri(filename) + self.check_make_xgraph(roundtripped, resolved=False) + + def test_unresolved_deferred_import_io(self) -> None: + """Test round-tripping an unresolved PipelineGraph through + serialization, without immediately importing tasks on read. + """ + stream = io.BytesIO() + self.graph.write_stream(stream) + stream.seek(0) + roundtripped = PipelineGraph.read_stream(stream, import_and_configure=False) + self.check_make_xgraph(roundtripped, resolved=False, imported_and_configured=False) + # Check that we can still resolve the graph without importing tasks. + roundtripped.resolve(MockRegistry(self.dimensions, {})) + self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False) + roundtripped.import_and_configure(assume_edges_unchanged=True) + self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True) + + def test_resolved_accessors(self) -> None: + """Test attribute accessors, iteration, and simple methods on a graph + that has had `PipelineGraph.resolve` called on it. + + This includes the accessors available on unresolved graphs as well as + new ones, and we expect the resolved graph to be sorted as well. + """ + self.graph.resolve(MockRegistry(self.dimensions, {})) + self.check_base_accessors(self.graph) + self.check_sorted(self.graph) + self.assertEqual( + repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask, {})" + ) + self.assertEqual(self.graph.tasks["a"].dimensions, self.dimensions.empty) + self.assertEqual(repr(self.graph.dataset_types["input_1"]), "input_1 (_mock_StructuredDataDict, {})") + self.assertEqual(self.graph.dataset_types["input_1"].key, NodeKey(NodeType.DATASET_TYPE, "input_1")) + self.assertEqual(self.graph.dataset_types["input_1"].dimensions, self.dimensions.empty) + self.assertEqual(self.graph.dataset_types["input_1"].storage_class_name, "_mock_StructuredDataDict") + self.assertEqual(self.graph.dataset_types["input_1"].storage_class.name, "_mock_StructuredDataDict") + + def test_resolved_xgraph_export(self) -> None: + """Test exporting a resolved PipelineGraph to networkx in various + ways. 
+ """ + self.graph.resolve(MockRegistry(self.dimensions, {})) + self.check_make_xgraph(self.graph, resolved=True) + self.check_make_bipartite_xgraph(self.graph, resolved=True) + self.check_make_task_xgraph(self.graph, resolved=True) + self.check_make_dataset_type_xgraph(self.graph, resolved=True) + + def test_resolved_stream_io(self) -> None: + """Test round-tripping a resolved PipelineGraph through in-memory + serialization. + """ + self.graph.resolve(MockRegistry(self.dimensions, {})) + stream = io.BytesIO() + self.graph.write_stream(stream) + stream.seek(0) + roundtripped = PipelineGraph.read_stream(stream) + self.check_make_xgraph(roundtripped, resolved=True) + + def test_resolved_file_io(self) -> None: + """Test round-tripping a resolved PipelineGraph through file + serialization. + """ + self.graph.resolve(MockRegistry(self.dimensions, {})) + with lsst.utils.tests.getTempFilePath(".json.gz") as filename: + self.graph.write_uri(filename) + roundtripped = PipelineGraph.read_uri(filename) + self.check_make_xgraph(roundtripped, resolved=True) + + def test_resolved_deferred_import_io(self) -> None: + """Test round-tripping a resolved PipelineGraph through serialization, + without immediately importing tasks on read. + """ + self.graph.resolve(MockRegistry(self.dimensions, {})) + stream = io.BytesIO() + self.graph.write_stream(stream) + stream.seek(0) + roundtripped = PipelineGraph.read_stream(stream, import_and_configure=False) + self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False) + roundtripped.import_and_configure(check_edges_unchanged=True) + self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True) + + def test_unresolved_copies(self) -> None: + """Test making copies of an unresolved PipelineGraph.""" + copy1 = self.graph.copy() + self.assertIsNot(copy1, self.graph) + self.check_make_xgraph(copy1, resolved=False) + copy2 = copy.copy(self.graph) + self.assertIsNot(copy2, self.graph) + self.check_make_xgraph(copy2, resolved=False) + copy3 = copy.deepcopy(self.graph) + self.assertIsNot(copy3, self.graph) + self.check_make_xgraph(copy3, resolved=False) + + def test_resolved_copies(self) -> None: + """Test making copies of a resolved PipelineGraph.""" + self.graph.resolve(MockRegistry(self.dimensions, {})) + copy1 = self.graph.copy() + self.assertIsNot(copy1, self.graph) + self.check_make_xgraph(copy1, resolved=True) + copy2 = copy.copy(self.graph) + self.assertIsNot(copy2, self.graph) + self.check_make_xgraph(copy2, resolved=True) + copy3 = copy.deepcopy(self.graph) + self.assertIsNot(copy3, self.graph) + self.check_make_xgraph(copy3, resolved=True) + + def check_base_accessors(self, graph: PipelineGraph) -> None: + """Run parameterized tests that check attribute access, iteration, and + simple methods. + + The given graph must be unchanged from the one defined in `setUp`, + other than sorting. 
+ """ + self.assertEqual(graph.description, self.description) + self.assertEqual(graph.tasks.keys(), {"a", "b"}) + self.assertEqual( + graph.dataset_types.keys(), + { + "schema", + "input_1", + "intermediate_1", + "output_1", + "a_config", + "a_log", + "a_metadata", + "b_config", + "b_log", + "b_metadata", + }, + ) + self.assertEqual(graph.task_subsets.keys(), {"only_b"}) + self.assertEqual( + {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=False)}, + { + ( + NodeKey(NodeType.DATASET_TYPE, "input_1"), + NodeKey(NodeType.TASK, "a"), + "input_1 -> a (input1)", + ), + ( + NodeKey(NodeType.TASK, "a"), + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), + "a -> intermediate_1 (output1)", + ), + ( + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), + NodeKey(NodeType.TASK, "b"), + "intermediate_1 -> b (input1)", + ), + ( + NodeKey(NodeType.TASK, "b"), + NodeKey(NodeType.DATASET_TYPE, "output_1"), + "b -> output_1 (output1)", + ), + (NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.DATASET_TYPE, "a_log"), "a -> a_log (_log)"), + ( + NodeKey(NodeType.TASK, "a"), + NodeKey(NodeType.DATASET_TYPE, "a_metadata"), + "a -> a_metadata (_metadata)", + ), + (NodeKey(NodeType.TASK, "b"), NodeKey(NodeType.DATASET_TYPE, "b_log"), "b -> b_log (_log)"), + ( + NodeKey(NodeType.TASK, "b"), + NodeKey(NodeType.DATASET_TYPE, "b_metadata"), + "b -> b_metadata (_metadata)", + ), + }, + ) + self.assertEqual( + {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=True)}, + { + ( + NodeKey(NodeType.TASK_INIT, "a"), + NodeKey(NodeType.DATASET_TYPE, "schema"), + "a -> schema (out_schema)", + ), + ( + NodeKey(NodeType.DATASET_TYPE, "schema"), + NodeKey(NodeType.TASK_INIT, "b"), + "schema -> b (in_schema)", + ), + ( + NodeKey(NodeType.TASK_INIT, "a"), + NodeKey(NodeType.DATASET_TYPE, "a_config"), + "a -> a_config (_config)", + ), + ( + NodeKey(NodeType.TASK_INIT, "b"), + NodeKey(NodeType.DATASET_TYPE, "b_config"), + "b -> b_config (_config)", + ), + }, + ) + self.assertEqual( + {(node_type, name) for node_type, name, _ in graph.iter_nodes()}, + { + NodeKey(NodeType.TASK, "a"), + NodeKey(NodeType.TASK, "b"), + NodeKey(NodeType.TASK_INIT, "a"), + NodeKey(NodeType.TASK_INIT, "b"), + NodeKey(NodeType.DATASET_TYPE, "schema"), + NodeKey(NodeType.DATASET_TYPE, "input_1"), + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), + NodeKey(NodeType.DATASET_TYPE, "output_1"), + NodeKey(NodeType.DATASET_TYPE, "a_config"), + NodeKey(NodeType.DATASET_TYPE, "a_log"), + NodeKey(NodeType.DATASET_TYPE, "a_metadata"), + NodeKey(NodeType.DATASET_TYPE, "b_config"), + NodeKey(NodeType.DATASET_TYPE, "b_log"), + NodeKey(NodeType.DATASET_TYPE, "b_metadata"), + }, + ) + self.assertEqual({name for name, _ in graph.iter_overall_inputs()}, {"input_1"}) + self.assertEqual({edge.task_label for edge in graph.consumers_of("input_1")}, {"a"}) + self.assertEqual({edge.task_label for edge in graph.consumers_of("intermediate_1")}, {"b"}) + self.assertEqual({edge.task_label for edge in graph.consumers_of("output_1")}, set()) + self.assertIsNone(graph.producer_of("input_1")) + self.assertEqual(graph.producer_of("intermediate_1").task_label, "a") + self.assertEqual(graph.producer_of("output_1").task_label, "b") + self.assertTrue(repr(self.graph).startswith(f"PipelineGraph({self.description!r}, tasks=")) + self.assertEqual( + repr(graph.task_subsets["only_b"]), f"only_b: {self.subset_description!r}, tasks={{b}}" + ) + + def check_sorted(self, graph: PipelineGraph) -> None: + """Run a battery of tests on a PipelineGraph that must be + 
deterministically sorted. + + The given graph must be unchanged from the one defined in `setUp`, + other than sorting. + """ + self.assertTrue(graph.has_been_sorted) + self.assertTrue(graph.is_sorted) + self.assertEqual( + [(node_type, name) for node_type, name, _ in graph.iter_nodes()], + [ + # We only advertise that the order is topological and + # deterministic, so this test is slightly over-specified; there + # are other orders that are consistent with our guarantees. + NodeKey(NodeType.DATASET_TYPE, "input_1"), + NodeKey(NodeType.TASK_INIT, "a"), + NodeKey(NodeType.DATASET_TYPE, "a_config"), + NodeKey(NodeType.DATASET_TYPE, "schema"), + NodeKey(NodeType.TASK_INIT, "b"), + NodeKey(NodeType.DATASET_TYPE, "b_config"), + NodeKey(NodeType.TASK, "a"), + NodeKey(NodeType.DATASET_TYPE, "a_log"), + NodeKey(NodeType.DATASET_TYPE, "a_metadata"), + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), + NodeKey(NodeType.TASK, "b"), + NodeKey(NodeType.DATASET_TYPE, "b_log"), + NodeKey(NodeType.DATASET_TYPE, "b_metadata"), + NodeKey(NodeType.DATASET_TYPE, "output_1"), + ], + ) + # Most users should only care that the tasks and dataset types are + # topologically sorted. + self.assertEqual(list(graph.tasks), ["a", "b"]) + self.assertEqual( + list(graph.dataset_types), + [ + "input_1", + "a_config", + "schema", + "b_config", + "a_log", + "a_metadata", + "intermediate_1", + "b_log", + "b_metadata", + "output_1", + ], + ) + # __str__ and __repr__ of course work on unsorted mapping views, too, + # but the order of elements is then nondeterministic and hard to test. + self.assertEqual(repr(self.graph.tasks), "TaskMappingView({a, b})") + self.assertEqual( + repr(self.graph.dataset_types), + ( + "DatasetTypeMappingView({input_1, a_config, schema, b_config, a_log, a_metadata, " + "intermediate_1, b_log, b_metadata, output_1})" + ), + ) + + def check_make_xgraph( + self, graph: PipelineGraph, resolved: bool, imported_and_configured: bool = True + ) -> None: + """Check that the given graph exports as expected to networkx. + + The given graph must be unchanged from the one defined in `setUp`, + other than being resolved (if ``resolved=True``) or round-tripped + through serialization without tasks being imported (if + ``imported_and_configured=False``). 
+ """ + xgraph = graph.make_xgraph() + expected_edges = ( + {edge.key for edge in graph.iter_edges()} + | {edge.key for edge in graph.iter_edges(init=True)} + | { + (NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK, "a"), Edge.INIT_TO_TASK_NAME), + (NodeKey(NodeType.TASK_INIT, "b"), NodeKey(NodeType.TASK, "b"), Edge.INIT_TO_TASK_NAME), + } + ) + test_edges = set(xgraph.edges) + self.assertEqual(test_edges, expected_edges) + expected_nodes = { + NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node( + "a", resolved, imported_and_configured=imported_and_configured + ), + NodeKey(NodeType.TASK, "a"): self.get_expected_task_node( + "a", resolved, imported_and_configured=imported_and_configured + ), + NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node( + "b", resolved, imported_and_configured=imported_and_configured + ), + NodeKey(NodeType.TASK, "b"): self.get_expected_task_node( + "b", resolved, imported_and_configured=imported_and_configured + ), + NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node( + "schema", resolved, is_initial_query_constraint=False + ), + NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node( + "input_1", resolved, is_initial_query_constraint=True + ), + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node( + "intermediate_1", resolved, is_initial_query_constraint=False + ), + NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node( + "output_1", resolved, is_initial_query_constraint=False + ), + } + test_nodes = dict(xgraph.nodes.items()) + self.assertEqual(set(test_nodes.keys()), set(expected_nodes.keys())) + for key, expected_node in expected_nodes.items(): + test_node = test_nodes[key] + self.assertEqual(expected_node, test_node, key) + + def check_make_bipartite_xgraph(self, graph: PipelineGraph, resolved: bool) -> None: + """Check that the given graph's init-only or runtime subset exports as + expected to networkx. + + The given graph must be unchanged from the one defined in `setUp`, + other than being resolved (if ``resolved=True``). 
+ """ + run_xgraph = graph.make_bipartite_xgraph() + self.assertEqual(set(run_xgraph.edges), {edge.key for edge in graph.iter_edges()}) + self.assertEqual( + dict(run_xgraph.nodes.items()), + { + NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved), + NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node( + "input_1", resolved, is_initial_query_constraint=True + ), + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node( + "intermediate_1", resolved, is_initial_query_constraint=False + ), + NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node( + "output_1", resolved, is_initial_query_constraint=False + ), + }, + ) + init_xgraph = graph.make_bipartite_xgraph( + init=True, + ) + self.assertEqual(set(init_xgraph.edges), {edge.key for edge in graph.iter_edges(init=True)}) + self.assertEqual( + dict(init_xgraph.nodes.items()), + { + NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved), + NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node( + "schema", resolved, is_initial_query_constraint=False + ), + NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved), + }, + ) + + def check_make_task_xgraph(self, graph: PipelineGraph, resolved: bool) -> None: + """Check that the given graph's task-only projection exports as + expected to networkx. + + The given graph must be unchanged from the one defined in `setUp`, + other than being resolved (if ``resolved=True``). + """ + run_xgraph = graph.make_task_xgraph() + self.assertEqual(set(run_xgraph.edges), {(NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.TASK, "b"))}) + self.assertEqual( + dict(run_xgraph.nodes.items()), + { + NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved), + NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved), + }, + ) + init_xgraph = graph.make_task_xgraph( + init=True, + ) + self.assertEqual( + set(init_xgraph.edges), + {(NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK_INIT, "b"))}, + ) + self.assertEqual( + dict(init_xgraph.nodes.items()), + { + NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved), + NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved), + }, + ) + + def check_make_dataset_type_xgraph(self, graph: PipelineGraph, resolved: bool) -> None: + """Check that the given graph's dataset-type-only projection exports as + expected to networkx. + + The given graph must be unchanged from the one defined in `setUp`, + other than being resolved (if ``resolved=True``). 
+ """ + run_xgraph = graph.make_dataset_type_xgraph() + self.assertEqual( + set(run_xgraph.edges), + { + (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "intermediate_1")), + (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_log")), + (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_metadata")), + ( + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), + NodeKey(NodeType.DATASET_TYPE, "output_1"), + ), + (NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), NodeKey(NodeType.DATASET_TYPE, "b_log")), + ( + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), + NodeKey(NodeType.DATASET_TYPE, "b_metadata"), + ), + }, + ) + self.assertEqual( + dict(run_xgraph.nodes.items()), + { + NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node( + "input_1", resolved, is_initial_query_constraint=True + ), + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node( + "intermediate_1", resolved, is_initial_query_constraint=False + ), + NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node( + "output_1", resolved, is_initial_query_constraint=False + ), + }, + ) + init_xgraph = graph.make_dataset_type_xgraph(init=True) + self.assertEqual( + set(init_xgraph.edges), + {(NodeKey(NodeType.DATASET_TYPE, "schema"), NodeKey(NodeType.DATASET_TYPE, "b_config"))}, + ) + self.assertEqual( + dict(init_xgraph.nodes.items()), + { + NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node( + "schema", resolved, is_initial_query_constraint=False + ), + NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved), + }, + ) + + def get_expected_task_node( + self, label: str, resolved: bool, imported_and_configured: bool = True + ) -> dict[str, Any]: + """Construct a networkx-export task node for comparison.""" + result = self.get_expected_task_init_node( + label, resolved, imported_and_configured=imported_and_configured + ) + if resolved: + result["dimensions"] = self.dimensions.empty + result["raw_dimensions"] = frozenset() + return result + + def get_expected_task_init_node( + self, label: str, resolved: bool, imported_and_configured: bool = True + ) -> dict[str, Any]: + """Construct a networkx-export task init for comparison.""" + result = { + "task_class_name": "lsst.pipe.base.tests.mocks.DynamicTestPipelineTask", + "bipartite": 1, + } + if imported_and_configured: + result["task_class"] = DynamicTestPipelineTask + result["config"] = getattr(self, f"{label}_config") + return result + + def get_expected_config_node(self, label: str, resolved: bool) -> dict[str, Any]: + """Construct a networkx-export init-output config dataset type node for + comparison. 
+ """ + if not resolved: + return {"bipartite": 0} + else: + return { + "dataset_type": DatasetType( + acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label), + self.dimensions.empty, + acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, + ), + "is_initial_query_constraint": False, + "is_prerequisite": False, + "dimensions": self.dimensions.empty, + "storage_class_name": acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, + "bipartite": 0, + } + + def get_expected_log_node(self, label: str, resolved: bool) -> dict[str, Any]: + """Construct a networkx-export output log dataset type node for + comparison. + """ + if not resolved: + return {"bipartite": 0} + else: + return { + "dataset_type": DatasetType( + acc.LOG_OUTPUT_TEMPLATE.format(label=label), + self.dimensions.empty, + acc.LOG_OUTPUT_STORAGE_CLASS, + ), + "is_initial_query_constraint": False, + "is_prerequisite": False, + "dimensions": self.dimensions.empty, + "storage_class_name": acc.LOG_OUTPUT_STORAGE_CLASS, + "bipartite": 0, + } + + def get_expected_metadata_node(self, label: str, resolved: bool) -> dict[str, Any]: + """Construct a networkx-export output metadata dataset type node for + comparison. + """ + if not resolved: + return {"bipartite": 0} + else: + return { + "dataset_type": DatasetType( + acc.METADATA_OUTPUT_TEMPLATE.format(label=label), + self.dimensions.empty, + acc.METADATA_OUTPUT_STORAGE_CLASS, + ), + "is_initial_query_constraint": False, + "is_prerequisite": False, + "dimensions": self.dimensions.empty, + "storage_class_name": acc.METADATA_OUTPUT_STORAGE_CLASS, + "bipartite": 0, + } + + def get_expected_connection_node( + self, name: str, resolved: bool, *, is_initial_query_constraint: bool + ) -> dict[str, Any]: + """Construct a networkx-export dataset type node for comparison.""" + if not resolved: + return {"bipartite": 0} + else: + return { + "dataset_type": DatasetType( + name, + self.dimensions.empty, + get_mock_name("StructuredDataDict"), + ), + "is_initial_query_constraint": is_initial_query_constraint, + "is_prerequisite": False, + "dimensions": self.dimensions.empty, + "storage_class_name": get_mock_name("StructuredDataDict"), + "bipartite": 0, + } + + def test_construct_with_data_coordinate(self) -> None: + """Test constructing a graph with a DataCoordinate. + + Since this creates a graph with DimensionUniverse, all tasks added to + it should have resolved dimensions, but not (yet) resolved dataset + types. We use that to test a few other operations in that state. + """ + data_id = DataCoordinate.standardize(instrument="I", universe=self.dimensions) + graph = PipelineGraph(data_id=data_id) + self.assertEqual(graph.universe, self.dimensions) + self.assertEqual(graph.data_id, data_id) + graph.add_task("b1", DynamicTestPipelineTask, self.b_config) + self.assertEqual(graph.tasks["b1"].dimensions, self.dimensions.empty) + # Still can't group by dimensions, because the dataset types aren't + # resolved. + with self.assertRaises(UnresolvedGraphError): + graph.group_by_dimensions() + # Transferring a node from this graph to ``self.graph`` should + # unresolve the dimensions. + self.graph.add_task_nodes([graph.tasks["b1"]]) + self.assertIsNot(self.graph.tasks["b1"], graph.tasks["b1"]) + self.assertFalse(self.graph.tasks["b1"].has_resolved_dimensions) + # Do the opposite transfer, which should resolve dimensions. 
+ graph.add_task_nodes([self.graph.tasks["a"]]) + self.assertIsNot(self.graph.tasks["a"], graph.tasks["a"]) + self.assertTrue(graph.tasks["a"].has_resolved_dimensions) + + def test_group_by_dimensions(self) -> None: + """Test PipelineGraph.group_by_dimensions.""" + with self.assertRaises(UnresolvedGraphError): + self.graph.group_by_dimensions() + self.a_config.dimensions = ["visit"] + self.a_config.outputs["output1"].dimensions = ["visit"] + self.a_config.prerequisite_inputs["prereq1"] = DynamicConnectionConfig( + dataset_type_name="prereq_1", + multiple=True, + dimensions=["htm7"], + is_calibration=True, + ) + self.b_config.dimensions = ["htm7"] + self.b_config.inputs["input1"].dimensions = ["visit"] + self.b_config.inputs["input1"].multiple = True + self.b_config.outputs["output1"].dimensions = ["htm7"] + self.graph.reconfigure_tasks(a=self.a_config, b=self.b_config) + self.graph.resolve(MockRegistry(self.dimensions, {})) + visit_dims = self.dimensions.extract(["visit"]) + htm7_dims = self.dimensions.extract(["htm7"]) + expected = { + self.dimensions.empty: ( + {}, + { + "schema": self.graph.dataset_types["schema"], + "input_1": self.graph.dataset_types["input_1"], + "a_config": self.graph.dataset_types["a_config"], + "b_config": self.graph.dataset_types["b_config"], + }, + ), + visit_dims: ( + {"a": self.graph.tasks["a"]}, + { + "a_log": self.graph.dataset_types["a_log"], + "a_metadata": self.graph.dataset_types["a_metadata"], + "intermediate_1": self.graph.dataset_types["intermediate_1"], + }, + ), + htm7_dims: ( + {"b": self.graph.tasks["b"]}, + { + "b_log": self.graph.dataset_types["b_log"], + "b_metadata": self.graph.dataset_types["b_metadata"], + "output_1": self.graph.dataset_types["output_1"], + }, + ), + } + self.assertEqual(self.graph.group_by_dimensions(), expected) + expected[htm7_dims][1]["prereq_1"] = self.graph.dataset_types["prereq_1"] + self.assertEqual(self.graph.group_by_dimensions(prerequisites=True), expected) + + def test_add_and_remove(self) -> None: + """Tests for adding and removing tasks and task subsets from a + PipelineGraph. + """ + # Can't remove a task while it's still in a subset. + with self.assertRaises(PipelineGraphError): + self.graph.remove_tasks(["b"], drop_from_subsets=False) + # ...unless you remove the subset. + self.graph.remove_task_subset("only_b") + self.assertFalse(self.graph.task_subsets) + ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=False) + self.assertFalse(referencing_subsets) + self.assertEqual(self.graph.tasks.keys(), {"a"}) + # Add that task back in. + self.graph.add_task_nodes([b]) + self.assertEqual(self.graph.tasks.keys(), {"a", "b"}) + # Add the subset back in. + self.graph.add_task_subset("only_b", {"b"}) + self.assertEqual(self.graph.task_subsets.keys(), {"only_b"}) + # Resolve the graph's dataset types and task dimensions. + self.graph.resolve(MockRegistry(self.dimensions, {})) + self.assertTrue(self.graph.dataset_types.is_resolved("input_1")) + self.assertTrue(self.graph.dataset_types.is_resolved("output_1")) + self.assertTrue(self.graph.dataset_types.is_resolved("schema")) + self.assertTrue(self.graph.dataset_types.is_resolved("intermediate_1")) + # Remove the task while removing it from the subset automatically. This + # should also unresolve (only) the referenced dataset types and drop + # any datasets no longer attached to any task. 
+        self.assertEqual(self.graph.tasks.keys(), {"a", "b"})
+        ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=True)
+        self.assertEqual(referencing_subsets, {"only_b"})
+        self.assertEqual(self.graph.tasks.keys(), {"a"})
+        self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
+        self.assertNotIn("output_1", self.graph.dataset_types)
+        self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
+        self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))
+
+    def test_reconfigure(self) -> None:
+        """Tests for PipelineGraph.reconfigure_tasks."""
+        self.graph.resolve(MockRegistry(self.dimensions, {}))
+        self.b_config.outputs["output1"].storage_class = "TaskMetadata"
+        with self.assertRaises(ValueError):
+            # Can't check and assume together.
+            self.graph.reconfigure_tasks(
+                b=self.b_config, assume_edges_unchanged=True, check_edges_unchanged=True
+            )
+        # Check that graph is unchanged after error.
+        self.check_base_accessors(self.graph)
+        with self.assertRaises(EdgesChangedError):
+            self.graph.reconfigure_tasks(b=self.b_config, check_edges_unchanged=True)
+        self.check_base_accessors(self.graph)
+        # Make a change that does affect edges; this will unresolve most
+        # dataset types.
+        self.graph.reconfigure_tasks(b=self.b_config)
+        self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
+        self.assertFalse(self.graph.dataset_types.is_resolved("output_1"))
+        self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
+        self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))
+        # Resolving again will pick up the new storage class.
+        self.graph.resolve(MockRegistry(self.dimensions, {}))
+        self.assertEqual(
+            self.graph.dataset_types["output_1"].storage_class_name, get_mock_name("TaskMetadata")
+        )
+
+
+class PipelineGraphResolveTestCase(unittest.TestCase):
+    """More extensive tests for PipelineGraph.resolve and its private helper
+    methods.
+
+    These are in a separate TestCase because they utilize a different `setUp`
+    from the rest of the `PipelineGraph` tests.
+    """
+
+    def setUp(self) -> None:
+        self.a_config = DynamicTestPipelineTaskConfig()
+        self.b_config = DynamicTestPipelineTaskConfig()
+        self.dimensions = DimensionUniverse()
+        self.maxDiff = None
+
+    def make_graph(self) -> PipelineGraph:
+        graph = PipelineGraph()
+        graph.add_task("a", DynamicTestPipelineTask, self.a_config)
+        graph.add_task("b", DynamicTestPipelineTask, self.b_config)
+        return graph
+
+    def test_prerequisite_inconsistency(self) -> None:
+        """Test that we raise an exception when one edge defines a dataset type
+        as a prerequisite and another does not.
+
+        This test will hopefully someday go away (along with
+        `DatasetTypeNode.is_prerequisite`) when the QuantumGraph generation
+        algorithm becomes more flexible.
+        """
+        self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
+        self.b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
+        graph = self.make_graph()
+        with self.assertRaises(ConnectionTypeConsistencyError):
+            graph.resolve(MockRegistry(self.dimensions, {}))
+
+    def test_prerequisite_inconsistency_reversed(self) -> None:
+        """Same as `test_prerequisite_inconsistency`, with the order the edges
+        are added to the graph reversed.
+ """ + self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d") + self.b_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d") + graph = self.make_graph() + with self.assertRaises(ConnectionTypeConsistencyError): + graph.resolve(MockRegistry(self.dimensions, {})) + + def test_prerequisite_output(self) -> None: + """Test that we raise an exception when one edge defines a dataset type + as a prerequisite but another defines it as an output. + """ + self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d") + self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d") + graph = self.make_graph() + with self.assertRaises(ConnectionTypeConsistencyError): + graph.resolve(MockRegistry(self.dimensions, {})) + + def test_skypix_missing(self) -> None: + """Test that we raise an exception when one edge uses the "skypix" + dimension as a placeholder but the dataset type is not registered. + """ + self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d", dimensions={"skypix"} + ) + graph = self.make_graph() + with self.assertRaises(MissingDatasetTypeError): + graph.resolve(MockRegistry(self.dimensions, {})) + + def test_skypix_inconsistent(self) -> None: + """Test that we raise an exception when one edge uses the "skypix" + dimension as a placeholder but the rest of the dimensions are + inconsistent with the registered dataset type. + """ + self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d", dimensions={"skypix", "visit"} + ) + graph = self.make_graph() + with self.assertRaises(IncompatibleDatasetTypeError): + graph.resolve( + MockRegistry( + self.dimensions, + { + "d": DatasetType( + "d", dimensions=self.dimensions.extract(["htm7"]), storageClass="ArrowTable" + ) + }, + ) + ) + with self.assertRaises(IncompatibleDatasetTypeError): + graph.resolve( + MockRegistry( + self.dimensions, + { + "d": DatasetType( + "d", + dimensions=self.dimensions.extract(["htm7", "visit", "skymap"]), + storageClass="ArrowTable", + ) + }, + ) + ) + + def test_duplicate_outputs(self) -> None: + """Test that we raise an exception when a dataset type node would have + two write edges. + """ + self.a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d") + self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d") + graph = self.make_graph() + with self.assertRaises(DuplicateOutputError): + graph.resolve(MockRegistry(self.dimensions, {})) + + def test_component_of_unregistered_parent(self) -> None: + """Test that we raise an exception when a component dataset type's + parent is not registered. + """ + self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c") + graph = self.make_graph() + with self.assertRaises(MissingDatasetTypeError): + graph.resolve(MockRegistry(self.dimensions, {})) + + def test_undefined_component(self) -> None: + """Test that we raise an exception when a component dataset type's + parent is registered, but its storage class does not have that + component. 
+ """ + self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c") + graph = self.make_graph() + with self.assertRaises(IncompatibleDatasetTypeError): + graph.resolve( + MockRegistry( + self.dimensions, + {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))}, + ) + ) + + def test_bad_component_storage_class(self) -> None: + """Test that we raise an exception when a component dataset type's + parent is registered, but does not have that component. + """ + self.a_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d.schema", storage_class="StructuredDataDict" + ) + graph = self.make_graph() + with self.assertRaises(IncompatibleDatasetTypeError): + graph.resolve( + MockRegistry( + self.dimensions, + {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))}, + ) + ) + + def test_input_storage_class_incompatible_with_registry(self) -> None: + """Test that we raise an exception when an input connection's storage + class is incompatible with the registry definition. + """ + self.a_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="StructuredDataList" + ) + graph = self.make_graph() + with self.assertRaises(IncompatibleDatasetTypeError): + graph.resolve( + MockRegistry( + self.dimensions, + {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))}, + ) + ) + + def test_output_storage_class_incompatible_with_registry(self) -> None: + """Test that we raise an exception when an output connection's storage + class is incompatible with the registry definition. + """ + self.a_config.outputs["o"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="StructuredDataList" + ) + graph = self.make_graph() + with self.assertRaises(IncompatibleDatasetTypeError): + graph.resolve( + MockRegistry( + self.dimensions, + {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))}, + ) + ) + + def test_input_storage_class_incompatible_with_output(self) -> None: + """Test that we raise an exception when an input connection's storage + class is incompatible with the storage class of the output connection. + """ + self.a_config.outputs["o"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="ArrowTable" + ) + self.b_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="StructuredDataList" + ) + graph = self.make_graph() + with self.assertRaises(IncompatibleDatasetTypeError): + graph.resolve(MockRegistry(self.dimensions, {})) + + def test_ambiguous_storage_class(self) -> None: + """Test that we raise an exception when two input connections define + the same dataset with different storage classes (even compatible ones) + and there is no output connection or registry definition to take + precedence. + """ + self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable") + self.b_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="ArrowAstropy" + ) + graph = self.make_graph() + with self.assertRaises(MissingDatasetTypeError): + graph.resolve(MockRegistry(self.dimensions, {})) + + def test_inputs_compatible_with_registry(self) -> None: + """Test successful resolution of a dataset type where input edges have + different but compatible storage classes and the dataset type is + already registered. 
+ """ + self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable") + self.b_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="ArrowAstropy" + ) + graph = self.make_graph() + dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame")) + graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type})) + self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type) + a_i = graph.tasks["a"].inputs["i"] + b_i = graph.tasks["b"].inputs["i"] + self.assertEqual( + a_i.adapt_dataset_type(dataset_type), + dataset_type.overrideStorageClass(get_mock_name("ArrowTable")), + ) + self.assertEqual( + b_i.adapt_dataset_type(dataset_type), + dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")), + ) + data_id = DataCoordinate.makeEmpty(self.dimensions) + ref = DatasetRef(dataset_type, data_id, run="r") + a_ref = a_i.adapt_dataset_ref(ref) + b_ref = b_i.adapt_dataset_ref(ref) + self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable"))) + self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy"))) + self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref) + self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref) + + def test_output_compatible_with_registry(self) -> None: + """Test successful resolution of a dataset type where an output edge + has a different but compatible storage class from the dataset type + already registered. + """ + self.a_config.outputs["o"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="ArrowTable" + ) + graph = self.make_graph() + dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame")) + graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type})) + self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type) + a_o = graph.tasks["a"].outputs["o"] + self.assertEqual( + a_o.adapt_dataset_type(dataset_type), + dataset_type.overrideStorageClass(get_mock_name("ArrowTable")), + ) + data_id = DataCoordinate.makeEmpty(self.dimensions) + ref = DatasetRef(dataset_type, data_id, run="r") + a_ref = a_o.adapt_dataset_ref(ref) + self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable"))) + self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref) + + def test_inputs_compatible_with_output(self) -> None: + """Test successful resolution of a dataset type where an input edge has + a different but compatible storage class from the output edge, and + the dataset type is not registered. 
+ """ + self.a_config.outputs["o"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="ArrowTable" + ) + self.b_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="ArrowAstropy" + ) + graph = self.make_graph() + a_o = graph.tasks["a"].outputs["o"] + b_i = graph.tasks["b"].inputs["i"] + graph.resolve(MockRegistry(self.dimensions, {})) + self.assertEqual(graph.dataset_types["d"].storage_class_name, get_mock_name("ArrowTable")) + self.assertEqual( + a_o.adapt_dataset_type(graph.dataset_types["d"].dataset_type), + graph.dataset_types["d"].dataset_type, + ) + self.assertEqual( + b_i.adapt_dataset_type(graph.dataset_types["d"].dataset_type), + graph.dataset_types["d"].dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")), + ) + data_id = DataCoordinate.makeEmpty(self.dimensions) + ref = DatasetRef(graph.dataset_types["d"].dataset_type, data_id, run="r") + a_ref = a_o.adapt_dataset_ref(ref) + b_ref = b_i.adapt_dataset_ref(ref) + self.assertEqual(a_ref, ref) + self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy"))) + self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref) + self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref) + + def test_component_resolved_by_input(self) -> None: + """Test successful resolution of a component dataset type due to + another input referencing the parent dataset type. + """ + self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable") + self.b_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d.schema", storage_class="ArrowSchema" + ) + graph = self.make_graph() + parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable")) + graph.resolve(MockRegistry(self.dimensions, {})) + self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type) + a_i = graph.tasks["a"].inputs["i"] + b_i = graph.tasks["b"].inputs["i"] + self.assertEqual(b_i.dataset_type_name, "d.schema") + self.assertEqual(a_i.adapt_dataset_type(parent_dataset_type), parent_dataset_type) + self.assertEqual( + b_i.adapt_dataset_type(parent_dataset_type), + parent_dataset_type.makeComponentDatasetType("schema"), + ) + data_id = DataCoordinate.makeEmpty(self.dimensions) + ref = DatasetRef(parent_dataset_type, data_id, run="r") + a_ref = a_i.adapt_dataset_ref(ref) + b_ref = b_i.adapt_dataset_ref(ref) + self.assertEqual(a_ref, ref) + self.assertEqual(b_ref, ref.makeComponentRef("schema")) + self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref) + self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref) + + def test_component_resolved_by_output(self) -> None: + """Test successful resolution of a component dataset type due to + an output connection referencing the parent dataset type. 
+        """
+        self.a_config.outputs["o"] = DynamicConnectionConfig(
+            dataset_type_name="d", storage_class="ArrowTable"
+        )
+        self.b_config.inputs["i"] = DynamicConnectionConfig(
+            dataset_type_name="d.schema", storage_class="ArrowSchema"
+        )
+        graph = self.make_graph()
+        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
+        graph.resolve(MockRegistry(self.dimensions, {}))
+        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
+        a_o = graph.tasks["a"].outputs["o"]
+        b_i = graph.tasks["b"].inputs["i"]
+        self.assertEqual(b_i.dataset_type_name, "d.schema")
+        self.assertEqual(a_o.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
+        self.assertEqual(
+            b_i.adapt_dataset_type(parent_dataset_type),
+            parent_dataset_type.makeComponentDatasetType("schema"),
+        )
+        data_id = DataCoordinate.makeEmpty(self.dimensions)
+        ref = DatasetRef(parent_dataset_type, data_id, run="r")
+        a_ref = a_o.adapt_dataset_ref(ref)
+        b_ref = b_i.adapt_dataset_ref(ref)
+        self.assertEqual(a_ref, ref)
+        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
+        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
+        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
+
+    def test_component_resolved_by_registry(self) -> None:
+        """Test successful resolution of a component dataset type due to
+        the parent dataset type already being registered.
+        """
+        self.b_config.inputs["i"] = DynamicConnectionConfig(
+            dataset_type_name="d.schema", storage_class="ArrowSchema"
+        )
+        graph = self.make_graph()
+        parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
+        graph.resolve(MockRegistry(self.dimensions, {"d": parent_dataset_type}))
+        self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
+        b_i = graph.tasks["b"].inputs["i"]
+        self.assertEqual(b_i.dataset_type_name, "d.schema")
+        self.assertEqual(
+            b_i.adapt_dataset_type(parent_dataset_type),
+            parent_dataset_type.makeComponentDatasetType("schema"),
+        )
+        data_id = DataCoordinate.makeEmpty(self.dimensions)
+        ref = DatasetRef(parent_dataset_type, data_id, run="r")
+        b_ref = b_i.adapt_dataset_ref(ref)
+        self.assertEqual(b_ref, ref.makeComponentRef("schema"))
+        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
+
+
+if __name__ == "__main__":
+    lsst.utils.tests.init()
+    unittest.main()

From 7f0d750d9acb345dfa037c5d9ddb9f72cab99790 Mon Sep 17 00:00:00 2001
From: Jim Bosch 
Date: Wed, 1 Mar 2023 15:34:55 -0500
Subject: [PATCH 04/16] Integrate PipelineGraph with Pipeline and use it for
 sorting.

Much of the code changed here is actually stuff I want to deprecate in the
future, once PipelineGraph has been integrated with more things. In the
meantime, this addresses much of the duplication caused by adding
PipelineGraph.
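
As a rough illustration of the new path (not a literal excerpt from the diff
below; "pipeline" here stands for any Pipeline instance), ordering and
sorting now go through the graph:

    graph = pipeline.to_graph()  # applies overrides, freezes configs, checks contracts
    graph.sort()                 # deterministic topological sort
    task_defs = list(graph._iter_task_defs())  # legacy TaskDef view for existing callers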
--- python/lsst/pipe/base/pipeTools.py | 157 ++++++----------------------- python/lsst/pipe/base/pipeline.py | 100 +++++++++++++----- tests/test_pipeTools.py | 23 +---- 3 files changed, 105 insertions(+), 175 deletions(-) diff --git a/python/lsst/pipe/base/pipeTools.py b/python/lsst/pipe/base/pipeTools.py index 8793e27b..6e05a212 100644 --- a/python/lsst/pipe/base/pipeTools.py +++ b/python/lsst/pipe/base/pipeTools.py @@ -27,30 +27,17 @@ # No one should do import * from this module __all__ = ["isPipelineOrdered", "orderPipeline"] -# ------------------------------- -# Imports of standard modules -- -# ------------------------------- -import itertools -from collections.abc import Iterable, Sequence +from collections.abc import Iterable from typing import TYPE_CHECKING -# ----------------------------- -# Imports for other modules -- -# ----------------------------- -from .connections import iterConnections +from .pipeline import Pipeline, TaskDef + +# Exceptions re-exported here for backwards compatibility. +from .pipeline_graph import DuplicateOutputError, PipelineDataCycleError, PipelineGraph # noqa: F401 if TYPE_CHECKING: - from .pipeline import Pipeline, TaskDef from .taskFactory import TaskFactory -# ---------------------------------- -# Local non-exported definitions -- -# ---------------------------------- - -# ------------------------ -# Exported definitions -- -# ------------------------ - class MissingTaskFactoryError(Exception): """Exception raised when client fails to provide TaskFactory instance.""" @@ -58,20 +45,6 @@ class MissingTaskFactoryError(Exception): pass -class DuplicateOutputError(Exception): - """Exception raised when Pipeline has more than one task for the same - output. - """ - - pass - - -class PipelineDataCycleError(Exception): - """Exception raised when Pipeline has data dependency cycle.""" - - pass - - def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskFactory | None = None) -> bool: """Check whether tasks in pipeline are correctly ordered. @@ -80,9 +53,9 @@ def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskF Parameters ---------- - pipeline : `pipe.base.Pipeline` or `collections.abc.Iterable` [ `TaskDef` ] + pipeline : `Pipeline` or `collections.abc.Iterable` [ `TaskDef` ] Pipeline description. - taskFactory: `pipe.base.TaskFactory`, optional + taskFactory: `TaskFactory`, optional Ignored; present only for backwards compatibility. Returns @@ -97,38 +70,30 @@ def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskF DuplicateOutputError Raised when there is more than one producer for a dataset type. 
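+
+    Examples
+    --------
+    A minimal usage sketch (``pipeline`` here is assumed to be an
+    already-built `Pipeline` or an iterable of `TaskDef` objects)::
+
+        if not isPipelineOrdered(pipeline):
+            task_defs = orderPipeline(pipeline)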
""" - # Build a map of DatasetType name to producer's index in a pipeline - producerIndex = {} - for idx, taskDef in enumerate(pipeline): - for attr in iterConnections(taskDef.connections, "outputs"): - if attr.name in producerIndex: - raise DuplicateOutputError( - "DatasetType `{}' appears more than once as output".format(attr.name) - ) - producerIndex[attr.name] = idx - - # check all inputs that are also someone's outputs - for idx, taskDef in enumerate(pipeline): - # get task input DatasetTypes, this can only be done via class method - inputs = {name: getattr(taskDef.connections, name) for name in taskDef.connections.inputs} - for dsTypeDescr in inputs.values(): - # all pre-existing datasets have effective index -1 - prodIdx = producerIndex.get(dsTypeDescr.name, -1) - if prodIdx >= idx: - # not good, producer is downstream - return False - + if isinstance(pipeline, Pipeline): + graph = pipeline.to_graph() + else: + graph = PipelineGraph() + for task_def in pipeline: + graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections) + # Can't use graph.is_sorted because that requires sorted dataset type names + # as well as sorted tasks. + tasks_xgraph = graph.make_task_xgraph() + seen: set[str] = set() + for task_label in tasks_xgraph: + successors = set(tasks_xgraph.successors(task_label)) + if not successors.isdisjoint(seen): + return False + seen.add(task_label) return True -def orderPipeline(pipeline: Sequence[TaskDef]) -> list[TaskDef]: +def orderPipeline(pipeline: Pipeline | Iterable[TaskDef]) -> list[TaskDef]: """Re-order tasks in pipeline to satisfy data dependencies. - When possible new ordering keeps original relative order of the tasks. - Parameters ---------- - pipeline : `pipe.base.Pipeline` or `collections.abc.Iterable` [ `TaskDef` ] + pipeline : `Pipeline` or `collections.abc.Iterable` [ `TaskDef` ] Pipeline description. Returns @@ -143,69 +108,11 @@ def orderPipeline(pipeline: Sequence[TaskDef]) -> list[TaskDef]: PipelineDataCycleError Raised when the pipeline has dependency cycles. """ - # This is a modified version of Kahn's algorithm that preserves order - - # build mapping of the tasks to their inputs and outputs - inputs = {} # maps task index to its input DatasetType names - outputs = {} # maps task index to its output DatasetType names - allInputs = set() # all inputs of all tasks - allOutputs = set() # all outputs of all tasks - dsTypeTaskLabels: dict[str, str] = {} # maps DatasetType name to the label of its parent task - for idx, taskDef in enumerate(pipeline): - # task outputs - dsMap = {name: getattr(taskDef.connections, name) for name in taskDef.connections.outputs} - for dsTypeDescr in dsMap.values(): - if dsTypeDescr.name in allOutputs: - raise DuplicateOutputError( - f"DatasetType `{dsTypeDescr.name}' in task `{taskDef.label}' already appears as an " - f"output in task `{dsTypeTaskLabels[dsTypeDescr.name]}'." 
- ) - dsTypeTaskLabels[dsTypeDescr.name] = taskDef.label - outputs[idx] = set(dsTypeDescr.name for dsTypeDescr in dsMap.values()) - allOutputs.update(outputs[idx]) - - # task inputs - connectionInputs = itertools.chain(taskDef.connections.inputs, taskDef.connections.prerequisiteInputs) - inputs[idx] = set(getattr(taskDef.connections, name).name for name in connectionInputs) - allInputs.update(inputs[idx]) - - # for simplicity add pseudo-node which is a producer for all pre-existing - # inputs, its index is -1 - preExisting = allInputs - allOutputs - outputs[-1] = preExisting - - # Set of nodes with no incoming edges, initially set to pseudo-node - queue = [-1] - result = [] - while queue: - # move to final list, drop -1 - idx = queue.pop(0) - if idx >= 0: - result.append(idx) - - # remove task outputs from other tasks inputs - thisTaskOutputs = outputs.get(idx, set()) - for taskInputs in inputs.values(): - taskInputs -= thisTaskOutputs - - # find all nodes with no incoming edges and move them to the queue - topNodes = [key for key, value in inputs.items() if not value] - queue += topNodes - for key in topNodes: - del inputs[key] - - # keep queue ordered - queue.sort() - - # if there is something left it means cycles - if inputs: - # format it in usable way - loops = [] - for idx, inputNames in inputs.items(): - taskName = pipeline[idx].label - outputNames = outputs[idx] - edge = " {} -> {} -> {}".format(inputNames, taskName, outputNames) - loops.append(edge) - raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops)) - - return [pipeline[idx] for idx in result] + if isinstance(pipeline, Pipeline): + graph = pipeline.to_graph() + else: + graph = PipelineGraph() + for task_def in pipeline: + graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections) + graph.sort() + return list(graph._iter_task_defs()) diff --git a/python/lsst/pipe/base/pipeline.py b/python/lsst/pipe/base/pipeline.py index 65dfc4d0..3636850d 100644 --- a/python/lsst/pipe/base/pipeline.py +++ b/python/lsst/pipe/base/pipeline.py @@ -54,14 +54,12 @@ from lsst.utils.introspection import get_full_type_name from . import automatic_connection_constants as acc -from . import pipelineIR, pipeTools +from . import pipeline_graph, pipelineIR from ._instrument import Instrument as PipeBaseInstrument -from ._task_metadata import TaskMetadata from .config import PipelineTaskConfig from .connections import PipelineTaskConnections, iterConnections from .connectionTypes import Input from .pipelineTask import PipelineTask -from .task import _TASK_METADATA_TYPE if TYPE_CHECKING: # Imports needed only for type annotations; may be circular. from lsst.obs.base import Instrument @@ -750,6 +748,47 @@ def write_to_uri(self, uri: ResourcePathExpression) -> None: """ self._pipelineIR.write_to_uri(uri) + def to_graph(self) -> pipeline_graph.PipelineGraph: + """Construct a pipeline graph from this pipeline. + + Constructing a graph applies all configuration overrides, freezes all + configuration, checks all contracts, and checks for dataset type + consistency between tasks (as much as possible without access to a data + repository). It cannot be reversed. + + Returns + ------- + graph : `pipeline_graph.PipelineGraph` + Representation of the pipeline as a graph. 
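+
+        Examples
+        --------
+        A rough sketch of the intended call pattern, where ``pipeline`` is
+        any `Pipeline` and ``registry`` is assumed to be a
+        `~lsst.daf.butler.Registry` (resolving against it fills in task
+        dimensions and dataset types)::
+
+            graph = pipeline.to_graph()
+            graph.resolve(registry)
+            for label, task_node in graph.tasks.items():
+                print(label, task_node.dimensions)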
+ """ + instrument_class_name = self._pipelineIR.instrument + data_id = {} + if instrument_class_name is not None: + instrument_class: type[Instrument] = doImportType(instrument_class_name) + if instrument_class is not None: + data_id["instrument"] = instrument_class.getName() + graph = pipeline_graph.PipelineGraph(data_id=data_id) + graph.description = self._pipelineIR.description + for label in self._pipelineIR.tasks: + self._add_task_to_graph(label, graph) + if self._pipelineIR.contracts is not None: + label_to_config = {x.label: x.config for x in graph.tasks.values()} + for contract in self._pipelineIR.contracts: + # execute this in its own line so it can raise a good error + # message if there was problems with the eval + success = eval(contract.contract, None, label_to_config) + if not success: + extra_info = f": {contract.msg}" if contract.msg is not None else "" + raise pipelineIR.ContractError( + f"Contract(s) '{contract.contract}' were not satisfied{extra_info}" + ) + for label, subset in self._pipelineIR.labeled_subsets.items(): + graph.add_task_subset( + label, subset.subset, subset.description if subset.description is not None else "" + ) + graph.sort() + return graph + def toExpandedPipeline(self) -> Generator[TaskDef, None, None]: r"""Return a generator of `TaskDef`\s which can be used to create quantum graphs. @@ -766,31 +805,22 @@ def toExpandedPipeline(self) -> Generator[TaskDef, None, None]: If a dataId is supplied in a config block. This is in place for future use """ - taskDefs = [] - for label in self._pipelineIR.tasks: - taskDefs.append(self._buildTaskDef(label)) + yield from self.to_graph()._iter_task_defs() - # lets evaluate the contracts - if self._pipelineIR.contracts is not None: - label_to_config = {x.label: x.config for x in taskDefs} - for contract in self._pipelineIR.contracts: - # execute this in its own line so it can raise a good error - # message if there was problems with the eval - success = eval(contract.contract, None, label_to_config) - if not success: - extra_info = f": {contract.msg}" if contract.msg is not None else "" - raise pipelineIR.ContractError( - f"Contract(s) '{contract.contract}' were not satisfied{extra_info}" - ) - - taskDefs = sorted(taskDefs, key=lambda x: x.label) - yield from pipeTools.orderPipeline(taskDefs) + def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) -> None: + """Add a single task from this pipeline to a pipeline graph that is + under construction. - def _buildTaskDef(self, label: str) -> TaskDef: + Parameters + ---------- + label : `str` + Label for the task to be added. + graph : `pipeline_graph.PipelineGraph` + Graph to add the task to. 
+ """ if (taskIR := self._pipelineIR.tasks.get(label)) is None: raise NameError(f"Label {label} does not appear in this pipeline") taskClass: type[PipelineTask] = doImportType(taskIR.klass) - taskName = get_full_type_name(taskClass) config = taskClass.ConfigClass() instrument: PipeBaseInstrument | None = None if (instrumentName := self._pipelineIR.instrument) is not None: @@ -803,13 +833,19 @@ def _buildTaskDef(self, label: str) -> TaskDef: self._pipelineIR.parameters, label, ) - return TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label) + graph.add_task(label, taskClass, config) def __iter__(self) -> Generator[TaskDef, None, None]: return self.toExpandedPipeline() def __getitem__(self, item: str) -> TaskDef: - return self._buildTaskDef(item) + # Making a whole graph and then making a TaskDef from that is pretty + # backwards, but I'm hoping to deprecate this method shortly in favor + # of making the graph explicitly and working with its node objects. + graph = pipeline_graph.PipelineGraph() + self._add_task_to_graph(item, graph) + (result,) = graph._iter_task_defs() + return result def __len__(self) -> int: return len(self._pipelineIR.tasks) @@ -1083,7 +1119,7 @@ def makeDatasetTypesSet( DatasetType( taskDef.configDatasetName, registry.dimensions.empty, - storageClass="Config", + storageClass=acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, ) ) initOutputs.freeze() @@ -1101,7 +1137,7 @@ def makeDatasetTypesSet( current = registry.getDatasetType(taskDef.metadataDatasetName) except KeyError: # No previous definition so use the default. - storageClass = "TaskMetadata" if _TASK_METADATA_TYPE is TaskMetadata else "PropertySet" + storageClass = acc.METADATA_OUTPUT_STORAGE_CLASS else: storageClass = current.storageClass.name outputs.update({DatasetType(taskDef.metadataDatasetName, dimensions, storageClass)}) @@ -1109,7 +1145,15 @@ def makeDatasetTypesSet( if taskDef.logOutputDatasetName is not None: # Log output dimensions correspond to a task quantum. 
dimensions = registry.dimensions.extract(taskDef.connections.dimensions) - outputs.update({DatasetType(taskDef.logOutputDatasetName, dimensions, "ButlerLogRecords")}) + outputs.update( + { + DatasetType( + taskDef.logOutputDatasetName, + dimensions, + acc.LOG_OUTPUT_STORAGE_CLASS, + ) + } + ) outputs.freeze() diff --git a/tests/test_pipeTools.py b/tests/test_pipeTools.py index 280879a2..ac7a889f 100644 --- a/tests/test_pipeTools.py +++ b/tests/test_pipeTools.py @@ -136,18 +136,6 @@ def testIsOrdered(self): ) self.assertTrue(pipeTools.isPipelineOrdered(pipeline)) - def testIsOrderedExceptions(self): - """Tests for pipeTools.isPipelineOrdered method exceptions""" - # two producers should throw ValueError - with self.assertRaises(pipeTools.DuplicateOutputError): - _makePipeline( - [ - ("A", "B", "task1"), - ("B", "C", "task2"), - ("A", "C", "task3"), - ] - ) - def testOrderPipeline(self): """Tests for pipeTools.orderPipeline method""" pipeline = _makePipeline([("A", "B", "task1"), ("B", "C", "task2")]) @@ -203,16 +191,7 @@ def testOrderPipeline(self): self.assertEqual(pipeline[3].label, "task4") def testOrderPipelineExceptions(self): - """Tests for pipeTools.orderPipeline method exceptions""" - with self.assertRaises(pipeTools.DuplicateOutputError): - _makePipeline( - [ - ("A", "B", "task1"), - ("B", "C", "task2"), - ("A", "C", "task3"), - ] - ) - + """Tests for pipeTools.orderPipeline method exceptions.""" # cycle in a graph should throw ValueError with self.assertRaises(pipeTools.PipelineDataCycleError): _makePipeline([("A", ("A", "B"), "task1")]) From 22cbfa8ff7e3f6ff8c92405a9f34bd4b07c55190 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Thu, 2 Mar 2023 10:29:01 -0500 Subject: [PATCH 05/16] Use PipelineGraph instead of PipelineDatasetTypes in step tester. --- .../pipe/base/tests/pipelineStepTester.py | 40 +++++++------------ 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/python/lsst/pipe/base/tests/pipelineStepTester.py b/python/lsst/pipe/base/tests/pipelineStepTester.py index 22e08e7d..ddbc680f 100644 --- a/python/lsst/pipe/base/tests/pipelineStepTester.py +++ b/python/lsst/pipe/base/tests/pipelineStepTester.py @@ -28,7 +28,7 @@ import unittest from lsst.daf.butler import Butler, DatasetType -from lsst.pipe.base import Pipeline, PipelineDatasetTypes +from lsst.pipe.base import Pipeline @dataclasses.dataclass @@ -88,32 +88,22 @@ def run(self, butler: Butler, test_case: unittest.TestCase) -> None: pure_inputs: dict[str, str] = dict() for suffix in self.step_suffixes: - pipeline = Pipeline.from_uri(self.filename + suffix) - dataset_types = PipelineDatasetTypes.fromPipeline( - pipeline, - registry=butler.registry, - include_configs=False, - include_packages=False, - ) + step_graph = Pipeline.from_uri(self.filename + suffix).to_graph() + step_graph.resolve(butler.registry) - pure_inputs.update({k: suffix for k in dataset_types.prerequisites.names}) - parent_inputs = {t.nameAndComponent()[0] for t in dataset_types.inputs} - pure_inputs.update({k: suffix for k in parent_inputs - all_outputs.keys()}) - all_outputs.update(dataset_types.outputs.asMapping()) - all_outputs.update(dataset_types.intermediates.asMapping()) - - for name in dataset_types.inputs.names & all_outputs.keys(): - test_case.assertTrue( - all_outputs[name].is_compatible_with(dataset_types.inputs[name]), - msg=( - f"dataset type {name} is defined as {dataset_types.inputs[name]} as an " - f"input, but {all_outputs[name]} as an output, and these are not compatible." 
- ), - ) + pure_inputs.update( + {name: suffix for name, _ in step_graph.iter_overall_inputs() if name not in all_outputs} + ) + all_outputs.update( + { + name: node.dataset_type + for name, node in step_graph.dataset_types.items() + if step_graph.producer_of(name) is not None + } + ) - for dataset_type in dataset_types.outputs | dataset_types.intermediates: - if not dataset_type.isComponent(): - butler.registry.registerDatasetType(dataset_type) + for node in step_graph.dataset_types.values(): + butler.registry.registerDatasetType(node.dataset_type) if not pure_inputs.keys() <= self.expected_inputs: missing = [f"{k} ({pure_inputs[k]})" for k in pure_inputs.keys() - self.expected_inputs] From 6789be535f1287e4dc17e2b6fca256877d1f39e5 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Fri, 23 Jun 2023 14:24:55 -0400 Subject: [PATCH 06/16] Guard against some optional dependencies not being present in tests. --- tests/test_pipeline_graph.py | 60 +++++++++++++++++++++++++++++++----- 1 file changed, 52 insertions(+), 8 deletions(-) diff --git a/tests/test_pipeline_graph.py b/tests/test_pipeline_graph.py index 6c232bc3..bc11147e 100644 --- a/tests/test_pipeline_graph.py +++ b/tests/test_pipeline_graph.py @@ -29,7 +29,7 @@ import lsst.pipe.base.automatic_connection_constants as acc import lsst.utils.tests -from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse +from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, StorageClassFactory from lsst.daf.butler.registry import MissingDatasetTypeError from lsst.pipe.base.pipeline_graph import ( ConnectionTypeConsistencyError, @@ -872,6 +872,25 @@ def test_reconfigure(self) -> None: ) +def _have_example_storage_classes() -> bool: + """Check whether some storage classes work as expected. + + Given that these have registered converters, it shouldn't actually be + necessary to import be able to those types in order to determine that + they're convertible, but the storage class machinery is implemented such + that types that can't be imported can't be converted, and while that's + inconvenient here it's totally fine in non-testing scenarios where you only + care about a storage class if you can actually use it. + """ + getter = StorageClassFactory().getStorageClass + return ( + getter("ArrowTable").can_convert(getter("ArrowAstropy")) + and getter("ArrowAstropy").can_convert(getter("ArrowTable")) + and getter("ArrowTable").can_convert(getter("DataFrame")) + and getter("DataFrame").can_convert(getter("ArrowTable")) + ) + + class PipelineGraphResolveTestCase(unittest.TestCase): """More extensive tests for PipelineGraph.resolve and its primate helper methods. @@ -952,7 +971,9 @@ def test_skypix_inconsistent(self) -> None: self.dimensions, { "d": DatasetType( - "d", dimensions=self.dimensions.extract(["htm7"]), storageClass="ArrowTable" + "d", + dimensions=self.dimensions.extract(["htm7"]), + storageClass="StructuredDataDict", ) }, ) @@ -965,7 +986,7 @@ def test_skypix_inconsistent(self) -> None: "d": DatasetType( "d", dimensions=self.dimensions.extract(["htm7", "visit", "skymap"]), - storageClass="ArrowTable", + storageClass="StructuredDataDict", ) }, ) @@ -1005,6 +1026,9 @@ def test_undefined_component(self) -> None: ) ) + @unittest.skipUnless( + _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available." 
+ ) def test_bad_component_storage_class(self) -> None: """Test that we raise an exception when a component dataset type's parent is registered, but does not have that component. @@ -1033,7 +1057,7 @@ class is incompatible with the registry definition. graph.resolve( MockRegistry( self.dimensions, - {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))}, + {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))}, ) ) @@ -1049,7 +1073,7 @@ class is incompatible with the registry definition. graph.resolve( MockRegistry( self.dimensions, - {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))}, + {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))}, ) ) @@ -1058,7 +1082,7 @@ def test_input_storage_class_incompatible_with_output(self) -> None: class is incompatible with the storage class of the output connection. """ self.a_config.outputs["o"] = DynamicConnectionConfig( - dataset_type_name="d", storage_class="ArrowTable" + dataset_type_name="d", storage_class="StructuredDataDict" ) self.b_config.inputs["i"] = DynamicConnectionConfig( dataset_type_name="d", storage_class="StructuredDataList" @@ -1073,14 +1097,19 @@ def test_ambiguous_storage_class(self) -> None: and there is no output connection or registry definition to take precedence. """ - self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable") + self.a_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="StructuredDataDict" + ) self.b_config.inputs["i"] = DynamicConnectionConfig( - dataset_type_name="d", storage_class="ArrowAstropy" + dataset_type_name="d", storage_class="StructuredDataList" ) graph = self.make_graph() with self.assertRaises(MissingDatasetTypeError): graph.resolve(MockRegistry(self.dimensions, {})) + @unittest.skipUnless( + _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available." + ) def test_inputs_compatible_with_registry(self) -> None: """Test successful resolution of a dataset type where input edges have different but compatible storage classes and the dataset type is @@ -1113,6 +1142,9 @@ def test_inputs_compatible_with_registry(self) -> None: self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref) self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref) + @unittest.skipUnless( + _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available." + ) def test_output_compatible_with_registry(self) -> None: """Test successful resolution of a dataset type where an output edge has a different but compatible storage class from the dataset type @@ -1136,6 +1168,9 @@ def test_output_compatible_with_registry(self) -> None: self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable"))) self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref) + @unittest.skipUnless( + _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available." 
+ ) def test_inputs_compatible_with_output(self) -> None: """Test successful resolution of a dataset type where an input edge has a different but compatible storage class from the output edge, and @@ -1169,6 +1204,9 @@ def test_inputs_compatible_with_output(self) -> None: self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref) self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref) + @unittest.skipUnless( + _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available." + ) def test_component_resolved_by_input(self) -> None: """Test successful resolution of a component dataset type due to another input referencing the parent dataset type. @@ -1198,6 +1236,9 @@ def test_component_resolved_by_input(self) -> None: self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref) self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref) + @unittest.skipUnless( + _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available." + ) def test_component_resolved_by_output(self) -> None: """Test successful resolution of a component dataset type due to an output connection referencing the parent dataset type. @@ -1229,6 +1270,9 @@ def test_component_resolved_by_output(self) -> None: self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref) self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref) + @unittest.skipUnless( + _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available." + ) def test_component_resolved_by_registry(self) -> None: """Test successful resolution of a component dataset type due to the parent dataset type already being registered. From 6334faa97fcb3b823037b945804282b6b59dbdc1 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Fri, 23 Jun 2023 14:39:16 -0400 Subject: [PATCH 07/16] Add changelog entry. --- doc/changes/DM-33027.feature.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 doc/changes/DM-33027.feature.md diff --git a/doc/changes/DM-33027.feature.md b/doc/changes/DM-33027.feature.md new file mode 100644 index 00000000..947ecc1d --- /dev/null +++ b/doc/changes/DM-33027.feature.md @@ -0,0 +1 @@ +Add a PipelineGraph class that represents a Pipeline with all configuration overrides applied as a graph. From 69bf1e60cf2e8c6206a6b86c4c2ecf4908d9a9d6 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Thu, 20 Jul 2023 12:08:11 -0400 Subject: [PATCH 08/16] Add edges to DatasetTypeNode. This provides some symmetry with TaskNode and a bit of convenience. --- .../base/pipeline_graph/_dataset_types.py | 34 +++++--- python/lsst/pipe/base/pipeline_graph/io.py | 86 ++++++++++++++----- 2 files changed, 88 insertions(+), 32 deletions(-) diff --git a/python/lsst/pipe/base/pipeline_graph/_dataset_types.py b/python/lsst/pipe/base/pipeline_graph/_dataset_types.py index 4d949edf..b6e0b859 100644 --- a/python/lsst/pipe/base/pipeline_graph/_dataset_types.py +++ b/python/lsst/pipe/base/pipeline_graph/_dataset_types.py @@ -23,6 +23,7 @@ __all__ = ("DatasetTypeNode",) import dataclasses +from collections.abc import Collection from typing import TYPE_CHECKING, Any import networkx @@ -69,6 +70,12 @@ class DatasetTypeNode: the Registry before graph creation. 
""" + producing_edge: WriteEdge | None + """The edge to the task that produces this dataset type.""" + + consuming_edges: Collection[ReadEdge] + """The edges to tasks that consume this dataset type.""" + @classmethod def _from_edges( cls, key: NodeKey, xgraph: networkx.MultiDiGraph, registry: Registry, previous: DatasetTypeNode | None @@ -109,7 +116,7 @@ def _from_edges( is_initial_query_constraint = True is_prerequisite: bool | None = None producer: str | None = None - write_edge: WriteEdge + producing_edge: WriteEdge | None = None # Iterate over the incoming edges to this node, which represent the # output connections of tasks that write this dataset type; these take # precedence over the inputs in determining the graph-wide dataset type @@ -117,23 +124,26 @@ def _from_edges( # graph to register dataset types). There should only be one such # connection, but we won't necessarily have checked that rule until # here. As a result there can be at most one iteration of this loop. - for _, _, write_edge in xgraph.in_edges(key, data="instance"): + for _, _, producing_edge in xgraph.in_edges(key, data="instance"): + assert producing_edge is not None, "Should only be None if we never loop." if producer is not None: raise DuplicateOutputError( - f"Dataset type {key.name!r} is produced by both {write_edge.task_label!r} " + f"Dataset type {key.name!r} is produced by both {producing_edge.task_label!r} " f"and {producer!r}." ) - producer = write_edge.task_label - dataset_type = write_edge._resolve_dataset_type(dataset_type, universe=registry.dimensions) + producer = producing_edge.task_label + dataset_type = producing_edge._resolve_dataset_type(dataset_type, universe=registry.dimensions) is_prerequisite = False is_initial_query_constraint = False - read_edge: ReadEdge + consuming_edge: ReadEdge consumers: list[str] = [] - read_edges = list(read_edge for _, _, read_edge in xgraph.out_edges(key, data="instance")) + consuming_edges = list( + consuming_edge for _, _, consuming_edge in xgraph.out_edges(key, data="instance") + ) # Put edges that are not component datasets before any edges that are. - read_edges.sort(key=lambda read_edge: read_edge.component is not None) - for read_edge in read_edges: - dataset_type, is_initial_query_constraint, is_prerequisite = read_edge._resolve_dataset_type( + consuming_edges.sort(key=lambda consuming_edge: consuming_edge.component is not None) + for consuming_edge in consuming_edges: + dataset_type, is_initial_query_constraint, is_prerequisite = consuming_edge._resolve_dataset_type( current=dataset_type, universe=registry.dimensions, is_initial_query_constraint=is_initial_query_constraint, @@ -142,13 +152,15 @@ def _from_edges( producer=producer, consumers=consumers, ) - consumers.append(read_edge.task_label) + consumers.append(consuming_edge.task_label) assert dataset_type is not None, "Graph structure guarantees at least one edge." assert is_prerequisite is not None, "Having at least one edge guarantees is_prerequisite is known." 
return DatasetTypeNode( dataset_type=dataset_type, is_initial_query_constraint=is_initial_query_constraint, is_prerequisite=is_prerequisite, + producing_edge=producing_edge, + consuming_edges=tuple(consuming_edges), ) @property diff --git a/python/lsst/pipe/base/pipeline_graph/io.py b/python/lsst/pipe/base/pipeline_graph/io.py index 09e52df5..5b62c715 100644 --- a/python/lsst/pipe/base/pipeline_graph/io.py +++ b/python/lsst/pipe/base/pipeline_graph/io.py @@ -30,6 +30,7 @@ "SerializedPipelineGraph", ) +from collections.abc import Mapping from typing import Any, TypeVar import networkx @@ -118,14 +119,14 @@ def deserialize_read_edge( self, task_key: NodeKey, connection_name: str, + dataset_type_keys: Mapping[str, NodeKey], is_prerequisite: bool = False, ) -> ReadEdge: """Transform a `SerializedEdge` to a `ReadEdge`.""" parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(self.dataset_type_name) - dataset_type_key = NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name) return ReadEdge( - dataset_type_key, - task_key, + dataset_type_key=dataset_type_keys[parent_dataset_type_name], + task_key=task_key, storage_class_name=self.storage_class, is_prerequisite=is_prerequisite, component=component, @@ -139,12 +140,12 @@ def deserialize_write_edge( self, task_key: NodeKey, connection_name: str, + dataset_type_keys: Mapping[str, NodeKey], ) -> WriteEdge: """Transform a `SerializedEdge` to a `WriteEdge`.""" - dataset_type_key = NodeKey(NodeType.DATASET_TYPE, self.dataset_type_name) return WriteEdge( task_key=task_key, - dataset_type_key=dataset_type_key, + dataset_type_key=dataset_type_keys[self.dataset_type_name], storage_class_name=self.storage_class, connection_name=connection_name, is_calibration=self.is_calibration, @@ -198,21 +199,25 @@ def deserialize( key: NodeKey, task_class_name: str, config_str: str, + dataset_type_keys: Mapping[str, NodeKey], ) -> TaskInitNode: """Transform a `SerializedTaskInitNode` to a `TaskInitNode`.""" return TaskInitNode( key, inputs={ - connection_name: serialized_edge.deserialize_read_edge(key, connection_name) + connection_name: serialized_edge.deserialize_read_edge( + key, connection_name, dataset_type_keys + ) for connection_name, serialized_edge in self.inputs.items() }, outputs={ - connection_name: serialized_edge.deserialize_write_edge(key, connection_name) + connection_name: serialized_edge.deserialize_write_edge( + key, connection_name, dataset_type_keys + ) for connection_name, serialized_edge in self.outputs.items() }, config_output=self.config_output.deserialize_write_edge( - key, - acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, + key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, dataset_type_keys ), task_class_name=task_class_name, config_str=config_str, @@ -289,7 +294,13 @@ def serialize(cls, target: TaskNode) -> SerializedTaskNode: ), ) - def deserialize(self, key: NodeKey, init_key: NodeKey, universe: DimensionUniverse | None) -> TaskNode: + def deserialize( + self, + key: NodeKey, + init_key: NodeKey, + dataset_type_keys: Mapping[str, NodeKey], + universe: DimensionUniverse | None, + ) -> TaskNode: """Transform a `SerializedTaskNode` to a `TaskNode`.""" init = self.init.deserialize( init_key, @@ -297,25 +308,30 @@ def deserialize(self, key: NodeKey, init_key: NodeKey, universe: DimensionUniver config_str=expect_not_none( self.config_str, f"No serialized config file for task with label {key.name!r}." 
), + dataset_type_keys=dataset_type_keys, ) inputs = { - connection_name: serialized_edge.deserialize_read_edge(key, connection_name) + connection_name: serialized_edge.deserialize_read_edge(key, connection_name, dataset_type_keys) for connection_name, serialized_edge in self.inputs.items() } prerequisite_inputs = { - connection_name: serialized_edge.deserialize_read_edge(key, connection_name, is_prerequisite=True) + connection_name: serialized_edge.deserialize_read_edge( + key, connection_name, dataset_type_keys, is_prerequisite=True + ) for connection_name, serialized_edge in self.prerequisite_inputs.items() } outputs = { - connection_name: serialized_edge.deserialize_write_edge(key, connection_name) + connection_name: serialized_edge.deserialize_write_edge(key, connection_name, dataset_type_keys) for connection_name, serialized_edge in self.outputs.items() } if (serialized_log_output := self.log_output) is not None: - log_output = serialized_log_output.deserialize_write_edge(key, acc.LOG_OUTPUT_CONNECTION_NAME) + log_output = serialized_log_output.deserialize_write_edge( + key, acc.LOG_OUTPUT_CONNECTION_NAME, dataset_type_keys + ) else: log_output = None metadata_output = self.metadata_output.deserialize_write_edge( - key, acc.METADATA_OUTPUT_CONNECTION_NAME + key, acc.METADATA_OUTPUT_CONNECTION_NAME, dataset_type_keys ) dimensions: frozenset[str] | DimensionGraph if universe is not None: @@ -384,7 +400,9 @@ def serialize(cls, target: DatasetTypeNode | None) -> SerializedDatasetTypeNode: is_prerequisite=target.is_prerequisite, ) - def deserialize(self, key: NodeKey, universe: DimensionUniverse | None) -> DatasetTypeNode | None: + def deserialize( + self, key: NodeKey, xgraph: networkx.MultiDiGraph, universe: DimensionUniverse | None + ) -> DatasetTypeNode | None: """Transform a `SerializedDatasetTypeNode` to a `DatasetTypeNode`.""" if self.dimensions is not None: dataset_type = DatasetType( @@ -404,10 +422,25 @@ def deserialize(self, key: NodeKey, universe: DimensionUniverse | None) -> Datas "but no dimension universe was stored.", ), ) + producer: str | None = None + producing_edge: WriteEdge | None = None + for _, _, producing_edge in xgraph.in_edges(key, data="instance"): + assert producing_edge is not None, "Should only be None if we never loop." + if producer is not None: + raise PipelineGraphReadError( + f"Serialized dataset type {key.name!r} is produced by both " + f"{producing_edge.task_label!r} and {producer!r} in resolved graph." + ) + producer = producing_edge.task_label + consuming_edges = tuple( + consuming_edge for _, _, consuming_edge in xgraph.in_edges(key, data="instance") + ) return DatasetTypeNode( dataset_type=dataset_type, is_prerequisite=self.is_prerequisite, is_initial_query_constraint=self.is_initial_query_constraint, + producing_edge=producing_edge, + consuming_edges=consuming_edges, ) return None @@ -511,18 +544,23 @@ def deserialize( ) xgraph = networkx.MultiDiGraph() sort_index_map: dict[int, NodeKey] = {} + # Save the dataset type keys after the first time we make them - these + # may be tiny objects, but it's still to have only one copy of each + # value floating around the graph. 
+ dataset_type_keys: dict[str, NodeKey] = {} for dataset_type_name, serialized_dataset_type in self.dataset_types.items(): dataset_type_key = NodeKey(NodeType.DATASET_TYPE, dataset_type_name) - dataset_type_node = serialized_dataset_type.deserialize(dataset_type_key, universe) - xgraph.add_node( - dataset_type_key, instance=dataset_type_node, bipartite=NodeType.DATASET_TYPE.value - ) + # We intentionally don't attach a DatasetTypeNode instance here + # yet, since we need edges to do that and those are saved with + # the tasks. + xgraph.add_node(dataset_type_key, bipartite=NodeType.DATASET_TYPE.value) if serialized_dataset_type.index is not None: sort_index_map[serialized_dataset_type.index] = dataset_type_key + dataset_type_keys[dataset_type_name] = dataset_type_key for task_label, serialized_task in self.tasks.items(): task_key = NodeKey(NodeType.TASK, task_label) task_init_key = NodeKey(NodeType.TASK_INIT, task_label) - task_node = serialized_task.deserialize(task_key, task_init_key, universe) + task_node = serialized_task.deserialize(task_key, task_init_key, dataset_type_keys, universe) if serialized_task.index is not None: sort_index_map[serialized_task.index] = task_key if serialized_task.init.index is not None: @@ -558,6 +596,12 @@ def deserialize( write_edge.connection_name, instance=write_edge, ) + # Iterate over dataset types again to add instances. + for dataset_type_name, serialized_dataset_type in self.dataset_types.items(): + dataset_type_key = dataset_type_keys[dataset_type_name] + xgraph.nodes[dataset_type_key]["instance"] = serialized_dataset_type.deserialize( + dataset_type_key, xgraph, universe + ) result = PipelineGraph.__new__(PipelineGraph) result._init_from_args( xgraph, From 02bbc62325320c45b013a02d29a8fcd5f6e42d35 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Thu, 20 Jul 2023 12:45:15 -0400 Subject: [PATCH 09/16] Add, improve simple graph-traversal methods on PipelineGraph. --- .../base/pipeline_graph/_pipeline_graph.py | 175 +++++++++++++++++- .../pipe/base/tests/pipelineStepTester.py | 2 +- tests/test_pipeline_graph.py | 35 +++- 3 files changed, 201 insertions(+), 11 deletions(-) diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py index 5e8ea5c1..644869ff 100644 --- a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py +++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py @@ -37,6 +37,7 @@ from ._dataset_types import DatasetTypeNode from ._edges import Edge, ReadEdge, WriteEdge from ._exceptions import ( + DuplicateOutputError, EdgesChangedError, PipelineDataCycleError, PipelineGraphError, @@ -577,7 +578,7 @@ def sort(self) -> None: ) from err self._reorder(sorted_keys) - def producer_of(self, dataset_type_name: str) -> WriteEdge | None: + def producing_edge_of(self, dataset_type_name: str) -> WriteEdge | None: """Return the `WriteEdge` that links the producing task to the named dataset type. @@ -590,14 +591,36 @@ def producer_of(self, dataset_type_name: str) -> WriteEdge | None: ------- edge : `WriteEdge` or `None` Producing edge or `None` if there isn't one in this graph. + + Raises + ------ + DuplicateOutputError + Raised if there are multiple tasks defined to produce this dataset + type. This is only possible if the graph's dataset types are not + resolved. + + Notes + ----- + On resolved graphs, it may be slightly more efficient to use:: + + graph.dataset_types[dataset_type_name].producing_edge + + but this method works on graphs with unresolved dataset types as well. 
""" - for _, _, edge in self._xgraph.in_edges( + producer: str | None = None + producing_edge: WriteEdge | None = None + for _, _, producing_edge in self._xgraph.in_edges( NodeKey(NodeType.DATASET_TYPE, dataset_type_name), data="instance" ): - return edge - return None + assert producing_edge is not None, "Should only be None if we never loop." + if producer is not None: + raise DuplicateOutputError( + f"Dataset type {dataset_type_name!r} is produced by both {producing_edge.task_label!r} " + f"and {producer!r}." + ) + return producing_edge - def consumers_of(self, dataset_type_name: str) -> list[ReadEdge]: + def consuming_edges_of(self, dataset_type_name: str) -> list[ReadEdge]: """Return the `ReadEdge` objects that link the named dataset type to the tasks that consume it. @@ -610,6 +633,14 @@ def consumers_of(self, dataset_type_name: str) -> list[ReadEdge]: ------- edges : `list` [ `ReadEdge` ] Edges that connect this dataset type to the tasks that consume it. + + Notes + ----- + On resolved graphs, it may be slightly more efficient to use:: + + graph.dataset_types[dataset_type_name].producing_edges + + but this method works on graphs with unresolved dataset types as well. """ return [ edge @@ -618,6 +649,140 @@ def consumers_of(self, dataset_type_name: str) -> list[ReadEdge]: ) ] + def producer_of(self, dataset_type_name: str) -> TaskNode | TaskInitNode | None: + """Return the `TaskNode` or `TaskInitNode` that writes the given + dataset type. + + Parameters + ---------- + dataset_type_name : `str` + Dataset type name. Must not be a component. + + Returns + ------- + edge : `TaskNode`, `TaskInitNode`, or `None` + Producing node or `None` if there isn't one in this graph. + + Raises + ------ + DuplicateOutputError + Raised if there are multiple tasks defined to produce this dataset + type. This is only possible if the graph's dataset types are not + resolved. + """ + if (producing_edge := self.producing_edge_of(dataset_type_name)) is not None: + return self._xgraph.nodes[producing_edge.task_key]["instance"] + return None + + def consumers_of(self, dataset_type_name: str) -> list[TaskNode | TaskInitNode]: + """Return the `TaskNode` and/or `TaskInitNode` objects that read + the given dataset type. + + Parameters + ---------- + dataset_type_name : `str` + Dataset type name. Must not be a component. + + Returns + ------- + edges : `list` [ `ReadEdge` ] + Edges that connect this dataset type to the tasks that consume it. + + Notes + ----- + On resolved graphs, it may be slightly more efficient to use:: + + graph.dataset_types[dataset_type_name].producing_edges + + but this method works on graphs with unresolved dataset types as well. + """ + return [ + self._xgraph.nodes[consuming_edge.task_key]["instance"] + for consuming_edge in self.consuming_edges_of(dataset_type_name) + ] + + def inputs_of(self, task_label: str, init: bool = False) -> dict[str, DatasetTypeNode | None]: + """Return the dataset types that are inputs to a task. + + Parameters + ---------- + task_label : `str` + Label for the task in the pipeline. + init : `bool`, optional + If `True`, return init-input dataset types instead of runtime + (including prerequisite) inputs. + + Returns + ------- + inputs : `dict` [ `str`, `DatasetTypeNode` or `None` ] + Dictionary parent dataset type name keys and either + `DatasetTypeNode` values (if the dataset type has been resolved) + or `None` values. 
+ + Notes + ----- + To get the input edges of a task or task init node (which provide + information about storage class overrides nd components) use:: + + graph.tasks[task_label].iter_all_inputs() + + or + + graph.tasks[task_label].init.iter_all_inputs() + + or the various mapping attributes of the `TaskNode` and `TaskInitNode` + class. + """ + node: TaskNode | TaskInitNode = self.tasks[task_label] if not init else self.tasks[task_label].init + return { + edge.parent_dataset_type_name: self._xgraph.nodes[edge.dataset_type_key]["instance"] + for edge in node.iter_all_inputs() + } + + def outputs_of( + self, task_label: str, init: bool = False, include_automatic_connections: bool = True + ) -> dict[str, DatasetTypeNode | None]: + """Return the dataset types that are outputs of a task. + + Parameters + ---------- + task_label : `str` + Label for the task in the pipeline. + init : `bool`, optional + If `True`, return init-output dataset types instead of runtime + outputs. + include_automatic_connections : `bool`, optional + Whether to include automatic connections such as configs, metadata, + and logs. + + Returns + ------- + outputs : `dict` [ `str`, `DatasetTypeNode` or `None` ] + Dictionary parent dataset type name keys and either + `DatasetTypeNode` values (if the dataset type has been resolved) + or `None` values. + + Notes + ----- + To get the input edges of a task or task init node (which provide + information about storage class overrides nd components) use:: + + graph.tasks[task_label].iter_all_outputs() + + or + + graph.tasks[task_label].init.iter_all_outputs() + + or the various mapping attributes of the `TaskNode` and `TaskInitNode` + class. + """ + node: TaskNode | TaskInitNode = self.tasks[task_label] if not init else self.tasks[task_label].init + iterable = node.iter_all_outputs() if include_automatic_connections else node.outputs.values() + return { + edge.parent_dataset_type_name: self._xgraph.nodes[edge.dataset_type_key]["instance"] + for edge in iterable + } + def add_task( self, label: str, diff --git a/python/lsst/pipe/base/tests/pipelineStepTester.py b/python/lsst/pipe/base/tests/pipelineStepTester.py index ddbc680f..3411065f 100644 --- a/python/lsst/pipe/base/tests/pipelineStepTester.py +++ b/python/lsst/pipe/base/tests/pipelineStepTester.py @@ -98,7 +98,7 @@ def run(self, butler: Butler, test_case: unittest.TestCase) -> None: { name: node.dataset_type for name, node in step_graph.dataset_types.items() - if step_graph.producer_of(name) is not None + if step_graph.producing_edge_of(name) is not None } ) diff --git a/tests/test_pipeline_graph.py b/tests/test_pipeline_graph.py index bc11147e..dd8718e1 100644 --- a/tests/test_pipeline_graph.py +++ b/tests/test_pipeline_graph.py @@ -358,12 +358,37 @@ def check_base_accessors(self, graph: PipelineGraph) -> None: }, ) self.assertEqual({name for name, _ in graph.iter_overall_inputs()}, {"input_1"}) - self.assertEqual({edge.task_label for edge in graph.consumers_of("input_1")}, {"a"}) - self.assertEqual({edge.task_label for edge in graph.consumers_of("intermediate_1")}, {"b"}) - self.assertEqual({edge.task_label for edge in graph.consumers_of("output_1")}, set()) + self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("input_1")}, {"a"}) + self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("intermediate_1")}, {"b"}) + self.assertEqual({edge.task_label for edge in graph.consuming_edges_of("output_1")}, set()) + self.assertEqual({node.label for node in graph.consumers_of("input_1")}, {"a"}) + 
self.assertEqual({node.label for node in graph.consumers_of("intermediate_1")}, {"b"}) + self.assertEqual({node.label for node in graph.consumers_of("output_1")}, set()) + + self.assertIsNone(graph.producing_edge_of("input_1")) + self.assertEqual(graph.producing_edge_of("intermediate_1").task_label, "a") + self.assertEqual(graph.producing_edge_of("output_1").task_label, "b") self.assertIsNone(graph.producer_of("input_1")) - self.assertEqual(graph.producer_of("intermediate_1").task_label, "a") - self.assertEqual(graph.producer_of("output_1").task_label, "b") + self.assertEqual(graph.producer_of("intermediate_1").label, "a") + self.assertEqual(graph.producer_of("output_1").label, "b") + + self.assertEqual(graph.inputs_of("a").keys(), {"input_1"}) + self.assertEqual(graph.inputs_of("b").keys(), {"intermediate_1"}) + self.assertEqual(graph.inputs_of("a", init=True).keys(), set()) + self.assertEqual(graph.inputs_of("b", init=True).keys(), {"schema"}) + self.assertEqual(graph.outputs_of("a").keys(), {"intermediate_1", "a_log", "a_metadata"}) + self.assertEqual(graph.outputs_of("b").keys(), {"output_1", "b_log", "b_metadata"}) + self.assertEqual( + graph.outputs_of("a", include_automatic_connections=False).keys(), {"intermediate_1"} + ) + self.assertEqual(graph.outputs_of("b", include_automatic_connections=False).keys(), {"output_1"}) + self.assertEqual(graph.outputs_of("a", init=True).keys(), {"schema", "a_config"}) + self.assertEqual( + graph.outputs_of("a", init=True, include_automatic_connections=False).keys(), {"schema"} + ) + self.assertEqual(graph.outputs_of("b", init=True).keys(), {"b_config"}) + self.assertEqual(graph.outputs_of("b", init=True, include_automatic_connections=False).keys(), set()) + self.assertTrue(repr(self.graph).startswith(f"PipelineGraph({self.description!r}, tasks=")) self.assertEqual( repr(graph.task_subsets["only_b"]), f"only_b: {self.subset_description!r}, tasks={{b}}" From 83328320871eb265adbabd6264969bb7889ac2de Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Thu, 20 Jul 2023 17:26:59 -0400 Subject: [PATCH 10/16] Add pipeline_graph module to API reference docs. We can't tell Sphinx about having lifted the PipelineGraph symbol to lsst.pipe.base, unfortunately, as it doesn't like duplicates and can't do aliases. --- doc/lsst.pipe.base/index.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/lsst.pipe.base/index.rst b/doc/lsst.pipe.base/index.rst index 401a035e..aac96aa3 100644 --- a/doc/lsst.pipe.base/index.rst +++ b/doc/lsst.pipe.base/index.rst @@ -77,6 +77,10 @@ Python API reference :no-main-docstr: :skip: BuildId :skip: DatasetTypeName + :skip: PipelineGraph + +.. automodapi:: lsst.pipe.base.pipeline_graph + :no-main-docstr: .. automodapi:: lsst.pipe.base.testUtils :no-main-docstr: From 53cb0d59203961de4b5e7505b6b4844269819ed7 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Thu, 20 Jul 2023 17:27:54 -0400 Subject: [PATCH 11/16] Add option to resolve PipelineGraph when making it from a Pipeline. This is what most uses will want to do anyway, unless they don't have a data repository. 
--- python/lsst/pipe/base/pipeline.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/lsst/pipe/base/pipeline.py b/python/lsst/pipe/base/pipeline.py index 3636850d..81dc7f29 100644 --- a/python/lsst/pipe/base/pipeline.py +++ b/python/lsst/pipe/base/pipeline.py @@ -748,7 +748,7 @@ def write_to_uri(self, uri: ResourcePathExpression) -> None: """ self._pipelineIR.write_to_uri(uri) - def to_graph(self) -> pipeline_graph.PipelineGraph: + def to_graph(self, registry: Registry | None = None) -> pipeline_graph.PipelineGraph: """Construct a pipeline graph from this pipeline. Constructing a graph applies all configuration overrides, freezes all @@ -756,6 +756,12 @@ def to_graph(self) -> pipeline_graph.PipelineGraph: consistency between tasks (as much as possible without access to a data repository). It cannot be reversed. + Parameters + ---------- + registry : `lsst.daf.butler.Registry`, optional + Data repository client. If provided, the graph's dataset types + and dimensions will be resolved (see `PipelineGraph.resolve`). + Returns ------- graph : `pipeline_graph.PipelineGraph` @@ -787,6 +793,8 @@ def to_graph(self) -> pipeline_graph.PipelineGraph: label, subset.subset, subset.description if subset.description is not None else "" ) graph.sort() + if registry is not None: + graph.resolve(registry) return graph def toExpandedPipeline(self) -> Generator[TaskDef, None, None]: From 0ffd43b2cfecabb604139a6fa363303536cb72a7 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Thu, 20 Jul 2023 17:28:37 -0400 Subject: [PATCH 12/16] Provide user guide for PipelineGraph and reorder its API accordingly. --- doc/lsst.pipe.base/index.rst | 1 + .../working-with-pipeline-graphs.rst | 88 ++ .../base/pipeline_graph/_pipeline_graph.py | 1351 +++++++++-------- 3 files changed, 800 insertions(+), 640 deletions(-) create mode 100644 doc/lsst.pipe.base/working-with-pipeline-graphs.rst diff --git a/doc/lsst.pipe.base/index.rst b/doc/lsst.pipe.base/index.rst index aac96aa3..318ab78c 100644 --- a/doc/lsst.pipe.base/index.rst +++ b/doc/lsst.pipe.base/index.rst @@ -59,6 +59,7 @@ Developing Pipelines creating-a-pipeline.rst testing-pipelines-with-mocks.rst + working-with-pipeline-graphs.rst .. _lsst.pipe.base-contributing: diff --git a/doc/lsst.pipe.base/working-with-pipeline-graphs.rst b/doc/lsst.pipe.base/working-with-pipeline-graphs.rst new file mode 100644 index 00000000..92add070 --- /dev/null +++ b/doc/lsst.pipe.base/working-with-pipeline-graphs.rst @@ -0,0 +1,88 @@ +.. _pipe_base_pipeline_graphs: + +.. py:currentmodule:: lsst.pipe.base.pipeline_graph + +############################ +Working with Pipeline Graphs +############################ + +Pipeline objects are written as YAML documents, but once they are fully configured, they are conceptually directed acyclic graphs (DAGs). +In code, this graph version of a pipeline is represented by the `PipelineGraph` class. +`PipelineGraph` objects are usually constructed by calling the `.Pipeline.to_graph` method:: + + from lsst.daf.butler import Butler + from lsst.pipe.base import Pipeline + + butler = Butler("/some/repo") + pipeline = Pipeline.from_uri("my_pipeline.yaml") + graph = pipeline.to_graph(registry=butler.registry) + +The ``registry`` argument here is optional, but without it the graph will be incomplete ("unresolved") and the pipeline will not be checked for correctness until the `~PipelineGraph.resolve` method is called. 
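+For example (a rough sketch reusing the ``butler`` and ``pipeline`` objects from above), construction and resolution can also be done as two separate steps::
+
+    # Build the graph without a registry; its dataset types remain unresolved.
+    graph = pipeline.to_graph()
+    # Later, once a data repository is available, resolve it explicitly.
+    graph.resolve(butler.registry)
+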
+Resolving a graph compares all of the task connections (which are edges in the graph) that reference each dataset type to each other and to the registry's definition of that dataset type to determine a common graph-wide definition.
+A definition in the registry always takes precedence, followed by the output connection that produces the dataset type (if there is one).
+When a pipeline graph is used to register dataset types in a data repository, it is this common definition in the dataset type node that is registered.
+Edge dataset type descriptions represent storage class overrides for a task, or specify that the task only wants a component.
+
+Simple Graph Inspection
+-----------------------
+
+The basic structure of the graph can be explored via the `~PipelineGraph.tasks` and `~PipelineGraph.dataset_types` mapping attributes.
+These are keyed by task label and *parent* (never component) dataset type name, and they have `TaskNode` and `DatasetTypeNode` objects as values, respectively.
+A resolved pipeline graph is always sorted, which means iterations over these mappings will be in topological order.
+`TaskNode` objects have an `~TaskNode.init` attribute that holds a `TaskInitNode` instance; these resemble `TaskNode` objects and have edges to dataset types as well, but these edges represent the "init input" and "init output" connections of those tasks.
+
+Task and dataset type node objects have attributes holding all of their edges, but to get neighboring nodes, you have to go back to the graph object::
+
+    task_node = graph.tasks["task_a"]
+    for edge in task_node.inputs.values():
+        dataset_type_node = graph.dataset_types[edge.parent_dataset_type_name]
+        print(f"{task_node.label} takes {dataset_type_node.name} as an input.")
+
+There are also convenience methods on `PipelineGraph` to get the edges or neighbors of a node:
+
+ - `~PipelineGraph.producing_edge_of`: an alternative to `DatasetTypeNode.producing_edge`
+ - `~PipelineGraph.consuming_edges_of`: an alternative to `DatasetTypeNode.consuming_edges`
+ - `~PipelineGraph.producer_of`: a shortcut for getting the task that writes a dataset type
+ - `~PipelineGraph.consumers_of`: a shortcut for getting the tasks that read a dataset type
+ - `~PipelineGraph.inputs_of`: a shortcut for getting the dataset types that a task reads
+ - `~PipelineGraph.outputs_of`: a shortcut for getting the dataset types that a task writes
+
+Pipeline graphs also hold the `~PipelineGraph.description` and `~PipelineGraph.data_id` (usually just an instrument value) of the pipeline used to construct them, as well as the same mapping of labeled task subsets (`~PipelineGraph.task_subsets`).
+
+Modifying PipelineGraphs
+------------------------
+
+Usually the tasks in a pipeline are set before a `PipelineGraph` is ever constructed.
+In some cases it may be more convenient to add tasks to an existing `PipelineGraph`, either because a related graph is being created from an existing one, or because a (rare) task needs to be configured in a way that depends on the content or structure of the rest of the graph.
+`PipelineGraph` provides a number of mutation methods:
+
+- `~PipelineGraph.add_task` adds a brand new task from a `.PipelineTask` type object and its configuration;
+- `~PipelineGraph.add_task_nodes` adds one or more tasks from a different `PipelineGraph` instance;
+- `~PipelineGraph.reconfigure_tasks` replaces the configuration of an existing task with new configuration (note that this is typically less convenient than adding config *overrides* to a `Pipeline` object, because all configuration in a `PipelineGraph` must be validated and frozen);
+- `~PipelineGraph.remove_task_nodes` removes existing tasks;
+- `~PipelineGraph.add_task_subset` and `~PipelineGraph.remove_task_subset` modify the mapping of labeled task subsets (which can also be modified in-place).
+
+**The most important thing to remember when modifying PipelineGraph objects is that modifications typically reset some or all of the graph to an unresolved state.**
+
+The reference documentation for these methods describes exactly what guarantees they make about existing resolutions in detail, and what operations are still supported on unresolved or partially-resolved graphs, but it is easiest to just ensure `~PipelineGraph.resolve` is called after any modifications are complete.
+
+`PipelineGraph` mutator methods provide strong exception safety (the graph is left unchanged when an exception is raised and caught by calling code) unless the exception type raised is `PipelineGraphExceptionSafetyError`.
+
+Exporting to NetworkX
+---------------------
+
+NetworkX is a powerful Python library for graph manipulation, and in addition to being used in the implementation, `PipelineGraph` provides methods to create various native NetworkX graph objects.
+The node attributes of these graphs provide much of the same information as the `TaskNode` and `DatasetTypeNode` objects (see the documentation for those objects for details).
+
+The export methods include:
+
+ - `~PipelineGraph.make_xgraph` exports all nodes, including task nodes, dataset type nodes, and task init nodes, and the edges between them.
+   This is a `networkx.MultiDiGraph` because there can be (albeit rarely) multiple edges (representing different connections) between a dataset type and a task.
+   The edges of this graph have attributes as well as the nodes.
+ - `~PipelineGraph.make_bipartite_xgraph` exports just task nodes and dataset type nodes and the edges between them (or, if ``init=True``, just task init nodes and the dataset type nodes and edges between them).
+   A "bipartite" graph is one in which there are two kinds of nodes and edges only connect one type to the other.
+   This is also a `networkx.MultiDiGraph`, and its edges also have attributes.
+ - `~PipelineGraph.make_task_xgraph` exports just task (or task init) nodes; it is one "bipartite projection" of the full graph.
+   This is a `networkx.DiGraph`, because all dataset types that connect a pair of tasks are rolled into one edge, and edges have no state.
+ - `~PipelineGraph.make_dataset_type_xgraph` exports just dataset type nodes; it is one "bipartite projection" of the full graph.
+   This is a `networkx.DiGraph`, because all tasks that connect a pair of dataset types are rolled into one edge, and edges have no state.
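+
+As a rough sketch (reusing the resolved ``graph`` from the first example), the exported graphs can be handed directly to standard NetworkX algorithms::
+
+    import networkx
+
+    # Task-only projection: a plain DiGraph whose nodes identify tasks.
+    task_xgraph = graph.make_task_xgraph()
+    for node in networkx.topological_sort(task_xgraph):
+        print(node, "->", list(task_xgraph.successors(node)))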
diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py index 644869ff..6e29363d 100644 --- a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py +++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py @@ -63,8 +63,8 @@ class PipelineGraph: """A graph representation of fully-configured pipeline. `PipelineGraph` instances are typically constructed by calling - `.Pipeline.to_graph`, but in rare cases constructing and then populating - an empty one may be preferable. + `.Pipeline.to_graph`, but in rare cases constructing and then populating an + empty one may be preferable. Parameters ---------- @@ -79,6 +79,20 @@ class PipelineGraph: in the pipeline definition, if there was one. """ + ########################################################################### + # + # Simple Pipeline Graph Inspection Interface: + # + # - for inspecting graph structure, not modifying it (except to sort and] + # resolve); + # + # - no NodeKey objects, just string dataset type name and task label keys; + # + # - graph structure is represented as a pair of mappings, with methods to + # find neighbors and edges of nodes. + # + ########################################################################### + def __init__( self, *, @@ -95,70 +109,6 @@ def __init__( data_id=data_id, ) - def _init_from_args( - self, - xgraph: networkx.MultiDiGraph | None, - sorted_keys: Sequence[NodeKey] | None, - task_subsets: dict[str, TaskSubset] | None, - description: str, - universe: DimensionUniverse | None, - data_id: DataId | None, - ) -> None: - """Initialize the graph with possibly-nontrivial arguments. - - Parameters - ---------- - xgraph : `networkx.MultiDiGraph` or `None` - The backing networkx graph, or `None` to create an empty one. - This graph has `NodeKey` instances for nodes and the same structure - as the graph exported by `make_xgraph`, but its nodes and edges - have a single ``instance`` attribute that holds a `TaskNode`, - `TaskInitNode`, `DatasetTypeNode` (or `None`), `ReadEdge`, or - `WriteEdge` instance. - sorted_keys : `Sequence` [ `NodeKey` ] or `None` - Topologically sorted sequence of node keys, or `None` if the graph - is not sorted. - task_subsets : `dict` [ `str`, `TaskSubset` ] - Labeled subsets of tasks. Values must be constructed with - ``xgraph`` as their parent graph. - description : `str` - String description for this pipeline. - universe : `lsst.daf.butler.DimensionUniverse` or `None` - Definitions of all dimensions. - data_id : `lsst.daf.butler.DataCoordinate` or other data ID mapping. - Data ID that represents a constraint on all quanta generated from - this pipeline. - - Notes - ----- - Only empty `PipelineGraph` instances should be constructed directly by - users, which sets the signature of ``__init__`` itself, but methods on - `PipelineGraph` and its helper classes need to be able to create them - with state. Those methods can call this after calling ``__new__`` - manually, skipping ``__init__``. - - `PipelineGraph` mutator methods provide strong exception safety (the - graph is left unchanged when an exception is raised and caught) unless - the exception raised is `PipelineGraphExceptionSafetyError`. 
- """ - self._xgraph = xgraph if xgraph is not None else networkx.MultiDiGraph() - self._sorted_keys: Sequence[NodeKey] | None = None - self._task_subsets = task_subsets if task_subsets is not None else {} - self._description = description - self._tasks = TaskMappingView(self._xgraph) - self._dataset_types = DatasetTypeMappingView(self._xgraph) - self._raw_data_id: dict[str, Any] - if isinstance(data_id, DataCoordinate): - universe = data_id.universe - self._raw_data_id = data_id.byName() - elif data_id is None: - self._raw_data_id = {} - else: - self._raw_data_id = dict(data_id) - self._universe = universe - if sorted_keys is not None: - self._reorder(sorted_keys) - def __repr__(self) -> str: return f"{type(self).__name__}({self.description!r}, tasks={self.tasks!s})" @@ -219,393 +169,110 @@ def task_subsets(self) -> Mapping[str, TaskSubset]: """ return self._task_subsets - def iter_edges(self, init: bool = False) -> Iterator[Edge]: - """Iterate over edges in the graph. - - Parameters - ---------- - init : `bool`, optional - If `True` (`False` is default) iterate over the edges between task - initialization node and init input/output dataset types, instead of - the runtime task nodes and regular input/output/prerequisite - dataset types. - - Returns - ------- - edges : `~collections.abc.Iterator` [ `Edge` ] - A lazy iterator over `Edge` (`WriteEdge` or `ReadEdge`) instances. - - Notes - ----- - This method always returns _either_ init edges or runtime edges, never - both. The full (internal) graph that contains both also includes a - special edge that connects each task init node to its runtime node; - that is also never returned by this method, since it is never a part of - the init-only or runtime-only subgraphs. - """ - edge: Edge - for _, _, edge in self._xgraph.edges(data="instance"): - if edge is not None and edge.is_init == init: - yield edge - - def iter_nodes( - self, - ) -> Iterator[ - tuple[Literal[NodeType.TASK_INIT], str, TaskInitNode] - | tuple[Literal[NodeType.TASK], str, TaskInitNode] - | tuple[Literal[NodeType.DATASET_TYPE], str, DatasetTypeNode | None] - ]: - """Iterate over nodes in the graph. - - Returns - ------- - nodes : `~collections.abc.Iterator` [ `tuple` ] - A lazy iterator over all of the nodes in the graph. Each yielded - element is a tuple of: + @property + def is_sorted(self) -> bool: + """Whether this graph's tasks and dataset types are topologically + sorted with the exact same deterministic tiebreakers that `sort` would + apply. - - the node type enum value (`NodeType`); - - the string name for the node (task label or parent dataset type - name); - - the node value (`TaskNode`, `TaskInitNode`, `DatasetTypeNode`, - or `None` for dataset type nodes that have not been resolved). + This may perform (and then discard) a full sort if `has_been_sorted` is + `False`. If the goal is to obtain a sorted graph, it is better to just + call `sort` without guarding that with an ``if not graph.is_sorted`` + check. """ - key: NodeKey if self._sorted_keys is not None: - for key in self._sorted_keys: - yield key.node_type, key.name, self._xgraph.nodes[key]["instance"] # type: ignore - else: - for key, node in self._xgraph.nodes(data="instance"): - yield key.node_type, key.name, node # type: ignore - - def iter_overall_inputs(self) -> Iterator[tuple[str, DatasetTypeNode | None]]: - """Iterate over all of the dataset types that are consumed but not - produced by the graph. 
+ return True + return all( + sorted == unsorted + for sorted, unsorted in zip(networkx.lexicographical_topological_sort(self._xgraph), self._xgraph) + ) - Returns - ------- - dataset_types : `~collections.abc.Iterator` [ `tuple` ] - A lazy iterator over the overall-input dataset types (including - overall init inputs and prerequisites). Each yielded element is a - tuple of: + @property + def has_been_sorted(self) -> bool: + """Whether this graph's tasks and dataset types have been + topologically sorted (with unspecified but deterministic tiebreakers) + since the last modification to the graph. - - the parent dataset type name; - - the resolved `DatasetTypeNode`, or `None` if the dataset type has - - not been resolved. + This may return `False` if the graph *happens* to be sorted but `sort` + was never called, but it is potentially much faster than `is_sorted`, + which may attempt (and then discard) a full sort if `has_been_sorted` + is `False`. """ - for generation in networkx.algorithms.dag.topological_generations(self._xgraph): - key: NodeKey - for key in generation: - # While we expect all tasks to have at least one input and - # hence never appear in the first topological generation, that - # is not true of task init nodes. - if key.node_type is NodeType.DATASET_TYPE: - yield key.name, self._xgraph.nodes[key]["instance"] - return + return self._sorted_keys is not None - def make_xgraph(self) -> networkx.MultiDiGraph: - """Export a networkx representation of the full pipeline graph, - including both init and runtime edges. + def sort(self) -> None: + """Sort this graph's nodes topologically with deterministic (but + unspecified) tiebreakers. - Returns - ------- - xgraph : `networkx.MultiDiGraph` - Directed acyclic graph with parallel edges. + This does nothing if the graph is already known to be sorted. + """ + if self._sorted_keys is None: + try: + sorted_keys: Sequence[NodeKey] = list(networkx.lexicographical_topological_sort(self._xgraph)) + except networkx.NetworkXUnfeasible as err: # pragma: no cover + # Should't be possible to get here, because we check for cycles + # when adding tasks, but we guard against it anyway. + cycle = networkx.find_cycle(self._xgraph) + raise PipelineDataCycleError( + f"Cycle detected while attempting to sort graph: {cycle}." + ) from err + self._reorder(sorted_keys) - Notes - ----- - The returned graph uses `NodeKey` instances for nodes. Parallel edges - represent the same dataset type appearing in multiple connections for - the same task, and are hence rare. The connection name is used as the - edge key to disambiguate those parallel edges. + def copy(self) -> PipelineGraph: + """Return a copy of this graph that copies all mutable state.""" + xgraph = self._xgraph.copy() + result = PipelineGraph.__new__(PipelineGraph) + result._init_from_args( + xgraph, + self._sorted_keys, + task_subsets={ + k: TaskSubset(xgraph, v.label, set(v._members), v.description) + for k, v in self._task_subsets.items() + }, + description=self._description, + universe=self.universe, + data_id=self._raw_data_id, + ) + return result - Almost all edges connect dataset type nodes to task or task init nodes - or vice versa, but there is also a special edge that connects each task - init node to its runtime node. The existence of these nodes makes the - graph not quite bipartite, unless its init-only and runtime-only - subgraphs. + def __copy__(self) -> PipelineGraph: + # Fully shallow copies are dangerous; we don't want shared mutable + # state to lead to broken class invariants. 
+ return self.copy() - See `TaskNode`, `TaskInitNode`, `DatasetTypeNode`, `ReadEdge`, and - `WriteEdge` for the descriptive node and edge attributes added. - """ - return self._transform_xgraph_state(self._xgraph.copy(), skip_edges=False) + def __deepcopy__(self, memo: dict) -> PipelineGraph: + # Genuine deep copies are unnecessary, since we should only ever care + # that mutable state is copied. + return self.copy() - def make_bipartite_xgraph(self, init: bool = False) -> networkx.MultiDiGraph: - """Return a bipartite networkx representation of just the runtime or - init-time pipeline graph. + def producing_edge_of(self, dataset_type_name: str) -> WriteEdge | None: + """Return the `WriteEdge` that links the producing task to the named + dataset type. Parameters ---------- - init : `bool`, optional - If `True` (`False` is default) return the graph of task - initialization nodes and init input/output dataset types, instead - of the graph of runtime task nodes and regular - input/output/prerequisite dataset types. + dataset_type_name : `str` + Dataset type name. Must not be a component. Returns ------- - xgraph : `networkx.MultiDiGraph` - Directed acyclic graph with parallel edges. + edge : `WriteEdge` or `None` + Producing edge or `None` if there isn't one in this graph. + + Raises + ------ + DuplicateOutputError + Raised if there are multiple tasks defined to produce this dataset + type. This is only possible if the graph's dataset types are not + resolved. Notes ----- - The returned graph uses `NodeKey` instances for nodes. Parallel edges - represent the same dataset type appearing in multiple connections for - the same task, and are hence rare. The connection name is used as the - edge key to disambiguate those parallel edges. + On resolved graphs, it may be slightly more efficient to use:: - This graph is bipartite because each dataset type node only has edges - that connect it to a task [init] node, and vice versa. + graph.dataset_types[dataset_type_name].producing_edge - See `TaskNode`, `TaskInitNode`, `DatasetTypeNode`, `ReadEdge`, and - `WriteEdge` for the descriptive node and edge attributes added. - """ - return self._transform_xgraph_state( - self._make_bipartite_xgraph_internal(init).copy(), skip_edges=False - ) - - def make_task_xgraph(self, init: bool = False) -> networkx.DiGraph: - """Return a networkx representation of just the tasks in the pipeline. - - Parameters - ---------- - init : `bool`, optional - If `True` (`False` is default) return the graph of task - initialization nodes, instead of the graph of runtime task nodes. - - Returns - ------- - xgraph : `networkx.DiGraph` - Directed acyclic graph with no parallel edges. - - Notes - ----- - The returned graph uses `NodeKey` instances for nodes. The dataset - types that link these tasks are not represented at all; edges have no - attributes, and there are no parallel edges. - - See `TaskNode` and `TaskInitNode` for the descriptive node and - attributes added. - """ - bipartite_xgraph = self._make_bipartite_xgraph_internal(init) - task_keys = [ - key - for key, bipartite in bipartite_xgraph.nodes(data="bipartite") - if bipartite == NodeType.TASK.bipartite - ] - return self._transform_xgraph_state( - networkx.algorithms.bipartite.projected_graph(networkx.DiGraph(bipartite_xgraph), task_keys), - skip_edges=True, - ) - - def make_dataset_type_xgraph(self, init: bool = False) -> networkx.DiGraph: - """Return a networkx representation of just the dataset types in the - pipeline. 
- - Parameters - ---------- - init : `bool`, optional - If `True` (`False` is default) return the graph of init input and - output dataset types, instead of the graph of runtime (input, - output, prerequisite input) dataset types. - - Returns - ------- - xgraph : `networkx.DiGraph` - Directed acyclic graph with no parallel edges. - - Notes - ----- - The returned graph uses `NodeKey` instances for nodes. The tasks that - link these tasks are not represented at all; edges have no attributes, - and there are no parallel edges. - - See `DatasetTypeNode` for the descriptive node and attributes added. - """ - bipartite_xgraph = self._make_bipartite_xgraph_internal(init) - dataset_type_keys = [ - key - for key, bipartite in bipartite_xgraph.nodes(data="bipartite") - if bipartite == NodeType.DATASET_TYPE.bipartite - ] - return self._transform_xgraph_state( - networkx.algorithms.bipartite.projected_graph( - networkx.DiGraph(bipartite_xgraph), dataset_type_keys - ), - skip_edges=True, - ) - - def _make_bipartite_xgraph_internal(self, init: bool) -> networkx.MultiDiGraph: - """Make a bipartite init-only or runtime-only internal subgraph. - - See `make_bipartite_xgraph` for parameters and return values. - - Notes - ----- - This method returns a view of the `PipelineGraph` object's internal - backing graph, and hence should only be called in methods that copy the - result either explicitly or by running a copying algorithm before - returning it to the user. - """ - return self._xgraph.edge_subgraph([edge.key for edge in self.iter_edges(init)]) - - def _transform_xgraph_state(self, xgraph: _G, skip_edges: bool) -> _G: - """Transform networkx graph attributes in-place from the internal - "instance" attributes to the documented exported attributes. - - Parameters - ---------- - xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph` - Graph whose state should be transformed. - skip_edges : `bool` - If `True`, do not transform edge state. - - Returns - ------- - xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph` - The same object passed in, after modification. - - Notes - ----- - This should be called after making a copy of the internal graph but - before any projection down to just task or dataset type nodes, since - it assumes stateful edges. - """ - state: dict[str, Any] - for state in xgraph.nodes.values(): - node_value: TaskInitNode | TaskNode | DatasetTypeNode | None = state.pop("instance") - if node_value is not None: - state.update(node_value._to_xgraph_state()) - if not skip_edges: - for _, _, state in xgraph.edges(data=True): - edge: Edge | None = state.pop("instance", None) - if edge is not None: - state.update(edge._to_xgraph_state()) - return xgraph - - def group_by_dimensions( - self, prerequisites: bool = False - ) -> dict[DimensionGraph, tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]]]: - """Group this graph's tasks and dataset types by their dimensions. - - Parameters - ---------- - prerequisites : `bool`, optional - If `True`, include prerequisite dataset types as well as regular - input and output datasets (including intermediates). - - Returns - ------- - groups : `dict` [ `DimensionGraph`, `tuple` ] - A dictionary of groups keyed by `DimensionGraph`, in which each - value is a tuple of: - - - a `dict` of `TaskNode` instances, keyed by task label - - a `dict` of `DatasetTypeNode` instances, keyed by - dataset type name. - - that have those dimensions. 
- - Notes - ----- - Init inputs and outputs are always included, but always have empty - dimensions and are hence are all grouped together. - """ - result: dict[DimensionGraph, tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]]] = {} - next_new_value: tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]] = ({}, {}) - for task_label, task_node in self.tasks.items(): - if task_node.dimensions is None: - raise UnresolvedGraphError(f"Task with label {task_label!r} has not been resolved.") - if (group := result.setdefault(task_node.dimensions, next_new_value)) is next_new_value: - next_new_value = ({}, {}) # make new lists for next time - group[0][task_node.label] = task_node - for dataset_type_name, dataset_type_node in self.dataset_types.items(): - if dataset_type_node is None: - raise UnresolvedGraphError(f"Dataset type {dataset_type_name!r} has not been resolved.") - if not dataset_type_node.is_prerequisite or prerequisites: - if ( - group := result.setdefault(dataset_type_node.dataset_type.dimensions, next_new_value) - ) is next_new_value: - next_new_value = ({}, {}) # make new lists for next time - group[1][dataset_type_node.name] = dataset_type_node - return result - - @property - def is_sorted(self) -> bool: - """Whether this graph's tasks and dataset types are topologically - sorted with the exact same deterministic tiebreakers that `sort` would - apply. - - This may perform (and then discard) a full sort if `has_been_sorted` is - `False`. If the goal is to obtain a sorted graph, it is better to just - call `sort` without guarding that with an ``if not graph.is_sorted`` - check. - """ - if self._sorted_keys is not None: - return True - return all( - sorted == unsorted - for sorted, unsorted in zip(networkx.lexicographical_topological_sort(self._xgraph), self._xgraph) - ) - - @property - def has_been_sorted(self) -> bool: - """Whether this graph's tasks and dataset types have been - topologically sorted (with unspecified but deterministic tiebreakers) - since the last modification to the graph. - - This may return `False` if the graph *happens* to be sorted but `sort` - was never called, but it is potentially much faster than `is_sorted`, - which may attempt (and then discard) a full sort if `has_been_sorted` - is `False`. - """ - return self._sorted_keys is not None - - def sort(self) -> None: - """Sort this graph's nodes topologically with deterministic (but - unspecified) tiebreakers. - - This does nothing if the graph is already known to be sorted. - """ - if self._sorted_keys is None: - try: - sorted_keys: Sequence[NodeKey] = list(networkx.lexicographical_topological_sort(self._xgraph)) - except networkx.NetworkXUnfeasible as err: # pragma: no cover - # Should't be possible to get here, because we check for cycles - # when adding tasks, but we guard against it anyway. - cycle = networkx.find_cycle(self._xgraph) - raise PipelineDataCycleError( - f"Cycle detected while attempting to sort graph: {cycle}." - ) from err - self._reorder(sorted_keys) - - def producing_edge_of(self, dataset_type_name: str) -> WriteEdge | None: - """Return the `WriteEdge` that links the producing task to the named - dataset type. - - Parameters - ---------- - dataset_type_name : `str` - Dataset type name. Must not be a component. - - Returns - ------- - edge : `WriteEdge` or `None` - Producing edge or `None` if there isn't one in this graph. - - Raises - ------ - DuplicateOutputError - Raised if there are multiple tasks defined to produce this dataset - type. 
This is only possible if the graph's dataset types are not - resolved. - - Notes - ----- - On resolved graphs, it may be slightly more efficient to use:: - - graph.dataset_types[dataset_type_name].producing_edge - - but this method works on graphs with unresolved dataset types as well. + but this method works on graphs with unresolved dataset types as well. """ producer: str | None = None producing_edge: WriteEdge | None = None @@ -783,14 +450,96 @@ def outputs_of( for edge in iterable } - def add_task( - self, - label: str, - task_class: type[PipelineTask], - config: PipelineTaskConfig, - connections: PipelineTaskConnections | None = None, - ) -> TaskNode: - """Add a new task to the graph. + def resolve(self, registry: Registry) -> None: + """Resolve all dimensions and dataset types and check them for + consistency. + + Resolving a graph also causes it to be sorted. + + Parameters + ---------- + registry : `lsst.daf.butler.Registry` + Client for the data repository to resolve against. + + Notes + ----- + The `universe` attribute are set to ``registry.dimensions`` and used to + set all `TaskNode.dimensions` attributes. Dataset type nodes are + resolved by first looking for a registry definition, then using the + producing task's definition, then looking for consistency between all + consuming task definitions. + + Raises + ------ + ConnectionTypeConsistencyError + Raised if a prerequisite input for one task appears as a different + kind of connection in any other task. + DuplicateOutputError + Raised if multiple tasks have the same dataset type as an output. + IncompatibleDatasetTypeError + Raised if different tasks have different definitions of a dataset + type. Different but compatible storage classes are permitted. + MissingDatasetTypeError + Raised if a dataset type definition is required to exist in the + data repository but none was found. This should only occur for + dataset types that are not produced by a task in the pipeline and + are consumed with different storage classes or as components by + tasks in the pipeline. + EdgesChangedError + Raised if ``check_edges_unchanged=True`` and the edges of a task do + change after import and reconfiguration. + """ + node_key: NodeKey + updates: dict[NodeKey, TaskNode | DatasetTypeNode] = {} + for node_key, node_state in self._xgraph.nodes.items(): + match node_key.node_type: + case NodeType.TASK: + task_node: TaskNode = node_state["instance"] + new_task_node = task_node._resolved(registry.dimensions) + if new_task_node is not task_node: + updates[node_key] = new_task_node + case NodeType.DATASET_TYPE: + dataset_type_node: DatasetTypeNode | None = node_state["instance"] + new_dataset_type_node = DatasetTypeNode._from_edges( + node_key, self._xgraph, registry, previous=dataset_type_node + ) + if new_dataset_type_node is not dataset_type_node: + updates[node_key] = new_dataset_type_node + try: + for node_key, node_value in updates.items(): + self._xgraph.nodes[node_key]["instance"] = node_value + except Exception as err: # pragma: no cover + # There's no known way to get here, but we want to make it + # clear it's a big problem if we do. + raise PipelineGraphExceptionSafetyError( + "Error during dataset type resolution has left the graph in an inconsistent state." 
+ ) from err + self.sort() + self._universe = registry.dimensions + + ########################################################################### + # + # Graph Modification Interface: + # + # - methods to add, remove, and replace tasks; + # + # - methods to add and remove task subsets. + # + # These are all things that are usually done in a Pipeline before making a + # graph at all, but there may be cases where we want to modify the graph + # instead. (These are also the methods used to make a graph from a + # Pipeline, or make a graph from another graph.) + # + ########################################################################### + + def add_task( + self, + label: str, + task_class: type[PipelineTask], + config: PipelineTaskConfig, + connections: PipelineTaskConnections | None = None, + ) -> TaskNode: + """Add a new task to the graph. Parameters ---------- @@ -1093,32 +842,300 @@ def remove_task_subset(self, subset_label: str) -> None: """Remove a labeled set of tasks.""" del self._task_subsets[subset_label] - def copy(self) -> PipelineGraph: - """Return a copy of this graph that copies all mutable state.""" - xgraph = self._xgraph.copy() - result = PipelineGraph.__new__(PipelineGraph) - result._init_from_args( - xgraph, - self._sorted_keys, - task_subsets={ - k: TaskSubset(xgraph, v.label, set(v._members), v.description) - for k, v in self._task_subsets.items() - }, - description=self._description, - universe=self.universe, - data_id=self._raw_data_id, + ########################################################################### + # + # NetworkX Export Interface: + # + # - methods to export the PipelineGraph's content (or various subsets + # thereof) as NetworkX objects. + # + # These are particularly useful when writing tools to visualize the graph, + # while providing options for which aspects of the graph (tasks, dataset + # types, or both) to include, since all exported graphs have similar + # attributes regardless of their structure. + # + ########################################################################### + + def make_xgraph(self) -> networkx.MultiDiGraph: + """Export a networkx representation of the full pipeline graph, + including both init and runtime edges. + + Returns + ------- + xgraph : `networkx.MultiDiGraph` + Directed acyclic graph with parallel edges. + + Notes + ----- + The returned graph uses `NodeKey` instances for nodes. Parallel edges + represent the same dataset type appearing in multiple connections for + the same task, and are hence rare. The connection name is used as the + edge key to disambiguate those parallel edges. + + Almost all edges connect dataset type nodes to task or task init nodes + or vice versa, but there is also a special edge that connects each task + init node to its runtime node. The existence of these edges makes the + graph not quite bipartite, though its init-only and runtime-only + subgraphs are bipartite. + + See `TaskNode`, `TaskInitNode`, `DatasetTypeNode`, `ReadEdge`, and + `WriteEdge` for the descriptive node and edge attributes added. + """ + return self._transform_xgraph_state(self._xgraph.copy(), skip_edges=False) + + def make_bipartite_xgraph(self, init: bool = False) -> networkx.MultiDiGraph: + """Return a bipartite networkx representation of just the runtime or + init-time pipeline graph. 
+ + Parameters + ---------- + init : `bool`, optional + If `True` (`False` is default) return the graph of task + initialization nodes and init input/output dataset types, instead + of the graph of runtime task nodes and regular + input/output/prerequisite dataset types. + + Returns + ------- + xgraph : `networkx.MultiDiGraph` + Directed acyclic graph with parallel edges. + + Notes + ----- + The returned graph uses `NodeKey` instances for nodes. Parallel edges + represent the same dataset type appearing in multiple connections for + the same task, and are hence rare. The connection name is used as the + edge key to disambiguate those parallel edges. + + This graph is bipartite because each dataset type node only has edges + that connect it to a task [init] node, and vice versa. + + See `TaskNode`, `TaskInitNode`, `DatasetTypeNode`, `ReadEdge`, and + `WriteEdge` for the descriptive node and edge attributes added. + """ + return self._transform_xgraph_state( + self._make_bipartite_xgraph_internal(init).copy(), skip_edges=False ) - return result - def __copy__(self) -> PipelineGraph: - # Fully shallow copies are dangerous; we don't want shared mutable - # state to lead to broken class invariants. - return self.copy() + def make_task_xgraph(self, init: bool = False) -> networkx.DiGraph: + """Return a networkx representation of just the tasks in the pipeline. - def __deepcopy__(self, memo: dict) -> PipelineGraph: - # Genuine deep copies are unnecessary, since we should only ever care - # that mutable state is copied. - return self.copy() + Parameters + ---------- + init : `bool`, optional + If `True` (`False` is default) return the graph of task + initialization nodes, instead of the graph of runtime task nodes. + + Returns + ------- + xgraph : `networkx.DiGraph` + Directed acyclic graph with no parallel edges. + + Notes + ----- + The returned graph uses `NodeKey` instances for nodes. The dataset + types that link these tasks are not represented at all; edges have no + attributes, and there are no parallel edges. + + See `TaskNode` and `TaskInitNode` for the descriptive node and + attributes added. + """ + bipartite_xgraph = self._make_bipartite_xgraph_internal(init) + task_keys = [ + key + for key, bipartite in bipartite_xgraph.nodes(data="bipartite") + if bipartite == NodeType.TASK.bipartite + ] + return self._transform_xgraph_state( + networkx.algorithms.bipartite.projected_graph(networkx.DiGraph(bipartite_xgraph), task_keys), + skip_edges=True, + ) + + def make_dataset_type_xgraph(self, init: bool = False) -> networkx.DiGraph: + """Return a networkx representation of just the dataset types in the + pipeline. + + Parameters + ---------- + init : `bool`, optional + If `True` (`False` is default) return the graph of init input and + output dataset types, instead of the graph of runtime (input, + output, prerequisite input) dataset types. + + Returns + ------- + xgraph : `networkx.DiGraph` + Directed acyclic graph with no parallel edges. + + Notes + ----- + The returned graph uses `NodeKey` instances for nodes. The tasks that + link these tasks are not represented at all; edges have no attributes, + and there are no parallel edges. + + See `DatasetTypeNode` for the descriptive node and attributes added. 
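+
+        Examples
+        --------
+        A minimal, illustrative sketch; here ``graph`` stands in for any
+        existing `PipelineGraph` instance and is not part of the API::
+
+            import networkx
+
+            xgraph = graph.make_dataset_type_xgraph()
+            for key in networkx.topological_sort(xgraph):
+                print(key.name)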
+ """ + bipartite_xgraph = self._make_bipartite_xgraph_internal(init) + dataset_type_keys = [ + key + for key, bipartite in bipartite_xgraph.nodes(data="bipartite") + if bipartite == NodeType.DATASET_TYPE.bipartite + ] + return self._transform_xgraph_state( + networkx.algorithms.bipartite.projected_graph( + networkx.DiGraph(bipartite_xgraph), dataset_type_keys + ), + skip_edges=True, + ) + + ########################################################################### + # + # Serialization Interface. + # + ########################################################################### + + @classmethod + def read_stream( + cls, + stream: BinaryIO, + import_and_configure: bool = True, + check_edges_unchanged: bool = False, + assume_edges_unchanged: bool = False, + ) -> PipelineGraph: + """Read a serialized `PipelineGraph` from a file-like object. + + Parameters + ---------- + stream : `BinaryIO` + File-like object opened for binary reading, containing + gzip-compressed JSON. + import_and_configure : `bool`, optional + If `True`, import and configure all tasks immediately (see the + `import_and_configure` method). If `False`, some `TaskNode` and + `TaskInitNode` attributes will not be available, but reading may be + much faster. + check_edges_unchanged : `bool`, optional + Forwarded to `import_and_configure` after reading. + assume_edges_unchanged : `bool`, optional + Forwarded to `import_and_configure` after reading. + + Returns + ------- + graph : `PipelineGraph` + Deserialized pipeline graph. + + Raises + ------ + PipelineGraphReadError + Raised if the serialized `PipelineGraph` is not self-consistent. + EdgesChangedError + Raised if ``check_edges_unchanged=True`` and the edges of a task do + change after import and reconfiguration. + """ + from .io import SerializedPipelineGraph + + with gzip.open(stream, "rb") as uncompressed_stream: + data = json.load(uncompressed_stream) + serialized_graph = SerializedPipelineGraph.parse_obj(data) + return serialized_graph.deserialize( + import_and_configure=import_and_configure, + check_edges_unchanged=check_edges_unchanged, + assume_edges_unchanged=assume_edges_unchanged, + ) + + @classmethod + def read_uri( + cls, + uri: ResourcePathExpression, + import_and_configure: bool = True, + check_edges_unchanged: bool = False, + assume_edges_unchanged: bool = False, + ) -> PipelineGraph: + """Read a serialized `PipelineGraph` from a file at a URI. + + Parameters + ---------- + uri : convertible to `lsst.resources.ResourcePath` + URI to a gzip-compressed JSON file containing a serialized pipeline + graph. + import_and_configure : `bool`, optional + If `True`, import and configure all tasks immediately (see + the `import_and_configure` method). If `False`, some `TaskNode` + and `TaskInitNode` attributes will not be available, but reading + may be much faster. + check_edges_unchanged : `bool`, optional + Forwarded to `import_and_configure` after reading. + assume_edges_unchanged : `bool`, optional + Forwarded to `import_and_configure` after reading. + + Returns + ------- + graph : `PipelineGraph` + Deserialized pipeline graph. + + Raises + ------ + PipelineGraphReadError + Raised if the serialized `PipelineGraph` is not self-consistent. + EdgesChangedError + Raised if ``check_edges_unchanged=True`` and the edges of a task do + change after import and reconfiguration. 
+ """ + uri = ResourcePath(uri) + with uri.open("rb") as stream: + return cls.read_stream( + cast(BinaryIO, stream), + import_and_configure=import_and_configure, + check_edges_unchanged=check_edges_unchanged, + assume_edges_unchanged=assume_edges_unchanged, + ) + + def write_stream(self, stream: BinaryIO) -> None: + """Write the pipeline to a file-like object. + + Parameters + ---------- + stream + File-like object opened for binary writing. + + Notes + ----- + The file format is gzipped JSON, and is intended to be human-readable, + but it should not be considered a stable public interface for outside + code, which should always use `PipelineGraph` methods (or at least the + `io.SerializedPipelineGraph` class) to read these files. + """ + from .io import SerializedPipelineGraph + + with gzip.open(stream, mode="wb") as compressed_stream: + compressed_stream.write( + SerializedPipelineGraph.serialize(self).json(exclude_defaults=True, indent=2).encode("utf-8") + ) + + def write_uri(self, uri: ResourcePathExpression) -> None: + """Write the pipeline to a file given a URI. + + Parameters + ---------- + uri : convertible to `lsst.resources.ResourcePath` + URI to write to . May have ``.json.gz`` or no extension (which + will cause a ``.json.gz`` extension to be added). + + Notes + ----- + The file format is gzipped JSON, and is intended to be human-readable, + but it should not be considered a stable public interface for outside + code, which should always use `PipelineGraph` methods (or at least the + `io.SerializedPipelineGraph` class) to read these files. + """ + uri = ResourcePath(uri) + extension = uri.getExtension() + if not extension: + uri = uri.updatedExtension(".json.gz") + elif extension != ".json.gz": + raise ValueError("Expanded pipeline files should always have a .json.gz extension.") + with uri.open(mode="wb") as stream: + self.write_stream(cast(BinaryIO, stream)) def import_and_configure( self, check_edges_unchanged: bool = False, assume_edges_unchanged: bool = False @@ -1184,215 +1201,152 @@ def import_and_configure( ), ) - def resolve(self, registry: Registry) -> None: - """Resolve all dimensions and dataset types and check them for - consistency. + ########################################################################### + # + # Advanced PipelineGraph Inspection Interface: + # + # - methods to iterate over all nodes and edges, utilizing NodeKeys; + # + # - methods to find overall inputs and group nodes by their dimensions, + # which are important operations for QuantumGraph generation. + # + ########################################################################### - Resolving a graph also causes it to be sorted. + def iter_edges(self, init: bool = False) -> Iterator[Edge]: + """Iterate over edges in the graph. Parameters ---------- - registry : `lsst.daf.butler.Registry` - Client for the data repository to resolve against. + init : `bool`, optional + If `True` (`False` is default) iterate over the edges between task + initialization node and init input/output dataset types, instead of + the runtime task nodes and regular input/output/prerequisite + dataset types. + + Returns + ------- + edges : `~collections.abc.Iterator` [ `Edge` ] + A lazy iterator over `Edge` (`WriteEdge` or `ReadEdge`) instances. Notes ----- - The `universe` attribute are set to ``registry.dimensions`` and used to - set all `TaskNode.dimensions` attributes. 
Dataset type nodes are - resolved by first looking for a registry definition, then using the - producing task's definition, then looking for consistency between all - consuming task definitions. - - Raises - ------ - ConnectionTypeConsistencyError - Raised if a prerequisite input for one task appears as a different - kind of connection in any other task. - DuplicateOutputError - Raised if multiple tasks have the same dataset type as an output. - IncompatibleDatasetTypeError - Raised if different tasks have different definitions of a dataset - type. Different but compatible storage classes are permitted. - MissingDatasetTypeError - Raised if a dataset type definition is required to exist in the - data repository but none was found. This should only occur for - dataset types that are not produced by a task in the pipeline and - are consumed with different storage classes or as components by - tasks in the pipeline. - EdgesChangedError - Raised if ``check_edges_unchanged=True`` and the edges of a task do - change after import and reconfiguration. + This method always returns _either_ init edges or runtime edges, never + both. The full (internal) graph that contains both also includes a + special edge that connects each task init node to its runtime node; + that is also never returned by this method, since it is never a part of + the init-only or runtime-only subgraphs. """ - node_key: NodeKey - updates: dict[NodeKey, TaskNode | DatasetTypeNode] = {} - for node_key, node_state in self._xgraph.nodes.items(): - match node_key.node_type: - case NodeType.TASK: - task_node: TaskNode = node_state["instance"] - new_task_node = task_node._resolved(registry.dimensions) - if new_task_node is not task_node: - updates[node_key] = new_task_node - case NodeType.DATASET_TYPE: - dataset_type_node: DatasetTypeNode | None = node_state["instance"] - new_dataset_type_node = DatasetTypeNode._from_edges( - node_key, self._xgraph, registry, previous=dataset_type_node - ) - if new_dataset_type_node is not dataset_type_node: - updates[node_key] = new_dataset_type_node - try: - for node_key, node_value in updates.items(): - self._xgraph.nodes[node_key]["instance"] = node_value - except Exception as err: # pragma: no cover - # There's no known way to get here, but we want to make it - # clear it's a big problem if we do. - raise PipelineGraphExceptionSafetyError( - "Error during dataset type resolution has left the graph in an inconsistent state." - ) from err - self.sort() - self._universe = registry.dimensions - - @classmethod - def read_stream( - cls, - stream: BinaryIO, - import_and_configure: bool = True, - check_edges_unchanged: bool = False, - assume_edges_unchanged: bool = False, - ) -> PipelineGraph: - """Read a serialized `PipelineGraph` from a file-like object. + edge: Edge + for _, _, edge in self._xgraph.edges(data="instance"): + if edge is not None and edge.is_init == init: + yield edge - Parameters - ---------- - stream : `BinaryIO` - File-like object opened for binary reading, containing - gzip-compressed JSON. - import_and_configure : `bool`, optional - If `True`, import and configure all tasks immediately (see the - `import_and_configure` method). If `False`, some `TaskNode` and - `TaskInitNode` attributes will not be available, but reading may be - much faster. - check_edges_unchanged : `bool`, optional - Forwarded to `import_and_configure` after reading. - assume_edges_unchanged : `bool`, optional - Forwarded to `import_and_configure` after reading. 
+ def iter_nodes( + self, + ) -> Iterator[ + tuple[Literal[NodeType.TASK_INIT], str, TaskInitNode] + | tuple[Literal[NodeType.TASK], str, TaskInitNode] + | tuple[Literal[NodeType.DATASET_TYPE], str, DatasetTypeNode | None] + ]: + """Iterate over nodes in the graph. Returns ------- - graph : `PipelineGraph` - Deserialized pipeline graph. - - Raises - ------ - PipelineGraphReadError - Raised if the serialized `PipelineGraph` is not self-consistent. - EdgesChangedError - Raised if ``check_edges_unchanged=True`` and the edges of a task do - change after import and reconfiguration. - """ - from .io import SerializedPipelineGraph - - with gzip.open(stream, "rb") as uncompressed_stream: - data = json.load(uncompressed_stream) - serialized_graph = SerializedPipelineGraph.parse_obj(data) - return serialized_graph.deserialize( - import_and_configure=import_and_configure, - check_edges_unchanged=check_edges_unchanged, - assume_edges_unchanged=assume_edges_unchanged, - ) + nodes : `~collections.abc.Iterator` [ `tuple` ] + A lazy iterator over all of the nodes in the graph. Each yielded + element is a tuple of: - @classmethod - def read_uri( - cls, - uri: ResourcePathExpression, - import_and_configure: bool = True, - check_edges_unchanged: bool = False, - assume_edges_unchanged: bool = False, - ) -> PipelineGraph: - """Read a serialized `PipelineGraph` from a file at a URI. + - the node type enum value (`NodeType`); + - the string name for the node (task label or parent dataset type + name); + - the node value (`TaskNode`, `TaskInitNode`, `DatasetTypeNode`, + or `None` for dataset type nodes that have not been resolved). + """ + key: NodeKey + if self._sorted_keys is not None: + for key in self._sorted_keys: + yield key.node_type, key.name, self._xgraph.nodes[key]["instance"] # type: ignore + else: + for key, node in self._xgraph.nodes(data="instance"): + yield key.node_type, key.name, node # type: ignore - Parameters - ---------- - uri : convertible to `lsst.resources.ResourcePath` - URI to a gzip-compressed JSON file containing a serialized pipeline - graph. - import_and_configure : `bool`, optional - If `True`, import and configure all tasks immediately (see - the `import_and_configure` method). If `False`, some `TaskNode` - and `TaskInitNode` attributes will not be available, but reading - may be much faster. - check_edges_unchanged : `bool`, optional - Forwarded to `import_and_configure` after reading. - assume_edges_unchanged : `bool`, optional - Forwarded to `import_and_configure` after reading. + def iter_overall_inputs(self) -> Iterator[tuple[str, DatasetTypeNode | None]]: + """Iterate over all of the dataset types that are consumed but not + produced by the graph. Returns ------- - graph : `PipelineGraph` - Deserialized pipeline graph. + dataset_types : `~collections.abc.Iterator` [ `tuple` ] + A lazy iterator over the overall-input dataset types (including + overall init inputs and prerequisites). Each yielded element is a + tuple of: - Raises - ------ - PipelineGraphReadError - Raised if the serialized `PipelineGraph` is not self-consistent. - EdgesChangedError - Raised if ``check_edges_unchanged=True`` and the edges of a task do - change after import and reconfiguration. + - the parent dataset type name; + - the resolved `DatasetTypeNode`, or `None` if the dataset type has + - not been resolved. 
""" - uri = ResourcePath(uri) - with uri.open("rb") as stream: - return cls.read_stream( - cast(BinaryIO, stream), - import_and_configure=import_and_configure, - check_edges_unchanged=check_edges_unchanged, - assume_edges_unchanged=assume_edges_unchanged, - ) + for generation in networkx.algorithms.dag.topological_generations(self._xgraph): + key: NodeKey + for key in generation: + # While we expect all tasks to have at least one input and + # hence never appear in the first topological generation, that + # is not true of task init nodes. + if key.node_type is NodeType.DATASET_TYPE: + yield key.name, self._xgraph.nodes[key]["instance"] + return - def write_stream(self, stream: BinaryIO) -> None: - """Write the pipeline to a file-like object. + def group_by_dimensions( + self, prerequisites: bool = False + ) -> dict[DimensionGraph, tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]]]: + """Group this graph's tasks and dataset types by their dimensions. Parameters ---------- - stream - File-like object opened for binary writing. - - Notes - ----- - The file format is gzipped JSON, and is intended to be human-readable, - but it should not be considered a stable public interface for outside - code, which should always use `PipelineGraph` methods (or at least the - `io.SerializedPipelineGraph` class) to read these files. - """ - from .io import SerializedPipelineGraph + prerequisites : `bool`, optional + If `True`, include prerequisite dataset types as well as regular + input and output datasets (including intermediates). - with gzip.open(stream, mode="wb") as compressed_stream: - compressed_stream.write( - SerializedPipelineGraph.serialize(self).json(exclude_defaults=True, indent=2).encode("utf-8") - ) + Returns + ------- + groups : `dict` [ `DimensionGraph`, `tuple` ] + A dictionary of groups keyed by `DimensionGraph`, in which each + value is a tuple of: - def write_uri(self, uri: ResourcePathExpression) -> None: - """Write the pipeline to a file given a URI. + - a `dict` of `TaskNode` instances, keyed by task label + - a `dict` of `DatasetTypeNode` instances, keyed by + dataset type name. - Parameters - ---------- - uri : convertible to `lsst.resources.ResourcePath` - URI to write to . May have ``.json.gz`` or no extension (which - will cause a ``.json.gz`` extension to be added). + that have those dimensions. Notes ----- - The file format is gzipped JSON, and is intended to be human-readable, - but it should not be considered a stable public interface for outside - code, which should always use `PipelineGraph` methods (or at least the - `io.SerializedPipelineGraph` class) to read these files. + Init inputs and outputs are always included, but always have empty + dimensions and are hence are all grouped together. 
""" - uri = ResourcePath(uri) - extension = uri.getExtension() - if not extension: - uri = uri.updatedExtension(".json.gz") - elif extension != ".json.gz": - raise ValueError("Expanded pipeline files should always have a .json.gz extension.") - with uri.open(mode="wb") as stream: - self.write_stream(cast(BinaryIO, stream)) + result: dict[DimensionGraph, tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]]] = {} + next_new_value: tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]] = ({}, {}) + for task_label, task_node in self.tasks.items(): + if task_node.dimensions is None: + raise UnresolvedGraphError(f"Task with label {task_label!r} has not been resolved.") + if (group := result.setdefault(task_node.dimensions, next_new_value)) is next_new_value: + next_new_value = ({}, {}) # make new lists for next time + group[0][task_node.label] = task_node + for dataset_type_name, dataset_type_node in self.dataset_types.items(): + if dataset_type_node is None: + raise UnresolvedGraphError(f"Dataset type {dataset_type_name!r} has not been resolved.") + if not dataset_type_node.is_prerequisite or prerequisites: + if ( + group := result.setdefault(dataset_type_node.dataset_type.dimensions, next_new_value) + ) is next_new_value: + next_new_value = ({}, {}) # make new lists for next time + group[1][dataset_type_node.name] = dataset_type_node + return result + + ########################################################################### + # + # Class- and Package-Private Methods. + # + ########################################################################### def _iter_task_defs(self) -> Iterator[TaskDef]: """Iterate over this pipeline as a sequence of `TaskDef` instances. @@ -1420,6 +1374,114 @@ def _iter_task_defs(self) -> Iterator[TaskDef]: connections=node._get_imported_data().connections, ) + def _init_from_args( + self, + xgraph: networkx.MultiDiGraph | None, + sorted_keys: Sequence[NodeKey] | None, + task_subsets: dict[str, TaskSubset] | None, + description: str, + universe: DimensionUniverse | None, + data_id: DataId | None, + ) -> None: + """Initialize the graph with possibly-nontrivial arguments. + + Parameters + ---------- + xgraph : `networkx.MultiDiGraph` or `None` + The backing networkx graph, or `None` to create an empty one. + This graph has `NodeKey` instances for nodes and the same structure + as the graph exported by `make_xgraph`, but its nodes and edges + have a single ``instance`` attribute that holds a `TaskNode`, + `TaskInitNode`, `DatasetTypeNode` (or `None`), `ReadEdge`, or + `WriteEdge` instance. + sorted_keys : `Sequence` [ `NodeKey` ] or `None` + Topologically sorted sequence of node keys, or `None` if the graph + is not sorted. + task_subsets : `dict` [ `str`, `TaskSubset` ] + Labeled subsets of tasks. Values must be constructed with + ``xgraph`` as their parent graph. + description : `str` + String description for this pipeline. + universe : `lsst.daf.butler.DimensionUniverse` or `None` + Definitions of all dimensions. + data_id : `lsst.daf.butler.DataCoordinate` or other data ID mapping. + Data ID that represents a constraint on all quanta generated from + this pipeline. + + Notes + ----- + Only empty `PipelineGraph` instances should be constructed directly by + users, which sets the signature of ``__init__`` itself, but methods on + `PipelineGraph` and its helper classes need to be able to create them + with state. Those methods can call this after calling ``__new__`` + manually, skipping ``__init__``. 
+ """ + self._xgraph = xgraph if xgraph is not None else networkx.MultiDiGraph() + self._sorted_keys: Sequence[NodeKey] | None = None + self._task_subsets = task_subsets if task_subsets is not None else {} + self._description = description + self._tasks = TaskMappingView(self._xgraph) + self._dataset_types = DatasetTypeMappingView(self._xgraph) + self._raw_data_id: dict[str, Any] + if isinstance(data_id, DataCoordinate): + universe = data_id.universe + self._raw_data_id = data_id.byName() + elif data_id is None: + self._raw_data_id = {} + else: + self._raw_data_id = dict(data_id) + self._universe = universe + if sorted_keys is not None: + self._reorder(sorted_keys) + + def _make_bipartite_xgraph_internal(self, init: bool) -> networkx.MultiDiGraph: + """Make a bipartite init-only or runtime-only internal subgraph. + + See `make_bipartite_xgraph` for parameters and return values. + + Notes + ----- + This method returns a view of the `PipelineGraph` object's internal + backing graph, and hence should only be called in methods that copy the + result either explicitly or by running a copying algorithm before + returning it to the user. + """ + return self._xgraph.edge_subgraph([edge.key for edge in self.iter_edges(init)]) + + def _transform_xgraph_state(self, xgraph: _G, skip_edges: bool) -> _G: + """Transform networkx graph attributes in-place from the internal + "instance" attributes to the documented exported attributes. + + Parameters + ---------- + xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph` + Graph whose state should be transformed. + skip_edges : `bool` + If `True`, do not transform edge state. + + Returns + ------- + xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph` + The same object passed in, after modification. + + Notes + ----- + This should be called after making a copy of the internal graph but + before any projection down to just task or dataset type nodes, since + it assumes stateful edges. + """ + state: dict[str, Any] + for state in xgraph.nodes.values(): + node_value: TaskInitNode | TaskNode | DatasetTypeNode | None = state.pop("instance") + if node_value is not None: + state.update(node_value._to_xgraph_state()) + if not skip_edges: + for _, _, state in xgraph.edges(data=True): + edge: Edge | None = state.pop("instance", None) + if edge is not None: + state.update(edge._to_xgraph_state()) + return xgraph + def _replace_task_nodes( self, updates: Mapping[str, TaskNode], @@ -1552,3 +1614,12 @@ def _reset(self) -> None: self._sorted_keys = None self._tasks._reset() self._dataset_types._reset() + + _xgraph: networkx.MultiDiGraph + _sorted_keys: Sequence[NodeKey] | None + _task_subsets: dict[str, TaskSubset] + _description: str + _tasks: TaskMappingView + _dataset_types: DatasetTypeMappingView + _raw_data_id: dict[str, Any] + _universe: DimensionUniverse | None From 497ac247fb1aa9f8a57bd7d906b72206dcfdf500 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Thu, 20 Jul 2023 21:20:41 -0400 Subject: [PATCH 13/16] Mark PipelineGraph serialization as experimental. 
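
Serialization still round-trips exactly as before; only the entry points are
renamed with a leading underscore so that code outside this package does not
start relying on them (or on the file format) while they remain experimental.
A purely illustrative sketch of the internal usage the unit tests keep
exercising (``graph`` is any existing PipelineGraph; the in-memory stream is
just for demonstration):

    import io

    stream = io.BytesIO()
    graph._write_stream(stream)
    stream.seek(0)
    roundtripped = PipelineGraph._read_stream(stream)
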
--- .../base/pipeline_graph/_pipeline_graph.py | 39 +++++++++++++++---- python/lsst/pipe/base/pipeline_graph/io.py | 2 +- tests/test_pipeline_graph.py | 28 ++++++------- 3 files changed, 47 insertions(+), 22 deletions(-) diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py index 6e29363d..3b1e1671 100644 --- a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py +++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py @@ -992,10 +992,15 @@ def make_dataset_type_xgraph(self, init: bool = False) -> networkx.DiGraph: # # Serialization Interface. # + # Serialization of PipelineGraphs is currently experimental and may not be + # retained in the future. All serialization methods are + # underscore-prefixed to ensure nobody mistakes them for a stable interface + # (let a lone a stable file format). + # ########################################################################### @classmethod - def read_stream( + def _read_stream( cls, stream: BinaryIO, import_and_configure: bool = True, @@ -1031,6 +1036,12 @@ def read_stream( EdgesChangedError Raised if ``check_edges_unchanged=True`` and the edges of a task do change after import and reconfiguration. + + Notes + ----- + `PipelineGraph` serialization is currently experimental and may be + removed or significantly changed in the future, with no deprecation + period. """ from .io import SerializedPipelineGraph @@ -1044,7 +1055,7 @@ def read_stream( ) @classmethod - def read_uri( + def _read_uri( cls, uri: ResourcePathExpression, import_and_configure: bool = True, @@ -1080,17 +1091,23 @@ def read_uri( EdgesChangedError Raised if ``check_edges_unchanged=True`` and the edges of a task do change after import and reconfiguration. + + Notes + ----- + `PipelineGraph` serialization is currently experimental and may be + removed or significantly changed in the future, with no deprecation + period. """ uri = ResourcePath(uri) with uri.open("rb") as stream: - return cls.read_stream( + return cls._read_stream( cast(BinaryIO, stream), import_and_configure=import_and_configure, check_edges_unchanged=check_edges_unchanged, assume_edges_unchanged=assume_edges_unchanged, ) - def write_stream(self, stream: BinaryIO) -> None: + def _write_stream(self, stream: BinaryIO) -> None: """Write the pipeline to a file-like object. Parameters @@ -1100,6 +1117,10 @@ def write_stream(self, stream: BinaryIO) -> None: Notes ----- + `PipelineGraph` serialization is currently experimental and may be + removed or significantly changed in the future, with no deprecation + period. + The file format is gzipped JSON, and is intended to be human-readable, but it should not be considered a stable public interface for outside code, which should always use `PipelineGraph` methods (or at least the @@ -1112,7 +1133,7 @@ def write_stream(self, stream: BinaryIO) -> None: SerializedPipelineGraph.serialize(self).json(exclude_defaults=True, indent=2).encode("utf-8") ) - def write_uri(self, uri: ResourcePathExpression) -> None: + def _write_uri(self, uri: ResourcePathExpression) -> None: """Write the pipeline to a file given a URI. Parameters @@ -1123,6 +1144,10 @@ def write_uri(self, uri: ResourcePathExpression) -> None: Notes ----- + `PipelineGraph` serialization is currently experimental and may be + removed or significantly changed in the future, with no deprecation + period. 
+ The file format is gzipped JSON, and is intended to be human-readable, but it should not be considered a stable public interface for outside code, which should always use `PipelineGraph` methods (or at least the @@ -1135,9 +1160,9 @@ def write_uri(self, uri: ResourcePathExpression) -> None: elif extension != ".json.gz": raise ValueError("Expanded pipeline files should always have a .json.gz extension.") with uri.open(mode="wb") as stream: - self.write_stream(cast(BinaryIO, stream)) + self._write_stream(cast(BinaryIO, stream)) - def import_and_configure( + def _import_and_configure( self, check_edges_unchanged: bool = False, assume_edges_unchanged: bool = False ) -> None: """Import the `PipelineTask` classes referenced by all task nodes and diff --git a/python/lsst/pipe/base/pipeline_graph/io.py b/python/lsst/pipe/base/pipeline_graph/io.py index 5b62c715..33c9ef85 100644 --- a/python/lsst/pipe/base/pipeline_graph/io.py +++ b/python/lsst/pipe/base/pipeline_graph/io.py @@ -615,7 +615,7 @@ def deserialize( data_id=self.data_id, ) if import_and_configure: - result.import_and_configure( + result._import_and_configure( check_edges_unchanged=check_edges_unchanged, assume_edges_unchanged=assume_edges_unchanged, ) diff --git a/tests/test_pipeline_graph.py b/tests/test_pipeline_graph.py index dd8718e1..479e9202 100644 --- a/tests/test_pipeline_graph.py +++ b/tests/test_pipeline_graph.py @@ -133,9 +133,9 @@ def test_unresolved_stream_io(self) -> None: serialization. """ stream = io.BytesIO() - self.graph.write_stream(stream) + self.graph._write_stream(stream) stream.seek(0) - roundtripped = PipelineGraph.read_stream(stream) + roundtripped = PipelineGraph._read_stream(stream) self.check_make_xgraph(roundtripped, resolved=False) def test_unresolved_file_io(self) -> None: @@ -143,8 +143,8 @@ def test_unresolved_file_io(self) -> None: serialization. """ with lsst.utils.tests.getTempFilePath(".json.gz") as filename: - self.graph.write_uri(filename) - roundtripped = PipelineGraph.read_uri(filename) + self.graph._write_uri(filename) + roundtripped = PipelineGraph._read_uri(filename) self.check_make_xgraph(roundtripped, resolved=False) def test_unresolved_deferred_import_io(self) -> None: @@ -152,14 +152,14 @@ def test_unresolved_deferred_import_io(self) -> None: serialization, without immediately importing tasks on read. """ stream = io.BytesIO() - self.graph.write_stream(stream) + self.graph._write_stream(stream) stream.seek(0) - roundtripped = PipelineGraph.read_stream(stream, import_and_configure=False) + roundtripped = PipelineGraph._read_stream(stream, import_and_configure=False) self.check_make_xgraph(roundtripped, resolved=False, imported_and_configured=False) # Check that we can still resolve the graph without importing tasks. 
roundtripped.resolve(MockRegistry(self.dimensions, {})) self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False) - roundtripped.import_and_configure(assume_edges_unchanged=True) + roundtripped._import_and_configure(assume_edges_unchanged=True) self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True) def test_resolved_accessors(self) -> None: @@ -198,9 +198,9 @@ def test_resolved_stream_io(self) -> None: """ self.graph.resolve(MockRegistry(self.dimensions, {})) stream = io.BytesIO() - self.graph.write_stream(stream) + self.graph._write_stream(stream) stream.seek(0) - roundtripped = PipelineGraph.read_stream(stream) + roundtripped = PipelineGraph._read_stream(stream) self.check_make_xgraph(roundtripped, resolved=True) def test_resolved_file_io(self) -> None: @@ -209,8 +209,8 @@ def test_resolved_file_io(self) -> None: """ self.graph.resolve(MockRegistry(self.dimensions, {})) with lsst.utils.tests.getTempFilePath(".json.gz") as filename: - self.graph.write_uri(filename) - roundtripped = PipelineGraph.read_uri(filename) + self.graph._write_uri(filename) + roundtripped = PipelineGraph._read_uri(filename) self.check_make_xgraph(roundtripped, resolved=True) def test_resolved_deferred_import_io(self) -> None: @@ -219,11 +219,11 @@ def test_resolved_deferred_import_io(self) -> None: """ self.graph.resolve(MockRegistry(self.dimensions, {})) stream = io.BytesIO() - self.graph.write_stream(stream) + self.graph._write_stream(stream) stream.seek(0) - roundtripped = PipelineGraph.read_stream(stream, import_and_configure=False) + roundtripped = PipelineGraph._read_stream(stream, import_and_configure=False) self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False) - roundtripped.import_and_configure(check_edges_unchanged=True) + roundtripped._import_and_configure(check_edges_unchanged=True) self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True) def test_unresolved_copies(self) -> None: From 8f0d3fa055bbaa3633846432912c4cdc8e864323 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Mon, 24 Jul 2023 13:45:13 -0400 Subject: [PATCH 14/16] Make SerializedPipelineGraph compatible with pydantic v2. --- .../base/pipeline_graph/_pipeline_graph.py | 2 +- python/lsst/pipe/base/pipeline_graph/io.py | 27 ++++++++++--------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py index 3b1e1671..86f8d3dc 100644 --- a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py +++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py @@ -1130,7 +1130,7 @@ def _write_stream(self, stream: BinaryIO) -> None: with gzip.open(stream, mode="wb") as compressed_stream: compressed_stream.write( - SerializedPipelineGraph.serialize(self).json(exclude_defaults=True, indent=2).encode("utf-8") + SerializedPipelineGraph.serialize(self).json(exclude_defaults=True).encode("utf-8") ) def _write_uri(self, uri: ResourcePathExpression) -> None: diff --git a/python/lsst/pipe/base/pipeline_graph/io.py b/python/lsst/pipe/base/pipeline_graph/io.py index 33c9ef85..53013280 100644 --- a/python/lsst/pipe/base/pipeline_graph/io.py +++ b/python/lsst/pipe/base/pipeline_graph/io.py @@ -36,6 +36,7 @@ import networkx import pydantic from lsst.daf.butler import DatasetType, DimensionConfig, DimensionGraph, DimensionUniverse +from lsst.daf.butler._compat import _BaseModelCompat from .. 
import automatic_connection_constants as acc from ._dataset_types import DatasetTypeNode @@ -78,7 +79,7 @@ def expect_not_none(value: _U | None, msg: str) -> _U: return value -class SerializedEdge(pydantic.BaseModel): +class SerializedEdge(_BaseModelCompat): """Struct used to represent a serialized `Edge` in a `PipelineGraph`. All `ReadEdge` and `WriteEdge` state not included here is instead @@ -107,7 +108,7 @@ class SerializedEdge(pydantic.BaseModel): @classmethod def serialize(cls, target: Edge) -> SerializedEdge: """Transform an `Edge` to a `SerializedEdge`.""" - return SerializedEdge.construct( + return SerializedEdge.model_construct( storage_class=target.storage_class_name, dataset_type_name=target.dataset_type_name, raw_dimensions=sorted(target.raw_dimensions), @@ -153,7 +154,7 @@ def deserialize_write_edge( ) -class SerializedTaskInitNode(pydantic.BaseModel): +class SerializedTaskInitNode(_BaseModelCompat): """Struct used to represent a serialized `TaskInitNode` in a `PipelineGraph`. @@ -182,7 +183,7 @@ class SerializedTaskInitNode(pydantic.BaseModel): @classmethod def serialize(cls, target: TaskInitNode) -> SerializedTaskInitNode: """Transform a `TaskInitNode` to a `SerializedTaskInitNode`.""" - return cls.construct( + return cls.model_construct( inputs={ connection_name: SerializedEdge.serialize(edge) for connection_name, edge in sorted(target.inputs.items()) @@ -224,7 +225,7 @@ def deserialize( ) -class SerializedTaskNode(pydantic.BaseModel): +class SerializedTaskNode(_BaseModelCompat): """Struct used to represent a serialized `TaskNode` in a `PipelineGraph`. The task label is serialized by the context in which a @@ -271,7 +272,7 @@ class SerializedTaskNode(pydantic.BaseModel): @classmethod def serialize(cls, target: TaskNode) -> SerializedTaskNode: """Transform a `TaskNode` to a `SerializedTaskNode`.""" - return cls.construct( + return cls.model_construct( task_class=target.task_class_name, init=SerializedTaskInitNode.serialize(target.init), config_str=target.get_config_str(), @@ -350,7 +351,7 @@ def deserialize( ) -class SerializedDatasetTypeNode(pydantic.BaseModel): +class SerializedDatasetTypeNode(_BaseModelCompat): """Struct used to represent a serialized `DatasetTypeNode` in a `PipelineGraph`. @@ -391,8 +392,8 @@ class SerializedDatasetTypeNode(pydantic.BaseModel): def serialize(cls, target: DatasetTypeNode | None) -> SerializedDatasetTypeNode: """Transform a `DatasetTypeNode` to a `SerializedDatasetTypeNode`.""" if target is None: - return cls.construct() - return cls.construct( + return cls.model_construct() + return cls.model_construct( dimensions=list(target.dataset_type.dimensions.names), storage_class=target.dataset_type.storageClass_name, is_calibration=target.dataset_type.isCalibration(), @@ -445,7 +446,7 @@ def deserialize( return None -class SerializedTaskSubset(pydantic.BaseModel): +class SerializedTaskSubset(_BaseModelCompat): """Struct used to represent a serialized `TaskSubset` in a `PipelineGraph`. 
     The subset label is serialized by the context in which a
@@ -464,7 +465,7 @@ class SerializedTaskSubset(pydantic.BaseModel):
     @classmethod
     def serialize(cls, target: TaskSubset) -> SerializedTaskSubset:
         """Transform a `TaskSubset` into a `SerializedTaskSubset`."""
-        return cls.construct(description=target._description, tasks=list(sorted(target)))
+        return cls.model_construct(description=target._description, tasks=list(sorted(target)))
 
     def deserialize_task_subset(self, label: str, xgraph: networkx.MultiDiGraph) -> TaskSubset:
         """Transform a `SerializedTaskSubset` into a `TaskSubset`."""
@@ -472,7 +473,7 @@ def deserialize_task_subset(self, label: str, xgraph: networkx.MultiDiGraph) ->
         return TaskSubset(xgraph, label, members, self.description)
 
 
-class SerializedPipelineGraph(pydantic.BaseModel):
+class SerializedPipelineGraph(_BaseModelCompat):
     """Struct used to represent a serialized `PipelineGraph`."""
 
     version: str = ".".join(str(v) for v in _IO_VERSION_INFO)
@@ -500,7 +501,7 @@ class SerializedPipelineGraph(pydantic.BaseModel):
    @classmethod
     def serialize(cls, target: PipelineGraph) -> SerializedPipelineGraph:
         """Transform a `PipelineGraph` into a `SerializedPipelineGraph`."""
-        result = SerializedPipelineGraph.construct(
+        result = SerializedPipelineGraph.model_construct(
             description=target.description,
             tasks={label: SerializedTaskNode.serialize(node) for label, node in target.tasks.items()},
             dataset_types={

From 14334a420527513bbbe7ab6f207b0b8f922414b9 Mon Sep 17 00:00:00 2001
From: Jim Bosch
Date: Thu, 3 Aug 2023 12:46:56 -0400
Subject: [PATCH 15/16] Minor review-inspired cleanups for PipelineGraph.

---
 .../base/pipeline_graph/_pipeline_graph.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
index 86f8d3dc..e2ab4335 100644
--- a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
+++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
@@ -463,7 +463,7 @@ def resolve(self, registry: Registry) -> None:
 
         Notes
         -----
-        The `universe` attribute are set to ``registry.dimensions`` and used to
+        The `universe` attribute is set to ``registry.dimensions`` and used to
         set all `TaskNode.dimensions` attributes.  Dataset type nodes are
         resolved by first looking for a registry definition, then using the
         producing task's definition, then looking for consistency between all
@@ -503,6 +503,9 @@ def resolve(self, registry: Registry) -> None:
                 new_dataset_type_node = DatasetTypeNode._from_edges(
                     node_key, self._xgraph, registry, previous=dataset_type_node
                 )
+                # Usage of `is` here is intentional; `_from_edges` returns
+                # `previous=dataset_type_node` if it can determine that it
+                # doesn't need to change.
                 if new_dataset_type_node is not dataset_type_node:
                     updates[node_key] = new_dataset_type_node
         try:
@@ -582,12 +585,10 @@ def add_task(
         it references and marks the graph as unsorted.  It is most
         efficient to add all tasks up front and only then resolve and/or
         sort the graph.
""" - key = NodeKey(NodeType.TASK, label) - init_key = NodeKey(NodeType.TASK_INIT, label) task_node = TaskNode._from_imported_data( - key, - init_key, - _TaskNodeImportedData.configure(label, task_class, config, connections), + key=NodeKey(NodeType.TASK, label), + init_key=NodeKey(NodeType.TASK_INIT, label), + data=_TaskNodeImportedData.configure(label, task_class, config, connections), universe=self.universe, ) self.add_task_nodes([task_node]) @@ -1449,7 +1450,10 @@ def _init_from_args( self._dataset_types = DatasetTypeMappingView(self._xgraph) self._raw_data_id: dict[str, Any] if isinstance(data_id, DataCoordinate): - universe = data_id.universe + if universe is None: + universe = data_id.universe + else: + assert universe is data_id.universe, "data_id.universe and given universe differ" self._raw_data_id = data_id.byName() elif data_id is None: self._raw_data_id = {} From 8091ec1839ffc0334f3bd9ae53096a25a880484e Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Thu, 3 Aug 2023 16:55:05 -0400 Subject: [PATCH 16/16] Use enum to make PipelineGraph load options clearer. --- .../base/pipeline_graph/_pipeline_graph.py | 102 +++++++----------- .../lsst/pipe/base/pipeline_graph/_tasks.py | 35 +++++- python/lsst/pipe/base/pipeline_graph/io.py | 12 +-- tests/test_pipeline_graph.py | 9 +- 4 files changed, 83 insertions(+), 75 deletions(-) diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py index e2ab4335..8bfd6357 100644 --- a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py +++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py @@ -47,7 +47,7 @@ from ._mapping_views import DatasetTypeMappingView, TaskMappingView from ._nodes import NodeKey, NodeType from ._task_subsets import TaskSubset -from ._tasks import TaskInitNode, TaskNode, _TaskNodeImportedData +from ._tasks import TaskImportMode, TaskInitNode, TaskNode, _TaskNodeImportedData if TYPE_CHECKING: from ..config import PipelineTaskConfig @@ -1002,11 +1002,7 @@ def make_dataset_type_xgraph(self, init: bool = False) -> networkx.DiGraph: @classmethod def _read_stream( - cls, - stream: BinaryIO, - import_and_configure: bool = True, - check_edges_unchanged: bool = False, - assume_edges_unchanged: bool = False, + cls, stream: BinaryIO, import_mode: TaskImportMode = TaskImportMode.REQUIRE_CONSISTENT_EDGES ) -> PipelineGraph: """Read a serialized `PipelineGraph` from a file-like object. @@ -1015,15 +1011,11 @@ def _read_stream( stream : `BinaryIO` File-like object opened for binary reading, containing gzip-compressed JSON. - import_and_configure : `bool`, optional - If `True`, import and configure all tasks immediately (see the - `import_and_configure` method). If `False`, some `TaskNode` and - `TaskInitNode` attributes will not be available, but reading may be - much faster. - check_edges_unchanged : `bool`, optional - Forwarded to `import_and_configure` after reading. - assume_edges_unchanged : `bool`, optional - Forwarded to `import_and_configure` after reading. + import_mode : `TaskImportMode`, optional + Whether to import tasks, and how to reconcile any differences + between the imported task's connections and the those that were + persisted with the graph. Default is to check that they are the + same. Returns ------- @@ -1035,8 +1027,9 @@ def _read_stream( PipelineGraphReadError Raised if the serialized `PipelineGraph` is not self-consistent. 
         EdgesChangedError
-            Raised if ``check_edges_unchanged=True`` and the edges of a task do
-            change after import and reconfiguration.
+            Raised if ``import_mode`` is
+            `TaskImportMode.REQUIRE_CONSISTENT_EDGES` and the edges of a task
+            did change after import and reconfiguration.
 
         Notes
         -----
@@ -1049,19 +1042,13 @@ def _read_stream(
         with gzip.open(stream, "rb") as uncompressed_stream:
             data = json.load(uncompressed_stream)
             serialized_graph = SerializedPipelineGraph.parse_obj(data)
-            return serialized_graph.deserialize(
-                import_and_configure=import_and_configure,
-                check_edges_unchanged=check_edges_unchanged,
-                assume_edges_unchanged=assume_edges_unchanged,
-            )
+            return serialized_graph.deserialize(import_mode)
 
     @classmethod
     def _read_uri(
         cls,
         uri: ResourcePathExpression,
-        import_and_configure: bool = True,
-        check_edges_unchanged: bool = False,
-        assume_edges_unchanged: bool = False,
+        import_mode: TaskImportMode = TaskImportMode.REQUIRE_CONSISTENT_EDGES,
     ) -> PipelineGraph:
         """Read a serialized `PipelineGraph` from a file at a URI.
 
@@ -1070,15 +1057,11 @@ def _read_uri(
         uri : convertible to `lsst.resources.ResourcePath`
             URI to a gzip-compressed JSON file containing a serialized
             pipeline graph.
-        import_and_configure : `bool`, optional
-            If `True`, import and configure all tasks immediately (see
-            the `import_and_configure` method).  If `False`, some `TaskNode`
-            and `TaskInitNode` attributes will not be available, but reading
-            may be much faster.
-        check_edges_unchanged : `bool`, optional
-            Forwarded to `import_and_configure` after reading.
-        assume_edges_unchanged : `bool`, optional
-            Forwarded to `import_and_configure` after reading.
+        import_mode : `TaskImportMode`, optional
+            Whether to import tasks, and how to reconcile any differences
+            between the imported task's connections and those that were
+            persisted with the graph.  Default is to check that they are the
+            same.
 
         Returns
         -------
         graph : `PipelineGraph`
             Deserialized pipeline graph.
 
         Raises
         ------
         PipelineGraphReadError
             Raised if the serialized `PipelineGraph` is not self-consistent.
         EdgesChangedError
-            Raised if ``check_edges_unchanged=True`` and the edges of a task do
-            change after import and reconfiguration.
+            Raised if ``import_mode`` is
+            `TaskImportMode.REQUIRE_CONSISTENT_EDGES` and the edges of a task
+            did change after import and reconfiguration.
 
         Notes
         -----
@@ -1101,12 +1085,7 @@ def _read_uri(
         """
         uri = ResourcePath(uri)
         with uri.open("rb") as stream:
-            return cls._read_stream(
-                cast(BinaryIO, stream),
-                import_and_configure=import_and_configure,
-                check_edges_unchanged=check_edges_unchanged,
-                assume_edges_unchanged=assume_edges_unchanged,
-            )
+            return cls._read_stream(cast(BinaryIO, stream), import_mode=import_mode)
 
     def _write_stream(self, stream: BinaryIO) -> None:
         """Write the pipeline to a file-like object.
@@ -1164,31 +1143,26 @@ def _write_uri(self, uri: ResourcePathExpression) -> None:
             self._write_stream(cast(BinaryIO, stream))
 
     def _import_and_configure(
-        self, check_edges_unchanged: bool = False, assume_edges_unchanged: bool = False
+        self, import_mode: TaskImportMode = TaskImportMode.REQUIRE_CONSISTENT_EDGES
     ) -> None:
         """Import the `PipelineTask` classes referenced by all task nodes and
         update those nodes accordingly.
 
         Parameters
         ----------
-        check_edges_unchanged : `bool`, optional
-            If `True`, require the edges (connections) of the modified tasks to
-            remain unchanged after importing and configuring each task, and
-            verify that this is the case.
-        assume_edges_unchanged : `bool`, optional
-            If `True`, the caller declares that the edges (connections) of the
-            modified tasks will remain unchanged importing and configuring each
-            task, and that it is unnecessary to check this.
+        import_mode : `TaskImportMode`, optional
+            Whether to import tasks, and how to reconcile any differences
+            between the imported task's connections and those that were
+            persisted with the graph.  Default is to check that they are the
+            same.  This method does nothing if this is
+            `TaskImportMode.DO_NOT_IMPORT`.
 
         Raises
         ------
-        ValueError
-            Raised if ``assume_edges_unchanged`` and ``check_edges_unchanged``
-            are both `True`, or if a full config is provided for a task after
-            another full config or an override has already been provided.
         EdgesChangedError
-            Raised if ``check_edges_unchanged=True`` and the edges of a task do
-            change.
+            Raised if ``import_mode`` is
+            `TaskImportMode.REQUIRE_CONSISTENT_EDGES` and the edges of a task
+            did change after import and reconfiguration.
 
         Notes
         -----
@@ -1202,13 +1176,19 @@ def _import_and_configure(
         usually because the software used to read a serialized graph is newer
         than the software used to write it (e.g. a new config option has been
         added, or the task was moved to a new module with a forwarding alias
-        left behind).  These changes are allowed by ``check=True``.
+        left behind).  These changes are allowed by
+        `TaskImportMode.REQUIRE_CONSISTENT_EDGES`.
 
         If importing and configuring a task causes its edges to change, any
         dataset type nodes linked to those edges will be reset to the
         unresolved state.
         """
-        rebuild = check_edges_unchanged or not assume_edges_unchanged
+        if import_mode is TaskImportMode.DO_NOT_IMPORT:
+            return
+        rebuild = (
+            import_mode is TaskImportMode.REQUIRE_CONSISTENT_EDGES
+            or import_mode is TaskImportMode.OVERRIDE_EDGES
+        )
         updates: dict[str, TaskNode] = {}
         node_key: NodeKey
         for node_key, node_state in self._xgraph.nodes.items():
@@ -1219,8 +1199,8 @@ def _import_and_configure(
                     updates[task_node.label] = new_task_node
         self._replace_task_nodes(
             updates,
-            check_edges_unchanged=check_edges_unchanged,
-            assume_edges_unchanged=assume_edges_unchanged,
+            check_edges_unchanged=(import_mode is TaskImportMode.REQUIRE_CONSISTENT_EDGES),
+            assume_edges_unchanged=(import_mode is TaskImportMode.ASSUME_CONSISTENT_EDGES),
             message_header=(
                 "In task with label {task_label!r}, persisted edges (A) "
                 "differ from imported and configured edges (B):"
             ),
diff --git a/python/lsst/pipe/base/pipeline_graph/_tasks.py b/python/lsst/pipe/base/pipeline_graph/_tasks.py
index 8c3b4e2b..d1d7b236 100644
--- a/python/lsst/pipe/base/pipeline_graph/_tasks.py
+++ b/python/lsst/pipe/base/pipeline_graph/_tasks.py
@@ -20,9 +20,10 @@
 # along with this program.  If not, see .
 
 from __future__ import annotations
 
-__all__ = ("TaskNode", "TaskInitNode")
+__all__ = ("TaskNode", "TaskInitNode", "TaskImportMode")
 
 import dataclasses
+import enum
 from collections.abc import Iterator, Mapping
 from typing import TYPE_CHECKING, Any, cast
 
@@ -43,6 +44,38 @@
 from ..pipelineTask import PipelineTask
 
 
+class TaskImportMode(enum.Enum):
+    """Enumeration of the ways to handle importing tasks when reading a
+    serialized PipelineGraph.
+    """
+
+    DO_NOT_IMPORT = enum.auto()
+    """Do not import tasks or instantiate their configs and connections."""
+
+    REQUIRE_CONSISTENT_EDGES = enum.auto()
+    """Import tasks and instantiate their config and connection objects, and
+    check that the connections still define the same edges.
+    """
+
+    ASSUME_CONSISTENT_EDGES = enum.auto()
+    """Import tasks and instantiate their config and connection objects, but do
+    not check that the connections still define the same edges.
+
+    This is safe only when the caller knows the task definition has not changed
+    since the pipeline graph was persisted, such as when it was saved and
+    loaded with the same pipeline version.
+    """
+
+    OVERRIDE_EDGES = enum.auto()
+    """Import tasks and instantiate their config and connection objects, and
+    allow the edges defined in those connections to override those in the
+    persisted graph.
+
+    This may cause dataset type nodes to be unresolved, since resolutions
+    consistent with the original edges may be invalidated.
+    """
+
+
 @dataclasses.dataclass(frozen=True)
 class _TaskNodeImportedData:
     """An internal struct that holds `TaskNode` and `TaskInitNode` state that
diff --git a/python/lsst/pipe/base/pipeline_graph/io.py b/python/lsst/pipe/base/pipeline_graph/io.py
index 53013280..02506ac3 100644
--- a/python/lsst/pipe/base/pipeline_graph/io.py
+++ b/python/lsst/pipe/base/pipeline_graph/io.py
@@ -45,7 +45,7 @@
 from ._nodes import NodeKey, NodeType
 from ._pipeline_graph import PipelineGraph
 from ._task_subsets import TaskSubset
-from ._tasks import TaskInitNode, TaskNode
+from ._tasks import TaskImportMode, TaskInitNode, TaskNode
 
 _U = TypeVar("_U")
 
@@ -527,9 +527,7 @@ def serialize(cls, target: PipelineGraph) -> SerializedPipelineGraph:
 
     def deserialize(
         self,
-        import_and_configure: bool = True,
-        check_edges_unchanged: bool = False,
-        assume_edges_unchanged: bool = False,
+        import_mode: TaskImportMode,
     ) -> PipelineGraph:
         """Transform a `SerializedPipelineGraph` into a `PipelineGraph`."""
         universe: DimensionUniverse | None = None
@@ -615,9 +613,5 @@ def deserialize(
             universe=universe,
             data_id=self.data_id,
         )
-        if import_and_configure:
-            result._import_and_configure(
-                check_edges_unchanged=check_edges_unchanged,
-                assume_edges_unchanged=assume_edges_unchanged,
-            )
+        result._import_and_configure(import_mode)
         return result
diff --git a/tests/test_pipeline_graph.py b/tests/test_pipeline_graph.py
index 479e9202..44e568b0 100644
--- a/tests/test_pipeline_graph.py
+++ b/tests/test_pipeline_graph.py
@@ -41,6 +41,7 @@
     NodeType,
     PipelineGraph,
     PipelineGraphError,
+    TaskImportMode,
     UnresolvedGraphError,
 )
 from lsst.pipe.base.tests.mocks import (
@@ -154,12 +155,12 @@ def test_unresolved_deferred_import_io(self) -> None:
         stream = io.BytesIO()
         self.graph._write_stream(stream)
         stream.seek(0)
-        roundtripped = PipelineGraph._read_stream(stream, import_and_configure=False)
+        roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT)
         self.check_make_xgraph(roundtripped, resolved=False, imported_and_configured=False)
         # Check that we can still resolve the graph without importing tasks.
         roundtripped.resolve(MockRegistry(self.dimensions, {}))
         self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
-        roundtripped._import_and_configure(assume_edges_unchanged=True)
+        roundtripped._import_and_configure(TaskImportMode.ASSUME_CONSISTENT_EDGES)
         self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)
 
     def test_resolved_accessors(self) -> None:
@@ -221,9 +222,9 @@ def test_resolved_deferred_import_io(self) -> None:
         stream = io.BytesIO()
         self.graph._write_stream(stream)
         stream.seek(0)
-        roundtripped = PipelineGraph._read_stream(stream, import_and_configure=False)
+        roundtripped = PipelineGraph._read_stream(stream, import_mode=TaskImportMode.DO_NOT_IMPORT)
         self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
-        roundtripped._import_and_configure(check_edges_unchanged=True)
+        roundtripped._import_and_configure(TaskImportMode.REQUIRE_CONSISTENT_EDGES)
         self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)
 
     def test_unresolved_copies(self) -> None:
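
Editor's note (not part of the patch series): the new `TaskImportMode` enum introduced in PATCH 16 replaces the three boolean keyword arguments (`import_and_configure`, `check_edges_unchanged`, `assume_edges_unchanged`) on the read and import entry points. The sketch below mirrors the deferred-import round trip exercised by the updated tests in `tests/test_pipeline_graph.py`; it is illustrative only, uses the private underscore-prefixed methods exactly as those tests do, and assumes a caller already holds a fully constructed `PipelineGraph` named `graph`.

    import io

    from lsst.pipe.base.pipeline_graph import PipelineGraph, TaskImportMode


    def roundtrip_with_deferred_import(graph: PipelineGraph) -> PipelineGraph:
        # Serialize the graph to gzip-compressed JSON in memory, as
        # _write_stream does after this patch series.
        stream = io.BytesIO()
        graph._write_stream(stream)
        stream.seek(0)
        # Read the graph back without importing any PipelineTask classes; the
        # task nodes then carry only the state persisted with the graph.
        loaded = PipelineGraph._read_stream(
            stream, import_mode=TaskImportMode.DO_NOT_IMPORT
        )
        # Import and configure the tasks later, requiring that their
        # connections still define the same edges that were persisted.
        loaded._import_and_configure(TaskImportMode.REQUIRE_CONSISTENT_EDGES)
        return loaded

Collapsing the three booleans into a single enum makes the mutually exclusive options explicit at call sites and lets `TaskImportMode.DO_NOT_IMPORT` short-circuit the import step entirely.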