Skip to content

Commit

Permalink
Fixups from review
Browse files Browse the repository at this point in the history
  • Loading branch information
natelust committed Jun 11, 2024
1 parent 6830160 commit b99eaf5
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 1 deletion.
108 changes: 108 additions & 0 deletions python/lsst/pipe/base/tests/simpleQGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,114 @@ def makeTask(
return task


class SubTaskConnections(
PipelineTaskConnections,
dimensions=("instrument", "detector"),
defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
"""Connections for SubTask, has one input and two outputs,
plus one init output.
"""

input = cT.Input(
name="add_dataset{in_tmpl}",
dimensions=["instrument", "detector"],
storageClass="NumpyArray",
doc="Input dataset type for this task",
)
output = cT.Output(
name="add_dataset{out_tmpl}",
dimensions=["instrument", "detector"],
storageClass="NumpyArray",
doc="Output dataset type for this task",
)
output2 = cT.Output(
name="add2_dataset{out_tmpl}",
dimensions=["instrument", "detector"],
storageClass="NumpyArray",
doc="Output dataset type for this task",
)
initout = cT.InitOutput(
name="add_init_output{out_tmpl}",
storageClass="NumpyArray",
doc="Init Output dataset type for this task",
)


class SubTaskConfig(PipelineTaskConfig, pipelineConnections=SubTaskConnections):
"""Config for SubTask."""

subtract = pexConfig.Field[int](doc="amount to subtract", default=3)


class SubTask(PipelineTask):
"""Trivial PipelineTask for testing, has some extras useful for specific
unit tests.
"""

ConfigClass = SubTaskConfig
_DefaultName = "sub_task"

initout = numpy.array([999])
"""InitOutputs for this task"""

taskFactory: SubTaskFactoryMock | None = None
"""Factory that makes instances"""

def run(self, input: int) -> Struct:
if self.taskFactory:
# do some bookkeeping
if self.taskFactory.stopAt == self.taskFactory.countExec:
raise RuntimeError("pretend something bad happened")
self.taskFactory.countExec -= 1

Check warning on line 261 in python/lsst/pipe/base/tests/simpleQGraph.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/tests/simpleQGraph.py#L260-L261

Added lines #L260 - L261 were not covered by tests

self.config = cast(SubTaskConfig, self.config)
self.metadata.add("sub", self.config.subtract)
output = input - self.config.subtract
output2 = output + self.config.subtract
_LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
return Struct(output=output, output2=output2)

Check warning on line 268 in python/lsst/pipe/base/tests/simpleQGraph.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/tests/simpleQGraph.py#L263-L268

Added lines #L263 - L268 were not covered by tests


class SubTaskFactoryMock(TaskFactory):
"""Special task factory that instantiates AddTask.
It also defines some bookkeeping variables used by SubTask to report
progress to unit tests.
Parameters
----------
stopAt : `int`, optional
Number of times to call `run` before stopping.
"""

def __init__(self, stopAt: int = -1):
self.countExec = 100 # reduced by SubTask
self.stopAt = stopAt # AddTask raises exception at this call to run()

Check warning on line 285 in python/lsst/pipe/base/tests/simpleQGraph.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/tests/simpleQGraph.py#L284-L285

Added lines #L284 - L285 were not covered by tests

def makeTask(
self,
task_node: TaskDef | TaskNode,
/,
butler: LimitedButler,
initInputRefs: Iterable[DatasetRef] | None,
) -> PipelineTask:
if isinstance(task_node, TaskDef):
# TODO: remove support on DM-40443.
warnings.warn(

Check warning on line 296 in python/lsst/pipe/base/tests/simpleQGraph.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/tests/simpleQGraph.py#L296

Added line #L296 was not covered by tests
"Passing TaskDef to TaskFactory is deprecated and will not be supported after v27.",
FutureWarning,
find_outside_stacklevel("lsst.pipe.base"),
)
task_class = task_node.taskClass
assert task_class is not None

Check warning on line 302 in python/lsst/pipe/base/tests/simpleQGraph.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/tests/simpleQGraph.py#L301-L302

Added lines #L301 - L302 were not covered by tests
else:
task_class = task_node.task_class
task = task_class(config=task_node.config, initInputs=None, name=task_node.label)
task.taskFactory = self # type: ignore
return task

Check warning on line 307 in python/lsst/pipe/base/tests/simpleQGraph.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/tests/simpleQGraph.py#L304-L307

Added lines #L304 - L307 were not covered by tests


def registerDatasetTypes(registry: Registry, pipeline: Pipeline | Iterable[TaskDef] | PipelineGraph) -> None:
"""Register all dataset types used by tasks in a registry.
Expand Down
44 changes: 43 additions & 1 deletion tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import lsst.utils.tests
from lsst.pipe.base import LabelSpecifier, Pipeline, TaskDef
from lsst.pipe.base.pipelineIR import LabeledSubset
from lsst.pipe.base.tests.simpleQGraph import AddTask, makeSimplePipeline
from lsst.pipe.base.tests.simpleQGraph import AddTask, SubTask, makeSimplePipeline


class PipelineTestCase(unittest.TestCase):
Expand Down Expand Up @@ -131,6 +131,48 @@ def testMergingPipelines(self):
pipeline1.mergePipeline(pipeline2)
self.assertEqual(pipeline1._pipelineIR.tasks.keys(), {"task0", "task1", "task2", "task3"})

# Test merging pipelines with ambiguous tasks
pipeline1 = makeSimplePipeline(2)
pipeline2 = makeSimplePipeline(2)
pipeline2.addTask(SubTask, "task1")
pipeline2.mergePipeline(pipeline1)

# Now merge in another pipeline with a config applied.
pipeline3 = makeSimplePipeline(2)
pipeline3.addTask(SubTask, "task1")
pipeline3.addConfigOverride("task1", "subtract", 10)
pipeline3.mergePipeline(pipeline2)
graph = pipeline3.to_graph()
# assert equality from the graph to trigger ambiquity resolution
self.assertEqual(graph.tasks["task1"].config.subtract, 10)

# Now change the order of the merging
pipeline1 = makeSimplePipeline(2)
pipeline2 = makeSimplePipeline(2)
pipeline2.addTask(SubTask, "task1")
pipeline3 = makeSimplePipeline(2)
pipeline3.mergePipeline(pipeline2)
pipeline3.mergePipeline(pipeline1)
graph = pipeline3.to_graph()
# assert equality from the graph to trigger ambiquity resolution
self.assertEqual(graph.tasks["task1"].config.addend, 3)

# Now do two ambiguous chains
pipeline1 = makeSimplePipeline(2)
pipeline2 = makeSimplePipeline(2)
pipeline2.addTask(SubTask, "task1")
pipeline2.addConfigOverride("task1", "subtract", 10)
pipeline2.mergePipeline(pipeline1)

pipeline3 = makeSimplePipeline(2)
pipeline4 = makeSimplePipeline(2)
pipeline4.addTask(SubTask, "task1")
pipeline4.addConfigOverride("task1", "subtract", 7)
pipeline4.mergePipeline(pipeline3)
graph = pipeline4.to_graph()
# assert equality from the graph to trigger ambiquity resolution
self.assertEqual(graph.tasks["task1"].config.subtract, 7)

def testFindingSubset(self):
pipeline = makeSimplePipeline(2)
pipeline._pipelineIR.labeled_subsets["test1"] = LabeledSubset("test1", set(), None)
Expand Down

0 comments on commit b99eaf5

Please sign in to comment.