Skip to content

Commit

Permalink
Use PipelineGraph instead of PipelineDatasetTypes in step tester.
Browse files Browse the repository at this point in the history
  • Loading branch information
TallJimbo committed Jun 23, 2023
1 parent 386ecd1 commit b1d8206
Showing 1 changed file with 15 additions and 25 deletions.
40 changes: 15 additions & 25 deletions python/lsst/pipe/base/tests/pipelineStepTester.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import unittest

from lsst.daf.butler import Butler, DatasetType
from lsst.pipe.base import Pipeline, PipelineDatasetTypes
from lsst.pipe.base import Pipeline


@dataclasses.dataclass
Expand Down Expand Up @@ -88,32 +88,22 @@ def run(self, butler: Butler, test_case: unittest.TestCase) -> None:
pure_inputs: dict[str, str] = dict()

for suffix in self.step_suffixes:
pipeline = Pipeline.from_uri(self.filename + suffix)
dataset_types = PipelineDatasetTypes.fromPipeline(
pipeline,
registry=butler.registry,
include_configs=False,
include_packages=False,
)
step_graph = Pipeline.from_uri(self.filename + suffix).to_graph()
step_graph.resolve(butler.registry)

pure_inputs.update({k: suffix for k in dataset_types.prerequisites.names})
parent_inputs = {t.nameAndComponent()[0] for t in dataset_types.inputs}
pure_inputs.update({k: suffix for k in parent_inputs - all_outputs.keys()})
all_outputs.update(dataset_types.outputs.asMapping())
all_outputs.update(dataset_types.intermediates.asMapping())

for name in dataset_types.inputs.names & all_outputs.keys():
test_case.assertTrue(
all_outputs[name].is_compatible_with(dataset_types.inputs[name]),
msg=(
f"dataset type {name} is defined as {dataset_types.inputs[name]} as an "
f"input, but {all_outputs[name]} as an output, and these are not compatible."
),
)
pure_inputs.update(
{name: suffix for name, _ in step_graph.iter_overall_inputs() if name not in all_outputs}
)
all_outputs.update(
{
name: node.dataset_type
for name, node in step_graph.dataset_types.items()
if step_graph.producer_of(name) is not None
}
)

for dataset_type in dataset_types.outputs | dataset_types.intermediates:
if not dataset_type.isComponent():
butler.registry.registerDatasetType(dataset_type)
for node in step_graph.dataset_types.values():
butler.registry.registerDatasetType(node.dataset_type)

if not pure_inputs.keys() <= self.expected_inputs:
missing = [f"{k} ({pure_inputs[k]})" for k in pure_inputs.keys() - self.expected_inputs]
Expand Down

0 comments on commit b1d8206

Please sign in to comment.