Skip to content

Commit

Permalink
WIP: add PipelineGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
TallJimbo committed Jan 13, 2023
1 parent a136fad commit 44ac030
Show file tree
Hide file tree
Showing 8 changed files with 845 additions and 6 deletions.
6 changes: 5 additions & 1 deletion python/lsst/pipe/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import connectionTypes, pipelineIR
from . import connectionTypes, pipeline_graph, pipelineIR
from ._dataset_handle import *
from ._instrument import *
from ._status import *
Expand All @@ -10,6 +10,10 @@
from .graph import *
from .graphBuilder import *
from .pipeline import *

# We import the main PipelineGraph types and the module (above), but we don't
# lift all symbols to package scope.
from .pipeline_graph import PipelineGraph, ResolvedPipelineGraph
from .pipelineTask import *
from .struct import *
from .task import *
Expand Down
19 changes: 14 additions & 5 deletions python/lsst/pipe/base/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,12 @@
from lsst.utils import doImportType
from lsst.utils.introspection import get_full_type_name

from . import pipelineIR, pipeTools
from . import pipelineIR
from ._task_metadata import TaskMetadata
from .config import PipelineTaskConfig
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipeline_graph import PipelineGraph
from .pipelineTask import PipelineTask
from .task import _TASK_METADATA_TYPE

