Skip to content

Commit

Permalink
MyPy fixes and a small bug fix
Browse files Browse the repository at this point in the history
Fixing mypy annotations revealed a small bug in adjustQuantum with
dataset type names, this commit also fixes that.
  • Loading branch information
natelust committed Jul 4, 2023
1 parent 628e539 commit 4f276c9
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 13 deletions.
13 changes: 7 additions & 6 deletions python/lsst/pipe/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,12 +934,12 @@ class AdjustQuantumHelper:
connection-oriented mappings used inside `PipelineTaskConnections`.
"""

inputs: NamedKeyMapping[DatasetType, list[DatasetRef]]
inputs: NamedKeyMapping[DatasetType, tuple[DatasetRef]]
"""Mapping of regular input and prerequisite input datasets, grouped by
`~lsst.daf.butler.DatasetType`.
"""

outputs: NamedKeyMapping[DatasetType, list[DatasetRef]]
outputs: NamedKeyMapping[DatasetType, tuple[DatasetRef]]
"""Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`.
"""

Expand Down Expand Up @@ -997,7 +997,7 @@ def adjust_in_place(
# Translate adjustments to DatasetType-keyed, Quantum-oriented form,
# installing new mappings in self if necessary.
if adjusted_inputs_by_connection:
adjusted_inputs = NamedKeyDict[DatasetType, list[DatasetRef]](self.inputs)
adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef]](self.inputs)
for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
dataset_type_name = connection.name
if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
Expand All @@ -1006,21 +1006,22 @@ def adjust_in_place(
f"({dataset_type_name}) input datasets that are not a subset of those "
f"it was given for data ID {data_id}."
)
adjusted_inputs[dataset_type_name] = list(updated_refs)
adjusted_inputs[dataset_type_name] = tuple(updated_refs)
self.inputs = adjusted_inputs.freeze()
self.inputs_adjusted = True
else:
self.inputs_adjusted = False
if adjusted_outputs_by_connection:
adjusted_outputs = NamedKeyDict[DatasetType, list[DatasetRef]](self.outputs)
adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef]](self.outputs)
for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
dataset_type_name = connection.name
if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
raise RuntimeError(
f"adjustQuantum implementation for task with label {label} returned {name} "
f"({dataset_type_name}) output datasets that are not a subset of those "
f"it was given for data ID {data_id}."
)
adjusted_outputs[dataset_type_name] = list(updated_refs)
adjusted_outputs[dataset_type_name] = tuple(updated_refs)
self.outputs = adjusted_outputs.freeze()
self.outputs_adjusted = True
else:
Expand Down
4 changes: 2 additions & 2 deletions python/lsst/pipe/base/script/transfer_from_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def transfer_from_graph(
if refs := qgraph.initOutputRefs(task_def):
original_output_refs.update(refs)
for qnode in qgraph:
for refs in qnode.quantum.outputs.values():
original_output_refs.update(refs)
for otherRefs in qnode.quantum.outputs.values():
original_output_refs.update(otherRefs)

# Get data repository definitions from the QuantumGraph; these can have
# different storage classes than those in the quanta.
Expand Down
10 changes: 5 additions & 5 deletions python/lsst/pipe/base/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def check_output_run(graph: QuantumGraph, run: str) -> list[DatasetRef]:
the specified run.
"""
# Collect all inputs/outputs, so that we can build intermediate refs.
output_refs = []
input_refs = []
output_refs: list[DatasetRef] = []
input_refs: list[DatasetRef] = []
for node in graph:
for refs in node.quantum.outputs.values():
output_refs += refs
Expand All @@ -61,10 +61,10 @@ def check_output_run(graph: QuantumGraph, run: str) -> list[DatasetRef]:
if init_refs:
input_refs += init_refs
output_refs += graph.globalInitOutputRefs()
refs = [ref for ref in output_refs if ref.run != run]
newRefs = [ref for ref in output_refs if ref.run != run]

output_ids = {ref.id for ref in output_refs}
intermediates = [ref for ref in input_refs if ref.id in output_ids]
refs += [ref for ref in intermediates if ref.run != run]
newRefs += [ref for ref in intermediates if ref.run != run]

return refs
return newRefs

0 comments on commit 4f276c9

Please sign in to comment.