From d976e5ab2594c84035bc377a2f676277d996ab30 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Fri, 20 Jan 2023 00:01:35 -0500 Subject: [PATCH] WIP: Move pipeline graph code around, add copy methods. --- python/lsst/pipe/base/__init__.py | 2 +- python/lsst/pipe/base/pipeline.py | 4 +- .../lsst/pipe/base/pipeline_graph/__init__.py | 7 +- python/lsst/pipe/base/pipeline_graph/_abcs.py | 210 +++++++ .../base/pipeline_graph/_dataset_types.py | 109 ++++ .../lsst/pipe/base/pipeline_graph/_edges.py | 272 +++++++++ .../pipeline_graph/_generic_pipeline_graph.py | 119 +--- .../base/pipeline_graph/_pipeline_graph.py | 141 ++++- .../_resolved_pipeline_graph.py | 158 ----- .../lsst/pipe/base/pipeline_graph/_tasks.py | 181 ++++++ .../_unresolved_pipeline_graph.py | 76 --- .../pipeline_graph/_vertices_and_edges.py | 546 ------------------ 12 files changed, 920 insertions(+), 905 deletions(-) create mode 100644 python/lsst/pipe/base/pipeline_graph/_abcs.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_dataset_types.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_edges.py delete mode 100644 python/lsst/pipe/base/pipeline_graph/_resolved_pipeline_graph.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_tasks.py delete mode 100644 python/lsst/pipe/base/pipeline_graph/_unresolved_pipeline_graph.py delete mode 100644 python/lsst/pipe/base/pipeline_graph/_vertices_and_edges.py diff --git a/python/lsst/pipe/base/__init__.py b/python/lsst/pipe/base/__init__.py index e3dc92c08..c54ccd3f7 100644 --- a/python/lsst/pipe/base/__init__.py +++ b/python/lsst/pipe/base/__init__.py @@ -13,7 +13,7 @@ # We import the main PipelineGraph types and the module (above), but we don't # lift all symbols to package scope. -from .pipeline_graph import ResolvedPipelineGraph, UnresolvedPipelineGraph +from .pipeline_graph import MutablePipelineGraph, ResolvedPipelineGraph from .pipelineTask import * from .struct import * from .task import * diff --git a/python/lsst/pipe/base/pipeline.py b/python/lsst/pipe/base/pipeline.py index 3cb3e98d1..e3a0b97f8 100644 --- a/python/lsst/pipe/base/pipeline.py +++ b/python/lsst/pipe/base/pipeline.py @@ -67,7 +67,7 @@ from .config import PipelineTaskConfig from .configOverrides import ConfigOverrides from .connections import iterConnections -from .pipeline_graph import UnresolvedPipelineGraph +from .pipeline_graph import MutablePipelineGraph from .pipelineTask import PipelineTask from .task import _TASK_METADATA_TYPE @@ -677,7 +677,7 @@ def toExpandedPipeline(self) -> Generator[TaskDef, None, None]: raise pipelineIR.ContractError( f"Contract(s) '{contract.contract}' were not satisfied{extra_info}" ) - graph = UnresolvedPipelineGraph() + graph = MutablePipelineGraph() for task_def in taskDefs: graph.add_task( task_def.label, diff --git a/python/lsst/pipe/base/pipeline_graph/__init__.py b/python/lsst/pipe/base/pipeline_graph/__init__.py index d6ba9675c..36a779604 100644 --- a/python/lsst/pipe/base/pipeline_graph/__init__.py +++ b/python/lsst/pipe/base/pipeline_graph/__init__.py @@ -20,9 +20,10 @@ # along with this program. If not, see . 
from __future__ import annotations +from ._abcs import * +from ._dataset_types import * +from ._edges import * from ._exceptions import * from ._generic_pipeline_graph import * from ._pipeline_graph import * -from ._resolved_pipeline_graph import * -from ._unresolved_pipeline_graph import * -from ._vertices_and_edges import * +from ._tasks import * diff --git a/python/lsst/pipe/base/pipeline_graph/_abcs.py b/python/lsst/pipe/base/pipeline_graph/_abcs.py new file mode 100644 index 000000000..cfe4d264b --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_abcs.py @@ -0,0 +1,210 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ( + "Node", + "Edge", + "SubgraphView", + "MappingSubgraphView", +) + +import itertools +from abc import ABC, abstractmethod +from collections.abc import Iterable, Iterator, Mapping +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar + +import networkx +from lsst.daf.butler import DatasetType, DimensionUniverse, Registry + +from ._exceptions import ConnectionTypeConsistencyError, IncompatibleDatasetTypeError + +if TYPE_CHECKING: + from ..connectionTypes import BaseConnection + from ._dataset_types import DatasetTypeNode + + +class Node(ABC): + + BIPARTITE_CONSTANT: ClassVar[int] + + @abstractmethod + def _resolve(self, state: dict[str, Any], graph: networkx.DiGraph, registry: Registry) -> None: + raise NotImplementedError() + + @abstractmethod + def _unresolve(self, state: dict[str, Any]) -> None: + raise NotImplementedError() + + +class Edge(ABC): + + task_label: str + parent_dataset_type_name: str + + @property + @abstractmethod + def key(self) -> tuple[str, str]: + raise NotImplementedError() + + @property + def dataset_type_name(self) -> str: + return self.parent_dataset_type_name + + @classmethod + @abstractmethod + def _from_connection( + cls, + task_label: str, + connection: BaseConnection, + edge_data: list[tuple[str, str, dict[str, Any]]], + *, + is_init: bool, + is_prerequisite: bool = False, + ) -> Edge: + raise NotImplementedError() + + @abstractmethod + def _check_dataset_type( + self, + state: dict[str, Any], + xgraph: networkx.DiGraph, + dataset_type_node: DatasetTypeNode, + ) -> None: + if state["is_init"] != dataset_type_node.is_init: + referencing_tasks = list( + itertools.chain( + xgraph.predecessors(dataset_type_node.name), + xgraph.successors(dataset_type_node.name), + ) + ) + if state["is_init"]: + raise ConnectionTypeConsistencyError( + f"{dataset_type_node.name!r} is an init dataset in task {self.task_label!r}, " + f"but a run dataset in task(s) {referencing_tasks}." 
+                )
+            else:
+                raise ConnectionTypeConsistencyError(
+                    f"{dataset_type_node.name!r} is a run dataset in task {self.task_label!r}, "
+                    f"but an init dataset in task(s) {referencing_tasks}."
+                )
+        if state["is_prerequisite"] != dataset_type_node.is_prerequisite:
+            referencing_tasks = list(xgraph.successors(dataset_type_node.name))
+            if state["is_prerequisite"]:
+                raise ConnectionTypeConsistencyError(
+                    f"Dataset type {dataset_type_node.name!r} is a prerequisite input in "
+                    f"task {self.task_label!r}, but it was not a prerequisite to "
+                    f"{referencing_tasks}."
+                )
+            else:
+                raise ConnectionTypeConsistencyError(
+                    f"Dataset type {dataset_type_node.name!r} is not a prerequisite input in "
+                    f"task {self.task_label!r}, but it was a prerequisite to "
+                    f"{referencing_tasks}."
+                )
+        connection: BaseConnection = state["connection"]
+        if connection.isCalibration != dataset_type_node.is_calibration:
+            referencing_tasks = list(
+                itertools.chain(
+                    xgraph.predecessors(dataset_type_node.name),
+                    xgraph.successors(dataset_type_node.name),
+                )
+            )
+            if connection.isCalibration:
+                raise IncompatibleDatasetTypeError(
+                    f"Dataset type {dataset_type_node.name!r} is a calibration in "
+                    f"task {self.task_label}, but it was not in task(s) {referencing_tasks}."
+                )
+            else:
+                raise IncompatibleDatasetTypeError(
+                    f"Dataset type {dataset_type_node.name!r} is not a calibration in "
+                    f"task {self.task_label}, but it was in task(s) {referencing_tasks}."
+                )
+
+    @abstractmethod
+    def _make_dataset_type_state(self, state: dict[str, Any]) -> dict[str, Any]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def _resolve_dataset_type(
+        self,
+        state: dict[str, Any],
+        current: DatasetType | None,
+        universe: DimensionUniverse,
+    ) -> DatasetType:
+        raise NotImplementedError()
+
+
+_G = TypeVar("_G", bound=networkx.DiGraph, covariant=True)
+
+
+class SubgraphView(ABC, Generic[_G]):
+    def __init__(self, parent_xgraph: networkx.DiGraph) -> None:
+        self._parent_xgraph = parent_xgraph
+        self._xgraph: _G | None = None
+
+    @property
+    def xgraph(self) -> _G:
+        if self._xgraph is None:
+            self._xgraph = self._make_xgraph()
+        return self._xgraph
+
+    @abstractmethod
+    def _make_xgraph(self) -> _G:
+        raise NotImplementedError()
+
+    def _reset(self) -> None:
+        self._xgraph = None
+
+
+_N = TypeVar("_N", bound=Node, covariant=True)
+
+
+class MappingSubgraphView(Generic[_G, _N], Mapping[str, _N], SubgraphView[_G]):
+    def __init__(self, parent_xgraph: networkx.DiGraph) -> None:
+        # Chain to SubgraphView.__init__ so self._xgraph is initialized; the
+        # xgraph property relies on it.
+        super().__init__(parent_xgraph)
+        self._keys: list[str] | None = None
+
+    @abstractmethod
+    def __contains__(self, key: object) -> bool:
+        raise NotImplementedError()
+
+    def __iter__(self) -> Iterator[str]:
+        if self._keys is None:
+            self._keys = [k for k in self._parent_xgraph if k in self]
+        return iter(self._keys)
+
+    def __getitem__(self, key: str) -> _N:
+        if key not in self:
+            raise KeyError(key)
+        # Node attributes live in the .nodes mapping; indexing the graph
+        # directly returns the adjacency view instead.
+        return self._parent_xgraph.nodes[key]["instance"]
+
+    def __len__(self) -> int:
+        if self._keys is None:
+            self._keys = [k for k in self._parent_xgraph if k in self]
+        return len(self._keys)
+
+    def _reorder(self, parent_keys: Iterable[str]) -> None:
+        self._keys = [k for k in parent_keys if k in self]
+
+    def _reset(self) -> None:
+        super()._reset()
+        self._keys = None
diff --git a/python/lsst/pipe/base/pipeline_graph/_dataset_types.py b/python/lsst/pipe/base/pipeline_graph/_dataset_types.py
new file mode 100644
index 000000000..b176abd55
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_dataset_types.py
@@ -0,0 +1,109 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+__all__ = (
+    "DatasetTypeNode",
+    "ResolvedDatasetTypeNode",
+)
+
+import dataclasses
+import itertools
+from typing import Any, ClassVar, TypeVar
+
+import networkx
+from lsst.daf.butler import DatasetType, Registry
+from lsst.daf.butler.registry import MissingDatasetTypeError
+
+from ._abcs import Edge, MappingSubgraphView, Node
+
+
+@dataclasses.dataclass(frozen=True, eq=False)
+class DatasetTypeNode(Node):
+    BIPARTITE_CONSTANT: ClassVar[int] = 1
+
+    name: str
+    is_calibration: bool
+    is_init: bool
+    is_prerequisite: bool
+
+    def _resolve(self, state: dict[str, Any], graph: networkx.DiGraph, registry: Registry) -> None:
+        try:
+            dataset_type = registry.getDatasetType(self.name)
+        except MissingDatasetTypeError:
+            dataset_type = None
+        # in_edges/out_edges with data=True yield (u, v, data) triples; only
+        # the attribute dict is needed here.
+        for _, _, edge_state in itertools.chain(
+            graph.in_edges(self.name, data=True), graph.out_edges(self.name, data=True)
+        ):
+            edge: Edge = edge_state["instance"]
+            dataset_type = edge._resolve_dataset_type(
+                edge_state,
+                current=dataset_type,
+                universe=registry.dimensions,
+            )
+        assert dataset_type is not None, "Graph structure guarantees at least one edge."
+        # Replace this node's entry in the graph state with its resolved
+        # counterpart.
+        state["instance"] = ResolvedDatasetTypeNode(
+            name=self.name,
+            is_calibration=self.is_calibration,
+            is_init=self.is_init,
+            is_prerequisite=self.is_prerequisite,
+            dataset_type=dataset_type,
+        )
+
+    def _unresolve(self, state: dict[str, Any]) -> None:
+        pass
+
+
+@dataclasses.dataclass(frozen=True, eq=False)
+class ResolvedDatasetTypeNode(DatasetTypeNode):
+    dataset_type: DatasetType
+
+    def _resolve(self, state: dict[str, Any], graph: networkx.DiGraph, registry: Registry) -> None:
+        pass
+
+    def _unresolve(self, state: dict[str, Any]) -> None:
+        # Replace this node's entry in the graph state with its unresolved
+        # counterpart.
+        state["instance"] = DatasetTypeNode(
+            name=self.name,
+            is_calibration=self.is_calibration,
+            is_init=self.is_init,
+            is_prerequisite=self.is_prerequisite,
+        )
+
+
+_D = TypeVar("_D", bound=DatasetTypeNode, covariant=True)
+
+
+class DatasetTypeSubgraphView(MappingSubgraphView[networkx.DiGraph, _D]):
+    def __init__(self, parent_xgraph: networkx.DiGraph, is_init: bool):
+        super().__init__(parent_xgraph)
+        self._is_init = is_init
+
+    def __contains__(self, key: object) -> bool:
+        if state := self._parent_xgraph.nodes.get(key):
+            # Dataset type node states carry no "is_init" key of their own;
+            # read it off the node instance instead.
+            return (
+                state["bipartite"] == DatasetTypeNode.BIPARTITE_CONSTANT
+                and state["instance"].is_init == self._is_init
+            )
+        return False
+
+    def _make_xgraph(self) -> networkx.DiGraph:
+        return networkx.freeze(
+            networkx.bipartite.projected_graph(self._parent_xgraph, self, multigraph=False)
+        )
diff --git a/python/lsst/pipe/base/pipeline_graph/_edges.py b/python/lsst/pipe/base/pipeline_graph/_edges.py
new file mode 100644
index 000000000..deb7294ff
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_edges.py
@@ -0,0 +1,272 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+__all__ = (
+    "WriteEdge",
+    "ReadEdge",
+)
+
+import dataclasses
+from collections.abc import Callable, Iterable, Sequence
+from typing import TYPE_CHECKING, Any
+
+import networkx
+from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Registry
+from lsst.daf.butler.registry import MissingDatasetTypeError
+
+from ._abcs import Edge
+from ._dataset_types import DatasetTypeNode
+from ._exceptions import ConnectionTypeConsistencyError, IncompatibleDatasetTypeError, MultipleProducersError
+
+if TYPE_CHECKING:
+    from ..connectionTypes import BaseConnection
+
+
+@dataclasses.dataclass(frozen=True, eq=False)
+class ReadEdge(Edge):
+    parent_dataset_type_name: str
+    task_label: str
+    component: str | None
+    storage_class_name: str
+    lookup_function: Callable[
+        [DatasetType, Registry, DataCoordinate, Sequence[str]], Iterable[DatasetRef]
+    ] | None = None
+
+    @property
+    def key(self) -> tuple[str, str]:
+        return (self.parent_dataset_type_name, self.task_label)
+
+    @property
+    def dataset_type_name(self) -> str:
+        # Only component reads differ from the parent dataset type name.
+        if self.component is None:
+            return self.parent_dataset_type_name
+        return f"{self.parent_dataset_type_name}.{self.component}"
+
+    @classmethod
+    def _from_connection(
+        cls,
+        task_label: str,
+        connection: BaseConnection,
+        edge_data: list[tuple[str, str, dict[str, Any]]],
+        *,
+        is_init: bool,
+        is_prerequisite: bool = False,
+    ) -> ReadEdge:
+        parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name)
+        # Use keyword arguments: the dataclass field order (parent name first)
+        # differs from this factory's argument order (task label first).
+        instance = cls(
+            parent_dataset_type_name=parent_dataset_type_name,
+            task_label=task_label,
+            component=component,
+            storage_class_name=connection.storageClass,
+        )
+        edge_data.append(
+            (
+                instance.parent_dataset_type_name,
+                instance.task_label,
+                dict(
+                    instance=instance, connection=connection, is_init=is_init, is_prerequisite=is_prerequisite
+                ),
+            )
+        )
+        return instance
+
+    def _check_dataset_type(
+        self,
+        state: dict[str, Any],
+        xgraph: networkx.DiGraph,
+        dataset_type_node: DatasetTypeNode,
+    ) -> None:
+        super()._check_dataset_type(state, xgraph, dataset_type_node)
+        if state["is_prerequisite"] != dataset_type_node.is_prerequisite:
+            referencing_tasks = list(xgraph.successors(dataset_type_node.name))
+            if state["is_prerequisite"]:
+                raise ConnectionTypeConsistencyError(
+                    f"Dataset type {dataset_type_node.name!r} is a prerequisite input in "
+                    f"task {self.task_label!r}, but it was a regular input to "
+                    f"{referencing_tasks}."
+                )
+            else:
+                raise ConnectionTypeConsistencyError(
+                    f"Dataset type {dataset_type_node.name!r} is a regular input in "
+                    f"task {self.task_label!r}, but it was a prerequisite input to "
+                    f"{referencing_tasks}."
+                )
+
+    def _make_dataset_type_state(self, state: dict[str, Any]) -> dict[str, Any]:
+        connection: BaseConnection = state["connection"]
+        return dict(
+            instance=DatasetTypeNode(
+                self.parent_dataset_type_name,
+                is_calibration=connection.isCalibration,
+                is_init=state["is_init"],
+                is_prerequisite=state["is_prerequisite"],
+            ),
+            bipartite=DatasetTypeNode.BIPARTITE_CONSTANT,
+        )
+
+    def _resolve_dataset_type(
+        self,
+        state: dict[str, Any],
+        current: DatasetType | None,
+        universe: DimensionUniverse,
+    ) -> DatasetType:
+        connection: BaseConnection = state["connection"]
+        dimensions = connection.resolve_dimensions(
+            universe,
+            current.dimensions if current is not None else None,
+            self.task_label,
+        )
+        if self.component is not None:
+            if current is None:
+                raise MissingDatasetTypeError(
+                    f"Dataset type {self.parent_dataset_type_name!r} is not registered and not produced by "
+                    f"this pipeline, but it is used by task {self.task_label!r}, via component "
+                    f"{self.component!r}. This pipeline cannot be resolved until the parent dataset type is "
+                    "registered."
+                )
+            if self.component not in current.storageClass.components:
+                raise IncompatibleDatasetTypeError(
+                    f"Dataset type {self.parent_dataset_type_name!r} has storage class "
+                    f"{current.storageClass_name!r}, which does not include component {self.component!r} "
+                    f"as requested by task {self.task_label!r}."
+                )
+            if current.storageClass.components[self.component].name != connection.storageClass:
+                raise IncompatibleDatasetTypeError(
+                    f"Dataset type '{self.parent_dataset_type_name}.{self.component}' has storage class "
+                    f"{current.storageClass.components[self.component].name!r}, which does not match "
+                    f"{connection.storageClass!r}, as requested by task {self.task_label!r}. "
+                    "Note that storage class conversions of components are not supported."
+                )
+            return current
+        else:
+            dataset_type = DatasetType(
+                self.parent_dataset_type_name,
+                dimensions,
+                storageClass=connection.storageClass,
+                isCalibration=connection.isCalibration,
+            )
+            if current is not None:
+                if not dataset_type.is_compatible_with(current):
+                    raise IncompatibleDatasetTypeError(
+                        f"Incompatible definition for input dataset type {self.parent_dataset_type_name!r} "
+                        f"to task {self.task_label!r}: {current} != {dataset_type}."
+                    )
+                return current
+            else:
+                return dataset_type
+
+
+@dataclasses.dataclass(frozen=True, eq=False)
+class WriteEdge(Edge):
+    task_label: str
+    parent_dataset_type_name: str
+    storage_class_name: str
+
+    @property
+    def key(self) -> tuple[str, str]:
+        return (self.task_label, self.parent_dataset_type_name)
+
+    @classmethod
+    def _from_connection(
+        cls,
+        task_label: str,
+        connection: BaseConnection,
+        edge_data: list[tuple[str, str, dict[str, Any]]],
+        *,
+        is_init: bool,
+        is_prerequisite: bool = False,
+    ) -> WriteEdge:
+        parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name)
+        if component is not None:
+            raise ValueError(f"Illegal output component dataset {connection.name!r} in task {task_label!r}.")
+        instance = cls(task_label, parent_dataset_type_name, connection.storageClass)
+        edge_data.append(
+            (
+                instance.task_label,
+                instance.parent_dataset_type_name,
+                dict(
+                    instance=instance,
+                    connection=connection,
+                    is_init=is_init,
+                    is_prerequisite=is_prerequisite,
+                ),
+            )
+        )
+        return instance
+
+    def _check_dataset_type(
+        self,
+        state: dict[str, Any],
+        xgraph: networkx.DiGraph,
+        dataset_type_node: DatasetTypeNode,
+    ) -> None:
+        super()._check_dataset_type(state, xgraph, dataset_type_node)
+        if dataset_type_node.is_prerequisite:
+            referencing_tasks = list(xgraph.successors(dataset_type_node.name))
+            raise ConnectionTypeConsistencyError(
+                f"Dataset type {dataset_type_node.name!r} is an output of "
+                f"task {self.task_label!r}, but it was a prerequisite input to "
+                f"{referencing_tasks}."
+            )
+        # Write edges point from task to dataset type, so any existing
+        # producer is a predecessor (not a successor) of the dataset type
+        # node.
+        for existing_producer in xgraph.predecessors(dataset_type_node.name):
+            raise MultipleProducersError(
+                f"Dataset type {dataset_type_node.name!r} is produced by both {self.task_label!r} "
+                f"and {existing_producer!r}."
+            )
+
+    def _make_dataset_type_state(self, state: dict[str, Any]) -> dict[str, Any]:
+        connection: BaseConnection = state["connection"]
+        return dict(
+            instance=DatasetTypeNode(
+                self.parent_dataset_type_name,
+                is_calibration=connection.isCalibration,
+                is_init=state["is_init"],
+                is_prerequisite=False,
+            ),
+            bipartite=DatasetTypeNode.BIPARTITE_CONSTANT,
+        )
+
+    def _resolve_dataset_type(
+        self,
+        state: dict[str, Any],
+        current: DatasetType | None,
+        universe: DimensionUniverse,
+    ) -> DatasetType:
+        connection: BaseConnection = state["connection"]
+        dimensions = connection.resolve_dimensions(
+            universe,
+            current.dimensions if current is not None else None,
+            self.task_label,
+        )
+        dataset_type = DatasetType(
+            self.parent_dataset_type_name,
+            dimensions,
+            storageClass=connection.storageClass,
+            isCalibration=connection.isCalibration,
+        )
+        if current is not None:
+            if not current.is_compatible_with(dataset_type):
+                raise IncompatibleDatasetTypeError(
+                    f"Incompatible definition for output dataset type {self.parent_dataset_type_name!r} "
+                    f"from task {self.task_label!r}: {current} != {dataset_type}."
+                )
+            return current
+        else:
+            return dataset_type
diff --git a/python/lsst/pipe/base/pipeline_graph/_generic_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_generic_pipeline_graph.py
index 903a317c9..175f0f19a 100644
--- a/python/lsst/pipe/base/pipeline_graph/_generic_pipeline_graph.py
+++ b/python/lsst/pipe/base/pipeline_graph/_generic_pipeline_graph.py
@@ -20,126 +20,36 @@
 # along with this program. If not, see .
