Skip to content

Commit

Permalink
Switch to a single PipelineGraph class.
Browse files: browse the repository at this point in the history
  • Loading branch information
TallJimbo committed Jun 10, 2023
1 parent f370017 commit e8d3349
Show file tree
Hide file tree
Showing 18 changed files with 863 additions and 1,327 deletions.
4 changes: 2 additions & 2 deletions python/lsst/pipe/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from .graphBuilder import *
from .pipeline import *

# We import the main PipelineGraph types and the module (above), but we don't
# We import the main PipelineGraph type and the module (above), but we don't
# lift all symbols to package scope.
from .pipeline_graph import MutablePipelineGraph, ResolvedPipelineGraph
from .pipeline_graph import PipelineGraph
from .pipelineTask import *
from .struct import *
from .task import *
Expand Down
6 changes: 3 additions & 3 deletions python/lsst/pipe/base/pipeTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from .pipeline import Pipeline, TaskDef

# Exceptions re-exported here for backwards compatibility.
from .pipeline_graph import DuplicateOutputError, MutablePipelineGraph, PipelineDataCycleError # noqa: F401
from .pipeline_graph import DuplicateOutputError, PipelineDataCycleError, PipelineGraph # noqa: F401

if TYPE_CHECKING:
from .taskFactory import TaskFactory
Expand Down Expand Up @@ -73,7 +73,7 @@ def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskF
if isinstance(pipeline, Pipeline):
graph = pipeline.to_graph()
else:
graph = MutablePipelineGraph()
graph = PipelineGraph()
for task_def in pipeline:
graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
# Can't use graph.is_sorted because that requires sorted dataset type names
Expand Down Expand Up @@ -111,7 +111,7 @@ def orderPipeline(pipeline: Pipeline | Iterable[TaskDef]) -> list[TaskDef]:
if isinstance(pipeline, Pipeline):
graph = pipeline.to_graph()
else:
graph = MutablePipelineGraph()
graph = PipelineGraph()
for task_def in pipeline:
graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
graph.sort()
Expand Down
67 changes: 32 additions & 35 deletions python/lsst/pipe/base/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,7 @@

# -----------------------------
# Imports for other modules --
from lsst.daf.butler import (
DataCoordinate,
DatasetType,
DimensionUniverse,
NamedValueSet,
Registry,
)
from lsst.daf.butler import DataCoordinate, DatasetType, DimensionUniverse, NamedValueSet, Registry
from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils import doImportType
from lsst.utils.introspection import get_full_type_name
Expand Down Expand Up @@ -757,7 +751,7 @@ def write_to_uri(self, uri: ResourcePathExpression) -> None:
"""
self._pipelineIR.write_to_uri(uri)

def to_graph(self) -> pipeline_graph.MutablePipelineGraph:
def to_graph(self) -> pipeline_graph.PipelineGraph:
"""Construct a pipeline graph from this pipeline.
Constructing a graph applies all configuration overrides, freezes all
Expand All @@ -767,10 +761,10 @@ def to_graph(self) -> pipeline_graph.MutablePipelineGraph:
Returns
-------
graph : `pipeline_graph.MutablePipelineGraph`
graph : `pipeline_graph.PipelineGraph`
Representation of the pipeline as a graph.
"""
graph = pipeline_graph.MutablePipelineGraph()
graph = pipeline_graph.PipelineGraph()
graph.description = self._pipelineIR.description
for label in self._pipelineIR.tasks:
self._add_task_to_graph(label, graph)
Expand Down Expand Up @@ -810,15 +804,15 @@ def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
"""
yield from self.to_graph()._iter_task_defs()