Expand Down Expand Up @@ -660,7 +661,7 @@ def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
If a dataId is supplied in a config block. This is in place for
future use
"""
taskDefs = []
taskDefs: list[TaskDef] = []
for label in self._pipelineIR.tasks:
taskDefs.append(self._buildTaskDef(label))

Expand All @@ -676,9 +677,17 @@ def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
raise pipelineIR.ContractError(
f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
)

taskDefs = sorted(taskDefs, key=lambda x: x.label)
yield from pipeTools.orderPipeline(taskDefs)
graph = PipelineGraph()
for task_def in taskDefs:
graph.add_task(
task_def.label,
task_def.taskClass,
config=task_def.config,
connections=task_def.connections,
)
graph.sort()
sort_map = {label: n for n, label in enumerate(graph.tasks)}
yield from sorted(taskDefs, key=lambda x: sort_map[x.label])

def _buildTaskDef(self, label: str) -> TaskDef:
if (taskIR := self._pipelineIR.tasks.get(label)) is None:
Expand Down
27 changes: 27 additions & 0 deletions python/lsst/pipe/base/pipeline_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

from ._exceptions import *
from ._generic_pipeline_graph import *
from ._pipeline_graph import *
from ._resolved_pipeline_graph import *
from ._vertices_and_edges import *
54 changes: 54 additions & 0 deletions python/lsst/pipe/base/pipeline_graph/_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = (
"PipelineGraphError",
"MultipleProducersError",
"DatasetDependencyError",
"ConnectionTypeConsistencyError",
"IncompatibleDatasetTypeError",
"VertexTypeError",
)


class PipelineGraphError(RuntimeError):
pass


class MultipleProducersError(PipelineGraphError):
pass


class DatasetDependencyError(PipelineGraphError):
pass


class ConnectionTypeConsistencyError(PipelineGraphError):
pass


class IncompatibleDatasetTypeError(PipelineGraphError):
pass


class VertexTypeError(KeyError):
pass
223 changes: 223 additions & 0 deletions python/lsst/pipe/base/pipeline_graph/_generic_pipeline_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = (
"GenericPipelineGraph",
"TasksView",
"InitDatasetTypesView",
"RunDatasetTypesView",
)

import itertools
from abc import abstractmethod
from collections import ChainMap
from collections.abc import Iterable, Iterator, Mapping
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar

import networkx
from lsst.utils.classes import cached_getter

from ._exceptions import DatasetDependencyError

if TYPE_CHECKING:
from ._vertices_and_edges import DatasetTypeVertex, ReadEdge, TaskVertex, WriteEdge


_D = TypeVar("_D", bound=DatasetTypeVertex, covariant=True)
_T = TypeVar("_T", bound=TaskVertex, covariant=True)


class TasksView(Mapping[str, _T]):
def __init__(self, parent: GenericPipelineGraph[_T, _D]):
self._parent = parent

def __iter__(self) -> Iterator[str]:
return iter(self._parent._task_labels)

def __getitem__(self, key: str) -> _T:
return self._parent._make_task_vertex(key)

def __len__(self) -> int:
return len(self._parent._task_labels)

@property
@cached_getter
def xgraph(self) -> networkx.MultiDiGraph:
self._parent._get_task_xgraph()

@xgraph.deleter
def xgraph(self) -> None:
if hasattr(self._parent, "_cached_get_task_xggraph"):
delattr(self._parent, "_cached_get_task_xggraph")


class InitDatasetTypesView(Mapping[str, _D]):
def __init__(self, parent: GenericPipelineGraph[_T, _D]):
self._parent = parent

def __iter__(self) -> Iterator[str]:
return iter(self._parent._init_dataset_types)

def __getitem__(self, key: str) -> _D:
return self._parent._make_dataset_type_vertex(key, is_init=True)

def __len__(self) -> int:
return len(self._parent._init_dataset_types)

@property
@cached_getter
def xgraph(self) -> networkx.DiGraph:
self._parent._get_init_dataset_type_xgraph()

@xgraph.deleter
def xgraph(self) -> None:
if hasattr(self._parent, "_cached_get_init_dataset_type_xggraph"):
delattr(self._parent, "_cached_get_init_dataset_type_xggraph")


class RunDatasetTypesView(Mapping[str, _D]):
def __init__(self, parent: GenericPipelineGraph[_T, _D]):
self._parent = parent

def __iter__(self) -> Iterator[str]:
return iter(self._parent._run_dataset_types)

def __getitem__(self, key: str) -> _D:
return self._parent._make_dataset_type_vertex(key, is_init=False)

def __len__(self) -> int:
return len(self._parent._run_dataset_types)

@property
@cached_getter
def xgraph(self) -> networkx.DiGraph:
self._parent._get_run_dataset_type_xgraph()

@xgraph.deleter
def xgraph(self) -> None:
if hasattr(self._parent, "_cached_get_run_dataset_type_xggraph"):
delattr(self._parent, "_cached_get_run_dataset_type_xggraph")


class GenericPipelineGraph(Generic[_T, _D]):
def __init__(
self,
xgraph: networkx.DiGraph | None = None,
*,
task_labels: Iterable[str] = (),
init_dataset_types: Iterable[str] = (),
run_dataset_types: Iterable[str] = (),
prerequisite_inputs: Iterable[str] = (),
is_sorted: bool = False,
) -> None:
self._xgraph = networkx.DiGraph() if xgraph is None else xgraph.copy()
self._task_labels = list(task_labels)
self._init_dataset_types = list(init_dataset_types)
self._run_dataset_types = list(run_dataset_types)
self._prerequisite_inputs = set(prerequisite_inputs)
self._is_sorted = is_sorted

TASK_BIPARTITE: ClassVar[int] = 0
DATASET_TYPE_BIPARTITE: ClassVar[int] = 1

@property
def tasks(self) -> TasksView[_T]:
return TasksView(self)

@property
def init_dataset_types(self) -> InitDatasetTypesView[_D]:
return InitDatasetTypesView(self)

@property
def run_dataset_types(self) -> RunDatasetTypesView[_D]:
return RunDatasetTypesView(self)

@property
def dataset_types(self) -> Mapping[str, _D]:
# ChainMap is a MutableMapping, so it wants to be passed
# MutableMappings, but we've annotated the return type as just
# Mapping to ensure it really only needs Mappings.
return ChainMap(self.init_dataset_types, self.run_dataset_types) # type: ignore

@property
@cached_getter
def xgraph(self) -> networkx.DiGraph:
return self._xgraph.copy(as_view=True)

@property
def is_sorted(self) -> bool:
return self._is_sorted

def sort(self) -> None:
if not self._is_sorted:
task_labels = []
init_dataset_types = []
run_dataset_types = []
try:
for k in networkx.lexicographical_topological_sort(self._xgraph):
match self._xgraph[k]:
case {"bipartite": self.TASK_BIPARTITE}:
task_labels.append(k)
case {"bipartite": self.DATASET_TYPE_BIPARTITE, "is_init": True}:
init_dataset_types.append(k)
case {"bipartite": self.DATASET_TYPE_BIPARTITE, "is_init": False}:
run_dataset_types.append(k)
except networkx.NetworkXUnfeasible as err:
raise DatasetDependencyError("Cycle detected while attempting to sort graph.") from err
self._task_labels = task_labels
self._init_dataset_types = init_dataset_types
self._run_dataset_types = run_dataset_types
self._is_sorted = True

def producer_of(self, dataset_type_name: str) -> WriteEdge | None:
for _, _, edge in self._xgraph.in_edges(dataset_type_name, "state"):
return edge
return None

def consumers_of(self, dataset_type_name: str) -> dict[str, ReadEdge]:
return {
task_label: edge for task_label, _, edge in self._xgraph.out_edges(dataset_type_name, "state")
}

@abstractmethod
def _make_task_vertex(self, task_label: str) -> _T:
raise NotImplementedError()

@abstractmethod
def _make_dataset_type_vertex(self, dataset_type_name: str, is_init: bool) -> _D:
raise NotImplementedError()

@cached_getter
def _get_task_xgraph(self) -> networkx.MultiDiGraph:
return networkx.freeze(
networkx.bipartite.projected_graph(self._xgraph, self._task_labels, multigraph=True)
)

@cached_getter
def _get_init_dataset_type_xgraph(self) -> networkx.DiGraph:
init_subgraph = self._xgraph.subgraph(itertools.chain(self._task_labels, self._init_dataset_types))
return networkx.freeze(networkx.bipartite.projected_graph(init_subgraph, self._init_dataset_types))

@cached_getter
def _get_run_dataset_type_xgraph(self) -> networkx.DiGraph:
run_subgraph = self._xgraph.subgraph(itertools.chain(self._task_labels, self._init_dataset_types))
return networkx.freeze(networkx.bipartite.projected_graph(run_subgraph, self._run_dataset_types))
Loading

0 comments on commit 44ac030

Please sign in to comment.