From 44ac0303fee3f841d4caf768189c02ecdd4d382f Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Thu, 14 Jul 2022 15:04:27 -0400 Subject: [PATCH] WIP: add PipelineGraph --- python/lsst/pipe/base/__init__.py | 6 +- python/lsst/pipe/base/pipeline.py | 19 +- .../lsst/pipe/base/pipeline_graph/__init__.py | 27 +++ .../pipe/base/pipeline_graph/_exceptions.py | 54 +++++ .../pipeline_graph/_generic_pipeline_graph.py | 223 ++++++++++++++++++ .../base/pipeline_graph/_pipeline_graph.py | 130 ++++++++++ .../_resolved_pipeline_graph.py | 171 ++++++++++++++ .../pipeline_graph/_vertices_and_edges.py | 221 +++++++++++++++++ 8 files changed, 845 insertions(+), 6 deletions(-) create mode 100644 python/lsst/pipe/base/pipeline_graph/__init__.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_exceptions.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_generic_pipeline_graph.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_resolved_pipeline_graph.py create mode 100644 python/lsst/pipe/base/pipeline_graph/_vertices_and_edges.py diff --git a/python/lsst/pipe/base/__init__.py b/python/lsst/pipe/base/__init__.py index 0a20ff356..25ac614cf 100644 --- a/python/lsst/pipe/base/__init__.py +++ b/python/lsst/pipe/base/__init__.py @@ -1,4 +1,4 @@ -from . import connectionTypes, pipelineIR +from . import connectionTypes, pipeline_graph, pipelineIR from ._dataset_handle import * from ._instrument import * from ._status import * @@ -10,6 +10,10 @@ from .graph import * from .graphBuilder import * from .pipeline import * + +# We import the main PipelineGraph types and the module (above), but we don't +# lift all symbols to package scope. +from .pipeline_graph import PipelineGraph, ResolvedPipelineGraph from .pipelineTask import * from .struct import * from .task import * diff --git a/python/lsst/pipe/base/pipeline.py b/python/lsst/pipe/base/pipeline.py index c4107e7c2..8462ca9f2 100644 --- a/python/lsst/pipe/base/pipeline.py +++ b/python/lsst/pipe/base/pipeline.py @@ -62,11 +62,12 @@ from lsst.utils import doImportType from lsst.utils.introspection import get_full_type_name -from . import pipelineIR, pipeTools +from . import pipelineIR from ._task_metadata import TaskMetadata from .config import PipelineTaskConfig from .configOverrides import ConfigOverrides from .connections import iterConnections +from .pipeline_graph import PipelineGraph from .pipelineTask import PipelineTask from .task import _TASK_METADATA_TYPE @@ -660,7 +661,7 @@ def toExpandedPipeline(self) -> Generator[TaskDef, None, None]: If a dataId is supplied in a config block. 
             This is in place for future use
         """
-        taskDefs = []
+        taskDefs: list[TaskDef] = []
         for label in self._pipelineIR.tasks:
             taskDefs.append(self._buildTaskDef(label))
@@ -676,9 +677,17 @@ def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
                 raise pipelineIR.ContractError(
                     f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
                 )
-
-        taskDefs = sorted(taskDefs, key=lambda x: x.label)
-        yield from pipeTools.orderPipeline(taskDefs)
+        graph = PipelineGraph()
+        for task_def in taskDefs:
+            graph.add_task(
+                task_def.label,
+                task_def.taskClass,
+                config=task_def.config,
+                connections=task_def.connections,
+            )
+        graph.sort()
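+        # graph.tasks is now in topological order; index each label so the
+        # original TaskDef objects can be reordered to match.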
+        sort_map = {label: n for n, label in enumerate(graph.tasks)}
+        yield from sorted(taskDefs, key=lambda x: sort_map[x.label])
 
     def _buildTaskDef(self, label: str) -> TaskDef:
         if (taskIR := self._pipelineIR.tasks.get(label)) is None:
diff --git a/python/lsst/pipe/base/pipeline_graph/__init__.py b/python/lsst/pipe/base/pipeline_graph/__init__.py
new file mode 100644
index 000000000..5c077d811
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/__init__.py
@@ -0,0 +1,27 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+from __future__ import annotations
+
+from ._exceptions import *
+from ._generic_pipeline_graph import *
+from ._pipeline_graph import *
+from ._resolved_pipeline_graph import *
+from ._vertices_and_edges import *
diff --git a/python/lsst/pipe/base/pipeline_graph/_exceptions.py b/python/lsst/pipe/base/pipeline_graph/_exceptions.py
new file mode 100644
index 000000000..a8eba5da5
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_exceptions.py
@@ -0,0 +1,54 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+from __future__ import annotations

+__all__ = (
+    "PipelineGraphError",
+    "MultipleProducersError",
+    "DatasetDependencyError",
+    "ConnectionTypeConsistencyError",
+    "IncompatibleDatasetTypeError",
+    "VertexTypeError",
+)
+
+
+class PipelineGraphError(RuntimeError):
+    """Base exception raised when a pipeline graph cannot be constructed."""
+
+
+class MultipleProducersError(PipelineGraphError):
+    """Exception raised when multiple tasks produce the same dataset type."""
+
+
+class DatasetDependencyError(PipelineGraphError):
+    """Exception raised when dataset type dependencies are invalid, such as a cycle."""
+
+
+class ConnectionTypeConsistencyError(PipelineGraphError):
+    """Exception raised when tasks disagree on a dataset type's connection type (init, run, prerequisite)."""
+
+
+class IncompatibleDatasetTypeError(PipelineGraphError):
+    """Exception raised when the same dataset type is defined in incompatible ways."""
+
+
+class VertexTypeError(KeyError):
+    """Exception raised when a vertex in a graph is not of the expected type."""
diff --git a/python/lsst/pipe/base/pipeline_graph/_generic_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_generic_pipeline_graph.py
new file mode 100644
index 000000000..730cb05ab
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_generic_pipeline_graph.py
@@ -0,0 +1,223 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+from __future__ import annotations
+
+__all__ = (
+    "GenericPipelineGraph",
+    "TasksView",
+    "InitDatasetTypesView",
+    "RunDatasetTypesView",
+)
+
+import itertools
+from abc import abstractmethod
+from collections import ChainMap
+from collections.abc import Iterable, Iterator, Mapping
+from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
+
+import networkx
+from lsst.utils.classes import cached_getter
+
+from ._exceptions import DatasetDependencyError
+
+if TYPE_CHECKING:
+    from ._vertices_and_edges import DatasetTypeVertex, ReadEdge, TaskVertex, WriteEdge
+
+
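+# Vertex type variables: PipelineGraph parameterizes these with the
+# unresolved TaskVertex/DatasetTypeVertex, ResolvedPipelineGraph with the
+# resolved variants.  The bounds are strings because the vertex classes are
+# imported only for type checking, and TypeVar bounds are evaluated at
+# runtime even under `from __future__ import annotations`.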
+_D = TypeVar("_D", bound="DatasetTypeVertex", covariant=True)
+_T = TypeVar("_T", bound="TaskVertex", covariant=True)
+
+
+class TasksView(Mapping[str, _T]):
+    def __init__(self, parent: GenericPipelineGraph[_T, _D]):
+        self._parent = parent
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(self._parent._task_labels)
+
+    def __getitem__(self, key: str) -> _T:
+        return self._parent._make_task_vertex(key)
+
+    def __len__(self) -> int:
+        return len(self._parent._task_labels)
+
+    @property
+    @cached_getter
+    def xgraph(self) -> networkx.MultiDiGraph:
+        return self._parent._get_task_xgraph()
+
+    @xgraph.deleter
+    def xgraph(self) -> None:
+        # `cached_getter` caches under f"_cached_{method.__name__}".
+        if hasattr(self._parent, "_cached__get_task_xgraph"):
+            delattr(self._parent, "_cached__get_task_xgraph")
+
+
+class InitDatasetTypesView(Mapping[str, _D]):
+    def __init__(self, parent: GenericPipelineGraph[_T, _D]):
+        self._parent = parent
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(self._parent._init_dataset_types)
+
+    def __getitem__(self, key: str) -> _D:
+        return self._parent._make_dataset_type_vertex(key, is_init=True)
+
+    def __len__(self) -> int:
+        return len(self._parent._init_dataset_types)
+
+    @property
+    @cached_getter
+    def xgraph(self) -> networkx.DiGraph:
+        return self._parent._get_init_dataset_type_xgraph()
+
+    @xgraph.deleter
+    def xgraph(self) -> None:
+        if hasattr(self._parent, "_cached__get_init_dataset_type_xgraph"):
+            delattr(self._parent, "_cached__get_init_dataset_type_xgraph")
+
+
+class RunDatasetTypesView(Mapping[str, _D]):
+    def __init__(self, parent: GenericPipelineGraph[_T, _D]):
+        self._parent = parent
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(self._parent._run_dataset_types)
+
+    def __getitem__(self, key: str) -> _D:
+        return self._parent._make_dataset_type_vertex(key, is_init=False)
+
+    def __len__(self) -> int:
+        return len(self._parent._run_dataset_types)
+
+    @property
+    @cached_getter
+    def xgraph(self) -> networkx.DiGraph:
+        return self._parent._get_run_dataset_type_xgraph()
+
+    @xgraph.deleter
+    def xgraph(self) -> None:
+        if hasattr(self._parent, "_cached__get_run_dataset_type_xgraph"):
+            delattr(self._parent, "_cached__get_run_dataset_type_xgraph")
+
+
+class GenericPipelineGraph(Generic[_T, _D]):
+    def __init__(
+        self,
+        xgraph: networkx.DiGraph | None = None,
+        *,
+        task_labels: Iterable[str] = (),
+        init_dataset_types: Iterable[str] = (),
+        run_dataset_types: Iterable[str] = (),
+        prerequisite_inputs: Iterable[str] = (),
+        is_sorted: bool = False,
+    ) -> None:
+        self._xgraph = networkx.DiGraph() if xgraph is None else xgraph.copy()
+        self._task_labels = list(task_labels)
+        self._init_dataset_types = list(init_dataset_types)
+        self._run_dataset_types = list(run_dataset_types)
+        self._prerequisite_inputs = set(prerequisite_inputs)
+        self._is_sorted = is_sorted
+
+    TASK_BIPARTITE: ClassVar[int] = 0
+    DATASET_TYPE_BIPARTITE: ClassVar[int] = 1
+
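+    # Node and edge schema for the backing bipartite graph: task-label nodes
+    # hold the state built by TaskVertex._state_from_pipeline_task, dataset
+    # type name nodes hold "bipartite" and "is_init" flags (plus a resolved
+    # "dataset_type" in ResolvedPipelineGraph), and every edge stores its
+    # ReadEdge or WriteEdge under the "state" key.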
+    @property
+    def tasks(self) -> TasksView[_T]:
+        return TasksView(self)
+
+    @property
+    def init_dataset_types(self) -> InitDatasetTypesView[_D]:
+        return InitDatasetTypesView(self)
+
+    @property
+    def run_dataset_types(self) -> RunDatasetTypesView[_D]:
+        return RunDatasetTypesView(self)
+
+    @property
+    def dataset_types(self) -> Mapping[str, _D]:
+        # ChainMap is a MutableMapping, so it wants to be passed
+        # MutableMappings, but we've annotated the return type as just
+        # Mapping to ensure it really only needs Mappings.
+        return ChainMap(self.init_dataset_types, self.run_dataset_types)  # type: ignore
+
+    @property
+    @cached_getter
+    def xgraph(self) -> networkx.DiGraph:
+        return self._xgraph.copy(as_view=True)
+
+    @xgraph.deleter
+    def xgraph(self) -> None:
+        if hasattr(self, "_cached_xgraph"):
+            delattr(self, "_cached_xgraph")
+
+    @property
+    def is_sorted(self) -> bool:
+        return self._is_sorted
+
+    def sort(self) -> None:
+        if not self._is_sorted:
+            task_labels = []
+            init_dataset_types = []
+            run_dataset_types = []
+            try:
+                for k in networkx.lexicographical_topological_sort(self._xgraph):
+                    match self._xgraph.nodes[k]:
+                        case {"bipartite": self.TASK_BIPARTITE}:
+                            task_labels.append(k)
+                        case {"bipartite": self.DATASET_TYPE_BIPARTITE, "is_init": True}:
+                            init_dataset_types.append(k)
+                        case {"bipartite": self.DATASET_TYPE_BIPARTITE, "is_init": False}:
+                            run_dataset_types.append(k)
+            except networkx.NetworkXUnfeasible as err:
+                raise DatasetDependencyError("Cycle detected while attempting to sort graph.") from err
+            self._task_labels = task_labels
+            self._init_dataset_types = init_dataset_types
+            self._run_dataset_types = run_dataset_types
+            self._is_sorted = True
+
+    def producer_of(self, dataset_type_name: str) -> WriteEdge | None:
+        for _, _, edge in self._xgraph.in_edges(dataset_type_name, "state"):
+            return edge
+        return None
+
+    def consumers_of(self, dataset_type_name: str) -> dict[str, ReadEdge]:
+        return {
+            task_label: edge for _, task_label, edge in self._xgraph.out_edges(dataset_type_name, "state")
+        }
+
+    @abstractmethod
+    def _make_task_vertex(self, task_label: str) -> _T:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def _make_dataset_type_vertex(self, dataset_type_name: str, is_init: bool) -> _D:
+        raise NotImplementedError()
+
+    @cached_getter
+    def _get_task_xgraph(self) -> networkx.MultiDiGraph:
+        return networkx.freeze(
+            networkx.bipartite.projected_graph(self._xgraph, self._task_labels, multigraph=True)
+        )
+
+    @cached_getter
+    def _get_init_dataset_type_xgraph(self) -> networkx.DiGraph:
+        init_subgraph = self._xgraph.subgraph(itertools.chain(self._task_labels, self._init_dataset_types))
+        return networkx.freeze(networkx.bipartite.projected_graph(init_subgraph, self._init_dataset_types))
+
+    @cached_getter
+    def _get_run_dataset_type_xgraph(self) -> networkx.DiGraph:
+        run_subgraph = self._xgraph.subgraph(itertools.chain(self._task_labels, self._run_dataset_types))
+        return networkx.freeze(networkx.bipartite.projected_graph(run_subgraph, self._run_dataset_types))
diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
new file mode 100644
index 000000000..7451f8e20
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
@@ -0,0 +1,130 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+from __future__ import annotations
+
+__all__ = ("PipelineGraph",)
+
+import itertools
+from typing import TYPE_CHECKING
+
+from ..connections import PipelineTaskConnections
+from ._exceptions import ConnectionTypeConsistencyError, MultipleProducersError
+from ._generic_pipeline_graph import GenericPipelineGraph
+from ._vertices_and_edges import DatasetTypeVertex, TaskVertex
+
+if TYPE_CHECKING:
+    from ..config import PipelineTaskConfig
+    from ..pipelineTask import PipelineTask
+
+
+class PipelineGraph(GenericPipelineGraph[TaskVertex, DatasetTypeVertex]):
+    def __init__(self) -> None:
+        super().__init__()
+
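+    # add_task() adds a task vertex plus read/write edges for all of the
+    # task's connections, checking that connection categories (init vs. run
+    # vs. prerequisite) and producers are consistent with the tasks already
+    # present in the graph.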
+    def add_task(
+        self,
+        label: str,
+        task_class: type[PipelineTask],
+        config: PipelineTaskConfig,
+        connections: PipelineTaskConnections | None = None,
+    ) -> None:
+        # Mark the graph as unsorted and invalidate all cached graph views.
+        self._is_sorted = False
+        del self.xgraph
+        del self.tasks.xgraph
+        del self.init_dataset_types.xgraph
+        del self.run_dataset_types.xgraph
+        # Add the task node itself and then get a vertex view back for ease
+        # of use later in this method.
+        self._task_labels.append(label)
+        state = TaskVertex._state_from_pipeline_task(label, task_class, config, connections=connections)
+        self._xgraph.add_node(label, **state)
+        task = self._make_task_vertex(label)
+        # Add [init_]input edges to the graph.
+        for read_edge in itertools.chain(task.init_inputs.values(), task.inputs.values()):
+            if read_edge.parent_dataset_type_name in self._prerequisite_inputs:
+                raise ConnectionTypeConsistencyError(
+                    f"Dataset type {read_edge.parent_dataset_type_name!r} is not a prerequisite input in "
+                    f"task {task.label!r}, but it is a prerequisite input to "
+                    f"{list(self.consumers_of(read_edge.parent_dataset_type_name))}."
+                )
+            self._xgraph.add_edge(read_edge.parent_dataset_type_name, read_edge.task_label, state=read_edge)
+            self._xgraph.nodes[read_edge.parent_dataset_type_name]["bipartite"] = self.DATASET_TYPE_BIPARTITE
+        # Add prerequisite input edges to the graph.
+        for read_edge in task.prerequisite_inputs.values():
+            if (
+                read_edge.parent_dataset_type_name in self._run_dataset_types
+                and read_edge.parent_dataset_type_name not in self._prerequisite_inputs
+            ):
+                if producer := self.producer_of(read_edge.parent_dataset_type_name):
+                    details = f"an output of task {producer.task_label!r}"
+                else:
+                    details = (
+                        f"a regular input of tasks "
+                        f"{list(self.consumers_of(read_edge.parent_dataset_type_name))}"
+                    )
+                raise ConnectionTypeConsistencyError(
+                    f"Dataset type {read_edge.parent_dataset_type_name!r} is a prerequisite input to task "
+                    f"{task.label!r}, but it is {details}."
+                )
+            self._xgraph.add_edge(read_edge.parent_dataset_type_name, read_edge.task_label, state=read_edge)
+            self._xgraph.nodes[read_edge.parent_dataset_type_name]["bipartite"] = self.DATASET_TYPE_BIPARTITE
+        # Add [init_]output edges to the graph.
+        for write_edge in itertools.chain(task.init_outputs.values(), task.outputs.values()):
+            if write_edge.dataset_type_name in self._prerequisite_inputs:
+                raise ConnectionTypeConsistencyError(
+                    f"Dataset type {write_edge.dataset_type_name!r} is an output of task {task.label!r}, "
+                    f"but it was a prerequisite input to "
+                    f"{list(self.consumers_of(write_edge.dataset_type_name))}."
+                )
+            if (existing_producer := self.producer_of(write_edge.dataset_type_name)) is not None:
+                raise MultipleProducersError(
+                    f"Dataset type {write_edge.dataset_type_name!r} is produced by both {task.label!r} and "
+                    f"{existing_producer.task_label!r}."
+                )
+            self._xgraph.add_edge(write_edge.task_label, write_edge.dataset_type_name, state=write_edge)
+            self._xgraph.nodes[write_edge.dataset_type_name]["bipartite"] = self.DATASET_TYPE_BIPARTITE
+        # Record init vs. run flags in the dataset type state dicts and check
+        # them for consistency.
+        for dataset_type_name in task.init_dataset_types:
+            if not self._xgraph.nodes[dataset_type_name].setdefault("is_init", True):
+                raise ConnectionTypeConsistencyError(
+                    f"Dataset type {dataset_type_name!r} is init-time in task {task.label!r}, "
+                    "but run-time in the current pipeline."
+                )
+        for dataset_type_name in task.run_dataset_types:
+            if self._xgraph.nodes[dataset_type_name].setdefault("is_init", False):
+                raise ConnectionTypeConsistencyError(
+                    f"Dataset type {dataset_type_name!r} is run-time in task {task.label!r}, "
+                    "but init-time in the current pipeline."
+                )
+        # Update the lists that categorize dataset types as init- or run-time,
+        # skipping any names this task shares with previous ones.  The lists
+        # are in topological order only after a sort() call.
+        self._init_dataset_types.extend(
+            name for name in task.init_dataset_types if name not in self._init_dataset_types
+        )
+        self._run_dataset_types.extend(
+            name for name in task.run_dataset_types if name not in self._run_dataset_types
+        )
+        # Update the set of prerequisites.
+        self._prerequisite_inputs.update(task.prerequisite_inputs)
+
+    def _make_task_vertex(self, task_label: str) -> TaskVertex:
+        return TaskVertex(task_label, self._xgraph.nodes[task_label])
+
+    def _make_dataset_type_vertex(self, dataset_type_name: str, is_init: bool) -> DatasetTypeVertex:
+        return DatasetTypeVertex(dataset_type_name, self._xgraph.nodes[dataset_type_name], is_init)
diff --git a/python/lsst/pipe/base/pipeline_graph/_resolved_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_resolved_pipeline_graph.py
new file mode 100644
index 000000000..8169f4048
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_resolved_pipeline_graph.py
@@ -0,0 +1,171 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+from __future__ import annotations
+
+__all__ = ("ResolvedPipelineGraph",)
+
+import itertools
+from typing import TYPE_CHECKING
+
+from lsst.daf.butler import DatasetType, DimensionUniverse, Registry
+from lsst.daf.butler.registry import MissingDatasetTypeError
+
+from ._exceptions import IncompatibleDatasetTypeError
+from ._generic_pipeline_graph import GenericPipelineGraph
+from ._pipeline_graph import PipelineGraph
+from ._vertices_and_edges import ResolvedDatasetTypeVertex, ResolvedTaskVertex
+
+if TYPE_CHECKING:
+    from ..connectionTypes import BaseConnection
+
+
+class ResolvedPipelineGraph(GenericPipelineGraph[ResolvedTaskVertex, ResolvedDatasetTypeVertex]):
+    def __init__(self, unresolved: PipelineGraph, registry: Registry):
+        super().__init__(
+            unresolved._xgraph.copy(),
+            task_labels=unresolved._task_labels,
+            init_dataset_types=unresolved._init_dataset_types,
+            run_dataset_types=unresolved._run_dataset_types,
+            prerequisite_inputs=unresolved._prerequisite_inputs,
+            is_sorted=unresolved._is_sorted,
+        )
+        for task_label in self._task_labels:
+            state = self._xgraph.nodes[task_label]
+            state["dimensions"] = registry.dimensions.extract(state["connections"].dimensions)
+        for dataset_type_name in itertools.chain(self._init_dataset_types, self._run_dataset_types):
+            state = self._xgraph.nodes[dataset_type_name]
+            state["dataset_type"] = self._resolve_dataset_type(dataset_type_name, registry)
+
+    def _make_task_vertex(self, task_label: str) -> ResolvedTaskVertex:
+        return ResolvedTaskVertex(task_label, self._xgraph.nodes[task_label])
+
+    def _make_dataset_type_vertex(self, dataset_type_name: str, is_init: bool) -> ResolvedDatasetTypeVertex:
+        return ResolvedDatasetTypeVertex(dataset_type_name, self._xgraph.nodes[dataset_type_name], is_init)
+
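+    # Resolution precedence: the registry definition wins when the dataset
+    # type is registered, then the producing task's connection, and otherwise
+    # all consuming tasks must agree on a single definition.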
+    def _resolve_dataset_type(self, name: str, registry: Registry) -> DatasetType:
+        try:
+            dataset_type = registry.getDatasetType(name)
+        except MissingDatasetTypeError:
+            dataset_type = None
+        if (write_edge := self.producer_of(name)) is not None:
+            dataset_type = self._make_connection_dataset_type(
+                write_edge.task_label,
+                name,
+                current=dataset_type,
+                universe=registry.dimensions,
+                is_input=False,
+            )
+        if dataset_type is not None:
+            # Registry and/or producer task have locked down the dataset type
+            # definition, so all we need to do here is check the consumers
+            # for consistency.
+            for consumer in self.consumers_of(name):
+                self._make_connection_dataset_type(
+                    consumer,
+                    name,
+                    current=dataset_type,
+                    universe=registry.dimensions,
+                    is_input=True,
+                )
+        else:
+            consumer_dataset_types = {
+                self._make_connection_dataset_type(
+                    consumer,
+                    name,
+                    current=dataset_type,
+                    universe=registry.dimensions,
+                    is_input=True,
+                )
+                for consumer in self.consumers_of(name)
+            }
+            if len(consumer_dataset_types) > 1:
+                raise MissingDatasetTypeError(
+                    f"Dataset type {name!r} is not registered and not produced by this pipeline, "
+                    f"but it is used as an input by multiple tasks with incompatible definitions. "
+                    f"This pipeline cannot be resolved until the dataset type is registered."
+                )
+            assert len(consumer_dataset_types) == 1, "Can't have a dataset type node with no edges."
+            (dataset_type,) = consumer_dataset_types
+        return dataset_type
+
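+    # Build the DatasetType implied by a single task's connection, checking
+    # it for compatibility with `current` (the registry or producer
+    # definition) when one exists.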
+    def _make_connection_dataset_type(
+        self,
+        task_label: str,
+        dataset_type_name: str,
+        current: DatasetType | None,
+        universe: DimensionUniverse,
+        is_input: bool,
+    ) -> DatasetType:
+        task_state = self._xgraph.nodes[task_label]
+        connection: BaseConnection = task_state["connections"].allConnections[
+            task_state["connection_reverse_mapping"][dataset_type_name]
+        ]
+        dimensions = connection.resolve_dimensions(
+            universe,
+            current.dimensions if current is not None else None,
+            task_label,
+        )
+        parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name)
+        assert parent_dataset_type_name == dataset_type_name, "Guaranteed by class invariants."
+        if component is not None:
+            if current is None:
+                raise MissingDatasetTypeError(
+                    f"Dataset type {dataset_type_name!r} is not registered and not produced by this "
+                    f"pipeline, but it is used by task {task_label!r} via component {component!r}. "
+                    f"This pipeline cannot be resolved until the parent dataset type is registered."
+                )
+            if component not in current.storageClass.components:
+                raise IncompatibleDatasetTypeError(
+                    f"Dataset type {parent_dataset_type_name!r} has storage class "
+                    f"{current.storageClass_name!r}, which does not have a component {component!r} "
+                    f"as requested by task {task_label!r}."
+                )
+            if current.storageClass.components[component].name != connection.storageClass:
+                raise IncompatibleDatasetTypeError(
+                    f"Dataset type {dataset_type_name!r} has storage class "
+                    f"{current.storageClass.components[component].name!r}, which does not match "
+                    f"{connection.storageClass!r}, as requested by task {task_label!r}. "
+                    "Note that storage class conversions of components are not supported."
+                )
+            dataset_type = DatasetType(
+                parent_dataset_type_name,
+                dimensions,
+                storageClass=current.storageClass,
+                isCalibration=connection.isCalibration,
+            )
+        else:
+            dataset_type = connection.makeDatasetType(universe)
+        if current is not None:
+            if is_input:
+                if not dataset_type.is_compatible_with(current):
+                    raise IncompatibleDatasetTypeError(
+                        f"Incompatible definition for input dataset type {dataset_type_name!r} to task "
+                        f"{task_label!r}: {current} != {dataset_type}."
+                    )
+            else:
+                if not current.is_compatible_with(dataset_type):
+                    raise IncompatibleDatasetTypeError(
+                        f"Incompatible definition for output dataset type {dataset_type_name!r} from task "
+                        f"{task_label!r}: {current} != {dataset_type}."
+                    )
+            return current
+        else:
+            return dataset_type
diff --git a/python/lsst/pipe/base/pipeline_graph/_vertices_and_edges.py b/python/lsst/pipe/base/pipeline_graph/_vertices_and_edges.py
new file mode 100644
index 000000000..c1344b747
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_vertices_and_edges.py
@@ -0,0 +1,221 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+from __future__ import annotations
+
+__all__ = (
+    "TaskVertex",
+    "ResolvedTaskVertex",
+    "DatasetTypeVertex",
+    "ResolvedDatasetTypeVertex",
+    "WriteEdge",
+    "ReadEdge",
+)
+
+import dataclasses
+from collections import ChainMap
+from collections.abc import Mapping, Set
+from typing import TYPE_CHECKING, Any, ClassVar, cast
+
+from lsst.daf.butler import DatasetType, DimensionGraph
+
+from ..connections import PipelineTaskConnections, iterConnections
+from ._exceptions import VertexTypeError
+
+if TYPE_CHECKING:
+    from ..config import PipelineTaskConfig
+    from ..connectionTypes import BaseConnection
+    from ..pipelineTask import PipelineTask
+
+
+class TaskVertex:
+    def __init__(self, label: str, state: dict[str, Any]):
+        if state["bipartite"] == DatasetTypeVertex.BIPARTITE_CONSTANT:
+            raise VertexTypeError(f"{label!r} is a dataset type name, not a task label.")
+        self._label = label
+        self._state = state
+
+    BIPARTITE_CONSTANT: ClassVar[int] = 0
+
+    @property
+    def label(self) -> str:
+        return self._label
+
+    @property
+    def task_class(self) -> type[PipelineTask]:
+        return self._state["task_class"]
+
+    @property
+    def config(self) -> PipelineTaskConfig:
+        return self._state["config"]
+
+    @property
+    def init_inputs(self) -> Mapping[str, ReadEdge]:
+        return self._state["init_inputs"]
+
+    @property
+    def prerequisite_inputs(self) -> Mapping[str, ReadEdge]:
+        return self._state["prerequisite_inputs"]
+
+    @property
+    def inputs(self) -> Mapping[str, ReadEdge]:
+        return self._state["inputs"]
+
+    @property
+    def init_outputs(self) -> Mapping[str, WriteEdge]:
+        return self._state["init_outputs"]
+
+    @property
+    def outputs(self) -> Mapping[str, WriteEdge]:
+        return self._state["outputs"]
+
+    @property
+    def all_inputs(self) -> Mapping[str, ReadEdge]:
+        return ChainMap(
+            cast(dict, self.init_inputs),
+            cast(dict, self.prerequisite_inputs),
+            cast(dict, self.inputs),
+        )
+
+    @property
+    def all_outputs(self) -> Mapping[str, WriteEdge]:
+        return ChainMap(
+            cast(dict, self.init_outputs),
+            cast(dict, self.outputs),
+        )
+
+    @property
+    def init_dataset_types(self) -> Set[str]:
+        return self.init_inputs.keys() | self.init_outputs.keys()
+
+    @property
+    def run_dataset_types(self) -> Set[str]:
+        return self.prerequisite_inputs.keys() | self.inputs.keys() | self.outputs.keys()
+
+    @classmethod
+    def _state_from_pipeline_task(
+        cls,
+        label: str,
+        task_class: type[PipelineTask],
+        config: PipelineTaskConfig,
+        connections: PipelineTaskConnections | None = None,
+    ) -> dict[str, Any]:
+        if connections is None:
+            connections = task_class.ConfigClass.ConnectionsClass(config=config)
+        # Key the edge mappings by parent dataset type name so they match the
+        # node names used in the graph, even for component connections.
+        init_inputs = {
+            edge.parent_dataset_type_name: edge
+            for edge in (
+                ReadEdge.from_connection(label, c) for c in iterConnections(connections, "initInputs")
+            )
+        }
+        prerequisite_inputs = {
+            edge.parent_dataset_type_name: edge
+            for edge in (
+                ReadEdge.from_connection(label, c, is_prerequisite=True)
+                for c in iterConnections(connections, "prerequisiteInputs")
+            )
+        }
+        inputs = {
+            edge.parent_dataset_type_name: edge
+            for edge in (
+                ReadEdge.from_connection(label, c) for c in iterConnections(connections, "inputs")
+            )
+        }
+        init_outputs = {
+            c.name: WriteEdge.from_connection(label, c) for c in iterConnections(connections, "initOutputs")
+        }
+        outputs = {
+            c.name: WriteEdge.from_connection(label, c) for c in iterConnections(connections, "outputs")
+        }
+        connection_reverse_mapping = {
+            DatasetType.splitDatasetTypeName(connection.name)[0]: internal_name
+            for internal_name, connection in connections.allConnections.items()
+        }
+        return dict(
+            task_class=task_class,
+            config=config,
+            init_inputs=init_inputs,
+            prerequisite_inputs=prerequisite_inputs,
+            inputs=inputs,
+            init_outputs=init_outputs,
+            outputs=outputs,
+            connections=connections,
+            connection_reverse_mapping=connection_reverse_mapping,
+            bipartite=cls.BIPARTITE_CONSTANT,
+        )
+
+
+class ResolvedTaskVertex(TaskVertex):
+    @property
+    def dimensions(self) -> DimensionGraph:
+        return self._state["dimensions"]
+
+
+class DatasetTypeVertex:
+    def __init__(self, name: str, state: dict[str, Any], is_init: bool):
+        if state["bipartite"] == TaskVertex.BIPARTITE_CONSTANT:
+            raise VertexTypeError(f"{name!r} is a task label, not a dataset type name.")
+        if state["is_init"] != is_init:
+            if is_init:
+                raise VertexTypeError(f"{name!r} is a run dataset type, not an init dataset type.")
+            else:
+                raise VertexTypeError(f"{name!r} is an init dataset type, not a run dataset type.")
+        self._name = name
+        self._state = state
+
+    BIPARTITE_CONSTANT: ClassVar[int] = 1
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def is_init(self) -> bool:
+        return self._state["is_init"]
+
+
+class ResolvedDatasetTypeVertex(DatasetTypeVertex):
+    @property
+    def dataset_type(self) -> DatasetType:
+        return self._state["dataset_type"]
+
+
+@dataclasses.dataclass(frozen=True)
+class ReadEdge:
+    task_label: str
+    parent_dataset_type_name: str
+    component: str | None
+    storage_class_name: str
+    is_prerequisite: bool = False
+
+    @classmethod
+    def from_connection(
+        cls, task_label: str, connection: BaseConnection, is_prerequisite: bool = False
+    ) -> ReadEdge:
+        parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name)
+        return cls(
+            task_label,
+            parent_dataset_type_name,
+            component,
+            connection.storageClass,
+            is_prerequisite=is_prerequisite,
+        )
+
+
+@dataclasses.dataclass(frozen=True)
+class WriteEdge:
+    task_label: str
+    dataset_type_name: str
+    storage_class_name: str
+
+    @classmethod
+    def from_connection(cls, task_label: str, connection: BaseConnection) -> WriteEdge:
+        parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name)
+        if component is not None:
+            raise ValueError(f"Illegal output component dataset {connection.name!r} in task {task_label!r}.")
+        return cls(task_label, parent_dataset_type_name, connection.storageClass)
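---

Illustrative usage (not part of the patch): a minimal sketch of how the new
API is expected to fit together, assuming `MyTask` is some existing
PipelineTask subclass whose connections declare its inputs and outputs; the
task label here is hypothetical.

    from lsst.pipe.base import PipelineGraph

    # Build a graph with one (hypothetical) task and sort it.
    graph = PipelineGraph()
    graph.add_task("my_task", MyTask, config=MyTask.ConfigClass())
    graph.sort()  # topological order; raises DatasetDependencyError on cycles

    # Mapping views over the vertices, now in sorted order.
    for label, task in graph.tasks.items():
        print(label, sorted(task.inputs), sorted(task.outputs))
    for name in graph.run_dataset_types:
        print(name, graph.producer_of(name), list(graph.consumers_of(name)))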