From 4f276c94eceec1566d9a74cc5ab217307b9d6ea8 Mon Sep 17 00:00:00 2001 From: Nate Lust Date: Tue, 4 Jul 2023 09:18:12 -0400 Subject: [PATCH] MyPy fixes and a small bug fix Fixing mypy annotations revealed a small bug in adjustQuantum with dataset type names, this commit also fixes that. --- python/lsst/pipe/base/connections.py | 13 +++++++------ python/lsst/pipe/base/script/transfer_from_graph.py | 4 ++-- python/lsst/pipe/base/tests/util.py | 10 +++++----- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/python/lsst/pipe/base/connections.py b/python/lsst/pipe/base/connections.py index e02d99e2..556cfa01 100644 --- a/python/lsst/pipe/base/connections.py +++ b/python/lsst/pipe/base/connections.py @@ -934,12 +934,12 @@ class AdjustQuantumHelper: connection-oriented mappings used inside `PipelineTaskConnections`. """ - inputs: NamedKeyMapping[DatasetType, list[DatasetRef]] + inputs: NamedKeyMapping[DatasetType, tuple[DatasetRef]] """Mapping of regular input and prerequisite input datasets, grouped by `~lsst.daf.butler.DatasetType`. """ - outputs: NamedKeyMapping[DatasetType, list[DatasetRef]] + outputs: NamedKeyMapping[DatasetType, tuple[DatasetRef]] """Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`. """ @@ -997,7 +997,7 @@ def adjust_in_place( # Translate adjustments to DatasetType-keyed, Quantum-oriented form, # installing new mappings in self if necessary. if adjusted_inputs_by_connection: - adjusted_inputs = NamedKeyDict[DatasetType, list[DatasetRef]](self.inputs) + adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef]](self.inputs) for name, (connection, updated_refs) in adjusted_inputs_by_connection.items(): dataset_type_name = connection.name if not set(updated_refs).issubset(self.inputs[dataset_type_name]): @@ -1006,21 +1006,22 @@ def adjust_in_place( f"({dataset_type_name}) input datasets that are not a subset of those " f"it was given for data ID {data_id}." ) - adjusted_inputs[dataset_type_name] = list(updated_refs) + adjusted_inputs[dataset_type_name] = tuple(updated_refs) self.inputs = adjusted_inputs.freeze() self.inputs_adjusted = True else: self.inputs_adjusted = False if adjusted_outputs_by_connection: - adjusted_outputs = NamedKeyDict[DatasetType, list[DatasetRef]](self.outputs) + adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef]](self.outputs) for name, (connection, updated_refs) in adjusted_outputs_by_connection.items(): + dataset_type_name = connection.name if not set(updated_refs).issubset(self.outputs[dataset_type_name]): raise RuntimeError( f"adjustQuantum implementation for task with label {label} returned {name} " f"({dataset_type_name}) output datasets that are not a subset of those " f"it was given for data ID {data_id}." ) - adjusted_outputs[dataset_type_name] = list(updated_refs) + adjusted_outputs[dataset_type_name] = tuple(updated_refs) self.outputs = adjusted_outputs.freeze() self.outputs_adjusted = True else: diff --git a/python/lsst/pipe/base/script/transfer_from_graph.py b/python/lsst/pipe/base/script/transfer_from_graph.py index 547885ae..ad1a2549 100644 --- a/python/lsst/pipe/base/script/transfer_from_graph.py +++ b/python/lsst/pipe/base/script/transfer_from_graph.py @@ -66,8 +66,8 @@ def transfer_from_graph( if refs := qgraph.initOutputRefs(task_def): original_output_refs.update(refs) for qnode in qgraph: - for refs in qnode.quantum.outputs.values(): - original_output_refs.update(refs) + for otherRefs in qnode.quantum.outputs.values(): + original_output_refs.update(otherRefs) # Get data repository definitions from the QuantumGraph; these can have # different storage classes than those in the quanta. diff --git a/python/lsst/pipe/base/tests/util.py b/python/lsst/pipe/base/tests/util.py index 832cd91c..7d9aa2b5 100644 --- a/python/lsst/pipe/base/tests/util.py +++ b/python/lsst/pipe/base/tests/util.py @@ -46,8 +46,8 @@ def check_output_run(graph: QuantumGraph, run: str) -> list[DatasetRef]: the specified run. """ # Collect all inputs/outputs, so that we can build intermediate refs. - output_refs = [] - input_refs = [] + output_refs: list[DatasetRef] = [] + input_refs: list[DatasetRef] = [] for node in graph: for refs in node.quantum.outputs.values(): output_refs += refs @@ -61,10 +61,10 @@ def check_output_run(graph: QuantumGraph, run: str) -> list[DatasetRef]: if init_refs: input_refs += init_refs output_refs += graph.globalInitOutputRefs() - refs = [ref for ref in output_refs if ref.run != run] + newRefs = [ref for ref in output_refs if ref.run != run] output_ids = {ref.id for ref in output_refs} intermediates = [ref for ref in input_refs if ref.id in output_ids] - refs += [ref for ref in intermediates if ref.run != run] + newRefs += [ref for ref in intermediates if ref.run != run] - return refs + return newRefs