WIP: Move pipeline graph code around, add copy methods.
TallJimbo committed Jan 20, 2023
1 parent d1ea94e commit d976e5a
Showing 12 changed files with 920 additions and 905 deletions.
2 changes: 1 addition & 1 deletion python/lsst/pipe/base/__init__.py
@@ -13,7 +13,7 @@

# We import the main PipelineGraph types and the module (above), but we don't
# lift all symbols to package scope.
from .pipeline_graph import ResolvedPipelineGraph, UnresolvedPipelineGraph
from .pipeline_graph import MutablePipelineGraph, ResolvedPipelineGraph
from .pipelineTask import *
from .struct import *
from .task import *
4 changes: 2 additions & 2 deletions python/lsst/pipe/base/pipeline.py
@@ -67,7 +67,7 @@
from .config import PipelineTaskConfig
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .pipeline_graph import UnresolvedPipelineGraph
from .pipeline_graph import MutablePipelineGraph
from .pipelineTask import PipelineTask
from .task import _TASK_METADATA_TYPE

@@ -677,7 +677,7 @@ def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
raise pipelineIR.ContractError(
f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
)
graph = UnresolvedPipelineGraph()
graph = MutablePipelineGraph()
for task_def in taskDefs:
graph.add_task(
task_def.label,
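The net effect of these two hunks on downstream code is a class rename: UnresolvedPipelineGraph becomes MutablePipelineGraph, re-exported from lsst.pipe.base. A minimal sketch of the import-level change, assuming an environment with this commit's version of the package installed; only the names visible in the diff are used:

# Sketch of the rename shown in the hunks above.
from lsst.pipe.base import MutablePipelineGraph  # formerly UnresolvedPipelineGraph

graph = MutablePipelineGraph()  # an empty, editable pipeline graph, as built in toExpandedPipeline()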
7 changes: 4 additions & 3 deletions python/lsst/pipe/base/pipeline_graph/__init__.py
@@ -20,9 +20,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

from ._abcs import *
from ._dataset_types import *
from ._edges import *
from ._exceptions import *
from ._generic_pipeline_graph import *
from ._pipeline_graph import *
from ._resolved_pipeline_graph import *
from ._unresolved_pipeline_graph import *
from ._vertices_and_edges import *
from ._tasks import *
210 changes: 210 additions & 0 deletions python/lsst/pipe/base/pipeline_graph/_abcs.py
@@ -0,0 +1,210 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = (
    "Node",
    "Edge",
    "SubgraphView",
    "MappingSubgraphView",
)

import itertools
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator, Mapping
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar

import networkx
from lsst.daf.butler import DatasetType, DimensionUniverse, Registry

from ._exceptions import ConnectionTypeConsistencyError, IncompatibleDatasetTypeError

if TYPE_CHECKING:
from ..connectionTypes import BaseConnection
from ._dataset_types import DatasetTypeNode


class Node(ABC):

    BIPARTITE_CONSTANT: ClassVar[int]

    @abstractmethod
    def _resolve(self, state: dict[str, Any], graph: networkx.DiGraph, registry: Registry) -> None:
        raise NotImplementedError()

    @abstractmethod
    def _unresolve(self, state: dict[str, Any]) -> None:
        raise NotImplementedError()


class Edge(ABC):

    task_label: str
    parent_dataset_type_name: str

    @property
    @abstractmethod
    def key(self) -> tuple[str, str]:
        raise NotImplementedError()

    @property
    def dataset_type_name(self) -> str:
        return self.parent_dataset_type_name

    @classmethod
    @abstractmethod
    def _from_connection(
        cls,
        task_label: str,
        connection: BaseConnection,
        edge_data: list[tuple[str, str, dict[str, Any]]],
        *,
        is_init: bool,
        is_prerequisite: bool = False,
    ) -> Edge:
        raise NotImplementedError()

    @abstractmethod
    def _check_dataset_type(
        self,
        state: dict[str, Any],
        xgraph: networkx.DiGraph,
        dataset_type_node: DatasetTypeNode,
    ) -> None:
        if state["is_init"] != dataset_type_node.is_init:
            referencing_tasks = list(
                itertools.chain(
                    xgraph.predecessors(dataset_type_node.name),
                    xgraph.successors(dataset_type_node.name),
                )
            )
            if state["is_init"]:
                raise ConnectionTypeConsistencyError(
                    f"{dataset_type_node.name!r} is an init dataset in task {self.task_label!r}, "
                    f"but a run dataset in task(s) {referencing_tasks}."
                )
            else:
                raise ConnectionTypeConsistencyError(
                    f"{dataset_type_node.name!r} is a run dataset in task {self.task_label!r}, "
                    f"but an init dataset in task(s) {referencing_tasks}."
                )
        if state["is_prerequisite"] != dataset_type_node.is_prerequisite:
            referencing_tasks = list(xgraph.successors(dataset_type_node.name))
            if state["is_prerequisite"]:
                raise ConnectionTypeConsistencyError(
                    f"Dataset type {dataset_type_node.name!r} is a prerequisite input in "
                    f"task {self.task_label!r}, but it was not a prerequisite to "
                    f"{referencing_tasks}."
                )
            else:
                raise ConnectionTypeConsistencyError(
                    f"Dataset type {dataset_type_node.name!r} is not a prerequisite input in "
                    f"task {self.task_label!r}, but it was a prerequisite to "
                    f"{referencing_tasks}."
                )
        connection: BaseConnection = state["connection"]
        if connection.isCalibration != dataset_type_node.is_calibration:
            referencing_tasks = list(
                itertools.chain(
                    xgraph.predecessors(dataset_type_node.name),
                    xgraph.successors(dataset_type_node.name),
                )
            )
            if connection.isCalibration:
                raise IncompatibleDatasetTypeError(
                    f"Dataset type {dataset_type_node.name!r} is a calibration in "
                    f"task {self.task_label}, but it was not in task(s) {referencing_tasks}."
                )
            else:
                raise IncompatibleDatasetTypeError(
                    f"Dataset type {dataset_type_node.name!r} is not a calibration in "
                    f"task {self.task_label}, but it was in task(s) {referencing_tasks}."
                )

    @abstractmethod
    def _make_dataset_type_state(self, state: dict[str, Any]) -> dict[str, Any]:
        raise NotImplementedError()

    @abstractmethod
    def _resolve_dataset_type(
        self,
        state: dict[str, Any],
        current: DatasetType | None,
        universe: DimensionUniverse,
    ) -> DatasetType:
        raise NotImplementedError()


_G = TypeVar("_G", bound=networkx.DiGraph, covariant=True)


class SubgraphView(Generic[_G]):
    def __init__(self, parent_xgraph: networkx.DiGraph) -> None:
        self._parent_xgraph = parent_xgraph
        self._xgraph: _G | None = None

    @property
    def xgraph(self) -> _G:
        if self._xgraph is None:
            self._xgraph = self._make_xgraph()
        return self._xgraph

    @abstractmethod
    def _make_xgraph(self) -> _G:
        raise NotImplementedError()

    def _reset(self) -> None:
        self._xgraph = None


_N = TypeVar("_N", bound=Node, covariant=True)


class MappingSubgraphView(Generic[_G, _N], Mapping[str, _N], SubgraphView[_G]):
    def __init__(self, parent_xgraph: networkx.DiGraph) -> None:
        # Initialize SubgraphView state (the parent graph and the cached view) as well.
        super().__init__(parent_xgraph)
        self._keys: list[str] | None = None

    @abstractmethod
    def __contains__(self, key: object) -> bool:
        raise NotImplementedError()

    def __iter__(self) -> Iterator[str]:
        if self._keys is None:
            self._keys = [k for k in self._parent_xgraph if k in self]
        return iter(self._keys)

    def __getitem__(self, key: str) -> _N:
        if key not in self:
            raise KeyError(key)
        # Node payloads are stored in the "instance" node attribute of the parent graph.
        return self._parent_xgraph.nodes[key]["instance"]

    def __len__(self) -> int:
        if self._keys is None:
            self._keys = [k for k in self._parent_xgraph if k in self]
        return len(self._keys)

    def _reorder(self, parent_keys: Iterable[str]) -> None:
        self._keys = [k for k in parent_keys if k in self]

    def _reset(self) -> None:
        super()._reset()
        self._keys = None
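Taken together, SubgraphView and MappingSubgraphView implement a lazily built, cached, read-only mapping view over a subset of a parent networkx graph's nodes. A minimal, self-contained sketch of the intended subclassing pattern, using an illustrative "kind" node attribute in place of the bipartite/is_init flags used by the real node classes; the class and node names below are hypothetical, not part of the commit:

import networkx

class _TaskOnlyView(MappingSubgraphView):
    """Hypothetical view exposing only nodes whose 'kind' attribute is 'task'."""

    def __contains__(self, key: object) -> bool:
        state = self._parent_xgraph.nodes.get(key)
        return bool(state) and state.get("kind") == "task"

    def _make_xgraph(self) -> networkx.DiGraph:
        # Freeze a copy of the induced subgraph so callers cannot mutate the cached view.
        return networkx.freeze(self._parent_xgraph.subgraph(self).copy())

parent = networkx.DiGraph()
parent.add_node("isr", kind="task", instance="task payload")
parent.add_node("raw", kind="dataset_type", instance="dataset type payload")
parent.add_edge("raw", "isr")

view = _TaskOnlyView(parent)
assert "isr" in view and "raw" not in view
assert view["isr"] == "task payload"       # payloads come from the "instance" node attribute
assert list(view.xgraph.nodes) == ["isr"]  # built on first access, then cached until _reset()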
109 changes: 109 additions & 0 deletions python/lsst/pipe/base/pipeline_graph/_dataset_types.py
@@ -0,0 +1,109 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = (
    "DatasetTypeNode",
    "ResolvedDatasetTypeNode",
)

import dataclasses
import itertools
from typing import Any, ClassVar, TypeVar

import networkx
from lsst.daf.butler import DatasetType, Registry
from lsst.daf.butler.registry import MissingDatasetTypeError

from ._abcs import Edge, MappingSubgraphView, Node


@dataclasses.dataclass(frozen=True, eq=False)
class DatasetTypeNode(Node):
    BIPARTITE_CONSTANT: ClassVar[int] = 1

    name: str
    is_calibration: bool
    is_init: bool
    is_prerequisite: bool

    def _resolve(self, state: dict[str, Any], graph: networkx.DiGraph, registry: Registry) -> None:
        try:
            dataset_type = registry.getDatasetType(self.name)
        except MissingDatasetTypeError:
            dataset_type = None
        # Edge payloads live in the edge-attribute dicts yielded with data=True.
        for _, _, edge_state in itertools.chain(
            graph.in_edges(self.name, data=True), graph.out_edges(self.name, data=True)
        ):
            edge: Edge = edge_state["instance"]
            dataset_type = edge._resolve_dataset_type(
                edge_state,
                current=dataset_type,
                universe=registry.dimensions,
            )
        assert dataset_type is not None, "Graph structure guarantees at least one edge."
        state[self.name] = ResolvedDatasetTypeNode(
            name=self.name,
            is_calibration=self.is_calibration,
            is_init=self.is_init,
            is_prerequisite=self.is_prerequisite,
            dataset_type=dataset_type,
        )

    def _unresolve(self, state: dict[str, Any]) -> None:
        pass


@dataclasses.dataclass(frozen=True, eq=False)
class ResolvedDatasetTypeNode(DatasetTypeNode):
    dataset_type: DatasetType

    def _resolve(self, state: dict[str, Any], graph: networkx.DiGraph, registry: Registry) -> None:
        pass

    def _unresolve(self, state: dict[str, Any]) -> None:
        state[self.name] = DatasetTypeNode(
            name=self.name,
            is_calibration=self.is_calibration,
            is_init=self.is_init,
            is_prerequisite=self.is_prerequisite,
        )


_D = TypeVar("_D", bound=DatasetTypeNode, covariant=True)


class DatasetTypeSubgraphView(MappingSubgraphView[networkx.DiGraph, _D]):
    def __init__(self, parent_xgraph: networkx.DiGraph, is_init: bool):
        super().__init__(parent_xgraph)
        self._is_init = is_init

    def __contains__(self, key: object) -> bool:
        if state := self._parent_xgraph.nodes.get(key):
            return (
                state["bipartite"] == DatasetTypeNode.BIPARTITE_CONSTANT and state["is_init"] == self._is_init
            )
        return False

    def _make_xgraph(self) -> networkx.DiGraph:
        return networkx.freeze(
            networkx.bipartite.projected_graph(self._parent_xgraph, self, multigraph=False)
        )
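The _make_xgraph override above delegates to networkx's bipartite projection: the parent graph is bipartite between task nodes and dataset-type nodes, and projecting onto the dataset-type side yields a graph that connects dataset types through the tasks that read and write them. A self-contained sketch of that operation on a toy graph; the node names are illustrative, not taken from the commit:

import networkx

xgraph = networkx.DiGraph()
xgraph.add_node("raw", bipartite=1)         # dataset-type node (BIPARTITE_CONSTANT = 1)
xgraph.add_node("isr", bipartite=0)         # task node
xgraph.add_node("postISRCCD", bipartite=1)  # dataset-type node
xgraph.add_edge("raw", "isr")               # task input
xgraph.add_edge("isr", "postISRCCD")        # task output

dataset_type_names = [n for n, d in xgraph.nodes(data=True) if d["bipartite"] == 1]
projected = networkx.freeze(
    networkx.bipartite.projected_graph(xgraph, dataset_type_names, multigraph=False)
)
print(list(projected.edges()))  # [('raw', 'postISRCCD')]: connected through the task 'isr'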
