-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
845 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# This file is part of pipe_base. | ||
# | ||
# Developed for the LSST Data Management System. | ||
# This product includes software developed by the LSST Project | ||
# (http://www.lsst.org). | ||
# See the COPYRIGHT file at the top-level directory of this distribution | ||
# for details of code ownership. | ||
# | ||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# (at your option) any later version. | ||
# | ||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU General Public License | ||
# along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
from __future__ import annotations | ||
|
||
from ._exceptions import * | ||
from ._generic_pipeline_graph import * | ||
from ._pipeline_graph import * | ||
from ._resolved_pipeline_graph import * | ||
from ._vertices_and_edges import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# This file is part of pipe_base. | ||
# | ||
# Developed for the LSST Data Management System. | ||
# This product includes software developed by the LSST Project | ||
# (http://www.lsst.org). | ||
# See the COPYRIGHT file at the top-level directory of this distribution | ||
# for details of code ownership. | ||
# | ||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# (at your option) any later version. | ||
# | ||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU General Public License | ||
# along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
from __future__ import annotations | ||
|
||
# Names re-exported when this module is star-imported; one entry per
# exception class defined below.
__all__ = (
    "PipelineGraphError", "MultipleProducersError",
    "DatasetDependencyError", "ConnectionTypeConsistencyError",
    "IncompatibleDatasetTypeError", "VertexTypeError",
)
|
||
|
||
class PipelineGraphError(RuntimeError):
    """Base class for most exceptions defined in this module.

    Adds no behavior of its own on top of `RuntimeError`.
    """
|
||
|
||
class MultipleProducersError(PipelineGraphError):
    """Pipeline graph error subclass with no added behavior.

    NOTE(review): judging by the name, this is meant for a dataset type
    that has more than one producer; the raise sites are not in this module,
    so confirm against the callers.
    """
|
||
|
||
class DatasetDependencyError(PipelineGraphError):
    """Pipeline graph error subclass with no added behavior.

    Raised by ``GenericPipelineGraph.sort`` when a cycle prevents the graph
    from being topologically sorted.
    """
|
||
|
||
class ConnectionTypeConsistencyError(PipelineGraphError):
    """Pipeline graph error subclass with no added behavior.

    NOTE(review): the name suggests it reports tasks disagreeing about the
    kind of connection used for a dataset type; the raise sites are not in
    this module, so confirm against the callers.
    """
|
||
|
||
class IncompatibleDatasetTypeError(PipelineGraphError):
    """Pipeline graph error subclass with no added behavior.

    NOTE(review): the name suggests it reports dataset type definitions
    that cannot be reconciled; the raise sites are not in this module, so
    confirm against the callers.
    """
|
||
|
||
class VertexTypeError(KeyError):
    """`KeyError` subclass with no added behavior.

    Unlike the other exceptions in this module it derives from `KeyError`
    rather than `PipelineGraphError`, so ``except PipelineGraphError`` does
    not catch it.

    NOTE(review): presumably raised when a lookup finds a vertex of the
    wrong kind; the raise sites are not in this module.
    """
223 changes: 223 additions & 0 deletions
223
python/lsst/pipe/base/pipeline_graph/_generic_pipeline_graph.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
# This file is part of pipe_base. | ||
# | ||
# Developed for the LSST Data Management System. | ||
# This product includes software developed by the LSST Project | ||
# (http://www.lsst.org). | ||
# See the COPYRIGHT file at the top-level directory of this distribution | ||
# for details of code ownership. | ||
# | ||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# (at your option) any later version. | ||
# | ||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU General Public License | ||
# along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
from __future__ import annotations | ||
|
||
# Names re-exported when this module is star-imported: the generic graph
# class and its three mapping views.
__all__ = (
    "GenericPipelineGraph", "TasksView",
    "InitDatasetTypesView", "RunDatasetTypesView",
)
|
||
import itertools | ||
from abc import abstractmethod | ||
from collections import ChainMap | ||
from collections.abc import Iterable, Iterator, Mapping | ||
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar | ||
|
||
import networkx | ||
from lsst.utils.classes import cached_getter | ||
|
||
from ._exceptions import DatasetDependencyError | ||
|
||
if TYPE_CHECKING: | ||
from ._vertices_and_edges import DatasetTypeVertex, ReadEdge, TaskVertex, WriteEdge | ||
|
||
|
||
# Type variables for the vertex classes a concrete graph works with.
#
# The bounds are string forward references because ``DatasetTypeVertex`` and
# ``TaskVertex`` are imported only under ``TYPE_CHECKING``.  ``TypeVar(...)``
# is an ordinary call evaluated at import time (``from __future__ import
# annotations`` does not defer it), so bare names would raise ``NameError``.
_D = TypeVar("_D", bound="DatasetTypeVertex", covariant=True)
_T = TypeVar("_T", bound="TaskVertex", covariant=True)
|
||
|
||
class TasksView(Mapping[str, _T]):
    """Read-only mapping view of the task vertices of a pipeline graph.

    Keys are iterated from the parent's ``_task_labels`` and values are
    built on demand by the parent's ``_make_task_vertex``.

    Parameters
    ----------
    parent : `GenericPipelineGraph`
        Graph this view delegates every operation to.
    """

    def __init__(self, parent: GenericPipelineGraph[_T, _D]):
        self._parent = parent

    def __iter__(self) -> Iterator[str]:
        return iter(self._parent._task_labels)

    def __getitem__(self, key: str) -> _T:
        return self._parent._make_task_vertex(key)

    def __len__(self) -> int:
        return len(self._parent._task_labels)

    @property
    def xgraph(self) -> networkx.MultiDiGraph:
        """Task-only projection of the parent graph
        (`networkx.MultiDiGraph`).

        Deleting this attribute drops the projection cached on the parent so
        that it is rebuilt on the next access.
        """
        # The parent getter is itself decorated with ``cached_getter``, so
        # there is nothing to gain from caching again on this short-lived
        # view; a view-level cache would also survive the deleter below.
        return self._parent._get_task_xgraph()

    @xgraph.deleter
    def xgraph(self) -> None:
        # ``cached_getter`` stores its result under "_cached_" + the
        # decorated method name; the method is ``_get_task_xgraph``, hence
        # the double underscore.  NOTE(review): relies on that naming
        # convention of lsst.utils.classes.cached_getter.
        cache_name = "_cached__get_task_xgraph"
        if hasattr(self._parent, cache_name):
            delattr(self._parent, cache_name)
