From edf04383957882745d5289cb0eee62448c1a892a Mon Sep 17 00:00:00 2001
From: Jim Bosch
Date: Thu, 14 Jul 2022 15:04:27 -0400
Subject: [PATCH] WIP: add PipelineGraph

---
 python/lsst/pipe/base/_pipeline_graph.py | 601 +++++++++++++++++++++++
 python/lsst/pipe/base/pipeline.py        |  10 +-
 2 files changed, 608 insertions(+), 3 deletions(-)
 create mode 100644 python/lsst/pipe/base/_pipeline_graph.py

diff --git a/python/lsst/pipe/base/_pipeline_graph.py b/python/lsst/pipe/base/_pipeline_graph.py
new file mode 100644
index 00000000..21e3c208
--- /dev/null
+++ b/python/lsst/pipe/base/_pipeline_graph.py
@@ -0,0 +1,601 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+from __future__ import annotations
+
+__all__ = (
+    "TaskVertex",
+    "DatasetTypeVertex",
+    "WriteEdge",
+    "ReadEdge",
+    "GenericPipelineGraph",
+    "PipelineGraph",
+    "ResolvedPipelineGraph",
+)
+
+import dataclasses
+import itertools
+from abc import abstractmethod
+from collections import ChainMap
+from collections.abc import Collection, Iterable, Iterator, Mapping, Set
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, TypeVar, cast
+
+import networkx
+from lsst.daf.butler import DatasetType, DimensionGraph, DimensionUniverse, Registry
+from lsst.daf.butler.registry import MissingDatasetTypeError
+from lsst.utils.classes import cached_getter
+
+from .connections import iterConnections
+
+if TYPE_CHECKING:
+    from .config import PipelineTaskConfig
+    from .connectionTypes import BaseConnection
+    from .pipelineTask import PipelineTask
+
+
+_C = TypeVar("_C", bound=Collection[str])
+_V = TypeVar("_V", covariant=True)
+_D = TypeVar("_D", bound="DatasetTypeVertex", covariant=True)
+_T = TypeVar("_T", bound="TaskVertex", covariant=True)
+
+
+class MultipleProducersError(RuntimeError):
+    pass
+
+
+class DatasetDependencyError(RuntimeError):
+    pass
+
+
+class ConnectionTypeConsistencyError(RuntimeError):
+    pass
+
+
+class IncompatibleDatasetTypeError(RuntimeError):
+    pass
+
+
+class TaskVertex:
+    def __init__(self, label: str, state: dict[str, Any]):
+        if state["bipartite"] == PipelineGraph.DATASET_TYPE_BIPARTITE:
+            raise KeyError(f"{label!r} is a dataset type name, not a task label.")
+        self._label = label
+        self._state = state
+
+    @property
+    def label(self) -> str:
+        return self._label
+
+    @property
+    def task_class(self) -> type[PipelineTask]:
+        return self._state["task_class"]
+
+    @property
+    def config(self) -> PipelineTaskConfig:
+        return self._state["config"]
+
+    @property
+    def init_inputs(self) -> Mapping[str, ReadEdge]:
+        return self._state["init_inputs"]
+
+    @property
+    def prerequisite_inputs(self) -> Mapping[str, ReadEdge]:
+        return self._state["prerequisite_inputs"]
+
+    @property
+    def inputs(self) -> Mapping[str, ReadEdge]:
+        return self._state["inputs"]
+
+    @property
+    def init_outputs(self) -> Mapping[str, WriteEdge]:
+        return self._state["init_outputs"]
+
+    @property
+    def outputs(self) -> Mapping[str, WriteEdge]:
+        return self._state["outputs"]
+
+    @property
+    def all_inputs(self) -> Mapping[str, ReadEdge]:
+        return ChainMap(
+            cast(dict, self.init_inputs),
+            cast(dict, self.prerequisite_inputs),
+            cast(dict, self.inputs),
+        )
+
+    @property
+    def all_outputs(self) -> Mapping[str, WriteEdge]:
+        return ChainMap(
+            cast(dict, self.init_outputs),
+            cast(dict, self.outputs),
+        )
+
+    @property
+    def init_dataset_types(self) -> Set[str]:
+        return self.init_inputs.keys() | self.init_outputs.keys()
+
+    @property
+    def run_dataset_types(self) -> Set[str]:
+        return self.prerequisite_inputs.keys() | self.inputs.keys() | self.outputs.keys()
+
+    @classmethod
+    def from_pipeline_task(
+        cls,
+        label: str,
+        task_class: type[PipelineTask],
+        config: PipelineTaskConfig,
+    ) -> TaskVertex:
+        return cls(label, cls._state_from_pipeline_task(label, task_class, config))
+
+    @classmethod
+    def _state_from_pipeline_task(
+        cls,
+        label: str,
+        task_class: type[PipelineTask],
+        config: PipelineTaskConfig,
+    ) -> dict[str, Any]:
+        connections = task_class.ConfigClass.ConnectionsClass(config=config)
+        init_inputs = {
+            c.name: ReadEdge.from_connection(label, c) for c in iterConnections(connections, "initInputs")
+        }
+        prerequisite_inputs = {
+            c.name: ReadEdge.from_connection(label, c, is_prerequisite=True)
+            for c in iterConnections(connections, "prerequisiteInputs")
+        }
+        inputs = {c.name: ReadEdge.from_connection(label, c) for c in iterConnections(connections, "inputs")}
+        init_outputs = {
+            c.name: WriteEdge.from_connection(label, c) for c in iterConnections(connections, "initOutputs")
+        }
+        outputs = {
+            c.name: WriteEdge.from_connection(label, c) for c in iterConnections(connections, "outputs")
+        }
+        connection_reverse_mapping = {
+            connection.name: internal_name
+            for internal_name, connection in connections.allConnections.items()
+        }
+        return dict(
+            task_class=task_class,
+            config=config,
+            init_inputs=init_inputs,
+            prerequisite_inputs=prerequisite_inputs,
+            inputs=inputs,
+            init_outputs=init_outputs,
+            outputs=outputs,
+            connections=connections,
+            connection_reverse_mapping=connection_reverse_mapping,
+            bipartite=PipelineGraph.TASK_BIPARTITE,
+        )
+
+
+class ResolvedTaskVertex(TaskVertex):
+    @property
+    def dimensions(self) -> DimensionGraph:
+        return self._state["dimensions"]
+
+
+class DatasetTypeVertex:
+    def __init__(self, name: str, state: dict[str, Any]):
+        if state["bipartite"] == PipelineGraph.TASK_BIPARTITE:
+            raise KeyError(f"{name!r} is a task label, not a dataset type name.")
+        self._name = name
+        self._state = state
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def is_init(self) -> bool:
+        return self._state["is_init"]
+
+
+class ResolvedDatasetTypeVertex(DatasetTypeVertex):
+    @property
+    def dataset_type(self) -> DatasetType:
+        return self._state["dataset_type"]
+
+
+@dataclasses.dataclass(frozen=True)
+class ReadEdge:
+    task_label: str
+    parent_dataset_type_name: str
+    component: str | None
+    storage_class_name: str
+    is_prerequisite: bool = False
+
+    @classmethod
+    def from_connection(
+        cls, task_label: str, connection: BaseConnection, is_prerequisite: bool = False
+    ) -> ReadEdge:
+        parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name)
+        return cls(
+            task_label,
+            parent_dataset_type_name,
+            component,
+            connection.storageClass,
+            is_prerequisite=is_prerequisite,
+        )
+
+
+@dataclasses.dataclass(frozen=True)
+class WriteEdge:
+    task_label: str
+    dataset_type_name: str
+    storage_class_name: str
+
+    @classmethod
+    def from_connection(cls, task_label: str, connection: BaseConnection) -> WriteEdge:
+        parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name)
+        if component is not None:
+            raise ValueError(f"Illegal output component dataset {connection.name!r} in task {task_label!r}.")
+        return cls(task_label, parent_dataset_type_name, connection.storageClass)
+
+
+class _GraphMappingView(Mapping[str, _V]):
+    def __init__(self, keys: Collection[str], view_factory: Callable[[str], _V]):
+        self._keys = keys
+        self._view_factory = view_factory
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(self._keys)
+
+    def __getitem__(self, key: str) -> _V:
+        return self._view_factory(key)
+
+    def __len__(self) -> int:
+        return len(self._keys)
+
+
+class GenericPipelineGraph(Generic[_T, _D]):
+    def __init__(
+        self,
+        bipartite_graph: networkx.DiGraph | None = None,
+        *,
+        task_labels: Iterable[str] = (),
+        init_dataset_types: Iterable[str] = (),
+        run_dataset_types: Iterable[str] = (),
+        prerequisite_inputs: Iterable[str] = (),
+        is_sorted: bool = False,
+    ) -> None:
+        self._bipartite_graph = networkx.DiGraph() if bipartite_graph is None else bipartite_graph.copy()
+        self._task_labels = list(task_labels)
+        self._init_dataset_types = list(init_dataset_types)
+        self._run_dataset_types = list(run_dataset_types)
+        self._prerequisite_inputs = set(prerequisite_inputs)
+        self._is_sorted = is_sorted
+
+    TASK_BIPARTITE: ClassVar[int] = 0
+    DATASET_TYPE_BIPARTITE: ClassVar[int] = 1
+
+    @property
+    def tasks(self) -> Mapping[str, _T]:
+        return _GraphMappingView(self._task_labels, self._make_task_vertex)
+
+    @property
+    def init_dataset_types(self) -> Mapping[str, _D]:
+        return _GraphMappingView(self._init_dataset_types, self._make_dataset_type_vertex)
+
+    @property
+    def run_dataset_types(self) -> Mapping[str, _D]:
+        return _GraphMappingView(self._run_dataset_types, self._make_dataset_type_vertex)
+
+    @property
+    def dataset_types(self) -> Mapping[str, _D]:
+        # ChainMap is a MutableMapping, so it wants to be passed
+        # MutableMappings, but we've annotated the return type as just
+        # Mapping to ensure it really only needs Mappings.
+        return ChainMap(self.init_dataset_types, self.run_dataset_types)  # type: ignore
+
+    @property
+    @cached_getter
+    def bipartite_graph(self) -> networkx.DiGraph:
+        return self._bipartite_graph.copy(as_view=True)
+
+    @property
+    @cached_getter
+    def task_graph(self) -> networkx.MultiDiGraph:
+        return networkx.freeze(
+            networkx.bipartite.projected_graph(self._bipartite_graph, self._task_labels, multigraph=True)
+        )
+
+    @property
+    @cached_getter
+    def init_dataset_type_graph(self) -> networkx.DiGraph:
+        init_subgraph = self._bipartite_graph.subgraph(
+            itertools.chain(self._task_labels, self._init_dataset_types)
+        )
+        return networkx.freeze(networkx.bipartite.projected_graph(init_subgraph, self._init_dataset_types))
+
+    @property
+    @cached_getter
+    def run_dataset_type_graph(self) -> networkx.DiGraph:
+        run_subgraph = self._bipartite_graph.subgraph(
+            itertools.chain(self._task_labels, self._run_dataset_types)
+        )
+        return networkx.freeze(networkx.bipartite.projected_graph(run_subgraph, self._run_dataset_types))
+
+    @property
+    def is_sorted(self) -> bool:
+        return self._is_sorted
+
+    def sort(self) -> None:
+        if not self._is_sorted:
+            task_labels = []
+            init_dataset_types = []
+            run_dataset_types = []
+            for k in networkx.lexicographical_topological_sort(self._bipartite_graph):
+                match self._bipartite_graph.nodes[k]:
+                    case {"bipartite": self.TASK_BIPARTITE}:
+                        task_labels.append(k)
+                    case {"bipartite": self.DATASET_TYPE_BIPARTITE, "is_init": True}:
+                        init_dataset_types.append(k)
+                    case {"bipartite": self.DATASET_TYPE_BIPARTITE, "is_init": False}:
+                        run_dataset_types.append(k)
+            self._task_labels = task_labels
+            self._init_dataset_types = init_dataset_types
+            self._run_dataset_types = run_dataset_types
+            self._is_sorted = True
+
+    def producer_of(self, dataset_type_name: str) -> WriteEdge | None:
+        for _, _, edge in self._bipartite_graph.in_edges(dataset_type_name, "state"):
+            return edge
+        return None
+
+    def consumers_of(self, dataset_type_name: str) -> dict[str, ReadEdge]:
+        return {
+            task_label: edge
+            for _, task_label, edge in self._bipartite_graph.out_edges(dataset_type_name, "state")
+        }
+
+    @abstractmethod
+    def _make_task_vertex(self, task_label: str) -> _T:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def _make_dataset_type_vertex(self, dataset_type_name: str) -> _D:
+        raise NotImplementedError()
+
+
+class PipelineGraph(GenericPipelineGraph[TaskVertex, DatasetTypeVertex]):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def add_task(self, label: str, task_class: type[PipelineTask], config: PipelineTaskConfig) -> None:
+        # Invalidate the sort order and all cached subgraph views.
+        self._is_sorted = False
+        for attr_name in ("task_graph", "init_dataset_type_graph", "run_dataset_type_graph"):
+            cache_name = f"_cached_{attr_name}"
+            if hasattr(self, cache_name):
+                delattr(self, cache_name)
+        # Add the task node itself and then get a vertex view back for ease
+        # of use later in this method.
+        self._task_labels.append(label)
+        state = TaskVertex._state_from_pipeline_task(label, task_class, config)
+        self._bipartite_graph.add_node(label, **state)
+        task = self._make_task_vertex(label)
+        # Add [init_]input edges to the graph.
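+        # Read edges point from the dataset type node to the task node, while
+        # write edges (added below) point from the task node to the dataset
+        # type node, so a well-formed pipeline yields a DAG that sort() can
+        # order topologically.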
+        for read_edge in itertools.chain(task.init_inputs.values(), task.inputs.values()):
+            if read_edge.parent_dataset_type_name in self._prerequisite_inputs:
+                raise ConnectionTypeConsistencyError(
+                    f"Dataset type {read_edge.parent_dataset_type_name!r} is not a prerequisite input in "
+                    f"task {task.label!r}, but it was a prerequisite input to "
+                    f"{self.consumers_of(read_edge.parent_dataset_type_name)}."
+                )
+            self._bipartite_graph.add_edge(
+                read_edge.parent_dataset_type_name, read_edge.task_label, state=read_edge
+            )
+            self._bipartite_graph.nodes[read_edge.parent_dataset_type_name][
+                "bipartite"
+            ] = self.DATASET_TYPE_BIPARTITE
+        # Add prerequisite input edges to the graph.
+        for read_edge in task.prerequisite_inputs.values():
+            if (
+                read_edge.parent_dataset_type_name in self._run_dataset_types
+                and read_edge.parent_dataset_type_name not in self._prerequisite_inputs
+            ):
+                if producer := self.producer_of(read_edge.parent_dataset_type_name):
+                    details = f"an output of task {producer.task_label!r}"
+                else:
+                    details = (
+                        f"a regular input of tasks {self.consumers_of(read_edge.parent_dataset_type_name)}"
+                    )
+                raise ConnectionTypeConsistencyError(
+                    f"Dataset type {read_edge.parent_dataset_type_name!r} is a prerequisite input to task "
+                    f"{task.label!r}, but it is {details}."
+                )
+            self._bipartite_graph.add_edge(
+                read_edge.parent_dataset_type_name, read_edge.task_label, state=read_edge
+            )
+            self._bipartite_graph.nodes[read_edge.parent_dataset_type_name][
+                "bipartite"
+            ] = self.DATASET_TYPE_BIPARTITE
+        # Add [init_]output edges to the graph.
+        for write_edge in itertools.chain(task.init_outputs.values(), task.outputs.values()):
+            if write_edge.dataset_type_name in self._prerequisite_inputs:
+                raise ConnectionTypeConsistencyError(
+                    f"Dataset type {write_edge.dataset_type_name!r} is an output of task {task.label!r}, "
+                    f"but it was a prerequisite input to {self.consumers_of(write_edge.dataset_type_name)}."
+                )
+            self._bipartite_graph.add_edge(
+                write_edge.task_label, write_edge.dataset_type_name, state=write_edge
+            )
+            self._bipartite_graph.nodes[write_edge.dataset_type_name][
+                "bipartite"
+            ] = self.DATASET_TYPE_BIPARTITE
+        # Add init- vs. run-time flags to the dataset type state dicts and
+        # check that they are consistent with the rest of the pipeline.
+        for dataset_type_name in task.init_dataset_types:
+            if not self._bipartite_graph.nodes[dataset_type_name].setdefault("is_init", True):
+                raise ConnectionTypeConsistencyError(
+                    f"Dataset type {dataset_type_name!r} is init-time in task {task.label!r}, "
+                    "but run-time in the current pipeline."
+                )
+        for dataset_type_name in task.run_dataset_types:
+            if self._bipartite_graph.nodes[dataset_type_name].setdefault("is_init", False):
+                raise ConnectionTypeConsistencyError(
+                    f"Dataset type {dataset_type_name!r} is run-time in task {task.label!r}, "
+                    "but init-time in the current pipeline."
+                )
+        # Update the lists that categorize dataset types as init- or run-time;
+        # they are no longer guaranteed to be sorted.
+        self._init_dataset_types.extend(task.init_dataset_types)
+        self._run_dataset_types.extend(task.run_dataset_types)
+        # Update the set of prerequisites.
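+        # Updating a set with a mapping adds its keys, i.e. the dataset type
+        # names of this task's prerequisite connections.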
+        self._prerequisite_inputs.update(task.prerequisite_inputs)
+
+    def _make_task_vertex(self, task_label: str) -> TaskVertex:
+        return TaskVertex(task_label, self._bipartite_graph.nodes[task_label])
+
+    def _make_dataset_type_vertex(self, dataset_type_name: str) -> DatasetTypeVertex:
+        return DatasetTypeVertex(dataset_type_name, self._bipartite_graph.nodes[dataset_type_name])
+
+
+class ResolvedPipelineGraph(GenericPipelineGraph[ResolvedTaskVertex, ResolvedDatasetTypeVertex]):
+    def __init__(self, unresolved: PipelineGraph, registry: Registry):
+        super().__init__(
+            unresolved._bipartite_graph.copy(),
+            task_labels=unresolved._task_labels,
+            init_dataset_types=unresolved._init_dataset_types,
+            run_dataset_types=unresolved._run_dataset_types,
+            prerequisite_inputs=unresolved._prerequisite_inputs,
+            is_sorted=unresolved._is_sorted,
+        )
+        for task_label in self._task_labels:
+            state = self._bipartite_graph.nodes[task_label]
+            state["dimensions"] = registry.dimensions.extract(state["connections"].dimensions)
+        for dataset_type_name in itertools.chain(self._init_dataset_types, self._run_dataset_types):
+            state = self._bipartite_graph.nodes[dataset_type_name]
+            state["dataset_type"] = self._resolve_dataset_type(dataset_type_name, registry)
+
+    def _make_task_vertex(self, task_label: str) -> ResolvedTaskVertex:
+        return ResolvedTaskVertex(task_label, self._bipartite_graph.nodes[task_label])
+
+    def _make_dataset_type_vertex(self, dataset_type_name: str) -> ResolvedDatasetTypeVertex:
+        return ResolvedDatasetTypeVertex(dataset_type_name, self._bipartite_graph.nodes[dataset_type_name])
+
+    def _resolve_dataset_type(self, name: str, registry: Registry) -> DatasetType:
+        try:
+            dataset_type = registry.getDatasetType(name)
+        except MissingDatasetTypeError:
+            dataset_type = None
+        if (write_edge := self.producer_of(name)) is not None:
+            dataset_type = self._make_connection_dataset_type(
+                write_edge.task_label,
+                name,
+                current=dataset_type,
+                universe=registry.dimensions,
+                is_input=False,
+            )
+        if dataset_type is not None:
+            # Registry and/or producer task have locked down the dataset type
+            # definition, so all we need to do here is check for consistency.
+            for consumer in self.consumers_of(name):
+                self._make_connection_dataset_type(
+                    consumer,
+                    name,
+                    current=dataset_type,
+                    universe=registry.dimensions,
+                    is_input=True,
+                )
+        else:
+            consumer_dataset_types = {
+                self._make_connection_dataset_type(
+                    consumer,
+                    name,
+                    current=dataset_type,
+                    universe=registry.dimensions,
+                    is_input=True,
+                )
+                for consumer in self.consumers_of(name)
+            }
+            if len(consumer_dataset_types) > 1:
+                raise MissingDatasetTypeError(
+                    f"Dataset type {name!r} is not registered and not produced by this pipeline, "
+                    f"but it is used as an input by multiple tasks with different storage classes. "
+                    f"This pipeline cannot be resolved until the dataset type is registered."
+                )
+            assert len(consumer_dataset_types) == 1, "Can't have a dataset type node with no edges."
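+            # All consumers agree on a single definition; adopt it as the
+            # resolved dataset type for this node.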
+            (dataset_type,) = consumer_dataset_types
+        return dataset_type
+
+    def _make_connection_dataset_type(
+        self,
+        task_label: str,
+        dataset_type_name: str,
+        current: DatasetType | None,
+        universe: DimensionUniverse,
+        is_input: bool,
+    ) -> DatasetType:
+        task_state = self._bipartite_graph.nodes[task_label]
+        connection: BaseConnection = task_state["connections"].allConnections[
+            task_state["connection_reverse_mapping"][dataset_type_name]
+        ]
+        dimensions = connection.resolve_dimensions(
+            universe,
+            current.dimensions if current is not None else None,
+            task_label,
+        )
+        parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name)
+        assert parent_dataset_type_name == dataset_type_name, "Guaranteed by class invariants."
+        if component is not None:
+            if current is None:
+                raise MissingDatasetTypeError(
+                    f"Dataset type {dataset_type_name!r} is not registered and not produced by this "
+                    f"pipeline, but it is used by task {task_label!r} via component {component!r}. "
+                    f"This pipeline cannot be resolved until the parent dataset type is registered."
+                )
+            if component not in current.storageClass.components:
+                raise IncompatibleDatasetTypeError(
+                    f"Dataset type {parent_dataset_type_name!r} has storage class "
+                    f"{current.storageClass_name!r}, which does not include a {component!r} component, "
+                    f"as requested by task {task_label!r}."
+                )
+            if current.storageClass.components[component].name != connection.storageClass:
+                raise IncompatibleDatasetTypeError(
+                    f"Component {component!r} of dataset type {dataset_type_name!r} has storage class "
+                    f"{current.storageClass.components[component].name!r}, which does not match "
+                    f"{connection.storageClass!r}, as requested by task {task_label!r}. "
+                    "Note that storage class conversions of components are not supported."
+                )
+            dataset_type = DatasetType(
+                parent_dataset_type_name,
+                dimensions,
+                storageClass=current.storageClass,
+                isCalibration=connection.isCalibration,
+            )
+        else:
+            dataset_type = connection.makeDatasetType(universe)
+        if current is not None:
+            if is_input:
+                if not dataset_type.is_compatible_with(current):
+                    raise IncompatibleDatasetTypeError(
+                        f"Incompatible definition for input dataset type {dataset_type_name!r} to task "
+                        f"{task_label!r}: {current} != {dataset_type}."
+                    )
+            else:
+                if not current.is_compatible_with(dataset_type):
+                    raise IncompatibleDatasetTypeError(
+                        f"Incompatible definition for output dataset type {dataset_type_name!r} from task "
+                        f"{task_label!r}: {current} != {dataset_type}."
+                    )
+            return current
+        else:
+            return dataset_type
diff --git a/python/lsst/pipe/base/pipeline.py b/python/lsst/pipe/base/pipeline.py
index c4107e7c..2798be95 100644
--- a/python/lsst/pipe/base/pipeline.py
+++ b/python/lsst/pipe/base/pipeline.py
@@ -62,7 +62,8 @@
 from lsst.utils import doImportType
 from lsst.utils.introspection import get_full_type_name
 
-from . import pipelineIR, pipeTools
+from . import pipelineIR
+from ._pipeline_graph import PipelineGraph
 from ._task_metadata import TaskMetadata
 from .config import PipelineTaskConfig
 from .configOverrides import ConfigOverrides
@@ -677,8 +678,11 @@ def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
                     f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
                 )
 
-        taskDefs = sorted(taskDefs, key=lambda x: x.label)
-        yield from pipeTools.orderPipeline(taskDefs)
+        graph = PipelineGraph()
+        # TODO: add nodes
+
+        sort_map = {label: n for n, label in enumerate(graph.tasks)}
+        yield from sorted(taskDefs, key=lambda x: sort_map[x.label])
 
     def _buildTaskDef(self, label: str) -> TaskDef:
         if (taskIR := self._pipelineIR.tasks.get(label)) is None:
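
Illustrative usage (not part of the patch): a minimal sketch of how the new
PipelineGraph classes fit together, assuming a concrete PipelineTask subclass
named ExampleTask; the task, its label, and the butler below are hypothetical.

    from lsst.pipe.base._pipeline_graph import PipelineGraph, ResolvedPipelineGraph

    graph = PipelineGraph()
    config = ExampleTask.ConfigClass()  # hypothetical task; any PipelineTask works
    graph.add_task("example", ExampleTask, config)
    graph.sort()  # deterministic topological order over tasks and dataset types
    print(list(graph.tasks))              # task labels
    print(list(graph.run_dataset_types))  # run-time dataset type names

    # Resolving against a butler registry attaches dimensions and DatasetType
    # objects to every vertex (requires a data repository):
    # resolved = ResolvedPipelineGraph(graph, butler.registry)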