Skip to content

Commit

Permalink
More typing changes
Browse files Browse the repository at this point in the history
  • Loading branch information
natelust committed Jul 4, 2023
1 parent 4f276c9 commit 44f5f6c
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
10 changes: 5 additions & 5 deletions python/lsst/pipe/base/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import itertools
import string
from collections import UserDict
from collections.abc import Collection, Generator, Iterable, Mapping, Set
from collections.abc import Collection, Generator, Iterable, Mapping, Sequence, Set
from dataclasses import dataclass
from types import MappingProxyType, SimpleNamespace
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -934,12 +934,12 @@ class AdjustQuantumHelper:
connection-oriented mappings used inside `PipelineTaskConnections`.
"""

inputs: NamedKeyMapping[DatasetType, tuple[DatasetRef]]
inputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
"""Mapping of regular input and prerequisite input datasets, grouped by
`~lsst.daf.butler.DatasetType`.
"""

outputs: NamedKeyMapping[DatasetType, tuple[DatasetRef]]
outputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
"""Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`.
"""

Expand Down Expand Up @@ -997,7 +997,7 @@ def adjust_in_place(
# Translate adjustments to DatasetType-keyed, Quantum-oriented form,
# installing new mappings in self if necessary.
if adjusted_inputs_by_connection:
adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef]](self.inputs)
adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.inputs)

Check warning on line 1000 in python/lsst/pipe/base/connections.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/connections.py#L1000

Added line #L1000 was not covered by tests
for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
dataset_type_name = connection.name
if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
Expand All @@ -1012,7 +1012,7 @@ def adjust_in_place(
else:
self.inputs_adjusted = False
if adjusted_outputs_by_connection:
adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef]](self.outputs)
adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.outputs)

Check warning on line 1015 in python/lsst/pipe/base/connections.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/connections.py#L1015

Added line #L1015 was not covered by tests
for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
dataset_type_name = connection.name

Check warning on line 1017 in python/lsst/pipe/base/connections.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/connections.py#L1017

Added line #L1017 was not covered by tests
if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
Expand Down
6 changes: 4 additions & 2 deletions python/lsst/pipe/base/graph/_implDetails.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,13 +313,15 @@ def _pruner(
# from the graph.
try:
helper.adjust_in_place(node.taskDef.connections, node.taskDef.label, node.quantum.dataId)
# ignore the types because quantum really can take a sequence
# of inputs
newQuantum = Quantum(
taskName=node.quantum.taskName,
taskClass=node.quantum.taskClass,
dataId=node.quantum.dataId,
initInputs=node.quantum.initInputs,
inputs=helper.inputs,
outputs=helper.outputs,
inputs=helper.inputs, # type: ignore
outputs=helper.outputs, # type: ignore
)
# If the inputs or outputs were adjusted to something different
# than what was supplied by the graph builder, disassociate
Expand Down
5 changes: 3 additions & 2 deletions python/lsst/pipe/base/graphBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,13 +455,14 @@ def makeQuantum(self, datastore_records: Mapping[str, DatastoreRecordData] | Non
matching_records = records.subset(input_ids)
if matching_records is not None:
quantum_records[datastore_name] = matching_records
# ignore the types because quantum really can take a sequence of inputs
return Quantum(
taskName=self.task.taskDef.taskName,
taskClass=self.task.taskDef.taskClass,
dataId=self.dataId,
initInputs=initInputs,
inputs=helper.inputs,
outputs=helper.outputs,
inputs=helper.inputs, # type: ignore
outputs=helper.outputs, # type: ignore
datastore_records=quantum_records,
)

Expand Down

0 comments on commit 44f5f6c

Please sign in to comment.