def _add_task_to_graph(self, label: str, graph: pipeline_graph.MutablePipelineGraph) -> None:
def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) -> None:
"""Add a single task from this pipeline to a pipeline graph that is
under construction.
Parameters
----------
label : `str`
Label for the task to be added.
graph : `pipeline_graph.MutablePipelineGraph`
graph : `pipeline_graph.PipelineGraph`
Graph to add the task to.
"""
if (taskIR := self._pipelineIR.tasks.get(label)) is None:
Expand All @@ -845,7 +839,7 @@ def __getitem__(self, item: str) -> TaskDef:
# Making a whole graph and then making a TaskDef from that is pretty
# backwards, but I'm hoping to deprecate this method shortly in favor
# of making the graph explicitly and working with its node objects.
graph = pipeline_graph.MutablePipelineGraph()
graph = pipeline_graph.PipelineGraph()
self._add_task_to_graph(item, graph)
(result,) = graph._iter_task_defs()
return result
Expand Down Expand Up @@ -971,17 +965,19 @@ def fromTaskDef(
# the whole class soon, but for now and before it's actually removed
# it's more important to avoid duplication with PipelineGraph's dataset
# type resolution logic.
mgraph = pipeline_graph.MutablePipelineGraph()
mgraph.add_task(taskDef.label, taskDef.taskClass, taskDef.config, taskDef.connections)
rgraph = mgraph.resolved(registry)
(task_node,) = rgraph.tasks.values()
return cls._from_graph_nodes(task_node, rgraph.dataset_types)
graph = pipeline_graph.PipelineGraph()
graph.add_task(taskDef.label, taskDef.taskClass, taskDef.config, taskDef.connections)
graph.resolve(registry)
(task_node,) = graph.tasks.values()
return cls._from_graph_nodes(
task_node, cast(Mapping[str, pipeline_graph.DatasetTypeNode], graph.dataset_types)
)

@classmethod
def _from_graph_nodes(
cls,
task_node: pipeline_graph.TaskNode,
dataset_type_nodes: Mapping[str, pipeline_graph.ResolvedDatasetTypeNode],
dataset_type_nodes: Mapping[str, pipeline_graph.DatasetTypeNode],
include_configs: bool = True,
) -> TaskDatasetTypes:
"""Construct from `PipelineGraph` nodes.
Expand Down Expand Up @@ -1146,17 +1142,17 @@ def fromPipeline(
of the same `Pipeline`.
"""
if isinstance(pipeline, Pipeline):
mgraph = pipeline.to_graph()
graph = pipeline.to_graph()
else:
mgraph = pipeline_graph.MutablePipelineGraph()
graph = pipeline_graph.PipelineGraph()
for task_def in pipeline:
mgraph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
rgraph = mgraph.resolved(registry)
graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
graph.resolve(registry)
byTask = dict()
for task_node in rgraph.tasks.values():
for task_node in graph.tasks.values():
byTask[task_node.label] = TaskDatasetTypes._from_graph_nodes(
task_node,
rgraph.dataset_types,
cast(Mapping[str, pipeline_graph.DatasetTypeNode], graph.dataset_types),
include_configs=include_configs,
)
result = cls(
Expand All @@ -1177,8 +1173,9 @@ def fromPipeline(
# PipelineGraph does, by putting components in the edge objects). But
# including all components as well is what this code has done in the
# past and changing that would break downstream code.
for dataset_type_node in rgraph.dataset_types.values():
if consumers := rgraph.consumers_of(dataset_type_node.name):
for dataset_type_node in graph.dataset_types.values():
assert dataset_type_node is not None, "Graph is expected to be resolved."
if consumers := graph.consumers_of(dataset_type_node.name):
dataset_types = [
(
dataset_type_node.dataset_type.makeComponentDatasetType(edge.component)
Expand All @@ -1188,21 +1185,21 @@ def fromPipeline(
for edge in consumers.values()
]
if any(edge.is_init for edge in consumers.values()):
if rgraph.producer_of(dataset_type_node.name) is None:
if graph.producer_of(dataset_type_node.name) is None:
result.initInputs.update(dataset_types)
else:
result.initIntermediates.update(dataset_types)
else:
if dataset_type_node.is_prerequisite:
result.prerequisites.update(dataset_types)
elif rgraph.producer_of(dataset_type_node.name) is None:
elif graph.producer_of(dataset_type_node.name) is None:
result.inputs.update(dataset_types)
if dataset_type_node.is_initial_query_constraint:
result.queryConstraints.add(dataset_type_node.dataset_type)
elif rgraph.consumers_of(dataset_type_node.name):
elif graph.consumers_of(dataset_type_node.name):
result.intermediates.update(dataset_types)
else:
producer = rgraph.producer_of(dataset_type_node.name)
producer = graph.producer_of(dataset_type_node.name)
assert (
producer is not None
), "Dataset type must have either a producer or consumers to be in graph."
Expand Down Expand Up @@ -1254,15 +1251,15 @@ def initOutputNames(
Name of the dataset type.
"""
if isinstance(pipeline, Pipeline):
mgraph = pipeline.to_graph()
graph = pipeline.to_graph()
else:
mgraph = pipeline_graph.MutablePipelineGraph()
graph = pipeline_graph.PipelineGraph()
for task_def in pipeline:
mgraph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
if include_packages:
# Package versions dataset type
yield cls.packagesDatasetName
for task_node in mgraph.tasks.values():
for task_node in graph.tasks.values():
edges = task_node.init.iter_all_outputs() if include_configs else task_node.init.outputs
for edge in edges:
yield edge.dataset_type_name
2 changes: 0 additions & 2 deletions python/lsst/pipe/base/pipeline_graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
from ._dataset_types import *
from ._edges import *
from ._exceptions import *
from ._mutable_pipeline_graph import *
from ._nodes import *
from ._pipeline_graph import *
from ._resolved_pipeline_graph import *
from ._task_subsets import *
from ._tasks import *
Loading

0 comments on commit e8d3349

Please sign in to comment.