From 4f92f7a5e0108f73e375ca7710b8e3f0d9a136d2 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Mon, 29 May 2023 11:51:23 -0400 Subject: [PATCH] Add PipelineGraph package. --- python/lsst/pipe/base/__init__.py | 6 +- python/lsst/pipe/base/pipeline.py | 26 +- .../lsst/pipe/base/pipeline_graph/__init__.py | 29 + .../base/pipeline_graph/_dataset_types.py | 214 +++ .../lsst/pipe/base/pipeline_graph/_edges.py | 714 +++++++++ .../pipe/base/pipeline_graph/_exceptions.py | 95 ++ .../base/pipeline_graph/_mapping_views.py | 197 +++ .../lsst/pipe/base/pipeline_graph/_nodes.py | 85 + .../base/pipeline_graph/_pipeline_graph.py | 1389 +++++++++++++++++ .../pipe/base/pipeline_graph/_task_subsets.py | 122 ++ .../lsst/pipe/base/pipeline_graph/_tasks.py | 855 ++++++++++ python/lsst/pipe/base/pipeline_graph/io.py | 578 +++++++ .../pipe/base/tests/mocks/_pipeline_task.py | 1 - tests/test_pipeline_graph.py | 1255 +++++++++++++++ 14 files changed, 5556 insertions(+), 10 deletions(-) create mode 100644 python/lsst/pipe/base/pipeline_graph/__init__.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_dataset_types.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_edges.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_exceptions.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_mapping_views.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_nodes.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_task_subsets.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_tasks.py create mode 100644 python/lsst/pipe/base/pipeline_graph/io.py create mode 100644 tests/test_pipeline_graph.py diff --git a/python/lsst/pipe/base/__init__.py b/python/lsst/pipe/base/__init__.py index 74339da9..51652a71 100644 --- a/python/lsst/pipe/base/__init__.py +++ b/python/lsst/pipe/base/__init__.py @@ -1,4 +1,4 @@ -from . import automatic_connection_constants, connectionTypes, pipelineIR +from . import automatic_connection_constants, connectionTypes, pipeline_graph, pipelineIR from ._dataset_handle import * from ._instrument import * from ._observation_dimension_packer import * @@ -11,6 +11,10 @@ from .graph import * from .graphBuilder import * from .pipeline import * + +# We import the main PipelineGraph type and the module (above), but we don't +# lift all symbols to package scope. +from .pipeline_graph import PipelineGraph from .pipelineTask import * from .struct import * from .task import * diff --git a/python/lsst/pipe/base/pipeline.py b/python/lsst/pipe/base/pipeline.py index 781defcc..922bd112 100644 --- a/python/lsst/pipe/base/pipeline.py +++ b/python/lsst/pipe/base/pipeline.py @@ -57,7 +57,7 @@ from ._instrument import Instrument as PipeBaseInstrument from ._task_metadata import TaskMetadata from .config import PipelineTaskConfig -from .connections import iterConnections +from .connections import PipelineTaskConnections, iterConnections from .connectionTypes import Input from .pipelineTask import PipelineTask from .task import _TASK_METADATA_TYPE @@ -126,6 +126,11 @@ class TaskDef: Task label, usually a short string unique in a pipeline. If not provided, ``taskClass`` must be, and ``taskClass._DefaultName`` will be used. + connections : `PipelineTaskConnections`, optional + Object that describes the dataset types used by the task. If not + provided, one will be constructed from the given configuration. If + provided, it is assumed that ``config`` has already been validated + and frozen. """ def __init__( @@ -134,6 +139,7 @@ def __init__( config: PipelineTaskConfig | None = None, taskClass: type[PipelineTask] | None = None, label: str | None = None, + connections: PipelineTaskConnections | None = None, ): if taskName is None: if taskClass is None: @@ -150,16 +156,20 @@ def __init__( raise ValueError("`taskClass` must be provided if `label` is not.") label = taskClass._DefaultName self.taskName = taskName - try: - config.validate() - except Exception: - _LOG.error("Configuration validation failed for task %s (%s)", label, taskName) - raise - config.freeze() + if connections is None: + # If we don't have connections yet, assume the config hasn't been + # validated yet. + try: + config.validate() + except Exception: + _LOG.error("Configuration validation failed for task %s (%s)", label, taskName) + raise + config.freeze() + connections = config.connections.ConnectionsClass(config=config) self.config = config self.taskClass = taskClass self.label = label - self.connections = config.connections.ConnectionsClass(config=config) + self.connections = connections @property def configDatasetName(self) -> str: diff --git a/python/lsst/pipe/base/pipeline_graph/__init__.py b/python/lsst/pipe/base/pipeline_graph/__init__.py new file mode 100644 index 00000000..3cf7a810 --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/__init__.py @@ -0,0 +1,29 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +from ._dataset_types import * +from ._edges import * +from ._exceptions import * +from ._nodes import * +from ._pipeline_graph import * +from ._task_subsets import * +from ._tasks import * diff --git a/python/lsst/pipe/base/pipeline_graph/_dataset_types.py b/python/lsst/pipe/base/pipeline_graph/_dataset_types.py new file mode 100644 index 00000000..7cc3dd56 --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_dataset_types.py @@ -0,0 +1,214 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ("DatasetTypeNode",) + +import dataclasses +from typing import TYPE_CHECKING, Any + +import networkx +from lsst.daf.butler import DatasetRef, DatasetType, DimensionGraph, Registry, StorageClass +from lsst.daf.butler.registry import MissingDatasetTypeError + +from ._exceptions import DuplicateOutputError +from ._nodes import NodeKey, NodeType + +if TYPE_CHECKING: + from ._edges import ReadEdge, WriteEdge + + +@dataclasses.dataclass(frozen=True, eq=False) +class DatasetTypeNode: + """A node in a pipeline graph that represents a resolved dataset type. + + Notes + ----- + A dataset type node represents a common definition of the dataset type + across the entire graph - it is never a component, and the storage class is + the registry dataset type's storage class or (if there isn't one) the one + defined by the producing task. + + Dataset type nodes are intentionally not equality comparable, since there + are many different (and useful) ways to compare these objects with no clear + winner as the most obvious behavior. + """ + + dataset_type: DatasetType + """Common definition of this dataset type for the graph. + """ + + is_initial_query_constraint: bool + """Whether this dataset should be included as a constraint in the initial + query for data IDs in QuantumGraph generation. + + This is only `True` for dataset types that are overall regular inputs, and + only if none of those input connections had ``deferQueryConstraint=True``. + """ + + is_prerequisite: bool + """Whether this dataset type is a prerequisite input that must exist in + the Registry before graph creation. + """ + + @classmethod + def _from_edges( + cls, key: NodeKey, xgraph: networkx.MultiDiGraph, registry: Registry, previous: DatasetTypeNode | None + ) -> DatasetTypeNode: + """Construct a dataset type node from its edges. + + Parameters + ---------- + key : `NodeKey` + Named tuple that holds the dataset type and serves as the node + object in the internal networkx graph. + xgraph : `networkx.MultiDiGraph` + The internal networkx graph. + registry : `lsst.daf.butler.Registry` + Registry client for the data repository. Only used to get + dataset type definitions and the dimension universe. + previous : `DatasetTypeNode` or `None` + Previous node for this dataset type. + + Returns + ------- + node : `DatasetTypeNode` + Node consistent with all edges pointing to it and the data + repository. + """ + try: + dataset_type = registry.getDatasetType(key.name) + is_registered = True + except MissingDatasetTypeError: + dataset_type = None + is_registered = False + if previous is not None and previous.dataset_type == dataset_type: + # This node was already resolved (with exactly the same edges + # contributing, since we clear resolutions when edges are added or + # removed). The only thing that might have changed was the + # definition in the registry, and it didn't. + return previous + is_initial_query_constraint = True + is_prerequisite: bool | None = None + producer: str | None = None + write_edge: WriteEdge + for _, _, write_edge in xgraph.in_edges(key, data="instance"): # will iterate zero or one time + if producer is not None: + raise DuplicateOutputError( + f"Dataset type {key.name!r} is produced by both {write_edge.task_label!r} " + f"and {producer!r}." + ) + producer = write_edge.task_label + dataset_type = write_edge._resolve_dataset_type(dataset_type, universe=registry.dimensions) + is_prerequisite = False + is_initial_query_constraint = False + read_edge: ReadEdge + consumers: list[str] = [] + read_edges = list(read_edge for _, _, read_edge in xgraph.out_edges(key, data="instance")) + # Put edges that are not component datasets before any edges that are. + read_edges.sort(key=lambda read_edge: read_edge.component is not None) + for read_edge in read_edges: + dataset_type, is_initial_query_constraint, is_prerequisite = read_edge._resolve_dataset_type( + current=dataset_type, + universe=registry.dimensions, + is_initial_query_constraint=is_initial_query_constraint, + is_prerequisite=is_prerequisite, + is_registered=is_registered, + producer=producer, + consumers=consumers, + ) + consumers.append(read_edge.task_label) + assert dataset_type is not None, "Graph structure guarantees at least one edge." + assert is_prerequisite is not None, "Having at least one edge guarantees is_prerequisite is known." + return DatasetTypeNode( + dataset_type=dataset_type, + is_initial_query_constraint=is_initial_query_constraint, + is_prerequisite=is_prerequisite, + ) + + @property + def name(self) -> str: + """Name of the dataset type. + + This is always the parent dataset type, never that of a component. + """ + return self.dataset_type.name + + @property + def key(self) -> NodeKey: + """Key that identifies this dataset type in internal and exported + networkx graphs. + """ + return NodeKey(NodeType.DATASET_TYPE, self.dataset_type.name) + + @property + def dimensions(self) -> DimensionGraph: + """Dimensions of the dataset type.""" + return self.dataset_type.dimensions + + @property + def storage_class_name(self) -> str: + """String name of the storage class for this dataset type.""" + return self.dataset_type.storageClass_name + + @property + def storage_class(self) -> StorageClass: + """Storage class for this dataset type.""" + return self.dataset_type.storageClass + + def __repr__(self) -> str: + return f"{self.name} ({self.storage_class_name}, {self.dimensions})" + + def generalize_ref(self, ref: DatasetRef) -> DatasetRef: + """Convert a `~lsst.daf.butler.DatasetRef` with the dataset type + associated with some task to one with the common dataset type defined + by this node. + + Parameters + ---------- + ref : `lsst.daf.butler.DatasetRef` + Reference whose dataset type is convertible to this node's, either + because it is a component with the node's dataset type as its + parent, or because it has a compatible storage class. + + Returns + ------- + ref : `lsst.daf.butler.DatasetRef` + Reference with exactly this node's dataset type. + """ + if ref.isComponent(): + ref = ref.makeCompositeRef() + if ref.datasetType.storageClass_name != self.dataset_type.storageClass_name: + return ref.overrideStorageClass(self.dataset_type.storageClass_name) + return ref + + def _to_xgraph_state(self) -> dict[str, Any]: + """Convert this node's attributes into a dictionary suitable for use + in exported networkx graphs. + """ + return { + "dataset_type": self.dataset_type, + "is_initial_query_constraint": self.is_initial_query_constraint, + "is_prerequisite": self.is_prerequisite, + "dimensions": self.dataset_type.dimensions, + "storage_class_name": self.dataset_type.storageClass_name, + "bipartite": NodeType.DATASET_TYPE.bipartite, + } diff --git a/python/lsst/pipe/base/pipeline_graph/_edges.py b/python/lsst/pipe/base/pipeline_graph/_edges.py new file mode 100644 index 00000000..10ea6b11 --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_edges.py @@ -0,0 +1,714 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ("Edge", "ReadEdge", "WriteEdge") + +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from typing import Any, ClassVar, TypeVar + +from lsst.daf.butler import DatasetRef, DatasetType, DimensionUniverse, SkyPixDimension +from lsst.daf.butler.registry import MissingDatasetTypeError +from lsst.utils.classes import immutable + +from ..connectionTypes import BaseConnection +from ._exceptions import ConnectionTypeConsistencyError, IncompatibleDatasetTypeError +from ._nodes import NodeKey, NodeType + +_S = TypeVar("_S", bound="Edge") + + +@immutable +class Edge(ABC): + """Base class for edges in a pipeline graph. + + This represents the link between a task node and an input or output dataset + type. + + Parameters + ---------- + task_key : `NodeKey` + Key for the task node this edge is connected to. + dataset_type_key : `NodeKey` + Key for the dataset type node this edge is connected to. + storage_class_name : `str` + Name of the dataset type's storage class as seen by the task. + connection_name : `str` + Internal name for the connection as seen by the task. + is_calibration : `bool` + Whether this dataset type can be included in + `~lsst.daf.butler.CollectionType.CALIBRATION` collections. + raw_dimensions : `frozenset` [ `str` ] + Raw dimensions from the connection definition. + """ + + def __init__( + self, + *, + task_key: NodeKey, + dataset_type_key: NodeKey, + storage_class_name: str, + connection_name: str, + is_calibration: bool, + raw_dimensions: frozenset[str], + ): + self.task_key = task_key + self.dataset_type_key = dataset_type_key + self.connection_name = connection_name + self.storage_class_name = storage_class_name + self.is_calibration = is_calibration + self.raw_dimensions = raw_dimensions + + INIT_TO_TASK_NAME: ClassVar[str] = "INIT" + """Edge key for the special edge that connects a task init node to the + task node itself (for regular edges, this would be the connection name). + """ + + task_key: NodeKey + """Task part of the key for this edge in networkx graphs.""" + + dataset_type_key: NodeKey + """Task part of the key for this edge in networkx graphs.""" + + connection_name: str + """Name used by the task to refer to this dataset type.""" + + storage_class_name: str + """Storage class expected by this task. + + If `ReadEdge.component` is not `None`, this is the component storage class, + not the parent storage class. + """ + + is_calibration: bool + """Whether this dataset type can be included in + `~lsst.daf.butler.CollectionType.CALIBRATION` collections. + """ + + raw_dimensions: frozenset[str] + """Raw dimensions in the task declaration. + + This can only be used safely for partial comparisons: two edges with the + same ``raw_dimensions`` (and the same parent dataset type name) always have + the same resolved dimensions, but edges with different ``raw_dimensions`` + may also have the same resolvd dimensions. + """ + + @property + def is_init(self) -> bool: + """Whether this dataset is read or written when the task is + constructed, not when it is run. + """ + return self.task_key.node_type is NodeType.TASK_INIT + + @property + def task_label(self) -> str: + """Label of the task.""" + return str(self.task_key) + + @property + def parent_dataset_type_name(self) -> str: + """Name of the parent dataset type. + + All dataset type nodes in a pipeline graph are for parent dataset + types; components are represented by additional `ReadEdge` state. + """ + return str(self.dataset_type_key) + + @property + @abstractmethod + def nodes(self) -> tuple[NodeKey, NodeKey]: + """The directed pair of `NodeKey` instances this edge connects. + + This tuple is ordered in the same direction as the pipeline flow: + `task_key` precedes `dataset_type_key` for writes, and the + reverse is true for reads. + """ + raise NotImplementedError() + + @property + def key(self) -> tuple[NodeKey, NodeKey, str]: + """Ordered tuple of node keys and connection name that uniquely + identifies this edge in a pipeline graph. + """ + return self.nodes + (self.connection_name,) + + def __repr__(self) -> str: + return f"{self.nodes[0]} -> {self.nodes[1]} ({self.connection_name})" + + @property + def dataset_type_name(self) -> str: + """Dataset type name seen by the task. + + This defaults to the parent dataset type name, which is appropriate + for all writes and most reads. + """ + return self.parent_dataset_type_name + + def diff(self: _S, other: _S, connection_type: str = "connection") -> list[str]: + """Compare this edge to another one from a possibly-different + configuration of the same task label. + + Parameters + ---------- + other : `Edge` + Another edge of the same type to compare to. + connection_type : `str` + Human-readable name of the connection type of this edge (e.g. + "init input", "output") for use in returned messages. + + Returns + ------- + differences : `list` [ `str` ] + List of string messages describing differences between ``self`` and + ``other``. Will be empty if ``self == other`` or if the only + difference is in the task label or connection name (which are not + checked). Messages will use 'A' to refer to ``self`` and 'B' to + refer to ``other``. + """ + result = [] + if self.dataset_type_name != other.dataset_type_name: + result.append( + f"{connection_type.capitalize()} {self.connection_name!r} has dataset type " + f"{self.dataset_type_name!r} in A, but {other.dataset_type_name!r} in B." + ) + if self.storage_class_name != other.storage_class_name: + result.append( + f"{connection_type.capitalize()} {self.connection_name!r} has storage class " + f"{self.storage_class_name!r} in A, but {other.storage_class_name!r} in B." + ) + if self.raw_dimensions != other.raw_dimensions: + result.append( + f"{connection_type.capitalize()} {self.connection_name!r} has raw dimensions " + f"{set(self.raw_dimensions)} in A, but {set(other.raw_dimensions)} in B " + "(differences in raw dimensions may not lead to differences in resolved dimensions, " + "but this cannot be checked without re-resolving the dataset type)." + ) + if self.is_calibration != other.is_calibration: + result.append( + f"{connection_type.capitalize()} {self.connection_name!r} is marked as a calibration " + f"{'in A but not in B' if self.is_calibration else 'in B but not in A'}." + ) + return result + + @abstractmethod + def adapt_dataset_type(self, dataset_type: DatasetType) -> DatasetType: + """Transform the graph's definition of a dataset type (parent, with the + registry or producer's storage class) to the one seen by this task. + """ + raise NotImplementedError() + + @abstractmethod + def adapt_dataset_ref(self, ref: DatasetRef) -> DatasetRef: + """Transform the graph's definition of a dataset reference (parent + dataset type, with the registry or producer's storage class) to the one + seen by this task. + """ + raise NotImplementedError() + + def _to_xgraph_state(self) -> dict[str, Any]: + """Convert this edges's attributes into a dictionary suitable for use + in exported networkx graphs. + """ + return { + "parent_dataset_type_name": self.parent_dataset_type_name, + "storage_class_name": self.storage_class_name, + "is_init": bool, + } + + +class ReadEdge(Edge): + """Representation of an input connection (including init-inputs and + prerequisites) in a pipeline graph. + + Parameters + ---------- + dataset_type_key : `NodeKey` + Key for the dataset type node this edge is connected to. This should + hold the parent dataset type name for component dataset types. + task_key : `NodeKey` + Key for the task node this edge is connected to. + storage_class_name : `str` + Name of the dataset type's storage class as seen by the task. + connection_name : `str` + Internal name for the connection as seen by the task. + is_calibration : `bool` + Whether this dataset type can be included in + `~lsst.daf.butler.CollectionType.CALIBRATION` collections. + raw_dimensions : `frozenset` [ `str` ] + Raw dimensions from the connection definition. + is_prerequisite : `bool` + Whether this dataset must be present in the data repository prior to + `QuantumGraph` generation. + component : `str` or `None` + Component of the dataset type requested by the task. + defer_query_constraint : `bool` + If `True`, by default do not include this dataset type's existence as a + constraint on the initial data ID query in QuantumGraph generation. + + Notes + ----- + When included in an exported `networkx` graph (e.g. + `PipelineGraph.make_xgraph`), read edges set the following edge attributes: + + - ``parent_dataset_type_name`` + - ``storage_class_name`` + - ``is_init`` + - ``component`` + - ``is_prerequisite`` + + As with `ReadEdge` instance attributes, these descriptions of dataset types + are those specific to a task, and may differ from the graph's resolved + dataset type or (if `PipelineGraph.resolve` has not been called) there may + not even be a consistent definition of the dataset type. + """ + + def __init__( + self, + dataset_type_key: NodeKey, + task_key: NodeKey, + *, + storage_class_name: str, + connection_name: str, + is_calibration: bool, + raw_dimensions: frozenset[str], + is_prerequisite: bool, + component: str | None, + defer_query_constraint: bool, + ): + super().__init__( + task_key=task_key, + dataset_type_key=dataset_type_key, + storage_class_name=storage_class_name, + connection_name=connection_name, + raw_dimensions=raw_dimensions, + is_calibration=is_calibration, + ) + self.is_prerequisite = is_prerequisite + self.component = component + self.defer_query_constraint = defer_query_constraint + + component: str | None + """Component to add to `parent_dataset_type_name` to form the dataset type + name seen by this task. + """ + + is_prerequisite: bool + """Whether this dataset must be present in the data repository prior to + `QuantumGraph` generation. + """ + + defer_query_constraint: bool + """If `True`, by default do not include this dataset type's existence as a + constraint on the initial data ID query in QuantumGraph generation. + """ + + @property + def nodes(self) -> tuple[NodeKey, NodeKey]: + # Docstring inherited. + return (self.dataset_type_key, self.task_key) + + @property + def dataset_type_name(self) -> str: + """Complete dataset type name, as seen by the task.""" + if self.component is not None: + return f"{self.parent_dataset_type_name}.{self.component}" + return self.parent_dataset_type_name + + def diff(self: ReadEdge, other: ReadEdge, connection_type: str = "connection") -> list[str]: + # Docstring inherited. + result = super().diff(other, connection_type) + if self.defer_query_constraint != other.defer_query_constraint: + result.append( + f"{connection_type.capitalize()} {self.connection_name!r} is marked as a deferred query " + f"constraint {'in A but not in B' if self.defer_query_constraint else 'in B but not in A'}." + ) + return result + + def adapt_dataset_type(self, dataset_type: DatasetType) -> DatasetType: + # Docstring inherited. + if self.component is not None: + assert ( + self.storage_class_name == dataset_type.storageClass.allComponents()[self.component].name + ), "components with storage class overrides are not supported" + return dataset_type.makeComponentDatasetType(self.component) + if self.storage_class_name != dataset_type.storageClass_name: + return dataset_type.overrideStorageClass(self.storage_class_name) + return dataset_type + + def adapt_dataset_ref(self, ref: DatasetRef) -> DatasetRef: + # Docstring inherited. + if self.component is not None: + assert ( + self.storage_class_name == ref.datasetType.storageClass.allComponents()[self.component].name + ), "components with storage class overrides are not supported" + return ref.makeComponentRef(self.component) + if self.storage_class_name != ref.datasetType.storageClass_name: + return ref.overrideStorageClass(self.storage_class_name) + return ref + + @classmethod + def _from_connection_map( + cls, + task_key: NodeKey, + connection_name: str, + connection_map: Mapping[str, BaseConnection], + is_prerequisite: bool = False, + ) -> ReadEdge: + """Construct a `ReadEdge` instance from a `.BaseConnection` object. + + Parameters + ---------- + task_key : `NodeKey` + Key for the associated task node or task init node. + connection_name : `str` + Internal name for the connection as seen by the task,. + connection_map : Mapping [ `str`, `.BaseConnection` ] + Mapping of post-configuration object to draw dataset type + information from, keyed by connection name. + is_prerequisite : `bool`, optional + Whether this dataset must be present in the data repository prior + to `QuantumGraph` generation. + + Returns + ------- + edge : `ReadEdge` + New edge instance. + """ + connection = connection_map[connection_name] + parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name) + return cls( + dataset_type_key=NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name), + task_key=task_key, + component=component, + storage_class_name=connection.storageClass, + # InitInput connections don't have .isCalibration. + is_calibration=getattr(connection, "isCalibration", False), + is_prerequisite=is_prerequisite, + connection_name=connection_name, + # InitInput connections don't have a .dimensions because they + # always have empty dimensions. + raw_dimensions=frozenset(getattr(connection, "dimensions", frozenset())), + # PrerequisiteInput and InitInput connections don't have a + # .eferQueryConstraints, because they never constrain the initial + # data ID query. + defer_query_constraint=getattr(connection, "deferQueryConstraint", False), + ) + + def _resolve_dataset_type( + self, + *, + current: DatasetType | None, + is_initial_query_constraint: bool, + is_prerequisite: bool | None, + universe: DimensionUniverse, + producer: str | None, + consumers: Sequence[str], + is_registered: bool, + ) -> tuple[DatasetType, bool, bool]: + """Participate in the construction of the `DatasetTypeNode` object + associated with this edge. + + Parameters + ---------- + current : `lsst.daf.butler.DatasetType` or `None` + The current graph-wide `DatasetType`, or `None`. This will always + be the registry's definition of the parent dataset type, if one + exists. If not, it will be the dataset type definition from the + task in the graph that writes it, if there is one. If there is no + such task, this will be `None`. + is_initial_query_constraint : `bool` + Whether this dataset type is currently marked as a constraint on + the initial data ID query in QuantumGraph generation. + is_prerequisite : `bool` | None` + Whether this dataset type is marked as a prerequisite input in all + edges processed so far. `None` if this is the first edge. + universe : `lsst.daf.butler.DimensionUniverse` + Object that holds all dimension definitions. + producer : `str` or `None` + The label of the task that produces this dataset type in the + pipeline, or `None` if it is an overall input. + consumers : `Sequence` [ `str` ] + Labels for other consuming tasks that have already participated in + this dataset type's resolution. + is_registered : `bool` + Whether a registration for this dataset type was found in the + data repository. + + Returns + ------- + dataset_type : `DatasetType` + The updated graph-wide dataset type. If ``current`` was provided, + this must be equal to it. + is_initial_query_constraint : `bool` + If `True`, this dataset type should be included as a constraint in + the initial data ID query during QuantumGraph generation; this + requires that ``is_initial_query_constraint`` also be `True` on + input. + is_prerequisite : `bool` + Whether this dataset type is marked as a prerequisite input in this + task and all other edges processed so far. + + Raises + ------ + MissingDatasetTypeError + Raised if ``current is None`` and this edge cannot define one on + its own. + IncompatibleDatasetTypeError + Raised if ``current is not None`` and this edge's definition is not + compatible with it. + ConnectionTypeConsistencyError + Raised if a prerequisite input for one task appears as a different + kind of connection in any other task. + """ + if "skypix" in self.raw_dimensions: + if current is None: + raise MissingDatasetTypeError( + f"DatasetType '{self.dataset_type_name}' referenced by " + f"{self.task_label!r} uses 'skypix' as a dimension " + f"placeholder, but has not been registered with the data repository. " + f"Note that reference catalog names are now used as the dataset " + f"type name instead of 'ref_cat'." + ) + rest1 = set(universe.extract(self.raw_dimensions - set(["skypix"])).names) + rest2 = set(dim.name for dim in current.dimensions if not isinstance(dim, SkyPixDimension)) + if rest1 != rest2: + raise IncompatibleDatasetTypeError( + f"Non-skypix dimensions for dataset type {self.dataset_type_name} declared in " + f"connections ({rest1}) are inconsistent with those in " + f"registry's version of this dataset ({rest2})." + ) + dimensions = current.dimensions + else: + dimensions = universe.extract(self.raw_dimensions) + is_initial_query_constraint = is_initial_query_constraint and not self.defer_query_constraint + if is_prerequisite is None: + is_prerequisite = self.is_prerequisite + elif is_prerequisite and not self.is_prerequisite: + raise ConnectionTypeConsistencyError( + f"Dataset type {self.parent_dataset_type_name!r} is a prerequisite input to {consumers}, " + f"but it is not a prerequisite to {self.task_label!r}." + ) + elif not is_prerequisite and self.is_prerequisite: + if producer is not None: + raise ConnectionTypeConsistencyError( + f"Dataset type {self.parent_dataset_type_name!r} is a prerequisite input to " + f"{self.task_label}, but it is produced by {producer!r}." + ) + else: + raise ConnectionTypeConsistencyError( + f"Dataset type {self.parent_dataset_type_name!r} is a prerequisite input to " + f"{self.task_label}, but it is a regular input to {consumers!r}." + ) + + def report_current_origin() -> str: + if is_registered: + return "data repository" + elif producer is not None: + return f"producing task {producer!r}" + else: + return f"consuming task(s) {consumers!r}" + + if self.component is not None: + if current is None: + raise MissingDatasetTypeError( + f"Dataset type {self.parent_dataset_type_name!r} is not registered and not produced by " + f"this pipeline, but it used by task {self.task_label!r}, via component " + f"{self.component!r}. This pipeline cannot be resolved until the parent dataset type is " + "registered." + ) + all_current_components = current.storageClass.allComponents() + if self.component not in all_current_components: + raise IncompatibleDatasetTypeError( + f"Dataset type {self.parent_dataset_type_name!r} has storage class " + f"{current.storageClass_name!r} (from {report_current_origin()}), " + f"which does not include component {self.component!r} " + f"as requested by task {self.task_label!r}." + ) + if all_current_components[self.component].name != self.storage_class_name: + raise IncompatibleDatasetTypeError( + f"Dataset type '{self.parent_dataset_type_name}.{self.component}' has storage class " + f"{all_current_components[self.component].name!r} " + f"(from {report_current_origin()}), which does not match " + f"{self.storage_class_name!r}, as requested by task {self.task_label!r}. " + "Note that storage class conversions of components are not supported." + ) + return current, is_initial_query_constraint, is_prerequisite + else: + dataset_type = DatasetType( + self.parent_dataset_type_name, + dimensions, + storageClass=self.storage_class_name, + isCalibration=self.is_calibration, + ) + if current is not None: + if not is_registered and producer is None: + # Current definition comes from another consumer; we + # require the dataset types to be exactly equal (not just + # compatible), since neither connection should take + # precedence. + if dataset_type != current: + raise MissingDatasetTypeError( + f"Definitions differ for input dataset type {self.parent_dataset_type_name!r}; " + f"task {self.task_label!r} has {dataset_type}, but the definition " + f"from {report_current_origin()} is {current}. If the storage classes are " + "compatible but different, registering the dataset type in the data repository " + "in advance will avoid this error." + ) + elif not dataset_type.is_compatible_with(current): + raise IncompatibleDatasetTypeError( + f"Incompatible definition for input dataset type {self.parent_dataset_type_name!r}; " + f"task {self.task_label!r} has {dataset_type}, but the definition " + f"from {report_current_origin()} is {current}." + ) + return current, is_initial_query_constraint, is_prerequisite + else: + return dataset_type, is_initial_query_constraint, is_prerequisite + + def _to_xgraph_state(self) -> dict[str, Any]: + # Docstring inherited. + result = super()._to_xgraph_state() + result["component"] = self.component + result["is_prerequisite"] = self.is_prerequisite + return result + + +class WriteEdge(Edge): + """Representation of an output connection (including init-outputs) in a + pipeline graph. + + Notes + ----- + When included in an exported `networkx` graph (e.g. + `PipelineGraph.make_xgraph`), write edges set the following edge + attributes: + + - ``parent_dataset_type_name`` + - ``storage_class_name`` + - ``is_init`` + + As with `WRiteEdge` instance attributes, these descriptions of dataset + types are those specific to a task, and may differ from the graph's + resolved dataset type or (if `PipelineGraph.resolve` has not been called) + there may not even be a consistent definition of the dataset type. + """ + + @property + def nodes(self) -> tuple[NodeKey, NodeKey]: + # Docstring inherited. + return (self.task_key, self.dataset_type_key) + + def adapt_dataset_type(self, dataset_type: DatasetType) -> DatasetType: + # Docstring inherited. + if self.storage_class_name != dataset_type.storageClass_name: + return dataset_type.overrideStorageClass(self.storage_class_name) + return dataset_type + + def adapt_dataset_ref(self, ref: DatasetRef) -> DatasetRef: + # Docstring inherited. + if self.storage_class_name != ref.datasetType.storageClass_name: + return ref.overrideStorageClass(self.storage_class_name) + return ref + + @classmethod + def _from_connection_map( + cls, + task_key: NodeKey, + connection_name: str, + connection_map: Mapping[str, BaseConnection], + ) -> WriteEdge: + """Construct a `WriteEdge` instance from a `.BaseConnection` object. + + Parameters + ---------- + task_key : `NodeKey` + Key for the associated task node or task init node. + connection_name : `str` + Internal name for the connection as seen by the task,. + connection_map : Mapping [ `str`, `.BaseConnection` ] + Mapping of post-configuration object to draw dataset type + information from, keyed by connection name. + + Returns + ------- + edge : `WriteEdge` + New edge instance. + """ + connection = connection_map[connection_name] + parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name) + if component is not None: + raise ValueError( + f"Illegal output component dataset {connection.name!r} in task {task_key.name!r}." + ) + return cls( + task_key=task_key, + dataset_type_key=NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name), + storage_class_name=connection.storageClass, + connection_name=connection_name, + # InitOutput connections don't have .isCalibration. + is_calibration=getattr(connection, "isCalibration", False), + # InitOutput connections don't have a .dimensions because they + # always have empty dimensions. + raw_dimensions=frozenset(getattr(connection, "dimensions", frozenset())), + ) + + def _resolve_dataset_type(self, current: DatasetType | None, universe: DimensionUniverse) -> DatasetType: + """Participate in the construction of the `DatasetTypeNode` object + associated with this edge. + + Parameters + ---------- + current : `lsst.daf.butler.DatasetType` or `None` + The current graph-wide `DatasetType`, or `None`. This will always + be the registry's definition of the parent dataset type, if one + exists. + universe : `lsst.daf.butler.DimensionUniverse` + Object that holds all dimension definitions. + + Returns + ------- + dataset_type : `DatasetType` + A dataset type compatible with this edge. If ``current`` was + provided, this must be equal to it. + + Raises + ------ + IncompatibleDatasetTypeError + Raised if ``current is not None`` and this edge's definition is not + compatible with it. + """ + dimensions = universe.extract(self.raw_dimensions) + dataset_type = DatasetType( + self.parent_dataset_type_name, + dimensions, + storageClass=self.storage_class_name, + isCalibration=self.is_calibration, + ) + if current is not None: + if not current.is_compatible_with(dataset_type): + raise IncompatibleDatasetTypeError( + f"Incompatible definition for output dataset type {self.parent_dataset_type_name!r}: " + f"task {self.task_label!r} has {current}, but data repository has {dataset_type}." + ) + return current + else: + return dataset_type diff --git a/python/lsst/pipe/base/pipeline_graph/_exceptions.py b/python/lsst/pipe/base/pipeline_graph/_exceptions.py new file mode 100644 index 00000000..8ed6cd16 --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_exceptions.py @@ -0,0 +1,95 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ( + "ConnectionTypeConsistencyError", + "DuplicateOutputError", + "IncompatibleDatasetTypeError", + "PipelineGraphExceptionSafetyError", + "PipelineDataCycleError", + "PipelineGraphError", + "PipelineGraphReadError", + "EdgesChangedError", + "UnresolvedGraphError", + "TaskNotImportedError", +) + + +class PipelineGraphError(RuntimeError): + """Base exception raised when there is a problem constructing or resolving + a pipeline graph. + """ + + +class DuplicateOutputError(PipelineGraphError): + """Exception raised when multiple tasks in one pipeline produce the same + output dataset type. + """ + + +class PipelineDataCycleError(PipelineGraphError): + """Exception raised when a pipeline graph contains a cycle.""" + + +class ConnectionTypeConsistencyError(PipelineGraphError): + """Exception raised when the tasks in a pipeline graph use different (and + incompatible) connection types for the same dataset type. + """ + + +class IncompatibleDatasetTypeError(PipelineGraphError): + """Exception raised when the tasks in a pipeline graph define dataset types + with the same name in incompatible ways, or when these are incompatible + with the data repository definition. + """ + + +class UnresolvedGraphError(PipelineGraphError): + """Exception raised when an operation requires dimensions or dataset types + to have been resolved, but they have not been. + """ + + +class PipelineGraphReadError(PipelineGraphError, IOError): + """Exception raised when a serialized PipelineGraph cannot be read.""" + + +class TaskNotImportedError(PipelineGraphError): + """Exception raised when accessing an attribute of a graph or graph node + that is not available unless the task class has been imported and + configured. + """ + + +class EdgesChangedError(PipelineGraphError): + """Exception raised when the edges in one version of a pipeline graph + are not consistent with those in another, but they were expected to be. + """ + + +class PipelineGraphExceptionSafetyError(PipelineGraphError): + """Exception raised when a PipelineGraph method could not provide strong + exception safety, and the graph may have been left in an inconsistent + state. + + The originating exception is always chained when this exception is raised. + """ diff --git a/python/lsst/pipe/base/pipeline_graph/_mapping_views.py b/python/lsst/pipe/base/pipeline_graph/_mapping_views.py new file mode 100644 index 00000000..6f12d42c --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_mapping_views.py @@ -0,0 +1,197 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +from collections.abc import Iterable, Iterator, Mapping +from typing import Any, ClassVar, Sequence, TypeVar, cast, overload + +import networkx + +from ._dataset_types import DatasetTypeNode +from ._exceptions import UnresolvedGraphError +from ._nodes import NodeKey, NodeType +from ._tasks import TaskInitNode, TaskNode + +_N = TypeVar("_N", covariant=True) +_T = TypeVar("_T") + + +class MappingView(Mapping[str, _N]): + """Base class for mapping views into nodes of certain types in a + `PipelineGraph`. + + + Parameters + ---------- + parent_xgraph : `networkx.MultiDiGraph` + Backing networkx graph for the `PipelineGraph` instance. + + Notes + ----- + Instances should only be constructed by `PipelineGraph` and its helper + classes. + + Iteration order is topologically sorted if and only if the backing + `PipelineGraph` has been sorted since its last modification. + """ + + def __init__(self, parent_xgraph: networkx.MultiDiGraph) -> None: + self._parent_xgraph = parent_xgraph + self._keys: list[str] | None = None + + _NODE_TYPE: ClassVar[NodeType] # defined by derived classes + + def __contains__(self, key: object) -> bool: + # The given key may not be a str, but if it isn't it'll just fail the + # check, which is what we want anyway. + return NodeKey(self._NODE_TYPE, cast(str, key)) in self._parent_xgraph + + def __iter__(self) -> Iterator[str]: + if self._keys is None: + self._keys = self._make_keys(self._parent_xgraph) + return iter(self._keys) + + def __getitem__(self, key: str) -> _N: + return self._parent_xgraph.nodes[NodeKey(self._NODE_TYPE, key)]["instance"] + + def __len__(self) -> int: + if self._keys is None: + self._keys = self._make_keys(self._parent_xgraph) + return len(self._keys) + + def __repr__(self) -> str: + return f"{type(self).__name__}({self!s})" + + def __str__(self) -> str: + return f"{{{', '.join(iter(self))}}}" + + def _reorder(self, parent_keys: Sequence[NodeKey]) -> None: + """Set this view's iteration order according to the given iterable of + parent keys. + + Parameters + ---------- + parent_keys : `~collections.abc.Sequence` [ `NodeKey` ] + Superset of the keys in this view, in the new order. + """ + self._keys = self._make_keys(parent_keys) + + def _reset(self) -> None: + """Reset all cached content. + + This should be called by the parent graph after any changes that could + invalidate the view, causing it to be reconstructed when next + requested. + """ + self._keys = None + + def _make_keys(self, parent_keys: Iterable[NodeKey]) -> list[str]: + """Make a sequence of keys for this view from an iterable of parent + keys. + + Parameters + ---------- + parent_keys : `~collections.abc.Iterable` [ `NodeKey` ] + Superset of the keys in this view. + """ + return [str(k) for k in parent_keys if k.node_type is self._NODE_TYPE] + + +class TaskMappingView(MappingView[TaskNode]): + """A mapping view of the tasks in a `PipelineGraph`. + + Notes + ----- + Mapping keys are task labels and values are `TaskNode` instances. + Iteration order is topological if and only if the `PipelineGraph` has been + sorted since its last modification. + """ + + _NODE_TYPE = NodeType.TASK + + +class TaskInitMappingView(MappingView[TaskInitNode]): + """A mapping view of the nodes representing task initialization in a + `PipelineGraph`. + + Notes + ----- + Mapping keys are task labels and values are `TaskInitNode` instances. + Iteration order is topological if and only if the `PipelineGraph` has been + sorted since its last modification. + """ + + _NODE_TYPE = NodeType.TASK_INIT + + +class DatasetTypeMappingView(MappingView[DatasetTypeNode]): + """A mapping view of the nodes representing task initialization in a + `PipelineGraph`. + + Notes + ----- + Mapping keys are parent dataset type names and values are `DatasetTypeNode` + instances, but values are only available for nodes that have been resolved + (see `PipelineGraph.resolve`). Attempting to access an unresolved value + will result in `UnresolvedGraphError` being raised. Keys for unresolved + nodes are always present and iterable. + + Iteration order is topological if and only if the `PipelineGraph` has been + sorted since its last modification. + """ + + _NODE_TYPE = NodeType.DATASET_TYPE + + def __getitem__(self, key: str) -> DatasetTypeNode: + if (result := super().__getitem__(key)) is None: + raise UnresolvedGraphError(f"Node for dataset type {key!r} has not been resolved.") + return result + + def is_resolved(self, key: str) -> bool: + """Test whether a node has been resolved.""" + return super().__getitem__(key) is not None + + @overload + def get_if_resolved(self, key: str) -> DatasetTypeNode | None: + ... # pragma: nocover + + @overload + def get_if_resolved(self, key: str, default: _T) -> DatasetTypeNode | _T: + ... # pragma: nocover + + def get_if_resolved(self, key: str, default: Any = None) -> DatasetTypeNode | Any: + """Get a node or return a default if it has not been resolved. + + Parameters + ---------- + key : `str` + Parent dataset type name. + default + Value to return if this dataset type has not been resolved. + + Raises + ------ + KeyError + Raised if the node is not present in the graph at all. + """ + if (result := super().__getitem__(key)) is None: + return default # type: ignore + return result diff --git a/python/lsst/pipe/base/pipeline_graph/_nodes.py b/python/lsst/pipe/base/pipeline_graph/_nodes.py new file mode 100644 index 00000000..b9ec00fc --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_nodes.py @@ -0,0 +1,85 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ( + "NodeKey", + "NodeType", +) + +import enum +from typing import NamedTuple + + +class NodeType(enum.Enum): + """Enumeration of the types of nodes in a PipelineGraph.""" + + DATASET_TYPE = 0 + TASK_INIT = 1 + TASK = 2 + + @property + def bipartite(self) -> int: + """The integer used as the "bipartite" key in networkx exports of a + `PipelineGraph`. + + This key is used by the `networkx.algorithms.bipartite` module. + """ + return int(self is not NodeType.DATASET_TYPE) + + def __lt__(self, other: NodeType) -> bool: + # We define __lt__ only to be able to provide deterministic tiebreaking + # on top of topological ordering of `PipelineGraph`` and views thereof. + return self.value < other.value + + +class NodeKey(NamedTuple): + """A special key type for nodes in networkx graphs. + + Notes + ----- + Using a tuple for the key allows tasks labels and dataset type names with + the same string value to coexist in the graph. These only rarely appear in + `PipelineGraph` public interfaces; when the node type is implicit, bare + `str` task labels or dataset type names are used instead. + + NodeKey objects stringify to just their name, which is used both as a way + to convert to the `str` objects used in the main public interface and as an + easy way to usefully stringify containers returned directly by networkx + algorithms (especially in error messages). Note that this requires `repr`, + not just `str`, because Python builtin containers always use `repr` on + their items, even in their implementations for `str`. + """ + + node_type: NodeType + """Node type enum for this key.""" + + name: str + """Task label or dataset type name. + + This is always the parent dataset type name for component dataset types. + """ + + def __repr__(self) -> str: + return self.name + + def __str__(self) -> str: + return self.name diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py new file mode 100644 index 00000000..5e8ea5c1 --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py @@ -0,0 +1,1389 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ("PipelineGraph",) + +import gzip +import itertools +import json +from collections.abc import Iterable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, BinaryIO, Literal, TypeVar, cast + +import networkx +import networkx.algorithms.bipartite +import networkx.algorithms.dag +from lsst.daf.butler import DataCoordinate, DataId, DimensionGraph, DimensionUniverse, Registry +from lsst.resources import ResourcePath, ResourcePathExpression + +from ._dataset_types import DatasetTypeNode +from ._edges import Edge, ReadEdge, WriteEdge +from ._exceptions import ( + EdgesChangedError, + PipelineDataCycleError, + PipelineGraphError, + PipelineGraphExceptionSafetyError, + UnresolvedGraphError, +) +from ._mapping_views import DatasetTypeMappingView, TaskMappingView +from ._nodes import NodeKey, NodeType +from ._task_subsets import TaskSubset +from ._tasks import TaskInitNode, TaskNode, _TaskNodeImportedData + +if TYPE_CHECKING: + from ..config import PipelineTaskConfig + from ..connections import PipelineTaskConnections + from ..pipeline import TaskDef + from ..pipelineTask import PipelineTask + + +_G = TypeVar("_G", bound=networkx.DiGraph | networkx.MultiDiGraph) + + +class PipelineGraph: + """A graph representation of fully-configured pipeline. + + `PipelineGraph` instances are typically constructed by calling + `.Pipeline.to_graph`, but in rare cases constructing and then populating + an empty one may be preferable. + + Parameters + ---------- + description : `str`, optional + String description for this pipeline. + universe : `lsst.daf.butler.DimensionUniverse`, optional + Definitions for all butler dimensions. If not provided, some + attributes will not be available until `resolve` is called. + data_id : `lsst.daf.butler.DataCoordinate` or other data ID, optional + Data ID that represents a constraint on all quanta generated by this + pipeline. This typically just holds the instrument constraint included + in the pipeline definition, if there was one. + """ + + def __init__( + self, + *, + description: str = "", + universe: DimensionUniverse | None = None, + data_id: DataId | None = None, + ) -> None: + self._init_from_args( + xgraph=None, + sorted_keys=None, + task_subsets=None, + description=description, + universe=universe, + data_id=data_id, + ) + + def _init_from_args( + self, + xgraph: networkx.MultiDiGraph | None, + sorted_keys: Sequence[NodeKey] | None, + task_subsets: dict[str, TaskSubset] | None, + description: str, + universe: DimensionUniverse | None, + data_id: DataId | None, + ) -> None: + """Initialize the graph with possibly-nontrivial arguments. + + Parameters + ---------- + xgraph : `networkx.MultiDiGraph` or `None` + The backing networkx graph, or `None` to create an empty one. + This graph has `NodeKey` instances for nodes and the same structure + as the graph exported by `make_xgraph`, but its nodes and edges + have a single ``instance`` attribute that holds a `TaskNode`, + `TaskInitNode`, `DatasetTypeNode` (or `None`), `ReadEdge`, or + `WriteEdge` instance. + sorted_keys : `Sequence` [ `NodeKey` ] or `None` + Topologically sorted sequence of node keys, or `None` if the graph + is not sorted. + task_subsets : `dict` [ `str`, `TaskSubset` ] + Labeled subsets of tasks. Values must be constructed with + ``xgraph`` as their parent graph. + description : `str` + String description for this pipeline. + universe : `lsst.daf.butler.DimensionUniverse` or `None` + Definitions of all dimensions. + data_id : `lsst.daf.butler.DataCoordinate` or other data ID mapping. + Data ID that represents a constraint on all quanta generated from + this pipeline. + + Notes + ----- + Only empty `PipelineGraph` instances should be constructed directly by + users, which sets the signature of ``__init__`` itself, but methods on + `PipelineGraph` and its helper classes need to be able to create them + with state. Those methods can call this after calling ``__new__`` + manually, skipping ``__init__``. + + `PipelineGraph` mutator methods provide strong exception safety (the + graph is left unchanged when an exception is raised and caught) unless + the exception raised is `PipelineGraphExceptionSafetyError`. + """ + self._xgraph = xgraph if xgraph is not None else networkx.MultiDiGraph() + self._sorted_keys: Sequence[NodeKey] | None = None + self._task_subsets = task_subsets if task_subsets is not None else {} + self._description = description + self._tasks = TaskMappingView(self._xgraph) + self._dataset_types = DatasetTypeMappingView(self._xgraph) + self._raw_data_id: dict[str, Any] + if isinstance(data_id, DataCoordinate): + universe = data_id.universe + self._raw_data_id = data_id.byName() + elif data_id is None: + self._raw_data_id = {} + else: + self._raw_data_id = dict(data_id) + self._universe = universe + if sorted_keys is not None: + self._reorder(sorted_keys) + + def __repr__(self) -> str: + return f"{type(self).__name__}({self.description!r}, tasks={self.tasks!s})" + + @property + def description(self) -> str: + """String description for this pipeline.""" + return self._description + + @description.setter + def description(self, value: str) -> None: + # Docstring in setter. + self._description = value + + @property + def universe(self) -> DimensionUniverse | None: + """Definitions for all butler dimensions.""" + return self._universe + + @property + def data_id(self) -> DataCoordinate: + """Data ID that represents a constraint on all quanta generated from + this pipeline. + + This is may not be available unless `universe` is not `None`. + """ + return DataCoordinate.standardize(self._raw_data_id, universe=self.universe) + + @property + def tasks(self) -> TaskMappingView: + """A mapping view of the tasks in the graph. + + This mapping has `str` task label keys and `TaskNode` values. Iteration + is topologically and deterministically ordered if and only if `sort` + has been called since the last modification to the graph. + """ + return self._tasks + + @property + def dataset_types(self) -> DatasetTypeMappingView: + """A mapping view of the dataset types in the graph. + + This mapping has `str` parent dataset type name keys, but only provides + access to its `DatasetTypeNode` values if `resolve` has been called + since the last modification involving a task that uses a dataset type. + See `DatasetTypeMappingView` for details. + """ + return self._dataset_types + + @property + def task_subsets(self) -> Mapping[str, TaskSubset]: + """A mapping of all labeled subsets of tasks. + + Keys are subset labels, values are sets of task labels. See + `TaskSubset` for more information. + + Use `add_task_subset` to add a new subset. The subsets themselves may + be modified in-place. + """ + return self._task_subsets + + def iter_edges(self, init: bool = False) -> Iterator[Edge]: + """Iterate over edges in the graph. + + Parameters + ---------- + init : `bool`, optional + If `True` (`False` is default) iterate over the edges between task + initialization node and init input/output dataset types, instead of + the runtime task nodes and regular input/output/prerequisite + dataset types. + + Returns + ------- + edges : `~collections.abc.Iterator` [ `Edge` ] + A lazy iterator over `Edge` (`WriteEdge` or `ReadEdge`) instances. + + Notes + ----- + This method always returns _either_ init edges or runtime edges, never + both. The full (internal) graph that contains both also includes a + special edge that connects each task init node to its runtime node; + that is also never returned by this method, since it is never a part of + the init-only or runtime-only subgraphs. + """ + edge: Edge + for _, _, edge in self._xgraph.edges(data="instance"): + if edge is not None and edge.is_init == init: + yield edge + + def iter_nodes( + self, + ) -> Iterator[ + tuple[Literal[NodeType.TASK_INIT], str, TaskInitNode] + | tuple[Literal[NodeType.TASK], str, TaskInitNode] + | tuple[Literal[NodeType.DATASET_TYPE], str, DatasetTypeNode | None] + ]: + """Iterate over nodes in the graph. + + Returns + ------- + nodes : `~collections.abc.Iterator` [ `tuple` ] + A lazy iterator over all of the nodes in the graph. Each yielded + element is a tuple of: + + - the node type enum value (`NodeType`); + - the string name for the node (task label or parent dataset type + name); + - the node value (`TaskNode`, `TaskInitNode`, `DatasetTypeNode`, + or `None` for dataset type nodes that have not been resolved). + """ + key: NodeKey + if self._sorted_keys is not None: + for key in self._sorted_keys: + yield key.node_type, key.name, self._xgraph.nodes[key]["instance"] # type: ignore + else: + for key, node in self._xgraph.nodes(data="instance"): + yield key.node_type, key.name, node # type: ignore + + def iter_overall_inputs(self) -> Iterator[tuple[str, DatasetTypeNode | None]]: + """Iterate over all of the dataset types that are consumed but not + produced by the graph. + + Returns + ------- + dataset_types : `~collections.abc.Iterator` [ `tuple` ] + A lazy iterator over the overall-input dataset types (including + overall init inputs and prerequisites). Each yielded element is a + tuple of: + + - the parent dataset type name; + - the resolved `DatasetTypeNode`, or `None` if the dataset type has + - not been resolved. + """ + for generation in networkx.algorithms.dag.topological_generations(self._xgraph): + key: NodeKey + for key in generation: + # While we expect all tasks to have at least one input and + # hence never appear in the first topological generation, that + # is not true of task init nodes. + if key.node_type is NodeType.DATASET_TYPE: + yield key.name, self._xgraph.nodes[key]["instance"] + return + + def make_xgraph(self) -> networkx.MultiDiGraph: + """Export a networkx representation of the full pipeline graph, + including both init and runtime edges. + + Returns + ------- + xgraph : `networkx.MultiDiGraph` + Directed acyclic graph with parallel edges. + + Notes + ----- + The returned graph uses `NodeKey` instances for nodes. Parallel edges + represent the same dataset type appearing in multiple connections for + the same task, and are hence rare. The connection name is used as the + edge key to disambiguate those parallel edges. + + Almost all edges connect dataset type nodes to task or task init nodes + or vice versa, but there is also a special edge that connects each task + init node to its runtime node. The existence of these nodes makes the + graph not quite bipartite, unless its init-only and runtime-only + subgraphs. + + See `TaskNode`, `TaskInitNode`, `DatasetTypeNode`, `ReadEdge`, and + `WriteEdge` for the descriptive node and edge attributes added. + """ + return self._transform_xgraph_state(self._xgraph.copy(), skip_edges=False) + + def make_bipartite_xgraph(self, init: bool = False) -> networkx.MultiDiGraph: + """Return a bipartite networkx representation of just the runtime or + init-time pipeline graph. + + Parameters + ---------- + init : `bool`, optional + If `True` (`False` is default) return the graph of task + initialization nodes and init input/output dataset types, instead + of the graph of runtime task nodes and regular + input/output/prerequisite dataset types. + + Returns + ------- + xgraph : `networkx.MultiDiGraph` + Directed acyclic graph with parallel edges. + + Notes + ----- + The returned graph uses `NodeKey` instances for nodes. Parallel edges + represent the same dataset type appearing in multiple connections for + the same task, and are hence rare. The connection name is used as the + edge key to disambiguate those parallel edges. + + This graph is bipartite because each dataset type node only has edges + that connect it to a task [init] node, and vice versa. + + See `TaskNode`, `TaskInitNode`, `DatasetTypeNode`, `ReadEdge`, and + `WriteEdge` for the descriptive node and edge attributes added. + """ + return self._transform_xgraph_state( + self._make_bipartite_xgraph_internal(init).copy(), skip_edges=False + ) + + def make_task_xgraph(self, init: bool = False) -> networkx.DiGraph: + """Return a networkx representation of just the tasks in the pipeline. + + Parameters + ---------- + init : `bool`, optional + If `True` (`False` is default) return the graph of task + initialization nodes, instead of the graph of runtime task nodes. + + Returns + ------- + xgraph : `networkx.DiGraph` + Directed acyclic graph with no parallel edges. + + Notes + ----- + The returned graph uses `NodeKey` instances for nodes. The dataset + types that link these tasks are not represented at all; edges have no + attributes, and there are no parallel edges. + + See `TaskNode` and `TaskInitNode` for the descriptive node and + attributes added. + """ + bipartite_xgraph = self._make_bipartite_xgraph_internal(init) + task_keys = [ + key + for key, bipartite in bipartite_xgraph.nodes(data="bipartite") + if bipartite == NodeType.TASK.bipartite + ] + return self._transform_xgraph_state( + networkx.algorithms.bipartite.projected_graph(networkx.DiGraph(bipartite_xgraph), task_keys), + skip_edges=True, + ) + + def make_dataset_type_xgraph(self, init: bool = False) -> networkx.DiGraph: + """Return a networkx representation of just the dataset types in the + pipeline. + + Parameters + ---------- + init : `bool`, optional + If `True` (`False` is default) return the graph of init input and + output dataset types, instead of the graph of runtime (input, + output, prerequisite input) dataset types. + + Returns + ------- + xgraph : `networkx.DiGraph` + Directed acyclic graph with no parallel edges. + + Notes + ----- + The returned graph uses `NodeKey` instances for nodes. The tasks that + link these tasks are not represented at all; edges have no attributes, + and there are no parallel edges. + + See `DatasetTypeNode` for the descriptive node and attributes added. + """ + bipartite_xgraph = self._make_bipartite_xgraph_internal(init) + dataset_type_keys = [ + key + for key, bipartite in bipartite_xgraph.nodes(data="bipartite") + if bipartite == NodeType.DATASET_TYPE.bipartite + ] + return self._transform_xgraph_state( + networkx.algorithms.bipartite.projected_graph( + networkx.DiGraph(bipartite_xgraph), dataset_type_keys + ), + skip_edges=True, + ) + + def _make_bipartite_xgraph_internal(self, init: bool) -> networkx.MultiDiGraph: + """Make a bipartite init-only or runtime-only internal subgraph. + + See `make_bipartite_xgraph` for parameters and return values. + + Notes + ----- + This method returns a view of the `PipelineGraph` object's internal + backing graph, and hence should only be called in methods that copy the + result either explicitly or by running a copying algorithm before + returning it to the user. + """ + return self._xgraph.edge_subgraph([edge.key for edge in self.iter_edges(init)]) + + def _transform_xgraph_state(self, xgraph: _G, skip_edges: bool) -> _G: + """Transform networkx graph attributes in-place from the internal + "instance" attributes to the documented exported attributes. + + Parameters + ---------- + xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph` + Graph whose state should be transformed. + skip_edges : `bool` + If `True`, do not transform edge state. + + Returns + ------- + xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph` + The same object passed in, after modification. + + Notes + ----- + This should be called after making a copy of the internal graph but + before any projection down to just task or dataset type nodes, since + it assumes stateful edges. + """ + state: dict[str, Any] + for state in xgraph.nodes.values(): + node_value: TaskInitNode | TaskNode | DatasetTypeNode | None = state.pop("instance") + if node_value is not None: + state.update(node_value._to_xgraph_state()) + if not skip_edges: + for _, _, state in xgraph.edges(data=True): + edge: Edge | None = state.pop("instance", None) + if edge is not None: + state.update(edge._to_xgraph_state()) + return xgraph + + def group_by_dimensions( + self, prerequisites: bool = False + ) -> dict[DimensionGraph, tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]]]: + """Group this graph's tasks and dataset types by their dimensions. + + Parameters + ---------- + prerequisites : `bool`, optional + If `True`, include prerequisite dataset types as well as regular + input and output datasets (including intermediates). + + Returns + ------- + groups : `dict` [ `DimensionGraph`, `tuple` ] + A dictionary of groups keyed by `DimensionGraph`, in which each + value is a tuple of: + + - a `dict` of `TaskNode` instances, keyed by task label + - a `dict` of `DatasetTypeNode` instances, keyed by + dataset type name. + + that have those dimensions. + + Notes + ----- + Init inputs and outputs are always included, but always have empty + dimensions and are hence are all grouped together. + """ + result: dict[DimensionGraph, tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]]] = {} + next_new_value: tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]] = ({}, {}) + for task_label, task_node in self.tasks.items(): + if task_node.dimensions is None: + raise UnresolvedGraphError(f"Task with label {task_label!r} has not been resolved.") + if (group := result.setdefault(task_node.dimensions, next_new_value)) is next_new_value: + next_new_value = ({}, {}) # make new lists for next time + group[0][task_node.label] = task_node + for dataset_type_name, dataset_type_node in self.dataset_types.items(): + if dataset_type_node is None: + raise UnresolvedGraphError(f"Dataset type {dataset_type_name!r} has not been resolved.") + if not dataset_type_node.is_prerequisite or prerequisites: + if ( + group := result.setdefault(dataset_type_node.dataset_type.dimensions, next_new_value) + ) is next_new_value: + next_new_value = ({}, {}) # make new lists for next time + group[1][dataset_type_node.name] = dataset_type_node + return result + + @property + def is_sorted(self) -> bool: + """Whether this graph's tasks and dataset types are topologically + sorted with the exact same deterministic tiebreakers that `sort` would + apply. + + This may perform (and then discard) a full sort if `has_been_sorted` is + `False`. If the goal is to obtain a sorted graph, it is better to just + call `sort` without guarding that with an ``if not graph.is_sorted`` + check. + """ + if self._sorted_keys is not None: + return True + return all( + sorted == unsorted + for sorted, unsorted in zip(networkx.lexicographical_topological_sort(self._xgraph), self._xgraph) + ) + + @property + def has_been_sorted(self) -> bool: + """Whether this graph's tasks and dataset types have been + topologically sorted (with unspecified but deterministic tiebreakers) + since the last modification to the graph. + + This may return `False` if the graph *happens* to be sorted but `sort` + was never called, but it is potentially much faster than `is_sorted`, + which may attempt (and then discard) a full sort if `has_been_sorted` + is `False`. + """ + return self._sorted_keys is not None + + def sort(self) -> None: + """Sort this graph's nodes topologically with deterministic (but + unspecified) tiebreakers. + + This does nothing if the graph is already known to be sorted. + """ + if self._sorted_keys is None: + try: + sorted_keys: Sequence[NodeKey] = list(networkx.lexicographical_topological_sort(self._xgraph)) + except networkx.NetworkXUnfeasible as err: # pragma: no cover + # Should't be possible to get here, because we check for cycles + # when adding tasks, but we guard against it anyway. + cycle = networkx.find_cycle(self._xgraph) + raise PipelineDataCycleError( + f"Cycle detected while attempting to sort graph: {cycle}." + ) from err + self._reorder(sorted_keys) + + def producer_of(self, dataset_type_name: str) -> WriteEdge | None: + """Return the `WriteEdge` that links the producing task to the named + dataset type. + + Parameters + ---------- + dataset_type_name : `str` + Dataset type name. Must not be a component. + + Returns + ------- + edge : `WriteEdge` or `None` + Producing edge or `None` if there isn't one in this graph. + """ + for _, _, edge in self._xgraph.in_edges( + NodeKey(NodeType.DATASET_TYPE, dataset_type_name), data="instance" + ): + return edge + return None + + def consumers_of(self, dataset_type_name: str) -> list[ReadEdge]: + """Return the `ReadEdge` objects that link the named dataset type to + the tasks that consume it. + + Parameters + ---------- + dataset_type_name : `str` + Dataset type name. Must not be a component. + + Returns + ------- + edges : `list` [ `ReadEdge` ] + Edges that connect this dataset type to the tasks that consume it. + """ + return [ + edge + for _, _, edge in self._xgraph.out_edges( + NodeKey(NodeType.DATASET_TYPE, dataset_type_name), data="instance" + ) + ] + + def add_task( + self, + label: str, + task_class: type[PipelineTask], + config: PipelineTaskConfig, + connections: PipelineTaskConnections | None = None, + ) -> TaskNode: + """Add a new task to the graph. + + Parameters + ---------- + label : `str` + Label for the task in the pipeline. + task_class : `type` [ `PipelineTask` ] + Class object for the task. + config : `PipelineTaskConfig` + Configuration for the task. + connections : `PipelineTaskConnections`, optional + Object that describes the dataset types used by the task. If not + provided, one will be constructed from the given configuration. If + provided, it is assumed that ``config`` has already been validated + and frozen. + + Returns + ------- + node : `TaskNode` + The new task node added to the graph. + + Raises + ------ + ValueError + Raised if configuration validation failed when constructing + ``connections``. + PipelineDataCycleError + Raised if the graph is cyclic after this addition. + RuntimeError + Raised if an unexpected exception (which will be chained) occurred + at a stage that may have left the graph in an inconsistent state. + Other exceptions should leave the graph unchanged. + + Notes + ----- + Checks for dataset type consistency and multiple producers do not occur + until `resolve` is called, since the resolution depends on both the + state of the data repository and all contributing tasks. + + Adding new tasks removes any existing resolutions of all dataset types + it references and marks the graph as unsorted. It is most effiecient + to add all tasks up front and only then resolve and/or sort the graph. + """ + key = NodeKey(NodeType.TASK, label) + init_key = NodeKey(NodeType.TASK_INIT, label) + task_node = TaskNode._from_imported_data( + key, + init_key, + _TaskNodeImportedData.configure(label, task_class, config, connections), + universe=self.universe, + ) + self.add_task_nodes([task_node]) + return task_node + + def add_task_nodes(self, nodes: Iterable[TaskNode]) -> None: + """Add one or more existing task nodes to the graph. + + Parameters + ---------- + nodes : `~collections.abc.Iterable` [ `TaskNode` ] + Iterable of task nodes to add. If any tasks have resolved + dimensions, they must have the same dimension universe as the rest + of the graph. + + Raises + ------ + PipelineDataCycleError + Raised if the graph is cyclic after this addition. + + Notes + ----- + Checks for dataset type consistency and multiple producers do not occur + until `resolve` is called, since the resolution depends on both the + state of the data repository and all contributing tasks. + + Adding new tasks removes any existing resolutions of all dataset types + it references and marks the graph as unsorted. It is most effiecient + to add all tasks up front and only then resolve and/or sort the graph. + """ + node_data: list[tuple[NodeKey, dict[str, Any]]] = [] + edge_data: list[tuple[NodeKey, NodeKey, str, dict[str, Any]]] = [] + for task_node in nodes: + task_node = task_node._resolved(self._universe) + node_data.append( + (task_node.key, {"instance": task_node, "bipartite": task_node.key.node_type.bipartite}) + ) + node_data.append( + ( + task_node.init.key, + {"instance": task_node.init, "bipartite": task_node.init.key.node_type.bipartite}, + ) + ) + # Convert the edge objects attached to the task node to networkx. + for read_edge in task_node.init.iter_all_inputs(): + self._append_graph_data_from_edge(node_data, edge_data, read_edge) + for write_edge in task_node.init.iter_all_outputs(): + self._append_graph_data_from_edge(node_data, edge_data, write_edge) + for read_edge in task_node.iter_all_inputs(): + self._append_graph_data_from_edge(node_data, edge_data, read_edge) + for write_edge in task_node.iter_all_outputs(): + self._append_graph_data_from_edge(node_data, edge_data, write_edge) + # Add a special edge (with no Edge instance) that connects the + # TaskInitNode to the runtime TaskNode. + edge_data.append((task_node.init.key, task_node.key, Edge.INIT_TO_TASK_NAME, {"instance": None})) + if not node_data and not edge_data: + return + # Checks and preparation complete; time to start the actual + # modification, during which it's hard to provide strong exception + # safety. Start by resetting the sort ordering, if there is one. + self._reset() + try: + self._xgraph.add_nodes_from(node_data) + self._xgraph.add_edges_from(edge_data) + if not networkx.algorithms.dag.is_directed_acyclic_graph(self._xgraph): + cycle = networkx.find_cycle(self._xgraph) + raise PipelineDataCycleError(f"Cycle detected while adding tasks: {cycle}.") + except Exception: + # First try to roll back our changes. + try: + self._xgraph.remove_edges_from(edge_data) + self._xgraph.remove_nodes_from(key for key, _ in node_data) + except Exception as err: # pragma: no cover + # There's no known way to get here, but we want to make it + # clear it's a big problem if we do. + raise PipelineGraphExceptionSafetyError( + "Error while attempting to revert PipelineGraph modification has left the graph in " + "an inconsistent state." + ) from err + # Successfully rolled back; raise the original exception. + raise + + def reconfigure_tasks( + self, + *args: tuple[str, PipelineTaskConfig], + check_edges_unchanged: bool = False, + assume_edges_unchanged: bool = False, + **kwargs: PipelineTaskConfig, + ) -> None: + """Update the configuration for one or more tasks. + + Parameters + ---------- + *args : `tuple` [ `str`, `.PipelineTaskConfig` ] + Positional arguments are each a 2-tuple of task label and new + config object. Note that the same arguments may also be passed as + ``**kwargs``, which is usually more readable, but task labels in + ``*args`` are not required to be valid Python identifiers. + check_edges_unchanged : `bool`, optional + If `True`, require the edges (connections) of the modified tasks to + remain unchanged after the configuration updates, and verify that + this is the case. + assume_edges_unchanged : `bool`, optional + If `True`, the caller declares that the edges (connections) of the + modified tasks will remain unchanged after the configuration + updates, and that it is unnecessary to check this. + **kwargs : `.PipelineTaskConfig` + New config objects or overrides to apply to copies of the current + config objects, with task labels as the keywords. + + Raises + ------ + ValueError + Raised if ``assume_edges_unchanged`` and ``check_edges_unchanged`` + are both `True`, or if the same task appears twice. + EdgesChangedError + Raised if ``check_edges_unchanged=True`` and the edges of a task do + change. + + Notes + ----- + If reconfiguring a task causes its edges to change, any dataset type + nodes connected to that task (not just those whose edges have changed!) + will be unresolved. + """ + new_configs: dict[str, PipelineTaskConfig] = {} + for task_label, config_update in itertools.chain(args, kwargs.items()): + if new_configs.setdefault(task_label, config_update) is not config_update: + raise ValueError(f"Config for {task_label!r} provided more than once.") + updates = { + task_label: self.tasks[task_label]._reconfigured(config, rebuild=not assume_edges_unchanged) + for task_label, config in new_configs.items() + } + self._replace_task_nodes( + updates, + check_edges_unchanged=check_edges_unchanged, + assume_edges_unchanged=assume_edges_unchanged, + message_header=( + "Unexpected change in edges for task {task_label!r} from original config (A) to " + "new configs (B):" + ), + ) + + def remove_tasks( + self, labels: Iterable[str], drop_from_subsets: bool = True + ) -> list[tuple[TaskNode, set[str]]]: + """Remove one or more tasks from the graph. + + Parameters + ---------- + labels : `~collections.abc.Iterable` [ `str` ] + Iterable of the labels of the tasks to remove. + drop_from_subsets : `bool`, optional + If `True`, drop each removed task from any subset in which it + currently appears. If `False`, raise `PipelineGraphError` if any + such subsets exist. + + Returns + ------- + nodes_and_subsets : `list` [ `tuple` [ `TaskNode`, `set` [ `str` ] ] ] + List of nodes removed and the labels of task subsets that + referenced them. + + Raises + ------ + PipelineGraphError + Raised if ``drop_from_subsets`` is `False` and the task is still + part of one or more subsets. + + Notes + ----- + Removing a task will cause dataset nodes with no other referencing + tasks to be removed. Any other dataset type nodes referenced by a + removed task will be reset to an "unresolved" state. + """ + task_nodes_and_subsets = [] + dataset_types: set[NodeKey] = set() + nodes_to_remove = set() + for label in labels: + task_node: TaskNode = self._xgraph.nodes[NodeKey(NodeType.TASK, label)]["instance"] + # Find task subsets that reference this task. + referencing_subsets = { + subset_label + for subset_label, task_subset in self.task_subsets.items() + if label in task_subset + } + if not drop_from_subsets and referencing_subsets: + raise PipelineGraphError( + f"Task {label!r} is still referenced by subset(s) {referencing_subsets}." + ) + task_nodes_and_subsets.append((task_node, referencing_subsets)) + # Find dataset types referenced by this task. + dataset_types.update(self._xgraph.predecessors(task_node.key)) + dataset_types.update(self._xgraph.successors(task_node.key)) + dataset_types.update(self._xgraph.predecessors(task_node.init.key)) + dataset_types.update(self._xgraph.successors(task_node.init.key)) + # Since there's an edge between the task and its init node, we'll + # have added those two nodes here, too, and we don't want that. + dataset_types.remove(task_node.init.key) + dataset_types.remove(task_node.key) + # Mark the task node and its init node for removal from the graph. + nodes_to_remove.add(task_node.key) + nodes_to_remove.add(task_node.init.key) + # Process the referenced datasets to see which ones are orphaned and + # need to be removed vs. just unresolved. + nodes_to_unresolve = [] + for dataset_type_key in dataset_types: + related_tasks = set() + related_tasks.update(self._xgraph.predecessors(dataset_type_key)) + related_tasks.update(self._xgraph.successors(dataset_type_key)) + related_tasks.difference_update(nodes_to_remove) + if not related_tasks: + nodes_to_remove.add(dataset_type_key) + else: + nodes_to_unresolve.append(dataset_type_key) + # Checks and preparation complete; time to start the actual + # modification, during which it's hard to provide strong exception + # safety. Start by resetting the sort ordering. + self._reset() + try: + for dataset_type_key in nodes_to_unresolve: + self._xgraph.nodes[dataset_type_key]["instance"] = None + for task_node, referencing_subsets in task_nodes_and_subsets: + for subset_label in referencing_subsets: + self._task_subsets[subset_label].remove(task_node.label) + self._xgraph.remove_nodes_from(nodes_to_remove) + except Exception as err: # pragma: no cover + # There's no known way to get here, but we want to make it + # clear it's a big problem if we do. + raise PipelineGraphExceptionSafetyError( + "Error during task removal has left the graph in an inconsistent state." + ) from err + return task_nodes_and_subsets + + def add_task_subset(self, subset_label: str, task_labels: Iterable[str], description: str = "") -> None: + """Add a label for a set of tasks that are already in the pipeline. + + Parameters + ---------- + subset_label : `str` + Label for this set of tasks. + task_labels : `~collections.abc.Iterable` [ `str` ] + Labels of the tasks to include in the set. All must already be + included in the graph. + description : `str`, optional + String description to associate with this label. + """ + subset = TaskSubset(self._xgraph, subset_label, set(task_labels), description) + self._task_subsets[subset_label] = subset + + def remove_task_subset(self, subset_label: str) -> None: + """Remove a labeled set of tasks.""" + del self._task_subsets[subset_label] + + def copy(self) -> PipelineGraph: + """Return a copy of this graph that copies all mutable state.""" + xgraph = self._xgraph.copy() + result = PipelineGraph.__new__(PipelineGraph) + result._init_from_args( + xgraph, + self._sorted_keys, + task_subsets={ + k: TaskSubset(xgraph, v.label, set(v._members), v.description) + for k, v in self._task_subsets.items() + }, + description=self._description, + universe=self.universe, + data_id=self._raw_data_id, + ) + return result + + def __copy__(self) -> PipelineGraph: + # Fully shallow copies are dangerous; we don't want shared mutable + # state to lead to broken class invariants. + return self.copy() + + def __deepcopy__(self, memo: dict) -> PipelineGraph: + # Genuine deep copies are unnecessary, since we should only ever care + # that mutable state is copied. + return self.copy() + + def import_and_configure( + self, check_edges_unchanged: bool = False, assume_edges_unchanged: bool = False + ) -> None: + """Import the `PipelineTask` classes referenced by all task nodes and + update those nodes accordingly. + + Parameters + ---------- + check_edges_unchanged : `bool`, optional + If `True`, require the edges (connections) of the modified tasks to + remain unchanged after importing and configuring each task, and + verify that this is the case. + assume_edges_unchanged : `bool`, optional + If `True`, the caller declares that the edges (connections) of the + modified tasks will remain unchanged importing and configuring each + task, and that it is unnecessary to check this. + + Raises + ------ + ValueError + Raised if ``assume_edges_unchanged`` and ``check_edges_unchanged`` + are both `True`, or if a full config is provided for a task after + another full config or an override has already been provided. + EdgesChangedError + Raised if ``check_edges_unchanged=True`` and the edges of a task do + change. + + Notes + ----- + This method shouldn't need to be called unless the graph was + deserialized without importing and configuring immediately, which is + not the default behavior (but it can greatly speed up deserialization). + If all tasks have already been imported this does nothing. + + Importing and configuring a task can change its + `~TaskNode.task_class_name` or `~TaskClass.get_config_str` output, + usually because the software used to read a serialized graph is newer + than the software used to write it (e.g. a new config option has been + added, or the task was moved to a new module with a forwarding alias + left behind). These changes are allowed by ``check=True``. + + If importing and configuring a task causes its edges to change, any + dataset type nodes linked to those edges will be reset to the + unresolved state. + """ + rebuild = check_edges_unchanged or not assume_edges_unchanged + updates: dict[str, TaskNode] = {} + node_key: NodeKey + for node_key, node_state in self._xgraph.nodes.items(): + if node_key.node_type is NodeType.TASK: + task_node: TaskNode = node_state["instance"] + new_task_node = task_node._imported_and_configured(rebuild) + if new_task_node is not task_node: + updates[task_node.label] = new_task_node + self._replace_task_nodes( + updates, + check_edges_unchanged=check_edges_unchanged, + assume_edges_unchanged=assume_edges_unchanged, + message_header=( + "In task with label {task_label!r}, persisted edges (A)" + "differ from imported and configured edges (B):" + ), + ) + + def resolve(self, registry: Registry) -> None: + """Resolve all dimensions and dataset types and check them for + consistency. + + Resolving a graph also causes it to be sorted. + + Parameters + ---------- + registry : `lsst.daf.butler.Registry` + Client for the data repository to resolve against. + + Notes + ----- + The `universe` attribute are set to ``registry.dimensions`` and used to + set all `TaskNode.dimensions` attributes. Dataset type nodes are + resolved by first looking for a registry definition, then using the + producing task's definition, then looking for consistency between all + consuming task definitions. + + Raises + ------ + ConnectionTypeConsistencyError + Raised if a prerequisite input for one task appears as a different + kind of connection in any other task. + DuplicateOutputError + Raised if multiple tasks have the same dataset type as an output. + IncompatibleDatasetTypeError + Raised if different tasks have different definitions of a dataset + type. Different but compatible storage classes are permitted. + MissingDatasetTypeError + Raised if a dataset type definition is required to exist in the + data repository but none was found. This should only occur for + dataset types that are not produced by a task in the pipeline and + are consumed with different storage classes or as components by + tasks in the pipeline. + EdgesChangedError + Raised if ``check_edges_unchanged=True`` and the edges of a task do + change after import and reconfiguration. + """ + node_key: NodeKey + updates: dict[NodeKey, TaskNode | DatasetTypeNode] = {} + for node_key, node_state in self._xgraph.nodes.items(): + match node_key.node_type: + case NodeType.TASK: + task_node: TaskNode = node_state["instance"] + new_task_node = task_node._resolved(registry.dimensions) + if new_task_node is not task_node: + updates[node_key] = new_task_node + case NodeType.DATASET_TYPE: + dataset_type_node: DatasetTypeNode | None = node_state["instance"] + new_dataset_type_node = DatasetTypeNode._from_edges( + node_key, self._xgraph, registry, previous=dataset_type_node + ) + if new_dataset_type_node is not dataset_type_node: + updates[node_key] = new_dataset_type_node + try: + for node_key, node_value in updates.items(): + self._xgraph.nodes[node_key]["instance"] = node_value + except Exception as err: # pragma: no cover + # There's no known way to get here, but we want to make it + # clear it's a big problem if we do. + raise PipelineGraphExceptionSafetyError( + "Error during dataset type resolution has left the graph in an inconsistent state." + ) from err + self.sort() + self._universe = registry.dimensions + + @classmethod + def read_stream( + cls, + stream: BinaryIO, + import_and_configure: bool = True, + check_edges_unchanged: bool = False, + assume_edges_unchanged: bool = False, + ) -> PipelineGraph: + """Read a serialized `PipelineGraph` from a file-like object. + + Parameters + ---------- + stream : `BinaryIO` + File-like object opened for binary reading, containing + gzip-compressed JSON. + import_and_configure : `bool`, optional + If `True`, import and configure all tasks immediately (see the + `import_and_configure` method). If `False`, some `TaskNode` and + `TaskInitNode` attributes will not be available, but reading may be + much faster. + check_edges_unchanged : `bool`, optional + Forwarded to `import_and_configure` after reading. + assume_edges_unchanged : `bool`, optional + Forwarded to `import_and_configure` after reading. + + Returns + ------- + graph : `PipelineGraph` + Deserialized pipeline graph. + + Raises + ------ + PipelineGraphReadError + Raised if the serialized `PipelineGraph` is not self-consistent. + EdgesChangedError + Raised if ``check_edges_unchanged=True`` and the edges of a task do + change after import and reconfiguration. + """ + from .io import SerializedPipelineGraph + + with gzip.open(stream, "rb") as uncompressed_stream: + data = json.load(uncompressed_stream) + serialized_graph = SerializedPipelineGraph.parse_obj(data) + return serialized_graph.deserialize( + import_and_configure=import_and_configure, + check_edges_unchanged=check_edges_unchanged, + assume_edges_unchanged=assume_edges_unchanged, + ) + + @classmethod + def read_uri( + cls, + uri: ResourcePathExpression, + import_and_configure: bool = True, + check_edges_unchanged: bool = False, + assume_edges_unchanged: bool = False, + ) -> PipelineGraph: + """Read a serialized `PipelineGraph` from a file at a URI. + + Parameters + ---------- + uri : convertible to `lsst.resources.ResourcePath` + URI to a gzip-compressed JSON file containing a serialized pipeline + graph. + import_and_configure : `bool`, optional + If `True`, import and configure all tasks immediately (see + the `import_and_configure` method). If `False`, some `TaskNode` + and `TaskInitNode` attributes will not be available, but reading + may be much faster. + check_edges_unchanged : `bool`, optional + Forwarded to `import_and_configure` after reading. + assume_edges_unchanged : `bool`, optional + Forwarded to `import_and_configure` after reading. + + Returns + ------- + graph : `PipelineGraph` + Deserialized pipeline graph. + + Raises + ------ + PipelineGraphReadError + Raised if the serialized `PipelineGraph` is not self-consistent. + EdgesChangedError + Raised if ``check_edges_unchanged=True`` and the edges of a task do + change after import and reconfiguration. + """ + uri = ResourcePath(uri) + with uri.open("rb") as stream: + return cls.read_stream( + cast(BinaryIO, stream), + import_and_configure=import_and_configure, + check_edges_unchanged=check_edges_unchanged, + assume_edges_unchanged=assume_edges_unchanged, + ) + + def write_stream(self, stream: BinaryIO) -> None: + """Write the pipeline to a file-like object. + + Parameters + ---------- + stream + File-like object opened for binary writing. + + Notes + ----- + The file format is gzipped JSON, and is intended to be human-readable, + but it should not be considered a stable public interface for outside + code, which should always use `PipelineGraph` methods (or at least the + `io.SerializedPipelineGraph` class) to read these files. + """ + from .io import SerializedPipelineGraph + + with gzip.open(stream, mode="wb") as compressed_stream: + compressed_stream.write( + SerializedPipelineGraph.serialize(self).json(exclude_defaults=True, indent=2).encode("utf-8") + ) + + def write_uri(self, uri: ResourcePathExpression) -> None: + """Write the pipeline to a file given a URI. + + Parameters + ---------- + uri : convertible to `lsst.resources.ResourcePath` + URI to write to . May have ``.json.gz`` or no extension (which + will cause a ``.json.gz`` extension to be added). + + Notes + ----- + The file format is gzipped JSON, and is intended to be human-readable, + but it should not be considered a stable public interface for outside + code, which should always use `PipelineGraph` methods (or at least the + `io.SerializedPipelineGraph` class) to read these files. + """ + uri = ResourcePath(uri) + extension = uri.getExtension() + if not extension: + uri = uri.updatedExtension(".json.gz") + elif extension != ".json.gz": + raise ValueError("Expanded pipeline files should always have a .json.gz extension.") + with uri.open(mode="wb") as stream: + self.write_stream(cast(BinaryIO, stream)) + + def _iter_task_defs(self) -> Iterator[TaskDef]: + """Iterate over this pipeline as a sequence of `TaskDef` instances. + + Notes + ----- + This is a package-private method intended to aid in the transition to a + codebase more fully integrated with the `PipelineGraph` class, in which + both `TaskDef` and `PipelineDatasetTypes` are expected to go away, and + much of the functionality on the `Pipeline` class will be moved to + `PipelineGraph` as well. + + Raises + ------ + TaskNotImportedError + Raised if `TaskNode.is_imported` is `False` for any task. + """ + from ..pipeline import TaskDef + + for node in self._tasks.values(): + yield TaskDef( + config=node.config, + taskClass=node.task_class, + label=node.label, + connections=node._get_imported_data().connections, + ) + + def _replace_task_nodes( + self, + updates: Mapping[str, TaskNode], + check_edges_unchanged: bool, + assume_edges_unchanged: bool, + message_header: str, + ) -> None: + """Replace task nodes and update edges and dataset type nodes + accordingly. + + Parameters + ---------- + updates : `Mapping` [ `str`, `TaskNode` ] + New task nodes with task label keys. All keys must be task labels + that are already present in the graph. + check_edges_unchanged : `bool`, optional + If `True`, require the edges (connections) of the modified tasks to + remain unchanged after importing and configuring each task, and + verify that this is the case. + assume_edges_unchanged : `bool`, optional + If `True`, the caller declares that the edges (connections) of the + modified tasks will remain unchanged importing and configuring each + task, and that it is unnecessary to check this. + message_header : `str` + Template for `str.format` with a single ``task_label`` placeholder + to use as the first line in `EdgesChangedError` messages that show + the differences between new task edges and old task edges. Should + include the fact that the rest of the message will refer to the old + task as "A" and the new task as "B", and end with a colon. + + Raises + ------ + ValueError + Raised if ``assume_edges_unchanged`` and ``check_edges_unchanged`` + are both `True`, or if a full config is provided for a task after + another full config or an override has already been provided. + EdgesChangedError + Raised if ``check_edges_unchanged=True`` and the edges of a task do + change. + """ + deep: dict[str, TaskNode] = {} + shallow: dict[str, TaskNode] = {} + if assume_edges_unchanged: + if check_edges_unchanged: + raise ValueError("Cannot simultaneously assume and check that edges have not changed.") + shallow.update(updates) + else: + for task_label, new_task_node in updates.items(): + old_task_node = self.tasks[task_label] + messages = old_task_node.diff_edges(new_task_node) + if messages: + if check_edges_unchanged: + messages.insert(0, message_header.format(task_label=task_label)) + raise EdgesChangedError("\n".join(messages)) + else: + deep[task_label] = new_task_node + else: + shallow[task_label] = new_task_node + try: + if deep: + removed = self.remove_tasks(deep.keys(), drop_from_subsets=True) + self.add_task_nodes(deep.values()) + for replaced_task_node, referencing_subsets in removed: + for subset_label in referencing_subsets: + self._task_subsets[subset_label].add(replaced_task_node.label) + for task_node in shallow.values(): + self._xgraph.nodes[task_node.key]["instance"] = task_node + self._xgraph.nodes[task_node.init.key]["instance"] = task_node.init + except PipelineGraphExceptionSafetyError: # pragma: no cover + raise + except Exception as err: # pragma: no cover + # There's no known way to get here, but we want to make it clear + # it's a big problem if we do. + raise PipelineGraphExceptionSafetyError( + "Error while replacing tasks has left the graph in an inconsistent state." + ) from err + + def _append_graph_data_from_edge( + self, + node_data: list[tuple[NodeKey, dict[str, Any]]], + edge_data: list[tuple[NodeKey, NodeKey, str, dict[str, Any]]], + edge: Edge, + ) -> None: + """Append networkx state dictionaries for an edge and the corresponding + dataset type node. + + Parameters + ---------- + node_data : `list` + List of node keys and state dictionaries. A node is appended if + one does not already exist for this dataset type. + edge_data : `list` + List of node key pairs, connection names, and state dictionaries + for edges. + edge : `Edge` + New edge being processed. + """ + if (existing_dataset_type_state := self._xgraph.nodes.get(edge.dataset_type_key)) is not None: + existing_dataset_type_state["instance"] = None + else: + node_data.append( + ( + edge.dataset_type_key, + { + "instance": None, + "bipartite": NodeType.DATASET_TYPE.bipartite, + }, + ) + ) + edge_data.append( + edge.nodes + + ( + edge.connection_name, + {"instance": edge}, + ) + ) + + def _reorder(self, sorted_keys: Sequence[NodeKey]) -> None: + """Set the order of all views of this graph from the given sorted + sequence of task labels and dataset type names. + """ + self._sorted_keys = sorted_keys + self._tasks._reorder(sorted_keys) + self._dataset_types._reorder(sorted_keys) + + def _reset(self) -> None: + """Reset the all views of this graph following a modification that + might invalidate them. + """ + self._sorted_keys = None + self._tasks._reset() + self._dataset_types._reset() diff --git a/python/lsst/pipe/base/pipeline_graph/_task_subsets.py b/python/lsst/pipe/base/pipeline_graph/_task_subsets.py new file mode 100644 index 00000000..1c48ecab --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_task_subsets.py @@ -0,0 +1,122 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ("TaskSubset",) + +from collections.abc import Iterator, MutableSet + +import networkx +import networkx.algorithms.boundary + +from ._exceptions import PipelineGraphError +from ._nodes import NodeKey, NodeType + + +class TaskSubset(MutableSet[str]): + """A specialized set that represents a labeles subset of the tasks in a + pipeline graph. + + Instances of this class should never be constructed directly; they should + only be accessed via the `PipelineGraph.task_subsets` attribute and created + by the `PipelineGraph.add_task_subset` method. + + Parameters + ---------- + parent_xgraph : `networkx.DiGraph` + Parent networkx graph that this subgraph is part of. + label : `str` + Label associated with this subset of the pipeline. + members : `set` [ `str` ] + Labels of the tasks that are members of this subset. + description : `str`, optional + Description string associated with this labeled subset. + + Notes + ----- + Iteration order is arbitrary, even when the parent pipeline graph is + ordered (there is no guarantee that an ordering of the tasks in the graph + implies a consistent ordering of subsets). + """ + + def __init__( + self, + parent_xgraph: networkx.DiGraph, + label: str, + members: set[str], + description: str, + ): + self._parent_xgraph = parent_xgraph + self._label = label + self._members = members + self._description = description + + @property + def label(self) -> str: + """Label associated with this subset of the pipeline.""" + return self._label + + @property + def description(self) -> str: + """Description string associated with this labeled subset.""" + return self._description + + @description.setter + def description(self, value: str) -> None: + # Docstring in getter. + self._description = value + + def __repr__(self) -> str: + return f"{self.label}: {self.description!r}, tasks={{{', '.join(iter(self))}}}" + + def __contains__(self, key: object) -> bool: + return key in self._members + + def __len__(self) -> int: + return len(self._members) + + def __iter__(self) -> Iterator[str]: + return iter(self._members) + + def add(self, task_label: str) -> None: + """Add a new task to this subset. + + Parameters + ---------- + task_label : `str` + Label for the task. Must already be present in the parent pipeline + graph. + """ + key = NodeKey(NodeType.TASK, task_label) + if key not in self._parent_xgraph: + raise PipelineGraphError(f"{task_label!r} is not a task in the parent pipeline.") + self._members.add(key.name) + + def discard(self, task_label: str) -> None: + """Remove a task from the subset if it is present. + + Parameters + ---------- + task_label : `str` + Label for the task. Must already be present in the parent pipeline + graph. + """ + self._members.discard(task_label) diff --git a/python/lsst/pipe/base/pipeline_graph/_tasks.py b/python/lsst/pipe/base/pipeline_graph/_tasks.py new file mode 100644 index 00000000..85eb239c --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_tasks.py @@ -0,0 +1,855 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ("TaskNode", "TaskInitNode") + +import dataclasses +from collections.abc import Iterator, Mapping +from typing import TYPE_CHECKING, Any, cast + +from lsst.daf.butler import DimensionGraph, DimensionUniverse +from lsst.utils.classes import immutable +from lsst.utils.doImport import doImportType +from lsst.utils.introspection import get_full_type_name + +from .. import automatic_connection_constants as acc +from ..connections import PipelineTaskConnections +from ..connectionTypes import BaseConnection, InitOutput, Output +from ._edges import Edge, ReadEdge, WriteEdge +from ._exceptions import TaskNotImportedError, UnresolvedGraphError +from ._nodes import NodeKey, NodeType + +if TYPE_CHECKING: + from ..config import PipelineTaskConfig + from ..pipelineTask import PipelineTask + + +@dataclasses.dataclass(frozen=True) +class _TaskNodeImportedData: + """An internal struct that holds `TaskNode` and `TaskInitNode` state that + requires task classes to be imported. + """ + + task_class: type[PipelineTask] + """Type object for the task.""" + + config: PipelineTaskConfig + """Configuration object for the task.""" + + connection_map: dict[str, BaseConnection] + """Mapping from connection name to connection. + + In addition to ``connections.allConnections``, this also holds the + "automatic" config, log, and metadata connections using the names defined + in the `.automatic_connection_constants` module. + """ + + connections: PipelineTaskConnections + """Configured connections object for the task.""" + + @classmethod + def configure( + cls, + label: str, + task_class: type[PipelineTask], + config: PipelineTaskConfig, + connections: PipelineTaskConnections | None = None, + ) -> _TaskNodeImportedData: + """Construct while creating a `PipelineTaskConnections` instance if + necessary. + + Parameters + ---------- + label : `str` + Label for the task in the pipeline. Only used in error messages. + task_class : `type` [ `.PipelineTask` ] + Pipeline task `type` object. + config : `.PipelineTaskConfig` + Configuration for the task. + connections : `.PipelineTaskConnections`, optional + Object that describes the dataset types used by the task. If not + provided, one will be constructed from the given configuration. If + provided, it is assumed that ``config`` has already been validated + and frozen. + + Returns + ------- + data : `_TaskNodeImportedData` + Instance of this struct. + """ + if connections is None: + # If we don't have connections yet, assume the config hasn't been + # validated yet. + try: + config.validate() + except Exception as err: + raise ValueError( + f"Configuration validation failed for task {label!r} (see chained exception)." + ) from err + config.freeze() + connections = task_class.ConfigClass.ConnectionsClass(config=config) + connection_map = dict(connections.allConnections) + connection_map[acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME] = InitOutput( + acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label), + acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, + ) + if not config.saveMetadata: + raise ValueError(f"Metadata for task {label} cannot be disabled.") + connection_map[acc.METADATA_OUTPUT_CONNECTION_NAME] = Output( + acc.METADATA_OUTPUT_TEMPLATE.format(label=label), + acc.METADATA_OUTPUT_STORAGE_CLASS, + dimensions=set(connections.dimensions), + ) + if config.saveLogOutput: + connection_map[acc.LOG_OUTPUT_CONNECTION_NAME] = Output( + acc.LOG_OUTPUT_TEMPLATE.format(label=label), + acc.LOG_OUTPUT_STORAGE_CLASS, + dimensions=set(connections.dimensions), + ) + return cls(task_class, config, connection_map, connections) + + +@immutable +class TaskInitNode: + """A node in a pipeline graph that represents the construction of a + `PipelineTask`. + + Parameters + ---------- + inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ] + Graph edges that represent inputs required just to construct an + instance of this task, keyed by connection name. + outputs : ~collections.abc.Mapping` [ `str`, `WriteEdge` ] + Graph edges that represent outputs of this task that are available + after just constructing it, keyed by connection name. + + This does not include the special `config_init_output` edge; use + `iter_all_outputs` to include that, too. + config_output : `WriteEdge` + The special init output edge that persists the task's configuration. + imported_data : `_TaskNodeImportedData`, optional + Internal struct that holds information that requires the task class to + have been be imported. + task_class_name : `str`, optional + Fully-qualified name of the task class. Must be provided if + ``imported_data`` is not. + config_str : `str`, optional + Configuration for the task as a string of override statements. Must be + provided if ``imported_data`` is not. + + Notes + ----- + When included in an exported `networkx` graph (e.g. + `PipelineGraph.make_xgraph`), task initialization nodes set the following + node attributes: + + - ``task_class_name`` + - ``bipartite`` (see `NodeType.bipartite`) + - ``task_class`` (only if `is_imported` is `True`) + - ``config`` (only if `is_importd` is `True`) + """ + + def __init__( + self, + key: NodeKey, + *, + inputs: Mapping[str, ReadEdge], + outputs: Mapping[str, WriteEdge], + config_output: WriteEdge, + imported_data: _TaskNodeImportedData | None = None, + task_class_name: str | None = None, + config_str: str | None = None, + ): + self.key = key + self.inputs = inputs + self.outputs = outputs + self.config_output = config_output + # Instead of setting attributes to None, we do not set them at all; + # this works better with the @immutable decorator, which supports + # deferred initialization but not reassignment. + if task_class_name is not None: + self._task_class_name = task_class_name + if config_str is not None: + self._config_str = config_str + if imported_data is not None: + self._imported_data = imported_data + else: + assert ( + self._task_class_name is not None and self._config_str is not None + ), "If imported_data is not present, task_class_name and config_str must be." + + key: NodeKey + """Key that identifies this node in internal and exported networkx graphs. + """ + + inputs: Mapping[str, ReadEdge] + """Graph edges that represent inputs required just to construct an instance + of this task, keyed by connection name. + """ + + outputs: Mapping[str, WriteEdge] + """Graph edges that represent outputs of this task that are available after + just constructing it, keyed by connection name. + + This does not include the special `config_output` edge; use + `iter_all_outputs` to include that, too. + """ + + config_output: WriteEdge + """The special output edge that persists the task's configuration. + """ + + @property + def label(self) -> str: + """Label of this configuration of a task in the pipeline.""" + return str(self.key) + + @property + def is_imported(self) -> bool: + """Whether this the task type for this node has been imported and + its configuration overrides applied. + + If this is `False`, the `task_class` and `config` attributes may not + be accessed. + """ + return hasattr(self, "_imported_data") + + @property + def task_class(self) -> type[PipelineTask]: + """Type object for the task. + + Accessing this attribute when `is_imported` is `False` will raise + `TaskNotImportedError`, but accessing `task_class_name` will not. + """ + return self._get_imported_data().task_class + + @property + def task_class_name(self) -> str: + """The fully-qualified string name of the task class.""" + try: + return self._task_class_name + except AttributeError: + pass + self._task_class_name = get_full_type_name(self.task_class) + return self._task_class_name + + @property + def config(self) -> PipelineTaskConfig: + """Configuration for the task. + + This is always frozen. + + Accessing this attribute when `is_imported` is `False` will raise + `TaskNotImportedError`, but calling `get_config_str` will not. + """ + return self._get_imported_data().config + + def get_config_str(self) -> str: + """Return the configuration for this task as a string of override + statements. + + Returns + ------- + config_str : `str` + String containing configuration-overload statements. + """ + try: + return self._config_str + except AttributeError: + pass + self._config_str = self.config.saveToString() + return self._config_str + + def iter_all_inputs(self) -> Iterator[ReadEdge]: + """Iterate over all inputs required for construction. + + This is the same as iteration over ``inputs.values()``, but it will be + updated to include any automatic init-input connections added in the + future, while `inputs` will continue to hold only task-defined init + inputs. + """ + return iter(self.inputs.values()) + + def iter_all_outputs(self) -> Iterator[WriteEdge]: + """Iterate over all outputs available after construction, including + special ones. + """ + yield from self.outputs.values() + yield self.config_output + + def diff_edges(self, other: TaskInitNode) -> list[str]: + """Compare the edges of this task initialization node to those from the + same task label in a different pipeline. + + Parameters + ---------- + other : `TaskInitNode` + Other node to compare to. Must have the same task label, but need + not have the same configuration or even the same task class. + + Returns + ------- + differences : `list` [ `str` ] + List of string messages describing differences between ``self`` and + ``other``. Will be empty if the two nodes have the same edges. + Messages will use 'A' to refer to ``self`` and 'B' to refer to + ``other``. + """ + result = [] + result += _diff_edge_mapping(self.inputs, self.inputs, self.label, "init input") + result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "init output") + result += self.config_output.diff(other.config_output, "config init output") + return result + + def _to_xgraph_state(self) -> dict[str, Any]: + """Convert this nodes's attributes into a dictionary suitable for use + in exported networkx graphs. + """ + result = {"task_class_name": self.task_class_name, "bipartite": NodeType.TASK_INIT.bipartite} + if hasattr(self, "_imported_data"): + result["task_class"] = self.task_class + result["config"] = self.config + return result + + def _get_imported_data(self) -> _TaskNodeImportedData: + """Return the imported data struct. + + Returns + ------- + imported_data : `_TaskNodeImportedData` + Internal structure holding state that requires the task class to + have been imported. + + Raises + ------ + TaskNotImportedError + Raised if `is_imported` is `False`. + """ + try: + return self._imported_data + except AttributeError: + raise TaskNotImportedError( + f"Task class {self.task_class_name!r} for label {self.label!r} has not been imported " + "(see PipelineGraph.import_and_configure)." + ) from None + + +@immutable +class TaskNode: + """A node in a pipeline graph that represents a labeled configuration of a + `PipelineTask`. + + Parameters + ---------- + key : `NodeKey` + Identifier for this node in networkx graphs. + init : `TaskInitNode` + Node representing the initialization of this task. + prerequisite_inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ] + Graph edges that represent prerequisite inputs to this task, keyed by + connection name. + + Prerequisite inputs must already exist in the data repository when a + `QuantumGraph` is built, but have more flexibility in how they are + looked up than regular inputs. + inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ] + Graph edges that represent regular runtime inputs to this task, keyed + by connection name. + outputs : ~collections.abc.Mapping` [ `str`, `WriteEdge` ] + Graph edges that represent regular runtime outputs of this task, keyed + by connection name. + + This does not include the special `log_output` and `metadata_output` + edges; use `iter_all_outputs` to include that, too. + log_output : `WriteEdge` or `None` + The special runtime output that persists the task's logs. + metadata_output : `WriteEdge` + The special runtime output that persists the task's metadata. + dimensions : `lsst.daf.butler.DimensionGraph` or `frozenset` + Dimensions of the task. If a `frozenset`, the dimensions have not been + resolved by a `~lsst.daf.butler.DimensionUniverse` and cannot be safely + compared to other sets of dimensions. + + Notes + ----- + Task nodes are intentionally not equality comparable, since there are many + different (and useful) ways to compare these objects with no clear winner + as the most obvious behavior. + + When included in an exported `networkx` graph (e.g. + `PipelineGraph.make_xgraph`), task nodes set the following node attributes: + + - ``task_class_name`` + - ``bipartite`` (see `NodeType.bipartite`) + - ``task_class`` (only if `is_imported` is `True`) + - ``config`` (only if `is_importd` is `True`) + """ + + def __init__( + self, + key: NodeKey, + init: TaskInitNode, + *, + prerequisite_inputs: Mapping[str, ReadEdge], + inputs: Mapping[str, ReadEdge], + outputs: Mapping[str, WriteEdge], + log_output: WriteEdge | None, + metadata_output: WriteEdge, + dimensions: DimensionGraph | frozenset, + ): + self.key = key + self.init = init + self.prerequisite_inputs = prerequisite_inputs + self.inputs = inputs + self.outputs = outputs + self.log_output = log_output + self.metadata_output = metadata_output + self._dimensions = dimensions + + @staticmethod + def _from_imported_data( + key: NodeKey, + init_key: NodeKey, + data: _TaskNodeImportedData, + universe: DimensionUniverse | None, + ) -> TaskNode: + """Construct from a `PipelineTask` type and its configuration. + + Parameters + ---------- + key : `NodeKey` + Identifier for this node in networkx graphs. + init : `TaskInitNode` + Node representing the initialization of this task. + data : `_TaskNodeImportedData` + Internal struct that holds information that requires the task class + to have been be imported. + universe : `lsst.daf.butler.DimensionUniverse` or `None` + Definitions of all dimensions. + + Returns + ------- + node : `TaskNode` + New task node. + + Raises + ------ + ValueError + Raised if configuration validation failed when constructing + ``connections``. + """ + + init_inputs = { + name: ReadEdge._from_connection_map(init_key, name, data.connection_map) + for name in data.connections.initInputs + } + prerequisite_inputs = { + name: ReadEdge._from_connection_map(key, name, data.connection_map, is_prerequisite=True) + for name in data.connections.prerequisiteInputs + } + inputs = { + name: ReadEdge._from_connection_map(key, name, data.connection_map) + for name in data.connections.inputs + } + init_outputs = { + name: WriteEdge._from_connection_map(init_key, name, data.connection_map) + for name in data.connections.initOutputs + } + outputs = { + name: WriteEdge._from_connection_map(key, name, data.connection_map) + for name in data.connections.outputs + } + init = TaskInitNode( + key=init_key, + inputs=init_inputs, + outputs=init_outputs, + config_output=WriteEdge._from_connection_map( + init_key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, data.connection_map + ), + imported_data=data, + ) + instance = TaskNode( + key=key, + init=init, + prerequisite_inputs=prerequisite_inputs, + inputs=inputs, + outputs=outputs, + log_output=( + WriteEdge._from_connection_map(key, acc.LOG_OUTPUT_CONNECTION_NAME, data.connection_map) + if data.config.saveLogOutput + else None + ), + metadata_output=WriteEdge._from_connection_map( + key, acc.METADATA_OUTPUT_CONNECTION_NAME, data.connection_map + ), + dimensions=( + frozenset(data.connections.dimensions) + if universe is None + else universe.extract(data.connections.dimensions) + ), + ) + return instance + + key: NodeKey + """Key that identifies this node in internal and exported networkx graphs. + """ + + prerequisite_inputs: Mapping[str, ReadEdge] + """Graph edges that represent prerequisite inputs to this task. + + Prerequisite inputs must already exist in the data repository when a + `QuantumGraph` is built, but have more flexibility in how they are looked + up than regular inputs. + """ + + inputs: Mapping[str, ReadEdge] + """Graph edges that represent regular runtime inputs to this task. + """ + + outputs: Mapping[str, WriteEdge] + """Graph edges that represent regular runtime outputs of this task. + + This does not include the special `log_output` and `metadata_output` edges; + use `iter_all_outputs` to include that, too. + """ + + log_output: WriteEdge | None + """The special runtime output that persists the task's logs. + """ + + metadata_output: WriteEdge + """The special runtime output that persists the task's metadata. + """ + + @property + def label(self) -> str: + """Label of this configuration of a task in the pipeline.""" + return self.key.name + + @property + def is_imported(self) -> bool: + """Whether this the task type for this node has been imported and + its configuration overrides applied. + + If this is `False`, the `task_class` and `config` attributes may not + be accessed. + """ + return self.init.is_imported + + @property + def task_class(self) -> type[PipelineTask]: + """Type object for the task. + + Accessing this attribute when `is_imported` is `False` will raise + `TaskNotImportedError`, but accessing `task_class_name` will not. + """ + return self.init.task_class + + @property + def task_class_name(self) -> str: + """The fully-qualified string name of the task class.""" + return self.init.task_class_name + + @property + def config(self) -> PipelineTaskConfig: + """Configuration for the task. + + This is always frozen. + + Accessing this attribute when `is_imported` is `False` will raise + `TaskNotImportedError`, but calling `get_config_str` will not. + """ + return self.init.config + + @property + def has_resolved_dimensions(self) -> bool: + """Whether the `dimensions` attribute may be accessed. + + If `False`, the `raw_dimensions` attribute may be used to obtain a + set of dimension names that has not been resolved by a + `~lsst.daf.butler.DimensionsUniverse`. + """ + return type(self._dimensions) is DimensionGraph + + @property + def dimensions(self) -> DimensionGraph: + """Standardized dimensions of the task.""" + if not self.has_resolved_dimensions: + raise UnresolvedGraphError(f"Dimensions for task {self.label!r} have not been resolved.") + return cast(DimensionGraph, self._dimensions) + + @property + def raw_dimensions(self) -> frozenset[str]: + """Raw dimensions of the task, with standardization by a + `~lsst.daf.butler.DimensionUniverse` not guaranteed. + """ + if self.has_resolved_dimensions: + return frozenset(cast(DimensionGraph, self._dimensions).names) + else: + return cast(frozenset[str], self._dimensions) + + def __repr__(self) -> str: + if self.has_resolved_dimensions: + return f"{self.label} ({self.task_class_name}, {self.dimensions})" + else: + return f"{self.label} ({self.task_class_name})" + + def get_config_str(self) -> str: + """Return the configuration for this task as a string of override + statements. + + Returns + ------- + config_str : `str` + String containing configuration-overload statements. + """ + return self.init.get_config_str() + + def iter_all_inputs(self) -> Iterator[ReadEdge]: + """Iterate over all runtime inputs, including both regular inputs and + prerequisites. + """ + yield from self.prerequisite_inputs.values() + yield from self.inputs.values() + + def iter_all_outputs(self) -> Iterator[WriteEdge]: + """Iterate over all runtime outputs, including special ones.""" + yield from self.outputs.values() + yield self.metadata_output + if self.log_output is not None: + yield self.log_output + + def diff_edges(self, other: TaskNode) -> list[str]: + """Compare the edges of this task node to those from the same task + label in a different pipeline. + + This also calls `TaskInitNode.diff_edges`. + + Parameters + ---------- + other : `TaskInitNode` + Other node to compare to. Must have the same task label, but need + not have the same configuration or even the same task class. + + Returns + ------- + differences : `list` [ `str` ] + List of string messages describing differences between ``self`` and + ``other``. Will be empty if the two nodes have the same edges. + Messages will use 'A' to refer to ``self`` and 'B' to refer to + ``other``. + """ + result = self.init.diff_edges(other.init) + result += _diff_edge_mapping( + self.prerequisite_inputs, other.prerequisite_inputs, self.label, "prerequisite input" + ) + result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "input") + result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "output") + if self.log_output is not None: + if other.log_output is not None: + result += self.log_output.diff(other.log_output, "log output") + else: + result.append("Log output is present in A, but not in B.") + elif other.log_output is not None: + result.append("Log output is present in B, but not in A.") + result += self.metadata_output.diff(other.metadata_output, "metadata output") + return result + + def _imported_and_configured(self, rebuild: bool) -> TaskNode: + """Import the task class and use it to construct a new instance. + + Parameters + ---------- + rebuild : `bool` + If `True`, import the task class and configure its connections to + generate new edges that may differ from the current ones. If + `False`, import the task class but just update the `task_class` and + `config` attributes, and assume the edges have not changed. + + Returns + ------- + node : `TaskNode` + Task node instance for which `is_imported` is `True`. Will be + ``self`` if this is the case already. + """ + from ..pipelineTask import PipelineTask + + if self.is_imported: + return self + task_class = doImportType(self.task_class_name) + if not issubclass(task_class, PipelineTask): + raise TypeError(f"{self.task_class_name!r} is not a PipelineTask subclass.") + config = task_class.ConfigClass() + config.loadFromString(self.get_config_str()) + return self._reconfigured(config, rebuild=rebuild, task_class=task_class) + + def _reconfigured( + self, + config: PipelineTaskConfig, + rebuild: bool, + task_class: type[PipelineTask] | None = None, + ) -> TaskNode: + """Return a version of this node with new configuration. + + Parameters + ---------- + config : `.PipelineTaskConfig` + New configuration for the task. + rebuild : `bool` + If `True`, use the configured connections to generate new edges + that may differ from the current ones. If `False`, just update the + `task_class` and `config` attributes, and assume the edges have not + changed. + task_class : `type` [ `PipelineTask` ], optional + Subclass of `PipelineTask`. This defaults to ``self.task_class`, + but may be passed as an argument if that is not available because + the task class was not imported when ``self`` was constructed. + + Returns + ------- + node : `TaskNode` + Task node instance with the new config. + """ + if task_class is None: + task_class = self.task_class + imported_data = _TaskNodeImportedData.configure(self.key.name, task_class, config) + if rebuild: + return self._from_imported_data( + self.key, + self.init.key, + imported_data, + universe=self._dimensions.universe if type(self._dimensions) is DimensionGraph else None, + ) + else: + return TaskNode( + self.key, + TaskInitNode( + self.init.key, + inputs=self.init.inputs, + outputs=self.init.outputs, + config_output=self.init.config_output, + imported_data=imported_data, + ), + prerequisite_inputs=self.prerequisite_inputs, + inputs=self.inputs, + outputs=self.outputs, + log_output=self.log_output, + metadata_output=self.metadata_output, + dimensions=self._dimensions, + ) + + def _resolved(self, universe: DimensionUniverse | None) -> TaskNode: + """Return an otherwise-equivalent task node with resolved dimensions. + + Parameters + ---------- + universe : `lsst.daf.butler.DimensionUniverse` or `None` + Definitions for all dimensions. + + Returns + ------- + node : `TaskNode` + Task node instance with `dimensions` resolved by the given + universe. Will be ``self`` if this is the case already. + """ + if self.has_resolved_dimensions: + if cast(DimensionGraph, self._dimensions).universe is universe: + return self + elif universe is None: + return self + return TaskNode( + key=self.key, + init=self.init, + prerequisite_inputs=self.prerequisite_inputs, + inputs=self.inputs, + outputs=self.outputs, + log_output=self.log_output, + metadata_output=self.metadata_output, + dimensions=( + universe.extract(self.raw_dimensions) if universe is not None else self.raw_dimensions + ), + ) + + def _to_xgraph_state(self) -> dict[str, Any]: + """Convert this nodes's attributes into a dictionary suitable for use + in exported networkx graphs. + """ + result = self.init._to_xgraph_state() + if self.has_resolved_dimensions: + result["dimensions"] = self._dimensions + result["raw_dimensions"] = self.raw_dimensions + return result + + def _get_imported_data(self) -> _TaskNodeImportedData: + """Return the imported data struct. + + Returns + ------- + imported_data : `_TaskNodeImportedData` + Internal structure holding state that requires the task class to + have been imported. + + Raises + ------ + TaskNotImportedError + Raised if `is_imported` is `False`. + """ + return self.init._get_imported_data() + + +def _diff_edge_mapping( + a_mapping: Mapping[str, Edge], b_mapping: Mapping[str, Edge], task_label: str, connection_type: str +) -> list[str]: + """Compare a pair of mappings of edges. + + Parameters + ---------- + a_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ] + First mapping to compare. Expected to have connection names as keys. + b_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ] + First mapping to compare. If keys differ from those of ``a_mapping``, + this will be reported as a difference (in addition to element-wise + comparisons). + task_label : `str` + Task label associated with both mappings. + connection_type : `str` + Type of connection (e.g. "input" or "init output") associated with both + connections. This is a human-readable string to include in difference + messages. + """ + results = [] + b_to_do = set(b_mapping.keys()) + for connection_name, a_edge in a_mapping.items(): + if (b_edge := b_mapping.get(connection_name)) is None: + results.append( + f"{connection_type.capitalize()} {connection_name!r} of task " + f"{task_label!r} exists in A, but not in B (or it may have a different connection type)." + ) + else: + results.extend(a_edge.diff(b_edge, connection_type)) + b_to_do.discard(connection_name) + for connection_name in b_to_do: + results.append( + f"{connection_type.capitalize()} {connection_name!r} of task " + f"{task_label!r} exists in A, but not in B (or it may have a different connection type)." + ) + return results diff --git a/python/lsst/pipe/base/pipeline_graph/io.py b/python/lsst/pipe/base/pipeline_graph/io.py new file mode 100644 index 00000000..09e52df5 --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/io.py @@ -0,0 +1,578 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ( + "expect_not_none", + "SerializedEdge", + "SerializedTaskInitNode", + "SerializedTaskNode", + "SerializedDatasetTypeNode", + "SerializedTaskSubset", + "SerializedPipelineGraph", +) + +from typing import Any, TypeVar + +import networkx +import pydantic +from lsst.daf.butler import DatasetType, DimensionConfig, DimensionGraph, DimensionUniverse + +from .. import automatic_connection_constants as acc +from ._dataset_types import DatasetTypeNode +from ._edges import Edge, ReadEdge, WriteEdge +from ._exceptions import PipelineGraphReadError +from ._nodes import NodeKey, NodeType +from ._pipeline_graph import PipelineGraph +from ._task_subsets import TaskSubset +from ._tasks import TaskInitNode, TaskNode + +_U = TypeVar("_U") + +_IO_VERSION_INFO = (0, 0, 1) +"""Version tuple embedded in saved PipelineGraphs. +""" + + +def expect_not_none(value: _U | None, msg: str) -> _U: + """Check that a value is not `None` and return it. + + Parameters + ---------- + value + Value to check + msg + Error message for the case where ``value is None``. + + Returns + ------- + value + Value, guaranteed not to be `None`. + + Raises + ------ + PipelineGraphReadError + Raised with ``msg`` if ``value is None``. + """ + if value is None: + raise PipelineGraphReadError(msg) + return value + + +class SerializedEdge(pydantic.BaseModel): + """Struct used to represent a serialized `Edge` in a `PipelineGraph`. + + All `ReadEdge` and `WriteEdge` state not included here is instead + effectively serialized by the context in which a `SerializedEdge` appears + (e.g. the keys of the nested dictionaries in which it serves as the value + type). + """ + + dataset_type_name: str + """Full dataset type name (including component).""" + + storage_class: str + """Name of the storage class.""" + + raw_dimensions: list[str] + """Raw dimensions of the dataset type from the task connections.""" + + is_calibration: bool = False + """Whether this dataset type can be included in + `~lsst.daf.butler.CollectionType.CALIBRATION` collections.""" + + defer_query_constraint: bool = False + """If `True`, by default do not include this dataset type's existence as a + constraint on the initial data ID query in QuantumGraph generation.""" + + @classmethod + def serialize(cls, target: Edge) -> SerializedEdge: + """Transform an `Edge` to a `SerializedEdge`.""" + return SerializedEdge.construct( + storage_class=target.storage_class_name, + dataset_type_name=target.dataset_type_name, + raw_dimensions=sorted(target.raw_dimensions), + is_calibration=target.is_calibration, + defer_query_constraint=getattr(target, "defer_query_constraint", False), + ) + + def deserialize_read_edge( + self, + task_key: NodeKey, + connection_name: str, + is_prerequisite: bool = False, + ) -> ReadEdge: + """Transform a `SerializedEdge` to a `ReadEdge`.""" + parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(self.dataset_type_name) + dataset_type_key = NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name) + return ReadEdge( + dataset_type_key, + task_key, + storage_class_name=self.storage_class, + is_prerequisite=is_prerequisite, + component=component, + connection_name=connection_name, + is_calibration=self.is_calibration, + defer_query_constraint=self.defer_query_constraint, + raw_dimensions=frozenset(self.raw_dimensions), + ) + + def deserialize_write_edge( + self, + task_key: NodeKey, + connection_name: str, + ) -> WriteEdge: + """Transform a `SerializedEdge` to a `WriteEdge`.""" + dataset_type_key = NodeKey(NodeType.DATASET_TYPE, self.dataset_type_name) + return WriteEdge( + task_key=task_key, + dataset_type_key=dataset_type_key, + storage_class_name=self.storage_class, + connection_name=connection_name, + is_calibration=self.is_calibration, + raw_dimensions=frozenset(self.raw_dimensions), + ) + + +class SerializedTaskInitNode(pydantic.BaseModel): + """Struct used to represent a serialized `TaskInitNode` in a + `PipelineGraph`. + + The task label is serialized by the context in which a + `SerializedTaskInitNode` appears (e.g. the keys of the nested dictionary + in which it serves as the value type), and the task class name and config + string are save with the corresponding `SerializedTaskNode`. + """ + + inputs: dict[str, SerializedEdge] + """Mapping of serialized init-input edges, keyed by connection name.""" + + outputs: dict[str, SerializedEdge] + """Mapping of serialized init-output edges, keyed by connection name.""" + + config_output: SerializedEdge + """The serialized config init-output edge.""" + + index: int | None = None + """The index of this node in the sorted sequence of `PipelineGraph`. + + This is `None` if the `PipelineGraph` was not sorted when it was + serialized. + """ + + @classmethod + def serialize(cls, target: TaskInitNode) -> SerializedTaskInitNode: + """Transform a `TaskInitNode` to a `SerializedTaskInitNode`.""" + return cls.construct( + inputs={ + connection_name: SerializedEdge.serialize(edge) + for connection_name, edge in sorted(target.inputs.items()) + }, + outputs={ + connection_name: SerializedEdge.serialize(edge) + for connection_name, edge in sorted(target.outputs.items()) + }, + config_output=SerializedEdge.serialize(target.config_output), + ) + + def deserialize( + self, + key: NodeKey, + task_class_name: str, + config_str: str, + ) -> TaskInitNode: + """Transform a `SerializedTaskInitNode` to a `TaskInitNode`.""" + return TaskInitNode( + key, + inputs={ + connection_name: serialized_edge.deserialize_read_edge(key, connection_name) + for connection_name, serialized_edge in self.inputs.items() + }, + outputs={ + connection_name: serialized_edge.deserialize_write_edge(key, connection_name) + for connection_name, serialized_edge in self.outputs.items() + }, + config_output=self.config_output.deserialize_write_edge( + key, + acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, + ), + task_class_name=task_class_name, + config_str=config_str, + ) + + +class SerializedTaskNode(pydantic.BaseModel): + """Struct used to represent a serialized `TaskNode` in a `PipelineGraph`. + + The task label is serialized by the context in which a + `SerializedTaskNode` appears (e.g. the keys of the nested dictionary in + which it serves as the value type). + """ + + task_class: str + """Fully-qualified name of the task class.""" + + init: SerializedTaskInitNode + """Serialized task initialization node.""" + + config_str: str + """Configuration for the task as a string of override statements.""" + + prerequisite_inputs: dict[str, SerializedEdge] + """Mapping of serialized prerequisiste input edges, keyed by connection + name. + """ + + inputs: dict[str, SerializedEdge] + """Mapping of serialized input edges, keyed by connection name.""" + + outputs: dict[str, SerializedEdge] + """Mapping of serialized output edges, keyed by connection name.""" + + metadata_output: SerializedEdge + """The serialized metadata output edge.""" + + dimensions: list[str] + """The task's dimensions, if they were resolved.""" + + log_output: SerializedEdge | None = None + """The serialized log output edge.""" + + index: int | None = None + """The index of this node in the sorted sequence of `PipelineGraph`. + + This is `None` if the `PipelineGraph` was not sorted when it was + serialized. + """ + + @classmethod + def serialize(cls, target: TaskNode) -> SerializedTaskNode: + """Transform a `TaskNode` to a `SerializedTaskNode`.""" + return cls.construct( + task_class=target.task_class_name, + init=SerializedTaskInitNode.serialize(target.init), + config_str=target.get_config_str(), + dimensions=list(target.raw_dimensions), + prerequisite_inputs={ + connection_name: SerializedEdge.serialize(edge) + for connection_name, edge in sorted(target.prerequisite_inputs.items()) + }, + inputs={ + connection_name: SerializedEdge.serialize(edge) + for connection_name, edge in sorted(target.inputs.items()) + }, + outputs={ + connection_name: SerializedEdge.serialize(edge) + for connection_name, edge in sorted(target.outputs.items()) + }, + metadata_output=SerializedEdge.serialize(target.metadata_output), + log_output=( + SerializedEdge.serialize(target.log_output) if target.log_output is not None else None + ), + ) + + def deserialize(self, key: NodeKey, init_key: NodeKey, universe: DimensionUniverse | None) -> TaskNode: + """Transform a `SerializedTaskNode` to a `TaskNode`.""" + init = self.init.deserialize( + init_key, + task_class_name=self.task_class, + config_str=expect_not_none( + self.config_str, f"No serialized config file for task with label {key.name!r}." + ), + ) + inputs = { + connection_name: serialized_edge.deserialize_read_edge(key, connection_name) + for connection_name, serialized_edge in self.inputs.items() + } + prerequisite_inputs = { + connection_name: serialized_edge.deserialize_read_edge(key, connection_name, is_prerequisite=True) + for connection_name, serialized_edge in self.prerequisite_inputs.items() + } + outputs = { + connection_name: serialized_edge.deserialize_write_edge(key, connection_name) + for connection_name, serialized_edge in self.outputs.items() + } + if (serialized_log_output := self.log_output) is not None: + log_output = serialized_log_output.deserialize_write_edge(key, acc.LOG_OUTPUT_CONNECTION_NAME) + else: + log_output = None + metadata_output = self.metadata_output.deserialize_write_edge( + key, acc.METADATA_OUTPUT_CONNECTION_NAME + ) + dimensions: frozenset[str] | DimensionGraph + if universe is not None: + dimensions = universe.extract(self.dimensions) + else: + dimensions = frozenset(self.dimensions) + return TaskNode( + key=key, + init=init, + inputs=inputs, + prerequisite_inputs=prerequisite_inputs, + outputs=outputs, + log_output=log_output, + metadata_output=metadata_output, + dimensions=dimensions, + ) + + +class SerializedDatasetTypeNode(pydantic.BaseModel): + """Struct used to represent a serialized `DatasetTypeNode` in a + `PipelineGraph`. + + Unresolved dataset types are serialized as instances with at most the + `index` attribute set, and are typically converted to JSON with pydantic's + ``exclude_defaults=True`` option to keep this compact. + + The dataset typename is serialized by the context in which a + `SerializedDatasetTypeNode` appears (e.g. the keys of the nested dictionary + in which it serves as the value type). + """ + + dimensions: list[str] | None = None + """Dimensions of the dataset type.""" + + storage_class: str | None = None + """Name of the storage class.""" + + is_calibration: bool = False + """Whether this dataset type is a calibration.""" + + is_initial_query_constraint: bool = False + """Whether this dataset type should be a query constraint during + `QuantumGraph` generation.""" + + is_prerequisite: bool = False + """Whether datasets of this dataset type must exist in the input collection + before `QuantumGraph` generation.""" + + index: int | None = None + """The index of this node in the sorted sequence of `PipelineGraph`. + + This is `None` if the `PipelineGraph` was not sorted when it was + serialized. + """ + + @classmethod + def serialize(cls, target: DatasetTypeNode | None) -> SerializedDatasetTypeNode: + """Transform a `DatasetTypeNode` to a `SerializedDatasetTypeNode`.""" + if target is None: + return cls.construct() + return cls.construct( + dimensions=list(target.dataset_type.dimensions.names), + storage_class=target.dataset_type.storageClass_name, + is_calibration=target.dataset_type.isCalibration(), + is_initial_query_constraint=target.is_initial_query_constraint, + is_prerequisite=target.is_prerequisite, + ) + + def deserialize(self, key: NodeKey, universe: DimensionUniverse | None) -> DatasetTypeNode | None: + """Transform a `SerializedDatasetTypeNode` to a `DatasetTypeNode`.""" + if self.dimensions is not None: + dataset_type = DatasetType( + key.name, + expect_not_none( + self.dimensions, + f"Serialized dataset type {key.name!r} has no dimensions.", + ), + storageClass=expect_not_none( + self.storage_class, + f"Serialized dataset type {key.name!r} has no storage class.", + ), + isCalibration=self.is_calibration, + universe=expect_not_none( + universe, + f"Serialized dataset type {key.name!r} has dimensions, " + "but no dimension universe was stored.", + ), + ) + return DatasetTypeNode( + dataset_type=dataset_type, + is_prerequisite=self.is_prerequisite, + is_initial_query_constraint=self.is_initial_query_constraint, + ) + return None + + +class SerializedTaskSubset(pydantic.BaseModel): + """Struct used to represent a serialized `TaskSubset` in a `PipelineGraph`. + + The subsetlabel is serialized by the context in which a + `SerializedDatasetTypeNode` appears (e.g. the keys of the nested dictionary + in which it serves as the value type). + """ + + description: str + """Description of the subset.""" + + tasks: list[str] + """Labels of tasks in the subset, sorted lexicographically for + determinism. + """ + + @classmethod + def serialize(cls, target: TaskSubset) -> SerializedTaskSubset: + """Transform a `TaskSubset` into a `SerializedTaskSubset`.""" + return cls.construct(description=target._description, tasks=list(sorted(target))) + + def deserialize_task_subset(self, label: str, xgraph: networkx.MultiDiGraph) -> TaskSubset: + """Transform a `SerializedTaskSubset` into a `TaskSubset`.""" + members = set(self.tasks) + return TaskSubset(xgraph, label, members, self.description) + + +class SerializedPipelineGraph(pydantic.BaseModel): + """Struct used to represent a serialized `PipelineGraph`.""" + + version: str = ".".join(str(v) for v in _IO_VERSION_INFO) + """Serialization version.""" + + description: str + """Human-readable description of the pipeline.""" + + tasks: dict[str, SerializedTaskNode] = pydantic.Field(default_factory=dict) + """Mapping of serialized tasks, keyed by label.""" + + dataset_types: dict[str, SerializedDatasetTypeNode] = pydantic.Field(default_factory=dict) + """Mapping of serialized dataset types, keyed by parent dataset type name. + """ + + task_subsets: dict[str, SerializedTaskSubset] = pydantic.Field(default_factory=dict) + """Mapping of task subsets, keyed by subset label.""" + + dimensions: dict[str, Any] | None = None + """Dimension universe configuration.""" + + data_id: dict[str, Any] = pydantic.Field(default_factory=dict) + """Data ID that constrains all quanta generated from this pipeline.""" + + @classmethod + def serialize(cls, target: PipelineGraph) -> SerializedPipelineGraph: + """Transform a `PipelineGraph` into a `SerializedPipelineGraph`.""" + result = SerializedPipelineGraph.construct( + description=target.description, + tasks={label: SerializedTaskNode.serialize(node) for label, node in target.tasks.items()}, + dataset_types={ + name: SerializedDatasetTypeNode().serialize(target.dataset_types.get_if_resolved(name)) + for name in target.dataset_types.keys() + }, + task_subsets={ + label: SerializedTaskSubset.serialize(subset) for label, subset in target.task_subsets.items() + }, + dimensions=target.universe.dimensionConfig.toDict() if target.universe is not None else None, + data_id=target._raw_data_id, + ) + if target._sorted_keys: + for index, node_key in enumerate(target._sorted_keys): + match node_key.node_type: + case NodeType.TASK: + result.tasks[node_key.name].index = index + case NodeType.DATASET_TYPE: + result.dataset_types[node_key.name].index = index + case NodeType.TASK_INIT: + result.tasks[node_key.name].init.index = index + return result + + def deserialize( + self, + import_and_configure: bool = True, + check_edges_unchanged: bool = False, + assume_edges_unchanged: bool = False, + ) -> PipelineGraph: + """Transform a `SerializedPipelineGraph` into a `PipelineGraph`.""" + universe: DimensionUniverse | None = None + if self.dimensions is not None: + universe = DimensionUniverse( + config=DimensionConfig( + expect_not_none( + self.dimensions, + "Serialized pipeline graph has not been resolved; " + "load it is a MutablePipelineGraph instead.", + ) + ) + ) + xgraph = networkx.MultiDiGraph() + sort_index_map: dict[int, NodeKey] = {} + for dataset_type_name, serialized_dataset_type in self.dataset_types.items(): + dataset_type_key = NodeKey(NodeType.DATASET_TYPE, dataset_type_name) + dataset_type_node = serialized_dataset_type.deserialize(dataset_type_key, universe) + xgraph.add_node( + dataset_type_key, instance=dataset_type_node, bipartite=NodeType.DATASET_TYPE.value + ) + if serialized_dataset_type.index is not None: + sort_index_map[serialized_dataset_type.index] = dataset_type_key + for task_label, serialized_task in self.tasks.items(): + task_key = NodeKey(NodeType.TASK, task_label) + task_init_key = NodeKey(NodeType.TASK_INIT, task_label) + task_node = serialized_task.deserialize(task_key, task_init_key, universe) + if serialized_task.index is not None: + sort_index_map[serialized_task.index] = task_key + if serialized_task.init.index is not None: + sort_index_map[serialized_task.init.index] = task_init_key + xgraph.add_node(task_key, instance=task_node, bipartite=NodeType.TASK.bipartite) + xgraph.add_node(task_init_key, instance=task_node.init, bipartite=NodeType.TASK_INIT.bipartite) + xgraph.add_edge(task_init_key, task_key, Edge.INIT_TO_TASK_NAME, instance=None) + for read_edge in task_node.init.iter_all_inputs(): + xgraph.add_edge( + read_edge.dataset_type_key, + read_edge.task_key, + read_edge.connection_name, + instance=read_edge, + ) + for write_edge in task_node.init.iter_all_outputs(): + xgraph.add_edge( + write_edge.task_key, + write_edge.dataset_type_key, + write_edge.connection_name, + instance=write_edge, + ) + for read_edge in task_node.iter_all_inputs(): + xgraph.add_edge( + read_edge.dataset_type_key, + read_edge.task_key, + read_edge.connection_name, + instance=read_edge, + ) + for write_edge in task_node.iter_all_outputs(): + xgraph.add_edge( + write_edge.task_key, + write_edge.dataset_type_key, + write_edge.connection_name, + instance=write_edge, + ) + result = PipelineGraph.__new__(PipelineGraph) + result._init_from_args( + xgraph, + sorted_keys=[sort_index_map[i] for i in range(len(xgraph))] if sort_index_map else None, + task_subsets={ + subset_label: serialized_subset.deserialize_task_subset(subset_label, xgraph) + for subset_label, serialized_subset in self.task_subsets.items() + }, + description=self.description, + universe=universe, + data_id=self.data_id, + ) + if import_and_configure: + result.import_and_configure( + check_edges_unchanged=check_edges_unchanged, + assume_edges_unchanged=assume_edges_unchanged, + ) + return result diff --git a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py index 217e1c8a..d7b936e7 100644 --- a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py +++ b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py @@ -371,7 +371,6 @@ def make_connection(self, cls: type[_T]) -> _T: return cls( name=self.dataset_type_name, storageClass=storage_class, - isCalibration=self.is_calibration, multiple=self.multiple, ) diff --git a/tests/test_pipeline_graph.py b/tests/test_pipeline_graph.py new file mode 100644 index 00000000..0e402092 --- /dev/null +++ b/tests/test_pipeline_graph.py @@ -0,0 +1,1255 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Tests of things related to the GraphBuilder class.""" + +import copy +import io +import logging +import unittest +from typing import Any + +import lsst.pipe.base.automatic_connection_constants as acc +import lsst.utils.tests +from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse +from lsst.daf.butler.registry import MissingDatasetTypeError +from lsst.pipe.base.pipeline_graph import ( + ConnectionTypeConsistencyError, + DuplicateOutputError, + Edge, + EdgesChangedError, + IncompatibleDatasetTypeError, + NodeKey, + NodeType, + PipelineGraph, + PipelineGraphError, + UnresolvedGraphError, +) +from lsst.pipe.base.tests.mocks import ( + DynamicConnectionConfig, + DynamicTestPipelineTask, + DynamicTestPipelineTaskConfig, + get_mock_name, +) + +_LOG = logging.getLogger(__name__) + + +class MockRegistry: + """A test-utility stand-in for lsst.daf.butler.Registry that just knows + how to get dataset types. + """ + + def __init__(self, dimensions: DimensionUniverse, dataset_types: dict[str, DatasetType]) -> None: + self.dimensions = dimensions + self._dataset_types = dataset_types + + def getDatasetType(self, name: str) -> DatasetType: + try: + return self._dataset_types[name] + except KeyError: + raise MissingDatasetTypeError(name) + + +class PipelineGraphTestCase(unittest.TestCase): + """Tests for the `PipelineGraph` class. + + Tests for `PipelineGraph.resolve` are mostly in + `PipelineGraphResolveTestCase` later in this file. + """ + + def setUp(self) -> None: + # Simple test pipeline has two tasks, 'a' and 'b', with dataset types + # 'input', 'intermediate', and 'output'. There are no dimensions on + # any of those. We add tasks in reverse order to better test sorting. + # There is one labeled task subset, 'only_b', with just 'b' in it. + # We copy the configs so the originals (the instance attributes) can + # be modified and reused after the ones passed in to the graph are + # frozen. + self.description = "A pipeline for PipelineGraph unit tests." + self.graph = PipelineGraph() + self.graph.description = self.description + self.b_config = DynamicTestPipelineTaskConfig() + self.b_config.init_inputs["in_schema"] = DynamicConnectionConfig(dataset_type_name="schema") + self.b_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1") + self.b_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="output_1") + self.graph.add_task("b", DynamicTestPipelineTask, copy.deepcopy(self.b_config)) + self.a_config = DynamicTestPipelineTaskConfig() + self.a_config.init_outputs["out_schema"] = DynamicConnectionConfig(dataset_type_name="schema") + self.a_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="input_1") + self.a_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1") + self.graph.add_task("a", DynamicTestPipelineTask, copy.deepcopy(self.a_config)) + self.graph.add_task_subset("only_b", ["b"]) + self.subset_description = "A subset with only task B in it." + self.graph.task_subsets["only_b"].description = self.subset_description + self.dimensions = DimensionUniverse() + self.maxDiff = None + + def test_unresolved_accessors(self) -> None: + """Test attribute accessors, iteration, and simple methods on a graph + that has not had `PipelineGraph.resolve` called on it.""" + self.check_base_accessors(self.graph) + self.assertEqual( + repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask)" + ) + + def test_sorting(self) -> None: + """Test sort methods on PipelineGraph.""" + self.assertFalse(self.graph.has_been_sorted) + self.assertFalse(self.graph.is_sorted) + self.graph.sort() + self.check_sorted(self.graph) + + def test_unresolved_xgraph_export(self) -> None: + """Test exporting an unresolved PipelineGraph to networkx in various + ways.""" + self.check_make_xgraph(self.graph, resolved=False) + self.check_make_bipartite_xgraph(self.graph, resolved=False) + self.check_make_task_xgraph(self.graph, resolved=False) + self.check_make_dataset_type_xgraph(self.graph, resolved=False) + + def test_unresolved_stream_io(self) -> None: + """Test round-tripping an unresolved PipelineGraph through in-memory + serialization. + """ + stream = io.BytesIO() + self.graph.write_stream(stream) + stream.seek(0) + roundtripped = PipelineGraph.read_stream(stream) + self.check_make_xgraph(roundtripped, resolved=False) + + def test_unresolved_file_io(self) -> None: + """Test round-tripping an unresolved PipelineGraph through file + serialization. + """ + with lsst.utils.tests.getTempFilePath(".json.gz") as filename: + self.graph.write_uri(filename) + roundtripped = PipelineGraph.read_uri(filename) + self.check_make_xgraph(roundtripped, resolved=False) + + def test_unresolved_deferred_import_io(self) -> None: + """Test round-tripping an unresolved PipelineGraph through + serialization, without immediately importing tasks on read. + """ + stream = io.BytesIO() + self.graph.write_stream(stream) + stream.seek(0) + roundtripped = PipelineGraph.read_stream(stream, import_and_configure=False) + self.check_make_xgraph(roundtripped, resolved=False, imported_and_configured=False) + # Check that we can still resolve the graph without importing tasks. + roundtripped.resolve(MockRegistry(self.dimensions, {})) + self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False) + roundtripped.import_and_configure(assume_edges_unchanged=True) + self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True) + + def test_resolved_accessors(self) -> None: + """Test attribute accessors, iteration, and simple methods on a graph + that has had `PipelineGraph.resolve` called on it. + + This includes the accessors available on unresolved graphs as well as + new ones, and we expect the resolved graph to be sorted as well. + """ + self.graph.resolve(MockRegistry(self.dimensions, {})) + self.check_base_accessors(self.graph) + self.check_sorted(self.graph) + self.assertEqual( + repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask, {})" + ) + self.assertEqual(self.graph.tasks["a"].dimensions, self.dimensions.empty) + self.assertEqual(repr(self.graph.dataset_types["input_1"]), "input_1 (_mock_StructuredDataDict, {})") + self.assertEqual(self.graph.dataset_types["input_1"].key, NodeKey(NodeType.DATASET_TYPE, "input_1")) + self.assertEqual(self.graph.dataset_types["input_1"].dimensions, self.dimensions.empty) + self.assertEqual(self.graph.dataset_types["input_1"].storage_class_name, "_mock_StructuredDataDict") + self.assertEqual(self.graph.dataset_types["input_1"].storage_class.name, "_mock_StructuredDataDict") + + def test_resolved_xgraph_export(self) -> None: + """Test exporting a resolved PipelineGraph to networkx in various + ways.""" + self.graph.resolve(MockRegistry(self.dimensions, {})) + self.check_make_xgraph(self.graph, resolved=True) + self.check_make_bipartite_xgraph(self.graph, resolved=True) + self.check_make_task_xgraph(self.graph, resolved=True) + self.check_make_dataset_type_xgraph(self.graph, resolved=True) + + def test_resolved_stream_io(self) -> None: + """Test round-tripping a resolved PipelineGraph through in-memory + serialization. + """ + self.graph.resolve(MockRegistry(self.dimensions, {})) + stream = io.BytesIO() + self.graph.write_stream(stream) + stream.seek(0) + roundtripped = PipelineGraph.read_stream(stream) + self.check_make_xgraph(roundtripped, resolved=True) + + def test_resolved_file_io(self) -> None: + """Test round-tripping a resolved PipelineGraph through file + serialization. + """ + self.graph.resolve(MockRegistry(self.dimensions, {})) + with lsst.utils.tests.getTempFilePath(".json.gz") as filename: + self.graph.write_uri(filename) + roundtripped = PipelineGraph.read_uri(filename) + self.check_make_xgraph(roundtripped, resolved=True) + + def test_resolved_deferred_import_io(self) -> None: + """Test round-tripping a resolved PipelineGraph through serialization, + without immediately importing tasks on read. + """ + self.graph.resolve(MockRegistry(self.dimensions, {})) + stream = io.BytesIO() + self.graph.write_stream(stream) + stream.seek(0) + roundtripped = PipelineGraph.read_stream(stream, import_and_configure=False) + self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False) + roundtripped.import_and_configure(check_edges_unchanged=True) + self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True) + + def test_unresolved_copies(self) -> None: + """Test making copies of an unresolved PipelineGraph.""" + copy1 = self.graph.copy() + self.assertIsNot(copy1, self.graph) + self.check_make_xgraph(copy1, resolved=False) + copy2 = copy.copy(self.graph) + self.assertIsNot(copy2, self.graph) + self.check_make_xgraph(copy2, resolved=False) + copy3 = copy.deepcopy(self.graph) + self.assertIsNot(copy3, self.graph) + self.check_make_xgraph(copy3, resolved=False) + + def test_resolved_copies(self) -> None: + """Test making copies of a resolved PipelineGraph.""" + self.graph.resolve(MockRegistry(self.dimensions, {})) + copy1 = self.graph.copy() + self.assertIsNot(copy1, self.graph) + self.check_make_xgraph(copy1, resolved=True) + copy2 = copy.copy(self.graph) + self.assertIsNot(copy2, self.graph) + self.check_make_xgraph(copy2, resolved=True) + copy3 = copy.deepcopy(self.graph) + self.assertIsNot(copy3, self.graph) + self.check_make_xgraph(copy3, resolved=True) + + def check_base_accessors(self, graph: PipelineGraph) -> None: + """Implementation for test methods that check attribute access, + iteration, and simple methods. + + The given graph must be unchanged from the one defined in `setUp`, + other than sorting. + """ + self.assertEqual(graph.description, self.description) + self.assertEqual(graph.tasks.keys(), {"a", "b"}) + self.assertEqual( + graph.dataset_types.keys(), + { + "schema", + "input_1", + "intermediate_1", + "output_1", + "a_config", + "a_log", + "a_metadata", + "b_config", + "b_log", + "b_metadata", + }, + ) + self.assertEqual(graph.task_subsets.keys(), {"only_b"}) + self.assertEqual( + {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=False)}, + { + ( + NodeKey(NodeType.DATASET_TYPE, "input_1"), + NodeKey(NodeType.TASK, "a"), + "input_1 -> a (input1)", + ), + ( + NodeKey(NodeType.TASK, "a"), + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), + "a -> intermediate_1 (output1)", + ), + ( + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), + NodeKey(NodeType.TASK, "b"), + "intermediate_1 -> b (input1)", + ), + ( + NodeKey(NodeType.TASK, "b"), + NodeKey(NodeType.DATASET_TYPE, "output_1"), + "b -> output_1 (output1)", + ), + (NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.DATASET_TYPE, "a_log"), "a -> a_log (_log)"), + ( + NodeKey(NodeType.TASK, "a"), + NodeKey(NodeType.DATASET_TYPE, "a_metadata"), + "a -> a_metadata (_metadata)", + ), + (NodeKey(NodeType.TASK, "b"), NodeKey(NodeType.DATASET_TYPE, "b_log"), "b -> b_log (_log)"), + ( + NodeKey(NodeType.TASK, "b"), + NodeKey(NodeType.DATASET_TYPE, "b_metadata"), + "b -> b_metadata (_metadata)", + ), + }, + ) + self.assertEqual( + {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=True)}, + { + ( + NodeKey(NodeType.TASK_INIT, "a"), + NodeKey(NodeType.DATASET_TYPE, "schema"), + "a -> schema (out_schema)", + ), + ( + NodeKey(NodeType.DATASET_TYPE, "schema"), + NodeKey(NodeType.TASK_INIT, "b"), + "schema -> b (in_schema)", + ), + ( + NodeKey(NodeType.TASK_INIT, "a"), + NodeKey(NodeType.DATASET_TYPE, "a_config"), + "a -> a_config (_config)", + ), + ( + NodeKey(NodeType.TASK_INIT, "b"), + NodeKey(NodeType.DATASET_TYPE, "b_config"), + "b -> b_config (_config)", + ), + }, + ) + self.assertEqual( + {(node_type, name) for node_type, name, _ in graph.iter_nodes()}, + { + NodeKey(NodeType.TASK, "a"), + NodeKey(NodeType.TASK, "b"), + NodeKey(NodeType.TASK_INIT, "a"), + NodeKey(NodeType.TASK_INIT, "b"), + NodeKey(NodeType.DATASET_TYPE, "schema"), + NodeKey(NodeType.DATASET_TYPE, "input_1"), + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), + NodeKey(NodeType.DATASET_TYPE, "output_1"), + NodeKey(NodeType.DATASET_TYPE, "a_config"), + NodeKey(NodeType.DATASET_TYPE, "a_log"), + NodeKey(NodeType.DATASET_TYPE, "a_metadata"), + NodeKey(NodeType.DATASET_TYPE, "b_config"), + NodeKey(NodeType.DATASET_TYPE, "b_log"), + NodeKey(NodeType.DATASET_TYPE, "b_metadata"), + }, + ) + self.assertEqual({name for name, _ in graph.iter_overall_inputs()}, {"input_1"}) + self.assertEqual({edge.task_label for edge in graph.consumers_of("input_1")}, {"a"}) + self.assertEqual({edge.task_label for edge in graph.consumers_of("intermediate_1")}, {"b"}) + self.assertEqual({edge.task_label for edge in graph.consumers_of("output_1")}, set()) + self.assertIsNone(graph.producer_of("input_1")) + self.assertEqual(graph.producer_of("intermediate_1").task_label, "a") + self.assertEqual(graph.producer_of("output_1").task_label, "b") + self.assertTrue(repr(self.graph).startswith(f"PipelineGraph({self.description!r}, tasks=")) + self.assertEqual( + repr(graph.task_subsets["only_b"]), f"only_b: {self.subset_description!r}, tasks={{b}}" + ) + + def check_sorted(self, graph: PipelineGraph) -> None: + """Run a battery of tests on a PipelineGraph that must be + deterministically sorted. + + The given graph must be unchanged from the one defined in `setUp`, + other than sorting. + """ + self.assertTrue(graph.has_been_sorted) + self.assertTrue(graph.is_sorted) + self.assertEqual( + [(node_type, name) for node_type, name, _ in graph.iter_nodes()], + [ + # We only advertise that the order is topological and + # deterministic, so this test is slightly over-specified; there + # are other orders that are consistent with our guarantees. + NodeKey(NodeType.DATASET_TYPE, "input_1"), + NodeKey(NodeType.TASK_INIT, "a"), + NodeKey(NodeType.DATASET_TYPE, "a_config"), + NodeKey(NodeType.DATASET_TYPE, "schema"), + NodeKey(NodeType.TASK_INIT, "b"), + NodeKey(NodeType.DATASET_TYPE, "b_config"), + NodeKey(NodeType.TASK, "a"), + NodeKey(NodeType.DATASET_TYPE, "a_log"), + NodeKey(NodeType.DATASET_TYPE, "a_metadata"), + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), + NodeKey(NodeType.TASK, "b"), + NodeKey(NodeType.DATASET_TYPE, "b_log"), + NodeKey(NodeType.DATASET_TYPE, "b_metadata"), + NodeKey(NodeType.DATASET_TYPE, "output_1"), + ], + ) + # Most users should only care that the tasks and dataset types are + # topologically sorted. + self.assertEqual(list(graph.tasks), ["a", "b"]) + self.assertEqual( + list(graph.dataset_types), + [ + "input_1", + "a_config", + "schema", + "b_config", + "a_log", + "a_metadata", + "intermediate_1", + "b_log", + "b_metadata", + "output_1", + ], + ) + # __str__ and __repr__ of course work on unsorted mapping views, too, + # but the order of elements is then nondeterministic and hard to test. + self.assertEqual(repr(self.graph.tasks), "TaskMappingView({a, b})") + self.assertEqual( + repr(self.graph.dataset_types), + ( + "DatasetTypeMappingView({input_1, a_config, schema, b_config, a_log, a_metadata, " + "intermediate_1, b_log, b_metadata, output_1})" + ), + ) + + def check_make_xgraph( + self, graph: PipelineGraph, resolved: bool, imported_and_configured: bool = True + ) -> None: + """Check that the given graph exports as expected to networkx. + + The given graph must be unchanged from the one defined in `setUp`, + other than being resolved (if ``resolved=True``) or round-tripped + through serialization without tasks being imported (if + ``imported_and_configured=False``). + """ + xgraph = graph.make_xgraph() + expected_edges = ( + {edge.key for edge in graph.iter_edges()} + | {edge.key for edge in graph.iter_edges(init=True)} + | { + (NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK, "a"), Edge.INIT_TO_TASK_NAME), + (NodeKey(NodeType.TASK_INIT, "b"), NodeKey(NodeType.TASK, "b"), Edge.INIT_TO_TASK_NAME), + } + ) + test_edges = set(xgraph.edges) + self.assertEqual(test_edges, expected_edges) + expected_nodes = { + NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node( + "a", resolved, imported_and_configured=imported_and_configured + ), + NodeKey(NodeType.TASK, "a"): self.get_expected_task_node( + "a", resolved, imported_and_configured=imported_and_configured + ), + NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node( + "b", resolved, imported_and_configured=imported_and_configured + ), + NodeKey(NodeType.TASK, "b"): self.get_expected_task_node( + "b", resolved, imported_and_configured=imported_and_configured + ), + NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node( + "schema", resolved, is_initial_query_constraint=False + ), + NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node( + "input_1", resolved, is_initial_query_constraint=True + ), + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node( + "intermediate_1", resolved, is_initial_query_constraint=False + ), + NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node( + "output_1", resolved, is_initial_query_constraint=False + ), + } + test_nodes = dict(xgraph.nodes.items()) + self.assertEqual(set(test_nodes.keys()), set(expected_nodes.keys())) + for key, expected_node in expected_nodes.items(): + test_node = test_nodes[key] + self.assertEqual(expected_node, test_node, key) + + def check_make_bipartite_xgraph(self, graph: PipelineGraph, resolved: bool) -> None: + """Check that the given graph's init-only or runtime subset exports as + expected to networkx. + + The given graph must be unchanged from the one defined in `setUp`, + other than being resolved (if ``resolved=True``). + """ + run_xgraph = graph.make_bipartite_xgraph() + self.assertEqual(set(run_xgraph.edges), {edge.key for edge in graph.iter_edges()}) + self.assertEqual( + dict(run_xgraph.nodes.items()), + { + NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved), + NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node( + "input_1", resolved, is_initial_query_constraint=True + ), + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node( + "intermediate_1", resolved, is_initial_query_constraint=False + ), + NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node( + "output_1", resolved, is_initial_query_constraint=False + ), + }, + ) + init_xgraph = graph.make_bipartite_xgraph( + init=True, + ) + self.assertEqual(set(init_xgraph.edges), {edge.key for edge in graph.iter_edges(init=True)}) + self.assertEqual( + dict(init_xgraph.nodes.items()), + { + NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved), + NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node( + "schema", resolved, is_initial_query_constraint=False + ), + NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved), + }, + ) + + def check_make_task_xgraph(self, graph: PipelineGraph, resolved: bool) -> None: + """Check that the given graph's task-only projection exports as + expected to networkx. + + The given graph must be unchanged from the one defined in `setUp`, + other than being resolved (if ``resolved=True``). + """ + run_xgraph = graph.make_task_xgraph() + self.assertEqual(set(run_xgraph.edges), {(NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.TASK, "b"))}) + self.assertEqual( + dict(run_xgraph.nodes.items()), + { + NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved), + NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved), + }, + ) + init_xgraph = graph.make_task_xgraph( + init=True, + ) + self.assertEqual( + set(init_xgraph.edges), + {(NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK_INIT, "b"))}, + ) + self.assertEqual( + dict(init_xgraph.nodes.items()), + { + NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved), + NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved), + }, + ) + + def check_make_dataset_type_xgraph(self, graph: PipelineGraph, resolved: bool) -> None: + """Check that the given graph's dataset-type-only projection exports as + expected to networkx. + + The given graph must be unchanged from the one defined in `setUp`, + other than being resolved (if ``resolved=True``). + """ + run_xgraph = graph.make_dataset_type_xgraph() + self.assertEqual( + set(run_xgraph.edges), + { + (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "intermediate_1")), + (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_log")), + (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_metadata")), + ( + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), + NodeKey(NodeType.DATASET_TYPE, "output_1"), + ), + (NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), NodeKey(NodeType.DATASET_TYPE, "b_log")), + ( + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), + NodeKey(NodeType.DATASET_TYPE, "b_metadata"), + ), + }, + ) + self.assertEqual( + dict(run_xgraph.nodes.items()), + { + NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved), + NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node( + "input_1", resolved, is_initial_query_constraint=True + ), + NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node( + "intermediate_1", resolved, is_initial_query_constraint=False + ), + NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node( + "output_1", resolved, is_initial_query_constraint=False + ), + }, + ) + init_xgraph = graph.make_dataset_type_xgraph(init=True) + self.assertEqual( + set(init_xgraph.edges), + {(NodeKey(NodeType.DATASET_TYPE, "schema"), NodeKey(NodeType.DATASET_TYPE, "b_config"))}, + ) + self.assertEqual( + dict(init_xgraph.nodes.items()), + { + NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node( + "schema", resolved, is_initial_query_constraint=False + ), + NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved), + NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved), + }, + ) + + def get_expected_task_node( + self, label: str, resolved: bool, imported_and_configured: bool = True + ) -> dict[str, Any]: + """Construct a networkx-export task node for comparison.""" + result = self.get_expected_task_init_node( + label, resolved, imported_and_configured=imported_and_configured + ) + if resolved: + result["dimensions"] = self.dimensions.empty + result["raw_dimensions"] = frozenset() + return result + + def get_expected_task_init_node( + self, label: str, resolved: bool, imported_and_configured: bool = True + ) -> dict[str, Any]: + """Construct a networkx-export task init for comparison.""" + result = { + "task_class_name": "lsst.pipe.base.tests.mocks.DynamicTestPipelineTask", + "bipartite": 1, + } + if imported_and_configured: + result["task_class"] = DynamicTestPipelineTask + result["config"] = getattr(self, f"{label}_config") + return result + + def get_expected_config_node(self, label: str, resolved: bool) -> dict[str, Any]: + """Construct a networkx-export init-output config dataset type node for + comparison. + """ + if not resolved: + return {"bipartite": 0} + else: + return { + "dataset_type": DatasetType( + acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label), + self.dimensions.empty, + acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, + ), + "is_initial_query_constraint": False, + "is_prerequisite": False, + "dimensions": self.dimensions.empty, + "storage_class_name": acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, + "bipartite": 0, + } + + def get_expected_log_node(self, label: str, resolved: bool) -> dict[str, Any]: + """Construct a networkx-export output log dataset type node for + comparison. + """ + if not resolved: + return {"bipartite": 0} + else: + return { + "dataset_type": DatasetType( + acc.LOG_OUTPUT_TEMPLATE.format(label=label), + self.dimensions.empty, + acc.LOG_OUTPUT_STORAGE_CLASS, + ), + "is_initial_query_constraint": False, + "is_prerequisite": False, + "dimensions": self.dimensions.empty, + "storage_class_name": acc.LOG_OUTPUT_STORAGE_CLASS, + "bipartite": 0, + } + + def get_expected_metadata_node(self, label: str, resolved: bool) -> dict[str, Any]: + """Construct a networkx-export output metadata dataset type node for + comparison. + """ + if not resolved: + return {"bipartite": 0} + else: + return { + "dataset_type": DatasetType( + acc.METADATA_OUTPUT_TEMPLATE.format(label=label), + self.dimensions.empty, + acc.METADATA_OUTPUT_STORAGE_CLASS, + ), + "is_initial_query_constraint": False, + "is_prerequisite": False, + "dimensions": self.dimensions.empty, + "storage_class_name": acc.METADATA_OUTPUT_STORAGE_CLASS, + "bipartite": 0, + } + + def get_expected_connection_node( + self, name: str, resolved: bool, *, is_initial_query_constraint: bool + ) -> dict[str, Any]: + """Construct a networkx-export dataset type node for comparison.""" + if not resolved: + return {"bipartite": 0} + else: + return { + "dataset_type": DatasetType( + name, + self.dimensions.empty, + get_mock_name("StructuredDataDict"), + ), + "is_initial_query_constraint": is_initial_query_constraint, + "is_prerequisite": False, + "dimensions": self.dimensions.empty, + "storage_class_name": get_mock_name("StructuredDataDict"), + "bipartite": 0, + } + + def test_construct_with_data_coordinate(self) -> None: + """Test constructing a graph with a DataCoordinate. + + Since this creates a graph with DimensionUniverse, all tasks added to + it should have resolved dimensions, but not (yet) resolved dataset + types. We use that to test a few other operations in that state. + """ + data_id = DataCoordinate.standardize(instrument="I", universe=self.dimensions) + graph = PipelineGraph(data_id=data_id) + self.assertEqual(graph.universe, self.dimensions) + self.assertEqual(graph.data_id, data_id) + graph.add_task("b1", DynamicTestPipelineTask, self.b_config) + self.assertEqual(graph.tasks["b1"].dimensions, self.dimensions.empty) + # Still can't group by dimensions, because the dataset types aren't + # resolved. + with self.assertRaises(UnresolvedGraphError): + graph.group_by_dimensions() + # Transferring a node from this graph to ``self.graph`` should + # unresolve the dimensions. + self.graph.add_task_nodes([graph.tasks["b1"]]) + self.assertIsNot(self.graph.tasks["b1"], graph.tasks["b1"]) + self.assertFalse(self.graph.tasks["b1"].has_resolved_dimensions) + # Do the opposite transfer, which should resolve dimensions. + graph.add_task_nodes([self.graph.tasks["a"]]) + self.assertIsNot(self.graph.tasks["a"], graph.tasks["a"]) + self.assertTrue(graph.tasks["a"].has_resolved_dimensions) + + def test_group_by_dimensions(self) -> None: + """Test PipelineGraph.group_by_dimensions.""" + with self.assertRaises(UnresolvedGraphError): + self.graph.group_by_dimensions() + self.a_config.dimensions = ["visit"] + self.a_config.outputs["output1"].dimensions = ["visit"] + self.a_config.prerequisite_inputs["prereq1"] = DynamicConnectionConfig( + dataset_type_name="prereq_1", + multiple=True, + dimensions=["htm7"], + is_calibration=True, + ) + self.b_config.dimensions = ["htm7"] + self.b_config.inputs["input1"].dimensions = ["visit"] + self.b_config.inputs["input1"].multiple = True + self.b_config.outputs["output1"].dimensions = ["htm7"] + self.graph.reconfigure_tasks(a=self.a_config, b=self.b_config) + self.graph.resolve(MockRegistry(self.dimensions, {})) + visit_dims = self.dimensions.extract(["visit"]) + htm7_dims = self.dimensions.extract(["htm7"]) + expected = { + self.dimensions.empty: ( + {}, + { + "schema": self.graph.dataset_types["schema"], + "input_1": self.graph.dataset_types["input_1"], + "a_config": self.graph.dataset_types["a_config"], + "b_config": self.graph.dataset_types["b_config"], + }, + ), + visit_dims: ( + {"a": self.graph.tasks["a"]}, + { + "a_log": self.graph.dataset_types["a_log"], + "a_metadata": self.graph.dataset_types["a_metadata"], + "intermediate_1": self.graph.dataset_types["intermediate_1"], + }, + ), + htm7_dims: ( + {"b": self.graph.tasks["b"]}, + { + "b_log": self.graph.dataset_types["b_log"], + "b_metadata": self.graph.dataset_types["b_metadata"], + "output_1": self.graph.dataset_types["output_1"], + }, + ), + } + self.assertEqual(self.graph.group_by_dimensions(), expected) + expected[htm7_dims][1]["prereq_1"] = self.graph.dataset_types["prereq_1"] + self.assertEqual(self.graph.group_by_dimensions(prerequisites=True), expected) + + def test_add_and_remove(self) -> None: + """Tests for adding and removing tasks and task subsets from a + PipelineGraph. + """ + # Can't remove a task while it's still in a subset. + with self.assertRaises(PipelineGraphError): + self.graph.remove_tasks(["b"], drop_from_subsets=False) + # ...unless you remove the subset. + self.graph.remove_task_subset("only_b") + self.assertFalse(self.graph.task_subsets) + ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=False) + self.assertFalse(referencing_subsets) + self.assertEqual(self.graph.tasks.keys(), {"a"}) + # Add that task back in. + self.graph.add_task_nodes([b]) + self.assertEqual(self.graph.tasks.keys(), {"a", "b"}) + # Add the subset back in. + self.graph.add_task_subset("only_b", {"b"}) + self.assertEqual(self.graph.task_subsets.keys(), {"only_b"}) + # Resolve the graph's dataset types and task dimensions. + self.graph.resolve(MockRegistry(self.dimensions, {})) + self.assertTrue(self.graph.dataset_types.is_resolved("input_1")) + self.assertTrue(self.graph.dataset_types.is_resolved("output_1")) + self.assertTrue(self.graph.dataset_types.is_resolved("schema")) + self.assertTrue(self.graph.dataset_types.is_resolved("intermediate_1")) + # Remove the task while removing it from the subset automatically. This + # should also unresolve (only) the referenced dataset types and drop + # any datasets no longer attached to any task. + self.assertEqual(self.graph.tasks.keys(), {"a", "b"}) + ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=True) + self.assertEqual(referencing_subsets, {"only_b"}) + self.assertEqual(self.graph.tasks.keys(), {"a"}) + self.assertTrue(self.graph.dataset_types.is_resolved("input_1")) + self.assertNotIn("output1", self.graph.dataset_types) + self.assertFalse(self.graph.dataset_types.is_resolved("schema")) + self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1")) + + def test_reconfigure(self) -> None: + """Tests for PipelineGraph.reconfigure.""" + self.graph.resolve(MockRegistry(self.dimensions, {})) + self.b_config.outputs["output1"].storage_class = "TaskMetadata" + with self.assertRaises(ValueError): + # Can't check and assume together. + self.graph.reconfigure_tasks( + b=self.b_config, assume_edges_unchanged=True, check_edges_unchanged=True + ) + # Check that graph is unchanged after error. + self.check_base_accessors(self.graph) + with self.assertRaises(EdgesChangedError): + self.graph.reconfigure_tasks(b=self.b_config, check_edges_unchanged=True) + self.check_base_accessors(self.graph) + # Make a change that does affect edges; this will unresolve most + # dataset types. + self.graph.reconfigure_tasks(b=self.b_config) + self.assertTrue(self.graph.dataset_types.is_resolved("input_1")) + self.assertFalse(self.graph.dataset_types.is_resolved("output_1")) + self.assertFalse(self.graph.dataset_types.is_resolved("schema")) + self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1")) + # Resolving again will pick up the new storage class + self.graph.resolve(MockRegistry(self.dimensions, {})) + self.assertEqual( + self.graph.dataset_types["output_1"].storage_class_name, get_mock_name("TaskMetadata") + ) + + +class PipelineGraphResolveTestCase(unittest.TestCase): + """More extensive tests for PipelineGraph.resolve and its primate helper + methods. + + These are in a separate TestCase because they utilize a different `setUp` + from the rest of the `PipelineGraph` tests. + """ + + def setUp(self) -> None: + self.a_config = DynamicTestPipelineTaskConfig() + self.b_config = DynamicTestPipelineTaskConfig() + self.dimensions = DimensionUniverse() + self.maxDiff = None + + def make_graph(self) -> PipelineGraph: + graph = PipelineGraph() + graph.add_task("a", DynamicTestPipelineTask, self.a_config) + graph.add_task("b", DynamicTestPipelineTask, self.b_config) + return graph + + def test_prerequisite_inconsistency(self) -> None: + """Test that we raise an exception when one edge defines a dataset type + as a prerequisite and another does not. + + This test will hopefully someday go away (along with + `DatasetTypeNode.is_prerequisite`) when the QuantumGraph generation + algorithm becomes more flexible. + """ + self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d") + self.b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d") + graph = self.make_graph() + with self.assertRaises(ConnectionTypeConsistencyError): + graph.resolve(MockRegistry(self.dimensions, {})) + + def test_prerequisite_inconsistency_reversed(self) -> None: + """Same as `test_prerequisite_inconsistency`, with the order the edges + are added to the graph reversed. + """ + self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d") + self.b_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d") + graph = self.make_graph() + with self.assertRaises(ConnectionTypeConsistencyError): + graph.resolve(MockRegistry(self.dimensions, {})) + + def test_prerequisite_output(self) -> None: + """Test that we raise an exception when one edge defines a dataset type + as a prerequisite but another defines it as an output. + """ + self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d") + self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d") + graph = self.make_graph() + with self.assertRaises(ConnectionTypeConsistencyError): + graph.resolve(MockRegistry(self.dimensions, {})) + + def test_skypix_missing(self) -> None: + """Test that we raise an exception when one edge uses the "skypix" + dimension as a placeholder but the dataset type is not registered. + """ + self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d", dimensions={"skypix"} + ) + graph = self.make_graph() + with self.assertRaises(MissingDatasetTypeError): + graph.resolve(MockRegistry(self.dimensions, {})) + + def test_skypix_inconsistent(self) -> None: + """Test that we raise an exception when one edge uses the "skypix" + dimension as a placeholder but the rest of the dimensions are + inconsistent with the registered dataset type. + """ + self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d", dimensions={"skypix", "visit"} + ) + graph = self.make_graph() + with self.assertRaises(IncompatibleDatasetTypeError): + graph.resolve( + MockRegistry( + self.dimensions, + { + "d": DatasetType( + "d", dimensions=self.dimensions.extract(["htm7"]), storageClass="ArrowTable" + ) + }, + ) + ) + with self.assertRaises(IncompatibleDatasetTypeError): + graph.resolve( + MockRegistry( + self.dimensions, + { + "d": DatasetType( + "d", + dimensions=self.dimensions.extract(["htm7", "visit", "skymap"]), + storageClass="ArrowTable", + ) + }, + ) + ) + + def test_duplicate_outputs(self) -> None: + """Test that we raise an exception when a dataset type node would have + two write edges. + """ + self.a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d") + self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d") + graph = self.make_graph() + with self.assertRaises(DuplicateOutputError): + graph.resolve(MockRegistry(self.dimensions, {})) + + def test_component_of_unregistered_parent(self) -> None: + """Test that we raise an exception when a component dataset type's + parent is not registered. + """ + self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c") + graph = self.make_graph() + with self.assertRaises(MissingDatasetTypeError): + graph.resolve(MockRegistry(self.dimensions, {})) + + def test_undefined_component(self) -> None: + """Test that we raise an exception when a component dataset type's + parent is registered, but its storage class does not have that + component. + """ + self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c") + graph = self.make_graph() + with self.assertRaises(IncompatibleDatasetTypeError): + graph.resolve( + MockRegistry( + self.dimensions, + {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))}, + ) + ) + + def test_bad_component_storage_class(self) -> None: + """Test that we raise an exception when a component dataset type's + parent is registered, but does not have that component. + """ + self.a_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d.schema", storage_class="StructuredDataDict" + ) + graph = self.make_graph() + with self.assertRaises(IncompatibleDatasetTypeError): + graph.resolve( + MockRegistry( + self.dimensions, + {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))}, + ) + ) + + def test_input_storage_class_incompatible_with_registry(self) -> None: + """Test that we raise an exception when an input connection's storage + class is incompatible with the registry definition. + """ + self.a_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="StructuredDataList" + ) + graph = self.make_graph() + with self.assertRaises(IncompatibleDatasetTypeError): + graph.resolve( + MockRegistry( + self.dimensions, + {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))}, + ) + ) + + def test_output_storage_class_incompatible_with_registry(self) -> None: + """Test that we raise an exception when an output connection's storage + class is incompatible with the registry definition. + """ + self.a_config.outputs["o"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="StructuredDataList" + ) + graph = self.make_graph() + with self.assertRaises(IncompatibleDatasetTypeError): + graph.resolve( + MockRegistry( + self.dimensions, + {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))}, + ) + ) + + def test_input_storage_class_incompatible_with_output(self) -> None: + """Test that we raise an exception when an input connection's storage + class is incompatible with the storage class of the output connection. + """ + self.a_config.outputs["o"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="ArrowTable" + ) + self.b_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="StructuredDataList" + ) + graph = self.make_graph() + with self.assertRaises(IncompatibleDatasetTypeError): + graph.resolve(MockRegistry(self.dimensions, {})) + + def test_ambiguous_storage_class(self) -> None: + """Test that we raise an exception when two input connections define + the same dataset with different storage classes (even compatible ones) + and there is no output connection or registry definition to take + precedence. + """ + self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable") + self.b_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="ArrowAstropy" + ) + graph = self.make_graph() + with self.assertRaises(MissingDatasetTypeError): + graph.resolve(MockRegistry(self.dimensions, {})) + + def test_inputs_compatible_with_registry(self) -> None: + """Test successful resolution of a dataset type where input edges have + different but compatible storage classes and the dataset type is + already registered. + """ + self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable") + self.b_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="ArrowAstropy" + ) + graph = self.make_graph() + dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame")) + graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type})) + self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type) + a_i = graph.tasks["a"].inputs["i"] + b_i = graph.tasks["b"].inputs["i"] + self.assertEqual( + a_i.adapt_dataset_type(dataset_type), + dataset_type.overrideStorageClass(get_mock_name("ArrowTable")), + ) + self.assertEqual( + b_i.adapt_dataset_type(dataset_type), + dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")), + ) + data_id = DataCoordinate.makeEmpty(self.dimensions) + ref = DatasetRef(dataset_type, data_id, run="r") + a_ref = a_i.adapt_dataset_ref(ref) + b_ref = b_i.adapt_dataset_ref(ref) + self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable"))) + self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy"))) + self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref) + self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref) + + def test_output_compatible_with_registry(self) -> None: + """Test successful resolution of a dataset type where an output edge + has a different but compatible storage class from the dataset type + already registered. + """ + self.a_config.outputs["o"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="ArrowTable" + ) + graph = self.make_graph() + dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame")) + graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type})) + self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type) + a_o = graph.tasks["a"].outputs["o"] + self.assertEqual( + a_o.adapt_dataset_type(dataset_type), + dataset_type.overrideStorageClass(get_mock_name("ArrowTable")), + ) + data_id = DataCoordinate.makeEmpty(self.dimensions) + ref = DatasetRef(dataset_type, data_id, run="r") + a_ref = a_o.adapt_dataset_ref(ref) + self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable"))) + self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref) + + def test_inputs_compatible_with_output(self) -> None: + """Test successful resolution of a dataset type where an input edge has + a different but compatible storage class from the output edge, and + the dataset type is not registered. + """ + self.a_config.outputs["o"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="ArrowTable" + ) + self.b_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="ArrowAstropy" + ) + graph = self.make_graph() + a_o = graph.tasks["a"].outputs["o"] + b_i = graph.tasks["b"].inputs["i"] + graph.resolve(MockRegistry(self.dimensions, {})) + self.assertEqual(graph.dataset_types["d"].storage_class_name, get_mock_name("ArrowTable")) + self.assertEqual( + a_o.adapt_dataset_type(graph.dataset_types["d"].dataset_type), + graph.dataset_types["d"].dataset_type, + ) + self.assertEqual( + b_i.adapt_dataset_type(graph.dataset_types["d"].dataset_type), + graph.dataset_types["d"].dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")), + ) + data_id = DataCoordinate.makeEmpty(self.dimensions) + ref = DatasetRef(graph.dataset_types["d"].dataset_type, data_id, run="r") + a_ref = a_o.adapt_dataset_ref(ref) + b_ref = b_i.adapt_dataset_ref(ref) + self.assertEqual(a_ref, ref) + self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy"))) + self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref) + self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref) + + def test_component_resolved_by_input(self) -> None: + """Test successful resolution of a component dataset type due to + another input referencing the parent dataset type. + """ + self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable") + self.b_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d.schema", storage_class="ArrowSchema" + ) + graph = self.make_graph() + parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable")) + graph.resolve(MockRegistry(self.dimensions, {})) + self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type) + a_i = graph.tasks["a"].inputs["i"] + b_i = graph.tasks["b"].inputs["i"] + self.assertEqual(b_i.dataset_type_name, "d.schema") + self.assertEqual(a_i.adapt_dataset_type(parent_dataset_type), parent_dataset_type) + self.assertEqual( + b_i.adapt_dataset_type(parent_dataset_type), + parent_dataset_type.makeComponentDatasetType("schema"), + ) + data_id = DataCoordinate.makeEmpty(self.dimensions) + ref = DatasetRef(parent_dataset_type, data_id, run="r") + a_ref = a_i.adapt_dataset_ref(ref) + b_ref = b_i.adapt_dataset_ref(ref) + self.assertEqual(a_ref, ref) + self.assertEqual(b_ref, ref.makeComponentRef("schema")) + self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref) + self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref) + + def test_component_resolved_by_output(self) -> None: + """Test successful resolution of a component dataset type due to + an output connection referencing the parent dataset type. + """ + self.a_config.outputs["o"] = DynamicConnectionConfig( + dataset_type_name="d", storage_class="ArrowTable" + ) + self.b_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d.schema", storage_class="ArrowSchema" + ) + graph = self.make_graph() + parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable")) + graph.resolve(MockRegistry(self.dimensions, {})) + self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type) + a_o = graph.tasks["a"].outputs["o"] + b_i = graph.tasks["b"].inputs["i"] + self.assertEqual(b_i.dataset_type_name, "d.schema") + self.assertEqual(a_o.adapt_dataset_type(parent_dataset_type), parent_dataset_type) + self.assertEqual( + b_i.adapt_dataset_type(parent_dataset_type), + parent_dataset_type.makeComponentDatasetType("schema"), + ) + data_id = DataCoordinate.makeEmpty(self.dimensions) + ref = DatasetRef(parent_dataset_type, data_id, run="r") + a_ref = a_o.adapt_dataset_ref(ref) + b_ref = b_i.adapt_dataset_ref(ref) + self.assertEqual(a_ref, ref) + self.assertEqual(b_ref, ref.makeComponentRef("schema")) + self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref) + self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref) + + def test_component_resolved_by_registry(self) -> None: + """Test successful resolution of a component dataset type due to + the parent dataset type already being registered. + """ + self.b_config.inputs["i"] = DynamicConnectionConfig( + dataset_type_name="d.schema", storage_class="ArrowSchema" + ) + graph = self.make_graph() + parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable")) + graph.resolve(MockRegistry(self.dimensions, {"d": parent_dataset_type})) + self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type) + b_i = graph.tasks["b"].inputs["i"] + self.assertEqual(b_i.dataset_type_name, "d.schema") + self.assertEqual( + b_i.adapt_dataset_type(parent_dataset_type), + parent_dataset_type.makeComponentDatasetType("schema"), + ) + data_id = DataCoordinate.makeEmpty(self.dimensions) + ref = DatasetRef(parent_dataset_type, data_id, run="r") + b_ref = b_i.adapt_dataset_ref(ref) + self.assertEqual(b_ref, ref.makeComponentRef("schema")) + self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref) + + +if __name__ == "__main__": + lsst.utils.tests.init() + unittest.main()