from __future__ import annotations -__all__ = ( - "GenericPipelineGraph", - "SubgraphView", - "TasksView", - "DatasetTypesView", -) +__all__ = ("GenericPipelineGraph",) -from abc import abstractmethod from collections import ChainMap -from collections.abc import Iterable, Iterator, Mapping -from typing import Generic, TypeVar +from collections.abc import Mapping +from typing import Any, Generic import networkx -from ._vertices_and_edges import DatasetTypeVertex, ReadEdge, TaskVertex, WriteEdge - -_G = TypeVar("_G", bound=networkx.DiGraph, covariant=True) -_V = TypeVar("_V", bound=TaskVertex | DatasetTypeVertex, covariant=True) -_D = TypeVar("_D", bound=DatasetTypeVertex, covariant=True) -_T = TypeVar("_T", bound=TaskVertex, covariant=True) - - -class SubgraphView(Generic[_G]): - def __init__(self, parent_xgraph: networkx.DiGraph) -> None: - self._parent_xgraph = parent_xgraph - self._xgraph: _G | None = None - - @property - def xgraph(self) -> _G: - if self._xgraph is None: - self._xgraph = self._make_xgraph() - return self._xgraph - - @abstractmethod - def _make_xgraph(self) -> _G: - raise NotImplementedError() - - def _reset(self) -> None: - self._xgraph = None - - -class MappingSubgraphView(Generic[_G, _V], Mapping[str, _V], SubgraphView[_G]): - def __init__(self, parent_xgraph: networkx.DiGraph) -> None: - self._parent_xgraph = parent_xgraph - self._keys: list[str] | None = None - - @abstractmethod - def __contains__(self, key: object) -> bool: - raise NotImplementedError() - - def __iter__(self) -> Iterator[str]: - if self._keys is None: - self._keys = [k for k in self._parent_xgraph if k in self] - return iter(self._keys) - - def __getitem__(self, key: str) -> _V: - if key not in self: - raise KeyError(key) - return self._parent_xgraph[key]["instance"] - - def __len__(self) -> int: - if self._keys is None: - self._keys = [k for k in self._parent_xgraph if k in self] - return len(self._keys) - - def _reorder(self, parent_keys: Iterable[str]) -> None: - self._keys = [k for k in parent_keys if k in self] - - def _reset(self) -> None: - super()._reset() - self._keys = None - - -class TasksView(MappingSubgraphView[networkx.MultiDiGraph, _T]): - def __contains__(self, key: object) -> bool: - return ( - key in self._parent_xgraph - and self._parent_xgraph[key]["bipartite"] == TaskVertex.BIPARTITE_CONSTANT - ) - - def _make_xgraph(self) -> networkx.MultiDiGraph: - return networkx.freeze(networkx.bipartite.projected_graph(self._parent_xgraph, self, multigraph=True)) - - -class DatasetTypesView(MappingSubgraphView[networkx.DiGraph, _D]): - def __init__(self, parent_xgraph: networkx.DiGraph, is_init: bool): - super().__init__(parent_xgraph) - self._is_init = is_init - - def __contains__(self, key: object) -> bool: - if state := self._parent_xgraph.nodes.get(key): - return ( - state["bipartite"] == DatasetTypeVertex.BIPARTITE_CONSTANT - and state["is_init"] == self._is_init - ) - return False - - def _make_xgraph(self) -> networkx.DiGraph: - return networkx.freeze( - networkx.bipartite.projected_graph(self._parent_xgraph, self, multigraph=False) - ) +from ._dataset_types import _D, DatasetTypeSubgraphView +from ._edges import ReadEdge, WriteEdge +from ._tasks import _T, TaskSubgraphView class GenericPipelineGraph(Generic[_T, _D]): def __init__(self, xgraph: networkx.DiGraph) -> None: self._xgraph = xgraph - self._tasks = TasksView[_T](xgraph) - self._init_dataset_types = DatasetTypesView[_D](xgraph, True) - self._run_dataset_types = DatasetTypesView[_D](xgraph, False) + self._tasks = 
TaskSubgraphView[_T](xgraph) + self._init_dataset_types = DatasetTypeSubgraphView[_D](xgraph, True) + self._run_dataset_types = DatasetTypeSubgraphView[_D](xgraph, False) @property - def tasks(self) -> TasksView[_T]: + def tasks(self) -> TaskSubgraphView[_T]: return self._tasks @property - def init_dataset_types(self) -> DatasetTypesView[_D]: + def init_dataset_types(self) -> DatasetTypeSubgraphView[_D]: return self._init_dataset_types @property - def run_dataset_types(self) -> DatasetTypesView[_D]: + def run_dataset_types(self) -> DatasetTypeSubgraphView[_D]: return self._run_dataset_types @property @@ -162,3 +72,8 @@ def consumers_of(self, dataset_type_name: str) -> dict[str, ReadEdge]: return { task_label: edge for task_label, _, edge in self._xgraph.out_edges(dataset_type_name, "state") } + + def _reorder(self, other: GenericPipelineGraph[Any, Any]) -> None: + self._tasks._reorder(other._tasks) + self._init_dataset_types._reorder(other._init_dataset_types) + self._run_dataset_types._reorder(other._run_dataset_types) diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py index fc97aaab5..71390425e 100644 --- a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py +++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py @@ -20,35 +20,54 @@ # along with this program. If not, see . from __future__ import annotations -__all__ = ("PipelineGraph",) +__all__ = ("PipelineGraph", "MutablePipelineGraph", "ResolvedPipelineGraph") -from collections.abc import Iterable -from typing import TypeVar +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, final import networkx +from lsst.daf.butler import Registry +from ._abcs import Edge, Node +from ._dataset_types import _D, DatasetTypeNode, ResolvedDatasetTypeNode from ._exceptions import DatasetDependencyError from ._generic_pipeline_graph import GenericPipelineGraph -from ._vertices_and_edges import DatasetTypeVertex, TaskVertex +from ._tasks import _T, ResolvedTaskNode, TaskNode -_D = TypeVar("_D", bound=DatasetTypeVertex, covariant=True) -_T = TypeVar("_T", bound=TaskVertex, covariant=True) +if TYPE_CHECKING: + from ..config import PipelineTaskConfig + from ..connections import PipelineTaskConnections + from ..pipelineTask import PipelineTask class PipelineGraph(GenericPipelineGraph[_T, _D]): - def __init__( - self, - xgraph: networkx.DiGraph, - prerequisite_inputs: Iterable[str] = (), - is_sorted: bool = False, - ) -> None: - super().__init__(xgraph) - self._prerequisite_inputs = set(prerequisite_inputs) - self._is_sorted = is_sorted - @property + @abstractmethod def is_sorted(self) -> bool: - return self._is_sorted + raise NotImplementedError() + + @abstractmethod + def sort(self) -> None: + raise NotImplementedError() + + @abstractmethod + def resolved(self, registry: Registry) -> ResolvedPipelineGraph: + raise NotImplementedError() + + @abstractmethod + def copy(self) -> PipelineGraph: + raise NotImplementedError() + + @abstractmethod + def mutable_copy(self) -> MutablePipelineGraph: + raise NotImplementedError() + + +@final +class MutablePipelineGraph(PipelineGraph[TaskNode, DatasetTypeNode]): + def __init__(self, xgraph: networkx.DiGraph | None = None) -> None: + super().__init__(xgraph if xgraph is not None else networkx.DiGraph()) + self._is_sorted = False def sort(self) -> None: if not self._is_sorted: @@ -60,3 +79,91 @@ def sort(self) -> None: self._init_dataset_types._reorder(sorted_nodes) self._run_dataset_types._reorder(sorted_nodes) 
             self._is_sorted = True