|
||
|
||
class InitDatasetTypesView(Mapping[str, _D]):
    """Read-only mapping view of the init dataset type vertices of a graph.

    Keys are iterated from the parent's ``_init_dataset_types`` and values
    are built on demand by the parent's ``_make_dataset_type_vertex`` with
    ``is_init=True``.

    Parameters
    ----------
    parent : `GenericPipelineGraph`
        Graph this view delegates every operation to.
    """

    def __init__(self, parent: GenericPipelineGraph[_T, _D]):
        self._parent = parent

    def __iter__(self) -> Iterator[str]:
        return iter(self._parent._init_dataset_types)

    def __getitem__(self, key: str) -> _D:
        return self._parent._make_dataset_type_vertex(key, is_init=True)

    def __len__(self) -> int:
        return len(self._parent._init_dataset_types)

    @property
    def xgraph(self) -> networkx.DiGraph:
        """Init-dataset-type-only projection of the parent graph
        (`networkx.DiGraph`).

        Deleting this attribute drops the projection cached on the parent so
        that it is rebuilt on the next access.
        """
        # The parent getter is itself decorated with ``cached_getter``, so
        # there is nothing to gain from caching again on this short-lived
        # view; a view-level cache would also survive the deleter below.
        return self._parent._get_init_dataset_type_xgraph()

    @xgraph.deleter
    def xgraph(self) -> None:
        # ``cached_getter`` stores its result under "_cached_" + the
        # decorated method name; the method is
        # ``_get_init_dataset_type_xgraph``, hence the double underscore.
        # NOTE(review): relies on that naming convention of
        # lsst.utils.classes.cached_getter.
        cache_name = "_cached__get_init_dataset_type_xgraph"
        if hasattr(self._parent, cache_name):
            delattr(self._parent, cache_name)
|
||
|
||
class RunDatasetTypesView(Mapping[str, _D]):
    """Read-only mapping view of the run dataset type vertices of a graph.

    Keys are iterated from the parent's ``_run_dataset_types`` and values
    are built on demand by the parent's ``_make_dataset_type_vertex`` with
    ``is_init=False``.

    Parameters
    ----------
    parent : `GenericPipelineGraph`
        Graph this view delegates every operation to.
    """

    def __init__(self, parent: GenericPipelineGraph[_T, _D]):
        self._parent = parent

    def __iter__(self) -> Iterator[str]:
        return iter(self._parent._run_dataset_types)

    def __getitem__(self, key: str) -> _D:
        return self._parent._make_dataset_type_vertex(key, is_init=False)

    def __len__(self) -> int:
        return len(self._parent._run_dataset_types)

    @property
    def xgraph(self) -> networkx.DiGraph:
        """Run-dataset-type-only projection of the parent graph
        (`networkx.DiGraph`).

        Deleting this attribute drops the projection cached on the parent so
        that it is rebuilt on the next access.
        """
        # The parent getter is itself decorated with ``cached_getter``, so
        # there is nothing to gain from caching again on this short-lived
        # view; a view-level cache would also survive the deleter below.
        return self._parent._get_run_dataset_type_xgraph()

    @xgraph.deleter
    def xgraph(self) -> None:
        # ``cached_getter`` stores its result under "_cached_" + the
        # decorated method name; the method is
        # ``_get_run_dataset_type_xgraph``, hence the double underscore.
        # NOTE(review): relies on that naming convention of
        # lsst.utils.classes.cached_getter.
        cache_name = "_cached__get_run_dataset_type_xgraph"
        if hasattr(self._parent, cache_name):
            delattr(self._parent, cache_name)
