Skip to content

Commit

Permalink
Move ExtractHelper to a separate file.
Browse files Browse the repository at this point in the history
  • Loading branch information
TallJimbo committed Mar 27, 2023
1 parent 65daa37 commit e4b109f
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 68 deletions.
106 changes: 106 additions & 0 deletions python/lsst/pipe/base/pipeline_graph/_extract_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("ExtractHelper",)

from collections.abc import Iterable
from types import EllipsisType
from typing import TYPE_CHECKING, Generic, TypeVar

import networkx
import networkx.algorithms.bipartite
import networkx.algorithms.dag
from lsst.utils.iteration import ensure_iterable

from ._nodes import Node, NodeKey, NodeType

if TYPE_CHECKING:
from ._mutable_pipeline_graph import MutablePipelineGraph
from ._pipeline_graph import PipelineGraph


_P = TypeVar("_P", bound="PipelineGraph", covariant=True)


class ExtractHelper(Generic[_P]):
def __init__(self, parent: _P) -> None:
self._parent = parent
self._run_xgraph: networkx.DiGraph | None = None
self._task_keys: set[NodeKey] = set()

def include_tasks(self, labels: str | Iterable[str] | EllipsisType = ...) -> None:
if labels is ...:
self._task_keys.update(key for key in self._parent._xgraph if key.node_type is NodeType.TASK)
else:
self._task_keys.update(
NodeKey(NodeType.TASK, task_label) for task_label in ensure_iterable(labels)
)

def exclude_tasks(self, labels: str | Iterable[str]) -> None:
self._task_keys.difference_update(
NodeKey(NodeType.TASK, task_label) for task_label in ensure_iterable(labels)
)

def include_subset(self, label: str) -> None:
self._task_keys.update(node.key for node in self._parent.task_subsets[label].values())

def exclude_subset(self, label: str) -> None:
self._task_keys.difference_update(node.key for node in self._parent.task_subsets[label].values())

def start_after(self, names: str | Iterable[str], node_type: NodeType) -> None:
to_exclude: set[NodeKey] = set()
for name in ensure_iterable(names):
key = NodeKey(node_type, name)
to_exclude.update(networkx.algorithms.dag.ancestors(self._get_run_xgraph(), key))
to_exclude.add(key)
self._task_keys.difference_update(to_exclude)

def stop_at(self, names: str | Iterable[str], node_type: NodeType) -> None:
to_exclude: set[NodeKey] = set()
for name in ensure_iterable(names):
key = NodeKey(node_type, name)
to_exclude.update(networkx.algorithms.dag.descendants(self._get_run_xgraph(), key))
self._task_keys.difference_update(to_exclude)

def finish(self, description: str | None = None) -> MutablePipelineGraph:
from ._mutable_pipeline_graph import MutablePipelineGraph

if description is None:
description = self._parent._description
# Combine the task_keys we're starting with and the keys for their init
# nodes.
keys = self._task_keys | {NodeKey(NodeType.TASK_INIT, key.name) for key in self._task_keys}
# Also add the keys for the adjacent dataset type nodes.
keys.update(networkx.node_boundary(self._parent._xgraph.to_undirected(as_view=True), keys))
# Make the new backing networkx graph.
xgraph: networkx.DiGraph = self._parent._xgraph.subgraph(keys).copy()
for state in xgraph.nodes.values():
node: Node = state["instance"]
state["instance"] = node._unresolved()
result = MutablePipelineGraph.__new__(MutablePipelineGraph)
result._init_from_args(xgraph, None, description=description)
return result

def _get_run_xgraph(self) -> networkx.DiGraph:
if self._run_xgraph is None:
self._run_xgraph = self._parent.make_bipartite_xgraph(init=False)
return self._run_xgraph
72 changes: 4 additions & 68 deletions python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,15 @@
import os
import tarfile
from abc import abstractmethod
from collections.abc import Iterable, Iterator, Mapping, Sequence
from collections.abc import Iterator, Mapping, Sequence
from datetime import datetime
from types import EllipsisType
from typing import TYPE_CHECKING, Any, BinaryIO, Generic, TypeVar, cast