+
+    @property
+    def is_sorted(self) -> bool:
+        return self._is_sorted
+
+    def resolved(self, registry: Registry) -> ResolvedPipelineGraph:
+        return ResolvedPipelineGraph(self, registry)
+
+    def copy(self) -> MutablePipelineGraph:
+        # We can get away with a shallow copy of the networkx graph because
+        # our own Node and Edge objects are all immutable, so it doesn't
+        # matter that they're shared.  Everything else that's mutable (the
+        # cached subgraph views and the sort flag) is rebuilt or copied
+        # explicitly below.
+        result = MutablePipelineGraph(self._xgraph.copy())
+        result._reorder(self)
+        result._is_sorted = self._is_sorted
+        return result
+
+    def mutable_copy(self) -> MutablePipelineGraph:
+        return self.copy()
+
+    def add_task(
+        self,
+        label: str,
+        task_class: type[PipelineTask],
+        config: PipelineTaskConfig,
+        connections: PipelineTaskConnections | None = None,
+    ) -> None:
+        # Make the task node, the corresponding state dict that will be held
+        # by the networkx graph (which includes the node instance), and the
+        # state dicts for the edges.
+        task_node, task_state, edge_data = TaskNode._from_pipeline_task(
+            label, task_class, config, connections
+        )
+        node_data: list[tuple[str, dict[str, Any]]] = [(label, task_state)]
+        for _, _, edge_state in edge_data:
+            edge: Edge = edge_state["instance"]
+            if edge.parent_dataset_type_name in self._xgraph:
+                dataset_type_node = self._xgraph.nodes[edge.parent_dataset_type_name]["instance"]
+                edge._check_dataset_type(edge_state, self._xgraph, dataset_type_node)
+            else:
+                node_data.append((edge.parent_dataset_type_name, edge._make_dataset_type_state(edge_state)))
+
+        # Checks complete; time to start the actual modification.
+        try:
+            self._is_sorted = False
+            self.tasks._reset()
+            self.init_dataset_types._reset()
+            self.run_dataset_types._reset()
+            self._xgraph.add_nodes_from(node_data)
+            self._xgraph.add_edges_from(edge_data)
+        except Exception as err:
+            raise RuntimeError(
+                "Error during PipelineGraph modification has left the graph in an inconsistent state."
+            ) from err
+
+
+@final
+class ResolvedPipelineGraph(PipelineGraph[ResolvedTaskNode, ResolvedDatasetTypeNode]):
+    def __init__(self, original: PipelineGraph, registry: Registry):
+        original.sort()
+        super().__init__(original._xgraph.copy())
+        for state in self._xgraph.nodes.values():
+            node: Node = state["instance"]
+            # _resolve expects the networkx graph, not the PipelineGraph
+            # wrapper, and replaces state["instance"] with a resolved node.
+            node._resolve(state, self._xgraph, registry)
+        self._reorder(original)
+
+    @property
+    def is_sorted(self) -> bool:
+        return True
+
+    def sort(self) -> None:
+        pass
+
+    def resolved(self, registry: Registry) -> ResolvedPipelineGraph:
+        return self
+
+    def copy(self) -> ResolvedPipelineGraph:
+        return self
+
+    def mutable_copy(self) -> MutablePipelineGraph:
+        xgraph = self._xgraph.copy()
+        for state in xgraph.nodes.values():
+            node: Node = state["instance"]
+            node._unresolve(state)
+        result = MutablePipelineGraph(xgraph)
+        result._reorder(self)
+        result._is_sorted = True
+        return result
diff --git a/python/lsst/pipe/base/pipeline_graph/_resolved_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_resolved_pipeline_graph.py
deleted file mode 100644
index 4c4f725d0..000000000
--- a/python/lsst/pipe/base/pipeline_graph/_resolved_pipeline_graph.py
+++ /dev/null
@@ -1,158 +0,0 @@
-# This file is part of pipe_base.
-#
-# Developed for the LSST Data Management System.
-# This product includes software developed by the LSST Project
-# (http://www.lsst.org).
-# See the COPYRIGHT file at the top-level directory of this distribution
-# for details of code ownership.
-# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . -from __future__ import annotations - -__all__ = ("ResolvedPipelineGraph",) - -from typing import TYPE_CHECKING - -from lsst.daf.butler import DatasetType, DimensionUniverse, Registry -from lsst.daf.butler.registry import MissingDatasetTypeError - -from ._exceptions import IncompatibleDatasetTypeError -from ._pipeline_graph import PipelineGraph -from ._unresolved_pipeline_graph import UnresolvedPipelineGraph -from ._vertices_and_edges import ResolvedDatasetTypeVertex, ResolvedTaskVertex, Vertex - -if TYPE_CHECKING: - from ..connectionTypes import BaseConnection - - -class ResolvedPipelineGraph(PipelineGraph[ResolvedTaskVertex, ResolvedDatasetTypeVertex]): - def __init__(self, unresolved: UnresolvedPipelineGraph, registry: Registry): - super().__init__( - unresolved._xgraph.copy(), - prerequisite_inputs=unresolved._prerequisite_inputs, - is_sorted=unresolved._is_sorted, - ) - for state in self._xgraph.nodes.values(): - vertex: Vertex = state["instance"] - vertex._resolve(state, self, registry) - - def _resolve_dataset_type(self, name: str, registry: Registry) -> DatasetType: - try: - dataset_type = registry.getDatasetType(name) - except MissingDatasetTypeError: - dataset_type = None - if (write_edge := self.producer_of(name)) is not None: - dataset_type = self._make_connection_dataset_type( - write_edge.task_label, - name, - current=dataset_type, - universe=registry.dimensions, - is_input=False, - ) - if dataset_type is not None: - # Registry and/or producer task have locked down the dataset type - # definition, so all we need to do here is check for consistency. - for consumer in self.consumers_of(name): - self._make_connection_dataset_type( - consumer, - name, - current=dataset_type, - universe=registry.dimensions, - is_input=False, - ) - else: - consumer_dataset_types = { - self._make_connection_dataset_type( - consumer, - name, - current=dataset_type, - universe=registry.dimensions, - is_input=False, - ) - for consumer in self.consumers_of(name) - } - if len(consumer_dataset_types) > 1: - raise MissingDatasetTypeError( - f"Dataset type {name!r} is not registered and not produced by this pipeline, " - f"but it used as an input by multiple tasks with different storage classes. " - f"This pipeline cannot be resolved until the dataset type is registered." - ) - assert len(consumer_dataset_types) == 1, "Can't have a dataset type node with no edges." 
- (dataset_type,) = consumer_dataset_types - return dataset_type - - def _make_connection_dataset_type( - self, - task_label: str, - dataset_type_name: str, - current: DatasetType | None, - universe: DimensionUniverse, - is_input: bool, - ) -> DatasetType: - task_state = self._xgraph[task_label] - connection: BaseConnection = task_state["connections"].allConnections[ - task_state["connection_reverse_mapping"][dataset_type_name] - ] - dimensions = connection.resolve_dimensions( - universe, - current.dimensions if current is not None else None, - task_label, - ) - parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name) - assert parent_dataset_type_name == dataset_type_name, "Guaranteed by class invariants." - if component is not None: - if current is None: - raise MissingDatasetTypeError( - f"Dataset type {dataset_type_name!r} is not registered and not produced by this " - f"pipeline, but it used by task {task_label!r} via component {component!r}. " - f"This pipeline cannot be resolved until the parent dataset type is registered." - ) - if component not in current.storageClass.components: - raise IncompatibleDatasetTypeError( - f"Dataset type {parent_dataset_type_name!r} has storage class " - f"{current.storageClass_name!r}, which does not include a {component!r} " - f"as requested by task {task_label!r}." - ) - if current.storageClass.components[component].name != connection.storageClass: - raise IncompatibleDatasetTypeError( - f"Dataset type {dataset_type_name!r} has storage class " - f"{current.storageClass.components[component].name!r}, which does not match " - f"{connection.storageClass!r}, as requested by task {task_label!r}. " - "Note that storage class conversions of components are not supported." - ) - dataset_type = DatasetType( - parent_dataset_type_name, - dimensions, - storageClass=current.storageClass, - isCalibration=connection.isCalibration, - ) - - else: - dataset_type = connection.makeDatasetType(universe) - if current is not None: - if is_input: - if not dataset_type.is_compatible_with(current): - raise IncompatibleDatasetTypeError( - f"Incompatible definition for input dataset type {dataset_type_name!r} to task " - f"{task_label!r}: {current} != {dataset_type}." - ) - else: - if not current.is_compatible_with(dataset_type): - raise IncompatibleDatasetTypeError( - f"Incompatible definition for output dataset type {dataset_type_name!r} from task " - f"{task_label!r}: {current} != {dataset_type}." - ) - return current - else: - return dataset_type diff --git a/python/lsst/pipe/base/pipeline_graph/_tasks.py b/python/lsst/pipe/base/pipeline_graph/_tasks.py new file mode 100644 index 000000000..1c05b9bc8 --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/_tasks.py @@ -0,0 +1,181 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+__all__ = (
+    "TaskNode",
+    "ResolvedTaskNode",
+    "TaskSubgraphView",
+)
+
+import dataclasses
+from collections import ChainMap
+from collections.abc import Mapping, Set
+from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast
+
+import networkx
+from lsst.daf.butler import DimensionGraph, Registry
+
+from ..connections import PipelineTaskConnections, iterConnections
+from ._abcs import MappingSubgraphView, Node
+from ._edges import ReadEdge, WriteEdge
+
+if TYPE_CHECKING:
+    from ..config import PipelineTaskConfig
+    from ..pipelineTask import PipelineTask
+
+
+@dataclasses.dataclass(frozen=True, eq=False)
+class TaskNode(Node):
+
+    BIPARTITE_CONSTANT: ClassVar[int] = 0
+
+    label: str
+    task_class: type[PipelineTask]
+    config: PipelineTaskConfig
+    init_inputs: Mapping[str, ReadEdge]
+    prerequisite_inputs: Mapping[str, ReadEdge]
+    inputs: Mapping[str, ReadEdge]
+    init_outputs: Mapping[str, WriteEdge]
+    outputs: Mapping[str, WriteEdge]
+
+    @property
+    def all_inputs(self) -> Mapping[str, ReadEdge]:
+        return ChainMap(
+            cast(dict, self.init_inputs),
+            cast(dict, self.prerequisite_inputs),
+            cast(dict, self.inputs),
+        )
+
+    @property
+    def all_outputs(self) -> Mapping[str, WriteEdge]:
+        return ChainMap(
+            cast(dict, self.init_outputs),
+            cast(dict, self.outputs),
+        )
+
+    @property
+    def init_dataset_types(self) -> Set[str]:
+        # Union, not intersection: every init dataset type this task touches.
+        return self.init_inputs.keys() | self.init_outputs.keys()
+
+    @property
+    def run_dataset_types(self) -> Set[str]:
+        # Union, not intersection: every run dataset type this task touches.
+        return self.prerequisite_inputs.keys() | self.inputs.keys() | self.outputs.keys()
+
+    @staticmethod
+    def _from_pipeline_task(
+        label: str,
+        task_class: type[PipelineTask],
+        config: PipelineTaskConfig,
+        connections: PipelineTaskConnections | None = None,
+    ) -> tuple[TaskNode, dict[str, Any], list[tuple[str, str, dict[str, Any]]]]:
+        if connections is None:
+            connections = task_class.ConfigClass.ConnectionsClass(config=config)
+        edge_data: list[tuple[str, str, dict[str, Any]]] = []
+        init_inputs = {
+            c.name: ReadEdge._from_connection(label, c, edge_data, is_init=True)
+            for c in iterConnections(connections, "initInputs")
+        }
+        prerequisite_inputs = {
+            c.name: ReadEdge._from_connection(label, c, edge_data, is_init=False, is_prerequisite=True)
+            for c in iterConnections(connections, "prerequisiteInputs")
+        }
+        inputs = {
+            c.name: ReadEdge._from_connection(label, c, edge_data, is_init=False)
+            for c in iterConnections(connections, "inputs")
+        }
+        init_outputs = {
+            c.name: WriteEdge._from_connection(label, c, edge_data, is_init=True)
+            for c in iterConnections(connections, "initOutputs")
+        }
+        outputs = {
+            c.name: WriteEdge._from_connection(label, c, edge_data, is_init=False)
+            for c in iterConnections(connections, "outputs")
+        }
+        instance = TaskNode(
+            label=label,
+            task_class=task_class,
+            config=config,
+            init_inputs=init_inputs,
+            prerequisite_inputs=prerequisite_inputs,
+            inputs=inputs,
+            init_outputs=init_outputs,
+            outputs=outputs,
+        )
+        return (
+            instance,
+            dict(
+                instance=instance,
+                connections=connections,
+                bipartite=TaskNode.BIPARTITE_CONSTANT,
+            ),
+            edge_data,
+        )
+
+    def _resolve(self, state: dict[str, Any], graph: networkx.DiGraph, registry: Registry) -> None:
+        # Replace this node's entry in the graph state with its resolved
+        # counterpart.
+        state["instance"] = ResolvedTaskNode(
+            label=self.label,
+            task_class=self.task_class,
+            config=self.config,
+            init_inputs=self.init_inputs,
+            prerequisite_inputs=self.prerequisite_inputs,
+            inputs=self.inputs,
+            init_outputs=self.init_outputs,
+            outputs=self.outputs,
+            dimensions=registry.dimensions.extract(state["connections"].dimensions),
+        )
+
+    def _unresolve(self, state: dict[str, Any]) -> None:
+        pass
+
+
+@dataclasses.dataclass(frozen=True, eq=False)
+class ResolvedTaskNode(TaskNode):
+    dimensions: DimensionGraph
+
+    def _resolve(self, state: dict[str, Any], graph: networkx.DiGraph, registry: Registry) -> None:
+        pass
+
+    def _unresolve(self, state: dict[str, Any]) -> None:
+        # Replace this node's entry in the graph state with its unresolved
+        # counterpart.
+        state["instance"] = TaskNode(
+            label=self.label,
+            task_class=self.task_class,
+            config=self.config,
+            init_inputs=self.init_inputs,
+            prerequisite_inputs=self.prerequisite_inputs,
+            inputs=self.inputs,
+            init_outputs=self.init_outputs,
+            outputs=self.outputs,
+        )
+
+
+_T = TypeVar("_T", bound=TaskNode, covariant=True)
+
+
+class TaskSubgraphView(MappingSubgraphView[networkx.MultiDiGraph, _T]):
+    def __contains__(self, key: object) -> bool:
+        # Node attributes live in the .nodes mapping, not in the adjacency
+        # view returned by indexing the graph directly.
+        return (
+            key in self._parent_xgraph
+            and self._parent_xgraph.nodes[key]["bipartite"] == TaskNode.BIPARTITE_CONSTANT
+        )
+
+    def _make_xgraph(self) -> networkx.MultiDiGraph:
+        return networkx.freeze(networkx.bipartite.projected_graph(self._parent_xgraph, self, multigraph=True))
diff --git a/python/lsst/pipe/base/pipeline_graph/_unresolved_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_unresolved_pipeline_graph.py
deleted file mode 100644
index 285abcb53..000000000
--- a/python/lsst/pipe/base/pipeline_graph/_unresolved_pipeline_graph.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# This file is part of pipe_base.
-#
-# Developed for the LSST Data Management System.
-# This product includes software developed by the LSST Project
-# (http://www.lsst.org).
-# See the COPYRIGHT file at the top-level directory of this distribution
-# for details of code ownership.
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see .
-from __future__ import annotations - -__all__ = ("UnresolvedPipelineGraph",) - -from typing import TYPE_CHECKING, Any - -import networkx - -from ._pipeline_graph import PipelineGraph -from ._vertices_and_edges import DatasetTypeVertex, Edge, TaskVertex - -if TYPE_CHECKING: - from ..config import PipelineTaskConfig - from ..connections import PipelineTaskConnections - from ..pipelineTask import PipelineTask - - -class UnresolvedPipelineGraph(PipelineGraph[TaskVertex, DatasetTypeVertex]): - def __init__(self) -> None: - super().__init__(networkx.DiGraph()) - - def add_task( - self, - label: str, - task_class: type[PipelineTask], - config: PipelineTaskConfig, - connections: PipelineTaskConnections | None = None, - ) -> None: - # Make the task vertex, the corresponding state dict that will be held - # by the networkx graph (which include the vertex instance), and the - # state dicts for the edges - task_vertex, task_state, edge_data = TaskVertex._from_pipeline_task( - label, task_class, config, connections - ) - vertex_data: list[tuple[str, dict[str, Any]]] = [(label, task_state)] - for _, _, edge_state in edge_data: - edge: Edge = edge_state["instance"] - if edge.parent_dataset_type_name in self._xgraph: - dataset_type_vertex = self._xgraph[edge.parent_dataset_type_name]["instance"] - edge._check_dataset_type(edge_state, self._xgraph, dataset_type_vertex) - else: - vertex_data.append((edge.parent_dataset_type_name, edge._make_dataset_type_state(edge_state))) - - # Checks complete; time to start the actual modification. - try: - self._is_sorted = False - self.tasks._reset() - self.init_dataset_types._reset() - self.run_dataset_types._reset() - self._xgraph.add_nodes_from(vertex_data) - self._xgraph.add_edges_from(edge_data) - self._prerequisite_inputs.update(task_vertex.prerequisite_inputs) - except Exception as err: - raise RuntimeError( - "Error during PipelineGraph modification has left the graph in an inconsistent state." - ) from err diff --git a/python/lsst/pipe/base/pipeline_graph/_vertices_and_edges.py b/python/lsst/pipe/base/pipeline_graph/_vertices_and_edges.py deleted file mode 100644 index 2d1891df4..000000000 --- a/python/lsst/pipe/base/pipeline_graph/_vertices_and_edges.py +++ /dev/null @@ -1,546 +0,0 @@ -# This file is part of pipe_base. -# -# Developed for the LSST Data Management System. -# This product includes software developed by the LSST Project -# (http://www.lsst.org). -# See the COPYRIGHT file at the top-level directory of this distribution -# for details of code ownership. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . 
-from __future__ import annotations - -__all__ = ( - "Vertex", - "TaskVertex", - "ResolvedTaskVertex", - "DatasetTypeVertex", - "ResolvedDatasetTypeVertex", - "Edge", - "WriteEdge", - "ReadEdge", -) - -import dataclasses -import itertools -from abc import ABC, abstractmethod -from collections import ChainMap -from collections.abc import Callable, Iterable, Mapping, Sequence, Set -from typing import TYPE_CHECKING, Any, ClassVar, cast - -import networkx -from lsst.daf.butler import ( - DataCoordinate, - DatasetRef, - DatasetType, - DimensionGraph, - DimensionUniverse, - Registry, -) -from lsst.daf.butler.registry import MissingDatasetTypeError - -from ..connections import PipelineTaskConnections, iterConnections -from ._exceptions import ConnectionTypeConsistencyError, IncompatibleDatasetTypeError, MultipleProducersError - -if TYPE_CHECKING: - from ..config import PipelineTaskConfig - from ..connectionTypes import BaseConnection - from ..pipelineTask import PipelineTask - - -class Vertex(ABC): - - BIPARTITE_CONSTANT: ClassVar[int] - - @abstractmethod - def _resolve(self, state: dict[str, Any], graph: networkx.DiGraph, registry: Registry) -> None: - raise NotImplementedError() - - -@dataclasses.dataclass(frozen=True, eq=False) -class TaskVertex: - - BIPARTITE_CONSTANT: ClassVar[int] = 0 - - label: str - task_class: type[PipelineTask] - config: PipelineTaskConfig - init_inputs: Mapping[str, ReadEdge] - prerequisite_inputs: Mapping[str, ReadEdge] - inputs: Mapping[str, ReadEdge] - init_outputs: Mapping[str, WriteEdge] - outputs: Mapping[str, WriteEdge] - - @property - def all_inputs(self) -> Mapping[str, ReadEdge]: - return ChainMap( - cast(dict, self.init_inputs), - cast(dict, self.prerequisite_inputs), - cast(dict, self.inputs), - ) - - @property - def all_outputs(self) -> Mapping[str, WriteEdge]: - return ChainMap( - cast(dict, self.init_outputs), - cast(dict, self.outputs), - ) - - @property - def init_dataset_types(self) -> Set[str]: - return self.init_inputs.keys() & self.init_outputs.keys() - - @property - def run_dataset_types(self) -> Set[str]: - return self.prerequisite_inputs.keys() & self.inputs.keys() & self.outputs.keys() - - @staticmethod - def _from_pipeline_task( - label: str, - task_class: type[PipelineTask], - config: PipelineTaskConfig, - connections: PipelineTaskConnections | None = None, - ) -> tuple[TaskVertex, dict[str, Any], list[tuple[str, str, dict[str, Any]]]]: - if connections is None: - connections = task_class.ConfigClass.ConnectionsClass(config=config) - edge_data: list[tuple[str, str, dict[str, Any]]] = [] - init_inputs = { - c.name: ReadEdge._from_connection(label, c, edge_data, is_init=True) - for c in iterConnections(connections, "initInputs") - } - prerequisite_inputs = { - c.name: ReadEdge._from_connection(label, c, edge_data, is_init=False, is_prerequisite=True) - for c in iterConnections(connections, "prerequisiteInputs") - } - inputs = { - c.name: ReadEdge._from_connection(label, c, edge_data, is_init=False) - for c in iterConnections(connections, "inputs") - } - init_outputs = { - c.name: WriteEdge._from_connection(label, c, edge_data, is_init=True) - for c in iterConnections(connections, "initOutputs") - } - outputs = { - c.name: WriteEdge._from_connection(label, c, edge_data, is_init=False) - for c in iterConnections(connections, "outputs") - } - instance = TaskVertex( - label=label, - task_class=task_class, - config=config, - init_inputs=init_inputs, - prerequisite_inputs=prerequisite_inputs, - inputs=inputs, - init_outputs=init_outputs, - 
outputs=outputs, - ) - return ( - instance, - dict( - instance=instance, - connections=connections, - bipartite=TaskVertex.BIPARTITE_CONSTANT, - ), - edge_data, - ) - - def _resolve(self, state: dict[str, Any], graph: networkx.DiGraph, registry: Registry) -> None: - state[self.label] = ResolvedTaskVertex( - label=self.label, - task_class=self.task_class, - config=self.config, - init_inputs=self.init_inputs, - prerequisite_inputs=self.prerequisite_inputs, - inputs=self.inputs, - init_outputs=self.init_outputs, - outputs=self.outputs, - dimensions=registry.dimensions.extract(state["connections"].dimensions), - ) - - -@dataclasses.dataclass(frozen=True, eq=False) -class ResolvedTaskVertex(TaskVertex): - dimensions: DimensionGraph - - def _resolve(self, state: dict[str, Any], graph: networkx.DiGraph, registry: Registry) -> None: - pass - - -@dataclasses.dataclass(frozen=True, eq=False) -class DatasetTypeVertex: - BIPARTITE_CONSTANT: ClassVar[int] = 1 - - name: str - is_calibration: bool - is_init: bool - is_prerequisite: bool - - def _resolve(self, state: dict[str, Any], graph: networkx.DiGraph, registry: Registry) -> None: - try: - dataset_type = registry.getDatasetType(self.name) - except MissingDatasetTypeError: - dataset_type = None - for edge_state in itertools.chain( - graph.in_edges(self.name, data=True), graph.out_edges(self.name, data=True) - ): - edge: Edge = edge_state["instance"] - dataset_type = edge._resolve_dataset_type( - edge_state, - current=dataset_type, - universe=registry.dimensions, - ) - assert dataset_type is not None, "Graph structure guarantees at least one edge." - state[self.name] = ResolvedDatasetTypeVertex( - name=self.name, - is_calibration=self.is_calibration, - is_init=self.is_init, - is_prerequisite=self.is_prerequisite, - dataset_type=dataset_type, - ) - - -@dataclasses.dataclass(frozen=True, eq=False) -class ResolvedDatasetTypeVertex(DatasetTypeVertex): - dataset_type: DatasetType - - def _resolve(self, state: dict[str, Any], graph: networkx.DiGraph, registry: Registry) -> None: - pass - - -class Edge(ABC): - - task_label: str - parent_dataset_type_name: str - - @property - @abstractmethod - def key(self) -> tuple[str, str]: - raise NotImplementedError() - - @property - def dataset_type_name(self) -> str: - return self.parent_dataset_type_name - - @classmethod - @abstractmethod - def _from_connection( - cls, - task_label: str, - connection: BaseConnection, - edge_data: list[tuple[str, str, dict[str, Any]]], - *, - is_init: bool, - is_prerequisite: bool = False, - ) -> Edge: - raise NotImplementedError() - - @abstractmethod - def _check_dataset_type( - self, - state: dict[str, Any], - xgraph: networkx.DiGraph, - dataset_type_vertex: DatasetTypeVertex, - ) -> None: - if state["is_init"] != dataset_type_vertex.is_init: - referencing_tasks = list( - itertools.chain( - xgraph.predecessors(dataset_type_vertex.name), - xgraph.successors(dataset_type_vertex.name), - ) - ) - if state["is_init"]: - raise ConnectionTypeConsistencyError( - f"{dataset_type_vertex.name!r} is an init dataset in task {self.task_label!r}, " - f"but a run dataset in task(s) {referencing_tasks}." - ) - else: - raise ConnectionTypeConsistencyError( - f"{dataset_type_vertex.name!r} is a run dataset in task {self.task_label!r}, " - f"but an init dataset in task(s) {referencing_tasks}." 
- ) - if state["is_prerequisite"] != dataset_type_vertex.is_prerequisite: - referencing_tasks = list(xgraph.successors(dataset_type_vertex.name)) - if state["is_prerequisite"]: - raise ConnectionTypeConsistencyError( - f"Dataset type {dataset_type_vertex.name!r} is a prerequisite input in " - f"task {self.task_label!r}, but it was not a prerequisite to " - f"{referencing_tasks}." - ) - else: - raise ConnectionTypeConsistencyError( - f"Dataset type {dataset_type_vertex.name!r} is not a prerequisite input in " - f"task {self.task_label!r}, but it was a prerequisite to " - f"{referencing_tasks}." - ) - connection: BaseConnection = state["connection"] - if connection.isCalibration != dataset_type_vertex.is_calibration: - referencing_tasks = list( - itertools.chain( - xgraph.predecessors(dataset_type_vertex.name), - xgraph.successors(dataset_type_vertex.name), - ) - ) - if connection.isCalibration: - raise IncompatibleDatasetTypeError( - f"Dataset type {dataset_type_vertex.name!r} is a calibration in " - f"task {self.task_label}, but it was not in task(s) {referencing_tasks}." - ) - else: - raise IncompatibleDatasetTypeError( - f"Dataset type {dataset_type_vertex.name!r} is not a calibration in " - f"task {self.task_label}, but it was in task(s) {referencing_tasks}." - ) - - @abstractmethod - def _make_dataset_type_state(self, state: dict[str, Any]) -> dict[str, Any]: - raise NotImplementedError() - - @abstractmethod - def _resolve_dataset_type( - self, - state: dict[str, Any], - current: DatasetType | None, - universe: DimensionUniverse, - ) -> DatasetType: - raise NotImplementedError() - - -@dataclasses.dataclass(frozen=True, eq=False) -class ReadEdge(Edge): - parent_dataset_type_name: str - task_label: str - component: str | None - storage_class_name: str - lookup_function: Callable[ - [DatasetType, Registry, DataCoordinate, Sequence[str]], Iterable[DatasetRef] - ] | None = None - - @property - def key(self) -> tuple[str, str]: - return (self.parent_dataset_type_name, self.task_label) - - @property - def dataset_type_name(self) -> str: - return f"{self.parent_dataset_type_name}.{self.component}" - - @classmethod - def _from_connection( - cls, - task_label: str, - connection: BaseConnection, - edge_data: list[tuple[str, str, dict[str, Any]]], - *, - is_init: bool, - is_prerequisite: bool = False, - ) -> ReadEdge: - parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name) - instance = cls( - task_label, - parent_dataset_type_name, - component, - connection.storageClass, - ) - edge_data.append( - ( - instance.parent_dataset_type_name, - instance.task_label, - dict( - instance=instance, connection=connection, is_init=is_init, is_prerequisite=is_prerequisite - ), - ) - ) - return instance - - def _check_dataset_type( - self, - state: dict[str, Any], - xgraph: networkx.DiGraph, - dataset_type_vertex: DatasetTypeVertex, - ) -> None: - super()._check_dataset_type(state, xgraph, dataset_type_vertex) - if state["is_prerequisite"] != dataset_type_vertex.is_prerequisite: - referencing_tasks = list(xgraph.successors(dataset_type_vertex.name)) - if state["is_prerequisite"]: - raise ConnectionTypeConsistencyError( - f"Dataset type {dataset_type_vertex.name!r} is a prerequisite input in " - f"task {self.task_label!r}, but it was a regular input to " - f"{referencing_tasks}." 
-                )
-            else:
-                raise ConnectionTypeConsistencyError(
-                    f"Dataset type {dataset_type_vertex.name!r} is a regular input in "
-                    f"task {self.task_label!r}, but it was a prerequisite input to "
-                    f"{referencing_tasks}."
-                )
-
-    def _make_dataset_type_state(self, state: dict[str, Any]) -> dict[str, Any]:
-        connection: BaseConnection = state["connection"]
-        return dict(
-            instance=DatasetTypeVertex(
-                self.parent_dataset_type_name,
-                is_calibration=connection.isCalibration,
-                is_init=state["is_init"],
-                is_prerequisite=state["is_prerequisite"],
-            ),
-            bipartite=DatasetTypeVertex.BIPARTITE_CONSTANT,
-        )
-
-    def _resolve_dataset_type(
-        self,
-        state: dict[str, Any],
-        current: DatasetType | None,
-        universe: DimensionUniverse,
-    ) -> DatasetType:
-        connection: BaseConnection = state["connection"]
-        dimensions = connection.resolve_dimensions(
-            universe,
-            current.dimensions if current is not None else None,
-            self.task_label,
-        )
-        if self.component is not None:
-            if current is None:
-                raise MissingDatasetTypeError(
-                    f"Dataset type {self.parent_dataset_type_name!r} is not registered and not produced by "
-                    f"this pipeline, but it is used by task {self.task_label!r}, via component "
-                    f"{self.component!r}. This pipeline cannot be resolved until the parent dataset type is "
-                    "registered."
-                )
-            if self.component not in current.storageClass.components:
-                raise IncompatibleDatasetTypeError(
-                    f"Dataset type {self.parent_dataset_type_name!r} has storage class "
-                    f"{current.storageClass.name!r}, which does not include component {self.component!r} "
-                    f"as requested by task {self.task_label!r}."
-                )
-            if current.storageClass.components[self.component].name != connection.storageClass:
-                raise IncompatibleDatasetTypeError(
-                    f"Dataset type '{self.parent_dataset_type_name}.{self.component}' has storage class "
-                    f"{current.storageClass.components[self.component].name!r}, which does not match "
-                    f"{connection.storageClass!r}, as requested by task {self.task_label!r}. "
-                    "Note that storage class conversions of components are not supported."
-                )
-            return current
-        else:
-            dataset_type = DatasetType(
-                self.parent_dataset_type_name,
-                dimensions,
-                storageClass=connection.storageClass,
-                isCalibration=connection.isCalibration,
-            )
-            if current is not None:
-                if not dataset_type.is_compatible_with(current):
-                    raise IncompatibleDatasetTypeError(
-                        f"Incompatible definition for input dataset type {self.parent_dataset_type_name!r} "
-                        f"to task {self.task_label!r}: {current} != {dataset_type}."
-                    )
-                return current
-            else:
-                return dataset_type
-
-
-@dataclasses.dataclass(frozen=True, eq=False)
-class WriteEdge(Edge):
-    task_label: str
-    parent_dataset_type_name: str
-    storage_class_name: str
-
-    @property
-    def key(self) -> tuple[str, str]:
-        return (self.task_label, self.parent_dataset_type_name)
-
-    @classmethod
-    def _from_connection(
-        cls,
-        task_label: str,
-        connection: BaseConnection,
-        edge_data: list[tuple[str, str, dict[str, Any]]],
-        *,
-        is_init: bool,
-        is_prerequisite: bool = False,
-    ) -> WriteEdge:
-        parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name)
-        if component is not None:
-            raise ValueError(f"Illegal output component dataset {connection.name!r} in task {task_label!r}.")
-        instance = cls(task_label, parent_dataset_type_name, connection.storageClass)
-        edge_data.append(
-            (
-                instance.task_label,
-                instance.parent_dataset_type_name,
-                dict(
-                    instance=instance,
-                    connection=connection,
-                    is_init=is_init,
-                    is_prerequisite=is_prerequisite,
-                ),
-            )
-        )
-        return instance
-
-    def _check_dataset_type(
-        self,
-        state: dict[str, Any],
-        xgraph: networkx.DiGraph,
-        dataset_type_vertex: DatasetTypeVertex,
-    ) -> None:
-        super()._check_dataset_type(state, xgraph, dataset_type_vertex)
-        if dataset_type_vertex.is_prerequisite:
-            referencing_tasks = list(xgraph.successors(dataset_type_vertex.name))
-            raise ConnectionTypeConsistencyError(
-                f"Dataset type {dataset_type_vertex.name!r} is an output of "
-                f"task {self.task_label!r}, but it was a prerequisite input to "
-                f"{referencing_tasks}."
-            )
-        for existing_producer in xgraph.predecessors(dataset_type_vertex.name):
-            raise MultipleProducersError(
-                f"Dataset type {dataset_type_vertex.name} is produced by both {self.task_label!r} "
-                f"and {existing_producer!r}."
-            )
-
-    def _make_dataset_type_state(self, state: dict[str, Any]) -> dict[str, Any]:
-        connection: BaseConnection = state["connection"]
-        return dict(
-            instance=DatasetTypeVertex(
-                self.parent_dataset_type_name,
-                is_calibration=connection.isCalibration,
-                is_init=state["is_init"],
-                is_prerequisite=False,
-            ),
-            bipartite=DatasetTypeVertex.BIPARTITE_CONSTANT,
-        )
-
-    def _resolve_dataset_type(
-        self,
-        state: dict[str, Any],
-        current: DatasetType | None,
-        universe: DimensionUniverse,
-    ) -> DatasetType:
-        connection: BaseConnection = state["connection"]
-        dimensions = connection.resolve_dimensions(
-            universe,
-            current.dimensions if current is not None else None,
-            self.task_label,
-        )
-        dataset_type = DatasetType(
-            self.parent_dataset_type_name,
-            dimensions,
-            storageClass=connection.storageClass,
-            isCalibration=connection.isCalibration,
-        )
-        if current is not None:
-            if not current.is_compatible_with(dataset_type):
-                raise IncompatibleDatasetTypeError(
-                    f"Incompatible definition for output dataset type {self.parent_dataset_type_name!r} "
-                    f"from task {self.task_label!r}: {current} != {dataset_type}."
-                )
-            return current
-        else:
-            return dataset_type