|
||
|
||
class GenericPipelineGraph(Generic[_T, _D]): | ||
def __init__( | ||
self, | ||
xgraph: networkx.DiGraph | None = None, | ||
*, | ||
task_labels: Iterable[str] = (), | ||
init_dataset_types: Iterable[str] = (), | ||
run_dataset_types: Iterable[str] = (), | ||
prerequisite_inputs: Iterable[str] = (), | ||
is_sorted: bool = False, | ||
) -> None: | ||
self._xgraph = networkx.DiGraph() if xgraph is None else xgraph.copy() | ||
self._task_labels = list(task_labels) | ||
self._init_dataset_types = list(init_dataset_types) | ||
self._run_dataset_types = list(run_dataset_types) | ||
self._prerequisite_inputs = set(prerequisite_inputs) | ||
self._is_sorted = is_sorted | ||
|
||
TASK_BIPARTITE: ClassVar[int] = 0 | ||
DATASET_TYPE_BIPARTITE: ClassVar[int] = 1 | ||
|
||
@property | ||
def tasks(self) -> TasksView[_T]: | ||
return TasksView(self) | ||
|
||
@property | ||
def init_dataset_types(self) -> InitDatasetTypesView[_D]: | ||
return InitDatasetTypesView(self) | ||
|
||
@property | ||
def run_dataset_types(self) -> RunDatasetTypesView[_D]: | ||
return RunDatasetTypesView(self) | ||
|
||
@property | ||
def dataset_types(self) -> Mapping[str, _D]: | ||
# ChainMap is a MutableMapping, so it wants to be passed | ||
# MutableMappings, but we've annotated the return type as just | ||
# Mapping to ensure it really only needs Mappings. | ||
return ChainMap(self.init_dataset_types, self.run_dataset_types) # type: ignore | ||
|
||
@property | ||
@cached_getter | ||
def xgraph(self) -> networkx.DiGraph: | ||
return self._xgraph.copy(as_view=True) | ||
|
||
@property | ||
def is_sorted(self) -> bool: | ||
return self._is_sorted | ||
|
||
def sort(self) -> None: | ||
if not self._is_sorted: | ||
task_labels = [] | ||
init_dataset_types = [] | ||
run_dataset_types = [] | ||
try: | ||
for k in networkx.lexicographical_topological_sort(self._xgraph): | ||
match self._xgraph[k]: | ||
case {"bipartite": self.TASK_BIPARTITE}: | ||
task_labels.append(k) | ||
case {"bipartite": self.DATASET_TYPE_BIPARTITE, "is_init": True}: | ||
init_dataset_types.append(k) | ||
case {"bipartite": self.DATASET_TYPE_BIPARTITE, "is_init": False}: | ||
run_dataset_types.append(k) | ||
except networkx.NetworkXUnfeasible as err: | ||
raise DatasetDependencyError("Cycle detected while attempting to sort graph.") from err | ||
self._task_labels = task_labels | ||
self._init_dataset_types = init_dataset_types | ||
self._run_dataset_types = run_dataset_types | ||
self._is_sorted = True | ||
|
||
def producer_of(self, dataset_type_name: str) -> WriteEdge | None: | ||
for _, _, edge in self._xgraph.in_edges(dataset_type_name, "state"): | ||
return edge | ||
return None | ||
|
||
def consumers_of(self, dataset_type_name: str) -> dict[str, ReadEdge]: | ||
return { | ||
task_label: edge for task_label, _, edge in self._xgraph.out_edges(dataset_type_name, "state") | ||
} | ||
|
||
@abstractmethod | ||
def _make_task_vertex(self, task_label: str) -> _T: | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def _make_dataset_type_vertex(self, dataset_type_name: str, is_init: bool) -> _D: | ||
raise NotImplementedError() | ||
|
||
@cached_getter | ||
def _get_task_xgraph(self) -> networkx.MultiDiGraph: | ||
return networkx.freeze( | ||
networkx.bipartite.projected_graph(self._xgraph, self._task_labels, multigraph=True) | ||
) | ||
|
||
@cached_getter | ||
def _get_init_dataset_type_xgraph(self) -> networkx.DiGraph: | ||
init_subgraph = self._xgraph.subgraph(itertools.chain(self._task_labels, self._init_dataset_types)) | ||
return networkx.freeze(networkx.bipartite.projected_graph(init_subgraph, self._init_dataset_types)) | ||
|
||
@cached_getter | ||
def _get_run_dataset_type_xgraph(self) -> networkx.DiGraph: | ||
run_subgraph = self._xgraph.subgraph(itertools.chain(self._task_labels, self._init_dataset_types)) | ||
return networkx.freeze(networkx.bipartite.projected_graph(run_subgraph, self._run_dataset_types)) |
Oops, something went wrong.