Merge branch 'tickets/DM-39582'
natelust committed Jul 4, 2023
2 parents 806fb6d + 44f5f6c commit 07d78ed
Showing 18 changed files with 72 additions and 66 deletions.
3 changes: 3 additions & 0 deletions doc/changes/DM-39582.feature.md
@@ -0,0 +1,3 @@
+The back-end to quantum graph loading has been optimized so that duplicate objects are no longer
+created in memory; repeated components now resolve to shared references. This results in a large
+decrease in memory usage and a decrease in load times.
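
The mechanism behind this entry, in outline: when two quanta reference the same serialized dataset ref or dimension record, deserialization should hand back one shared object rather than two equal copies. A minimal sketch of that interning idea (hypothetical helper, not the actual daf_butler code):

_interned: dict[str, object] = {}

def intern(key: str, build):
    """Return one shared instance per key, constructing it at most once."""
    obj = _interned.get(key)
    if obj is None:
        obj = _interned[key] = build()
    return obj

# two graph nodes naming the same record now share one object
a = intern("detector:42", lambda: {"detector": 42})
b = intern("detector:42", lambda: {"detector": 42})
assert a is b

Sharing also saves time: the second and later occurrences skip parsing and validation entirely.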
1 change: 1 addition & 0 deletions doc/changes/DM-39582.removal.md
@@ -0,0 +1 @@
+Deprecated the `recontitutedDimensions` argument to `QuantumNode.from_simple`.
2 changes: 1 addition & 1 deletion python/lsst/pipe/base/_instrument.py
@@ -668,7 +668,7 @@ class _DummyConfig(Config):

        config = _DummyConfig()

-        return config.packer.apply(data_id, is_exposure=is_exposure)
+        return config.packer.apply(data_id, is_exposure=is_exposure)  # type: ignore

    @staticmethod
    @final
10 changes: 5 additions & 5 deletions python/lsst/pipe/base/_quantumContext.py
@@ -270,7 +270,7 @@ def get(
            n_connections = len(dataset)
            n_retrieved = 0
            for i, (name, ref) in enumerate(dataset):
-                if isinstance(ref, list):
+                if isinstance(ref, (list, tuple)):
                    val = []
                    n_refs = len(ref)
                    for j, r in enumerate(ref):

@@ -301,7 +301,7 @@ def get(
                "Completed retrieval of %d datasets from %d connections", n_retrieved, n_connections
            )
            return retVal
-        elif isinstance(dataset, list):
+        elif isinstance(dataset, (list, tuple)):
            n_datasets = len(dataset)
            retrieved = []
            for i, x in enumerate(dataset):

@@ -363,14 +363,14 @@ def put(
            )
            for name, refs in dataset:
                valuesAttribute = getattr(values, name)
-                if isinstance(refs, list):
+                if isinstance(refs, (list, tuple)):
                    if len(refs) != len(valuesAttribute):
                        raise ValueError(f"There must be a object to put for every Dataset ref in {name}")
                    for i, ref in enumerate(refs):
                        self._put(valuesAttribute[i], ref)
                else:
                    self._put(valuesAttribute, refs)
-        elif isinstance(dataset, list):
+        elif isinstance(dataset, (list, tuple)):
            if not isinstance(values, Sequence):
                raise ValueError("Values to put must be a sequence")
            if len(dataset) != len(values):

@@ -401,7 +401,7 @@ def _checkMembership(self, ref: list[DatasetRef] | DatasetRef, inout: set) -> None:
        which may be important for Quanta with lots of
        `~lsst.daf.butler.DatasetRef`.
        """
-        if not isinstance(ref, list):
+        if not isinstance(ref, (list, tuple)):
            ref = [ref]
        for r in ref:
            if (r.datasetType, r.dataId) not in inout:
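
All of the (list, tuple) widenings in this file serve one purpose: adjusted connections now carry immutable tuples of refs (see connections.py below), so any code that used isinstance(x, list) to mean "a sequence of refs rather than a single ref" must accept tuples too. A standalone sketch of the pattern (illustrative, not the pipe_base API):

def normalize(refs):
    """Return a list, whether one ref or a sequence of refs was given."""
    if isinstance(refs, (list, tuple)):
        return list(refs)
    return [refs]

assert normalize("r1") == ["r1"]
assert normalize(("r1", "r2")) == ["r1", "r2"]

Checking the two concrete types instead of collections.abc.Sequence also sidesteps the classic gotcha that strings are sequences too.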
15 changes: 8 additions & 7 deletions python/lsst/pipe/base/connections.py
@@ -39,7 +39,7 @@
import itertools
import string
from collections import UserDict
-from collections.abc import Collection, Generator, Iterable, Mapping, Set
+from collections.abc import Collection, Generator, Iterable, Mapping, Sequence, Set
from dataclasses import dataclass
from types import MappingProxyType, SimpleNamespace
from typing import TYPE_CHECKING, Any

@@ -934,12 +934,12 @@ class AdjustQuantumHelper:
    connection-oriented mappings used inside `PipelineTaskConnections`.
    """

-    inputs: NamedKeyMapping[DatasetType, list[DatasetRef]]
+    inputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
    """Mapping of regular input and prerequisite input datasets, grouped by
    `~lsst.daf.butler.DatasetType`.
    """

-    outputs: NamedKeyMapping[DatasetType, list[DatasetRef]]
+    outputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
    """Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`.
    """

@@ -997,7 +997,7 @@ def adjust_in_place(
        # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
        # installing new mappings in self if necessary.
        if adjusted_inputs_by_connection:
-            adjusted_inputs = NamedKeyDict[DatasetType, list[DatasetRef]](self.inputs)
+            adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.inputs)
            for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.inputs[dataset_type_name]):

@@ -1006,21 +1006,22 @@
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) input datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
-                adjusted_inputs[dataset_type_name] = list(updated_refs)
+                adjusted_inputs[dataset_type_name] = tuple(updated_refs)
            self.inputs = adjusted_inputs.freeze()
            self.inputs_adjusted = True
        else:
            self.inputs_adjusted = False
        if adjusted_outputs_by_connection:
-            adjusted_outputs = NamedKeyDict[DatasetType, list[DatasetRef]](self.outputs)
+            adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.outputs)
            for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
                dataset_type_name = connection.name
                if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                    raise RuntimeError(
                        f"adjustQuantum implementation for task with label {label} returned {name} "
                        f"({dataset_type_name}) output datasets that are not a subset of those "
                        f"it was given for data ID {data_id}."
                    )
-                adjusted_outputs[dataset_type_name] = list(updated_refs)
+                adjusted_outputs[dataset_type_name] = tuple(updated_refs)
            self.outputs = adjusted_outputs.freeze()
            self.outputs_adjusted = True
        else:
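
Moving the adjusted mappings from list[DatasetRef] to tuple[DatasetRef, ...] complements the sharing work: once ref containers can be shared between holders, a mutable list would let one quantum's adjustQuantum silently edit a container someone else sees. A toy illustration of why the frozen container is safer (illustrative values, not DatasetRef):

shared = ("ref_a", "ref_b")  # one immutable container, many holders
quantum_1_inputs = shared
quantum_2_inputs = shared

# adjusting quantum 1 builds a new tuple instead of mutating in place
quantum_1_inputs = tuple(r for r in quantum_1_inputs if r != "ref_b")

assert quantum_2_inputs == ("ref_a", "ref_b")  # other holders unaffected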
4 changes: 2 additions & 2 deletions python/lsst/pipe/base/executionButlerBuilder.py
@@ -162,7 +162,7 @@ def _accumulate(
    for type, refs in attr.items():
        # This if block is because init inputs has a different
        # signature for its items
-        if not isinstance(refs, list):
+        if not isinstance(refs, (list, tuple)):
            refs = [refs]
        for ref in refs:
            if ref.isComponent():

@@ -177,7 +177,7 @@
        attr = getattr(quantum, attrName)

        for type, refs in attr.items():
-            if not isinstance(refs, list):
+            if not isinstance(refs, (list, tuple)):
                refs = [refs]
            if type.component() is not None:
                type = type.makeCompositeDatasetType()
6 changes: 4 additions & 2 deletions python/lsst/pipe/base/graph/_implDetails.py
@@ -313,13 +313,15 @@ def _pruner(
            # from the graph.
            try:
                helper.adjust_in_place(node.taskDef.connections, node.taskDef.label, node.quantum.dataId)
+                # ignore the types because quantum really can take a sequence
+                # of inputs
                newQuantum = Quantum(
                    taskName=node.quantum.taskName,
                    taskClass=node.quantum.taskClass,
                    dataId=node.quantum.dataId,
                    initInputs=node.quantum.initInputs,
-                    inputs=helper.inputs,
-                    outputs=helper.outputs,
+                    inputs=helper.inputs,  # type: ignore
+                    outputs=helper.outputs,  # type: ignore
                )
                # If the inputs or outputs were adjusted to something different
                # than what was supplied by the graph builder, dissassociate
8 changes: 6 additions & 2 deletions python/lsst/pipe/base/graph/_loadHelpers.py
@@ -31,7 +31,7 @@
from typing import TYPE_CHECKING, BinaryIO
from uuid import UUID

-from lsst.daf.butler import DimensionUniverse
+from lsst.daf.butler import DimensionUniverse, PersistenceContextVars
from lsst.resources import ResourceHandleProtocol, ResourcePath

if TYPE_CHECKING:

@@ -219,7 +219,11 @@ def load(
        _readBytes = self._readBytes
        if universe is None:
            universe = headerInfo.universe
-        return self.deserializer.constructGraph(nodeSet, _readBytes, universe)
+        # use the daf butler context vars to aid in ensuring deduplication in
+        # object instantiation.
+        runner = PersistenceContextVars()
+        graph = runner.run(self.deserializer.constructGraph, nodeSet, _readBytes, universe)
+        return graph

    def _readBytes(self, start: int, stop: int) -> bytes:
        """Load the specified byte range from the ResourcePath object
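
PersistenceContextVars is imported from lsst.daf.butler; as used above it activates context-local caches for the duration of constructGraph, so repeated serialized objects resolve to one shared instance and the caches are dropped when the call returns. A minimal sketch of that scoped-cache pattern using the standard contextvars module (hypothetical names, not the daf_butler implementation):

import contextvars

_cache: contextvars.ContextVar[dict | None] = contextvars.ContextVar("_cache", default=None)

def run_with_cache(func, *args):
    """Run ``func`` with a fresh deduplication cache active."""
    token = _cache.set({})
    try:
        return func(*args)
    finally:
        _cache.reset(token)  # the cache does not outlive the load call

def intern(key, build):
    """Return the cached instance for ``key``, building it at most once."""
    cache = _cache.get()
    if cache is None:  # no active context: plain construction
        return build()
    if key not in cache:
        cache[key] = build()
    return cache[key]

Scoping the cache to the call is what keeps the win from becoming a leak: memory is reclaimed as soon as the graph is built instead of accumulating in a process-global intern table.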
2 changes: 2 additions & 0 deletions python/lsst/pipe/base/graph/_versionDeserializers.py
@@ -557,6 +557,8 @@ def constructGraph(

            # Turn the json back into the pydandtic model
            nodeDeserialized = SerializedQuantumNode.direct(**dump)
+            del dump
+
            # attach the dictionary of dimension records to the pydantic model
            # these are stored separately because the are stored over and over
            # and this saves a lot of space and time.
33 changes: 8 additions & 25 deletions python/lsst/pipe/base/graph/graph.py
@@ -1276,49 +1276,32 @@ def updateRun(self, run: str, *, metadata_key: str | None = None, update_graph_id: bool = False) -> None:
        update_graph_id : `bool`, optional
            If `True` then also update graph ID with a new unique value.
        """
-        dataset_id_map = {}
-
-        def _update_output_refs_in_place(refs: list[DatasetRef], run: str) -> None:
+        def _update_refs_in_place(refs: list[DatasetRef], run: str) -> None:
            """Update list of `~lsst.daf.butler.DatasetRef` with new run and
            dataset IDs.
            """
-            new_refs = []
            for ref in refs:
-                new_ref = DatasetRef(ref.datasetType, ref.dataId, run=run, conform=False)
-                dataset_id_map[ref.id] = new_ref.id
-                new_refs.append(new_ref)
-            refs[:] = new_refs
-
-        def _update_input_refs_in_place(refs: list[DatasetRef], run: str) -> None:
-            """Update list of `~lsst.daf.butler.DatasetRef` with IDs from
-            dataset_id_map.
-            """
-            new_refs = []
-            for ref in refs:
-                if (new_id := dataset_id_map.get(ref.id)) is not None:
-                    new_ref = DatasetRef(ref.datasetType, ref.dataId, id=new_id, run=run, conform=False)
-                    new_refs.append(new_ref)
-                else:
-                    new_refs.append(ref)
-            refs[:] = new_refs
+                # hack the run to be replaced explicitly
+                object.__setattr__(ref, "run", run)

        # Loop through all outputs and update their datasets.
        for node in self._connectedQuanta:
            for refs in node.quantum.outputs.values():
-                _update_output_refs_in_place(refs, run)
+                _update_refs_in_place(refs, run)

        for refs in self._initOutputRefs.values():
-            _update_output_refs_in_place(refs, run)
+            _update_refs_in_place(refs, run)

-        _update_output_refs_in_place(self._globalInitOutputRefs, run)
+        _update_refs_in_place(self._globalInitOutputRefs, run)

        # Update all intermediates from their matching outputs.
        for node in self._connectedQuanta:
            for refs in node.quantum.inputs.values():
-                _update_input_refs_in_place(refs, run)
+                _update_refs_in_place(refs, run)

        for refs in self._initInputRefs.values():
-            _update_input_refs_in_place(refs, run)
+            _update_refs_in_place(refs, run)

        if update_graph_id:
            self._buildId = BuildId(f"{time.time()}-{os.getpid()}")
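
This rewrite leans directly on the new sharing guarantee: after a deduplicated load, the same DatasetRef object sits in the producer's outputs and in every consumer's inputs, so updating it once in place is visible everywhere, and the old two-pass dataset_id_map translation from outputs to inputs becomes unnecessary. DatasetRef is a frozen type, hence the object.__setattr__ escape hatch. A toy version of the sharing behavior (illustrative class, not DatasetRef):

class Ref:
    def __init__(self, name: str, run: str):
        self.name, self.run = name, run

shared = Ref("calexp", "old_run")
outputs = {"task_a": [shared]}
inputs = {"task_b": [shared]}  # consumer holds the same object

for ref in outputs["task_a"]:
    ref.run = "new_run"  # one in-place update...

assert inputs["task_b"][0].run == "new_run"  # ...seen by every holder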
19 changes: 15 additions & 4 deletions python/lsst/pipe/base/graph/quantumNode.py
@@ -23,6 +23,7 @@
__all__ = ("QuantumNode", "NodeId", "BuildId")

import uuid
+import warnings
from dataclasses import dataclass
from typing import Any, NewType

@@ -34,6 +35,7 @@
    Quantum,
    SerializedQuantum,
)
+from lsst.utils.introspection import find_outside_stacklevel
from pydantic import BaseModel

from ..pipeline import TaskDef

@@ -96,6 +98,8 @@ class QuantumNode:
    creation.
    """

+    __slots__ = ("quantum", "taskDef", "nodeId", "_precomputedHash")
+
    def __post_init__(self) -> None:
        # use setattr here to preserve the frozenness of the QuantumNode
        self._precomputedHash: int

@@ -135,15 +139,22 @@ def from_simple(
        universe: DimensionUniverse,
        recontitutedDimensions: dict[int, tuple[str, DimensionRecord]] | None = None,
    ) -> QuantumNode:
+        if recontitutedDimensions is not None:
+            warnings.warn(
+                "The recontitutedDimensions argument is now ignored and may be removed after v 27",
+                category=FutureWarning,
+                stacklevel=find_outside_stacklevel("lsst.pipe.base"),
+            )
        return QuantumNode(
-            quantum=Quantum.from_simple(
-                simple.quantum, universe, reconstitutedDimensions=recontitutedDimensions
-            ),
+            quantum=Quantum.from_simple(simple.quantum, universe),
            taskDef=taskDefMap[simple.taskLabel],
            nodeId=simple.nodeId,
        )


+_fields_set = {"quantum", "taskLabel", "nodeId"}
+
+
class SerializedQuantumNode(BaseModel):
    quantum: SerializedQuantum
    taskLabel: str

@@ -156,5 +167,5 @@ def direct(cls, *, quantum: dict[str, Any], taskLabel: str, nodeId: str) -> SerializedQuantumNode:
        setter(node, "quantum", SerializedQuantum.direct(**quantum))
        setter(node, "taskLabel", taskLabel)
        setter(node, "nodeId", uuid.UUID(nodeId))
-        setter(node, "__fields_set__", {"quantum", "taskLabel", "nodeId"})
+        setter(node, "__fields_set__", _fields_set)
        return node
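
Two small memory moves in this file are worth spelling out. Adding __slots__ to the frozen dataclass drops the per-instance __dict__, which adds up over the many thousands of QuantumNode objects in a big graph, at the cost that even __post_init__ must assign through object.__setattr__; and the module-level _fields_set shares one set across every SerializedQuantumNode.direct call instead of rebuilding it each time. A minimal sketch of the slots-plus-frozen combination (simplified, not the full QuantumNode):

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    payload: str

    __slots__ = ("payload", "_precomputedHash")

    def __post_init__(self) -> None:
        # frozen dataclasses block normal assignment, so store the cached
        # hash the same way the real class does
        object.__setattr__(self, "_precomputedHash", hash(self.payload))

    def __hash__(self) -> int:
        return self._precomputedHash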
5 changes: 3 additions & 2 deletions python/lsst/pipe/base/graphBuilder.py
@@ -455,13 +455,14 @@ def makeQuantum(self, datastore_records: Mapping[str, DatastoreRecordData] | None
                matching_records = records.subset(input_ids)
                if matching_records is not None:
                    quantum_records[datastore_name] = matching_records
+        # ignore the types because quantum really can take a sequence of inputs
        return Quantum(
            taskName=self.task.taskDef.taskName,
            taskClass=self.task.taskDef.taskClass,
            dataId=self.dataId,
            initInputs=initInputs,
-            inputs=helper.inputs,
-            outputs=helper.outputs,
+            inputs=helper.inputs,  # type: ignore
+            outputs=helper.outputs,  # type: ignore
            datastore_records=quantum_records,
        )
8 changes: 4 additions & 4 deletions python/lsst/pipe/base/pipeline.py
@@ -192,7 +192,7 @@ def logOutputDatasetName(self) -> str | None:
        """Name of a dataset type for log output from this task, `None` if
        logs are not to be saved (`str`)
        """
-        if cast(PipelineTaskConfig, self.config).saveLogOutput:
+        if self.config.saveLogOutput:
            return acc.LOG_OUTPUT_TEMPLATE.format(label=self.label)
        else:
            return None

@@ -623,7 +623,7 @@ def get_data_id(self, universe: DimensionUniverse) -> DataCoordinate:
        """
        instrument_class_name = self._pipelineIR.instrument
        if instrument_class_name is not None:
-            instrument_class = doImportType(instrument_class_name)
+            instrument_class = cast(PipeBaseInstrument, doImportType(instrument_class_name))
            if instrument_class is not None:
                return DataCoordinate.standardize(instrument=instrument_class.getName(), universe=universe)
        return DataCoordinate.makeEmpty(universe)

@@ -654,8 +654,8 @@ def addTask(self, task: type[PipelineTask] | str, label: str) -> None:
        # be defined without label which is not acceptable, use task
        # _DefaultName in that case
        if isinstance(task, str):
-            task_class = doImportType(task)
-            label = task_class._DefaultName
+            task_class = cast(PipelineTask, doImportType(task))
+            label = task_class._DefaultName
        self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
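
One note on the cast churn in this file: typing.cast is purely a static-typing assertion and does nothing at runtime, so these edits cannot change behavior. For example:

from typing import cast

value: object = "lsst.obs.subaru.HyperSuprimeCam"
name = cast(str, value)  # no runtime check or conversion happens
assert name is value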
4 changes: 2 additions & 2 deletions python/lsst/pipe/base/script/transfer_from_graph.py
@@ -66,8 +66,8 @@ def transfer_from_graph(
        if refs := qgraph.initOutputRefs(task_def):
            original_output_refs.update(refs)
    for qnode in qgraph:
-        for refs in qnode.quantum.outputs.values():
-            original_output_refs.update(refs)
+        for otherRefs in qnode.quantum.outputs.values():
+            original_output_refs.update(otherRefs)

    # Get data repository definitions from the QuantumGraph; these can have
    # different storage classes than those in the quanta.
4 changes: 1 addition & 3 deletions python/lsst/pipe/base/tests/mocks/_pipeline_task.py
@@ -283,9 +283,7 @@ def runQuantum(

        # store mock outputs
        for name, refs in outputRefs:
-            if not isinstance(refs, list):
-                refs = [refs]
-            for ref in refs:
+            for ref in ensure_iterable(refs):
                output = MockDataset(
                    ref=ref.to_simple(), quantum=mock_dataset_quantum, output_connection_name=name
                )
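
ensure_iterable comes from lsst.utils.iteration and lets the mock task treat a single ref and a sequence of refs uniformly, collapsing the old isinstance branch. A rough stand-in with the same shape (my sketch, not the lsst.utils source):

def ensure_iterable(a):
    """Yield the elements of ``a``, or ``a`` itself if it is not iterable."""
    if isinstance(a, (str, bytes)):  # treat strings as scalars
        yield a
        return
    try:
        yield from a
    except TypeError:
        yield a

assert list(ensure_iterable(5)) == [5]
assert list(ensure_iterable([5, 6])) == [5, 6]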
2 changes: 1 addition & 1 deletion python/lsst/pipe/base/tests/simpleQGraph.py
@@ -307,7 +307,7 @@ def populateButler(
    instrument = pipeline.getInstrument()
    if instrument is not None:
        instrument_class = doImportType(instrument)
-        instrumentName = instrument_class.getName()
+        instrumentName = cast(Instrument, instrument_class).getName()
        instrumentClass = get_full_type_name(instrument_class)
    else:
        instrumentName = "INSTR"