import networkx
import networkx.algorithms.bipartite
import networkx.algorithms.dag
from lsst.daf.butler import Registry
from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils.iteration import ensure_iterable

from ._edges import Edge, ReadEdge, WriteEdge
from ._exceptions import PipelineDataCycleError
Expand All @@ -47,6 +45,7 @@

if TYPE_CHECKING:
from ..pipeline import TaskDef
from ._extract_helper import ExtractHelper
from ._mutable_pipeline_graph import MutablePipelineGraph
from ._resolved_pipeline_graph import ResolvedPipelineGraph

Expand All @@ -55,71 +54,6 @@
_P = TypeVar("_P", bound="PipelineGraph", covariant=True)


class ExtractHelper(Generic[_P]):
def __init__(self, parent: _P) -> None:
self._parent = parent
self._run_xgraph: networkx.DiGraph | None = None
self._task_keys: set[NodeKey] = set()

def include_tasks(self, labels: str | Iterable[str] | EllipsisType = ...) -> None:
if labels is ...:
self._task_keys.update(key for key in self._parent._xgraph if key.node_type is NodeType.TASK)
else:
self._task_keys.update(
NodeKey(NodeType.TASK, task_label) for task_label in ensure_iterable(labels)
)

def exclude_tasks(self, labels: str | Iterable[str]) -> None:
self._task_keys.difference_update(
NodeKey(NodeType.TASK, task_label) for task_label in ensure_iterable(labels)
)

def include_subset(self, label: str) -> None:
self._task_keys.update(node.key for node in self._parent.task_subsets[label].values())

def exclude_subset(self, label: str) -> None:
self._task_keys.difference_update(node.key for node in self._parent.task_subsets[label].values())

def start_after(self, names: str | Iterable[str], node_type: NodeType) -> None:
to_exclude: set[NodeKey] = set()
for name in ensure_iterable(names):
key = NodeKey(node_type, name)
to_exclude.update(networkx.algorithms.dag.ancestors(self._get_run_xgraph(), key))
to_exclude.add(key)
self._task_keys.difference_update(to_exclude)

def stop_at(self, names: str | Iterable[str], node_type: NodeType) -> None:
to_exclude: set[NodeKey] = set()
for name in ensure_iterable(names):
key = NodeKey(node_type, name)
to_exclude.update(networkx.algorithms.dag.descendants(self._get_run_xgraph(), key))
self._task_keys.difference_update(to_exclude)

def finish(self, description: str | None = None) -> MutablePipelineGraph:
from ._mutable_pipeline_graph import MutablePipelineGraph

if description is None:
description = self._parent._description
# Combine the task_keys we're starting with and the keys for their init
# nodes.
keys = self._task_keys | {NodeKey(NodeType.TASK_INIT, key.name) for key in self._task_keys}
# Also add the keys for the adjacent dataset type nodes.
keys.update(networkx.node_boundary(self._parent._xgraph.to_undirected(as_view=True), keys))
# Make the new backing networkx graph.
xgraph: networkx.DiGraph = self._parent._xgraph.subgraph(keys).copy()
for state in xgraph.nodes.values():
node: Node = state["instance"]
state["instance"] = node._unresolved()
result = MutablePipelineGraph.__new__(MutablePipelineGraph)
result._init_from_args(xgraph, None, description=description)
return result

def _get_run_xgraph(self) -> networkx.DiGraph:
if self._run_xgraph is None:
self._run_xgraph = self._parent.make_bipartite_xgraph(init=False)
return self._run_xgraph


class PipelineGraph(Generic[_T, _D, _S]):
"""A base class for directed acyclic graph of `PipelineTask` definitions.
Expand Down Expand Up @@ -355,6 +289,8 @@ def extract(self) -> ExtractHelper:
"""Create a new `MutablePipelineGraph` containing just the tasks that
match the given criteria.
"""
from ._extract_helper import ExtractHelper

return ExtractHelper(self)

def _reorder(self, sorted_keys: Sequence[NodeKey]) -> None:
Expand Down

0 comments on commit e4b109f

Please sign in to comment.