From b99eaf507f085751ac030d10b3e09cadfa3b4d80 Mon Sep 17 00:00:00 2001 From: Nate Lust Date: Tue, 11 Jun 2024 13:55:11 -0400 Subject: [PATCH] Fixups from review --- python/lsst/pipe/base/tests/simpleQGraph.py | 108 ++++++++++++++++++++ tests/test_pipeline.py | 44 +++++++- 2 files changed, 151 insertions(+), 1 deletion(-) diff --git a/python/lsst/pipe/base/tests/simpleQGraph.py b/python/lsst/pipe/base/tests/simpleQGraph.py index 60cd86ebb..616f7b46b 100644 --- a/python/lsst/pipe/base/tests/simpleQGraph.py +++ b/python/lsst/pipe/base/tests/simpleQGraph.py @@ -199,6 +199,114 @@ def makeTask( return task +class SubTaskConnections( + PipelineTaskConnections, + dimensions=("instrument", "detector"), + defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"}, +): + """Connections for SubTask, has one input and two outputs, + plus one init output. + """ + + input = cT.Input( + name="add_dataset{in_tmpl}", + dimensions=["instrument", "detector"], + storageClass="NumpyArray", + doc="Input dataset type for this task", + ) + output = cT.Output( + name="add_dataset{out_tmpl}", + dimensions=["instrument", "detector"], + storageClass="NumpyArray", + doc="Output dataset type for this task", + ) + output2 = cT.Output( + name="add2_dataset{out_tmpl}", + dimensions=["instrument", "detector"], + storageClass="NumpyArray", + doc="Output dataset type for this task", + ) + initout = cT.InitOutput( + name="add_init_output{out_tmpl}", + storageClass="NumpyArray", + doc="Init Output dataset type for this task", + ) + + +class SubTaskConfig(PipelineTaskConfig, pipelineConnections=SubTaskConnections): + """Config for SubTask.""" + + subtract = pexConfig.Field[int](doc="amount to subtract", default=3) + + +class SubTask(PipelineTask): + """Trivial PipelineTask for testing, has some extras useful for specific + unit tests. + """ + + ConfigClass = SubTaskConfig + _DefaultName = "sub_task" + + initout = numpy.array([999]) + """InitOutputs for this task""" + + taskFactory: SubTaskFactoryMock | None = None + """Factory that makes instances""" + + def run(self, input: int) -> Struct: + if self.taskFactory: + # do some bookkeeping + if self.taskFactory.stopAt == self.taskFactory.countExec: + raise RuntimeError("pretend something bad happened") + self.taskFactory.countExec -= 1 + + self.config = cast(SubTaskConfig, self.config) + self.metadata.add("sub", self.config.subtract) + output = input - self.config.subtract + output2 = output + self.config.subtract + _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2) + return Struct(output=output, output2=output2) + + +class SubTaskFactoryMock(TaskFactory): + """Special task factory that instantiates AddTask. + + It also defines some bookkeeping variables used by SubTask to report + progress to unit tests. + + Parameters + ---------- + stopAt : `int`, optional + Number of times to call `run` before stopping. + """ + + def __init__(self, stopAt: int = -1): + self.countExec = 100 # reduced by SubTask + self.stopAt = stopAt # AddTask raises exception at this call to run() + + def makeTask( + self, + task_node: TaskDef | TaskNode, + /, + butler: LimitedButler, + initInputRefs: Iterable[DatasetRef] | None, + ) -> PipelineTask: + if isinstance(task_node, TaskDef): + # TODO: remove support on DM-40443. + warnings.warn( + "Passing TaskDef to TaskFactory is deprecated and will not be supported after v27.", + FutureWarning, + find_outside_stacklevel("lsst.pipe.base"), + ) + task_class = task_node.taskClass + assert task_class is not None + else: + task_class = task_node.task_class + task = task_class(config=task_node.config, initInputs=None, name=task_node.label) + task.taskFactory = self # type: ignore + return task + + def registerDatasetTypes(registry: Registry, pipeline: Pipeline | Iterable[TaskDef] | PipelineGraph) -> None: """Register all dataset types used by tasks in a registry. diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index a59aa9cc8..eca8e6cd0 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -35,7 +35,7 @@ import lsst.utils.tests from lsst.pipe.base import LabelSpecifier, Pipeline, TaskDef from lsst.pipe.base.pipelineIR import LabeledSubset -from lsst.pipe.base.tests.simpleQGraph import AddTask, makeSimplePipeline +from lsst.pipe.base.tests.simpleQGraph import AddTask, SubTask, makeSimplePipeline class PipelineTestCase(unittest.TestCase): @@ -131,6 +131,48 @@ def testMergingPipelines(self): pipeline1.mergePipeline(pipeline2) self.assertEqual(pipeline1._pipelineIR.tasks.keys(), {"task0", "task1", "task2", "task3"}) + # Test merging pipelines with ambiguous tasks + pipeline1 = makeSimplePipeline(2) + pipeline2 = makeSimplePipeline(2) + pipeline2.addTask(SubTask, "task1") + pipeline2.mergePipeline(pipeline1) + + # Now merge in another pipeline with a config applied. + pipeline3 = makeSimplePipeline(2) + pipeline3.addTask(SubTask, "task1") + pipeline3.addConfigOverride("task1", "subtract", 10) + pipeline3.mergePipeline(pipeline2) + graph = pipeline3.to_graph() + # assert equality from the graph to trigger ambiquity resolution + self.assertEqual(graph.tasks["task1"].config.subtract, 10) + + # Now change the order of the merging + pipeline1 = makeSimplePipeline(2) + pipeline2 = makeSimplePipeline(2) + pipeline2.addTask(SubTask, "task1") + pipeline3 = makeSimplePipeline(2) + pipeline3.mergePipeline(pipeline2) + pipeline3.mergePipeline(pipeline1) + graph = pipeline3.to_graph() + # assert equality from the graph to trigger ambiquity resolution + self.assertEqual(graph.tasks["task1"].config.addend, 3) + + # Now do two ambiguous chains + pipeline1 = makeSimplePipeline(2) + pipeline2 = makeSimplePipeline(2) + pipeline2.addTask(SubTask, "task1") + pipeline2.addConfigOverride("task1", "subtract", 10) + pipeline2.mergePipeline(pipeline1) + + pipeline3 = makeSimplePipeline(2) + pipeline4 = makeSimplePipeline(2) + pipeline4.addTask(SubTask, "task1") + pipeline4.addConfigOverride("task1", "subtract", 7) + pipeline4.mergePipeline(pipeline3) + graph = pipeline4.to_graph() + # assert equality from the graph to trigger ambiquity resolution + self.assertEqual(graph.tasks["task1"].config.subtract, 7) + def testFindingSubset(self): pipeline = makeSimplePipeline(2) pipeline._pipelineIR.labeled_subsets["test1"] = LabeledSubset("test1", set(), None)