diff --git a/doc/changes/DM-39582.feature.md b/doc/changes/DM-39582.feature.md
new file mode 100644
index 00000000..5c56c3a0
--- /dev/null
+++ b/doc/changes/DM-39582.feature.md
@@ -0,0 +1,3 @@
+The back end of quantum graph loading has been optimized so that duplicate objects are no longer created
+in memory; repeated components are instead shared by reference. This results in a large decrease in memory
+usage and a decrease in load times.
diff --git a/doc/changes/DM-39582.removal.md b/doc/changes/DM-39582.removal.md
new file mode 100644
index 00000000..b489f3bb
--- /dev/null
+++ b/doc/changes/DM-39582.removal.md
@@ -0,0 +1 @@
+Deprecated the `recontitutedDimensions` argument to `QuantumNode.from_simple`.
diff --git a/python/lsst/pipe/base/_instrument.py b/python/lsst/pipe/base/_instrument.py
index 002a8575..856ce75b 100644
--- a/python/lsst/pipe/base/_instrument.py
+++ b/python/lsst/pipe/base/_instrument.py
@@ -668,7 +668,7 @@ class _DummyConfig(Config):
 
         config = _DummyConfig()
 
-        return config.packer.apply(data_id, is_exposure=is_exposure)
+        return config.packer.apply(data_id, is_exposure=is_exposure)  # type: ignore
 
     @staticmethod
     @final
diff --git a/python/lsst/pipe/base/_quantumContext.py b/python/lsst/pipe/base/_quantumContext.py
index 62c8df00..f5a1fc31 100644
--- a/python/lsst/pipe/base/_quantumContext.py
+++ b/python/lsst/pipe/base/_quantumContext.py
@@ -270,7 +270,7 @@ def get(
             n_connections = len(dataset)
             n_retrieved = 0
             for i, (name, ref) in enumerate(dataset):
-                if isinstance(ref, list):
+                if isinstance(ref, (list, tuple)):
                     val = []
                     n_refs = len(ref)
                     for j, r in enumerate(ref):
@@ -301,7 +301,7 @@ def get(
                     "Completed retrieval of %d datasets from %d connections", n_retrieved, n_connections
                 )
             return retVal
-        elif isinstance(dataset, list):
+        elif isinstance(dataset, (list, tuple)):
             n_datasets = len(dataset)
             retrieved = []
             for i, x in enumerate(dataset):
@@ -363,14 +363,14 @@ def put(
                 )
             for name, refs in dataset:
                 valuesAttribute = getattr(values, name)
-                if isinstance(refs, list):
+                if isinstance(refs, (list, tuple)):
                     if len(refs) != len(valuesAttribute):
                         raise ValueError(f"There must be a object to put for every Dataset ref in {name}")
                     for i, ref in enumerate(refs):
                         self._put(valuesAttribute[i], ref)
                 else:
                     self._put(valuesAttribute, refs)
-        elif isinstance(dataset, list):
+        elif isinstance(dataset, (list, tuple)):
             if not isinstance(values, Sequence):
                 raise ValueError("Values to put must be a sequence")
             if len(dataset) != len(values):
@@ -401,7 +401,7 @@ def _checkMembership(self, ref: list[DatasetRef] | DatasetRef, inout: set) -> No
         which may be important for Quanta with lots of
         `~lsst.daf.butler.DatasetRef`.
         """
-        if not isinstance(ref, list):
+        if not isinstance(ref, (list, tuple)):
             ref = [ref]
         for r in ref:
             if (r.datasetType, r.dataId) not in inout:
diff --git a/python/lsst/pipe/base/connections.py b/python/lsst/pipe/base/connections.py
index e02d99e2..3d54150f 100644
--- a/python/lsst/pipe/base/connections.py
+++ b/python/lsst/pipe/base/connections.py
@@ -39,7 +39,7 @@
 import itertools
 import string
 from collections import UserDict
-from collections.abc import Collection, Generator, Iterable, Mapping, Set
+from collections.abc import Collection, Generator, Iterable, Mapping, Sequence, Set
 from dataclasses import dataclass
 from types import MappingProxyType, SimpleNamespace
 from typing import TYPE_CHECKING, Any
@@ -934,12 +934,12 @@ class AdjustQuantumHelper:
     connection-oriented mappings used inside `PipelineTaskConnections`.
     """
 
-    inputs: NamedKeyMapping[DatasetType, list[DatasetRef]]
+    inputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
     """Mapping of regular input and prerequisite input datasets, grouped by
     `~lsst.daf.butler.DatasetType`.
     """
 
-    outputs: NamedKeyMapping[DatasetType, list[DatasetRef]]
+    outputs: NamedKeyMapping[DatasetType, Sequence[DatasetRef]]
     """Mapping of output datasets, grouped by `~lsst.daf.butler.DatasetType`.
     """
 
@@ -997,7 +997,7 @@ def adjust_in_place(
         # Translate adjustments to DatasetType-keyed, Quantum-oriented form,
         # installing new mappings in self if necessary.
         if adjusted_inputs_by_connection:
-            adjusted_inputs = NamedKeyDict[DatasetType, list[DatasetRef]](self.inputs)
+            adjusted_inputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.inputs)
             for name, (connection, updated_refs) in adjusted_inputs_by_connection.items():
                 dataset_type_name = connection.name
                 if not set(updated_refs).issubset(self.inputs[dataset_type_name]):
@@ -1006,21 +1006,22 @@ def adjust_in_place(
                         f"({dataset_type_name}) input datasets that are not a subset of those "
                         f"it was given for data ID {data_id}."
                     )
-                adjusted_inputs[dataset_type_name] = list(updated_refs)
+                adjusted_inputs[dataset_type_name] = tuple(updated_refs)
             self.inputs = adjusted_inputs.freeze()
             self.inputs_adjusted = True
         else:
             self.inputs_adjusted = False
         if adjusted_outputs_by_connection:
-            adjusted_outputs = NamedKeyDict[DatasetType, list[DatasetRef]](self.outputs)
+            adjusted_outputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](self.outputs)
             for name, (connection, updated_refs) in adjusted_outputs_by_connection.items():
+                dataset_type_name = connection.name
                 if not set(updated_refs).issubset(self.outputs[dataset_type_name]):
                     raise RuntimeError(
                         f"adjustQuantum implementation for task with label {label} returned {name} "
                         f"({dataset_type_name}) output datasets that are not a subset of those "
                         f"it was given for data ID {data_id}."
                     )
-                adjusted_outputs[dataset_type_name] = list(updated_refs)
+                adjusted_outputs[dataset_type_name] = tuple(updated_refs)
             self.outputs = adjusted_outputs.freeze()
             self.outputs_adjusted = True
         else:
diff --git a/python/lsst/pipe/base/executionButlerBuilder.py b/python/lsst/pipe/base/executionButlerBuilder.py
index 638dc4e5..84204cae 100644
--- a/python/lsst/pipe/base/executionButlerBuilder.py
+++ b/python/lsst/pipe/base/executionButlerBuilder.py
@@ -162,7 +162,7 @@ def _accumulate(
             for type, refs in attr.items():
                 # This if block is because init inputs has a different
                 # signature for its items
-                if not isinstance(refs, list):
+                if not isinstance(refs, (list, tuple)):
                     refs = [refs]
                 for ref in refs:
                     if ref.isComponent():
@@ -177,7 +177,7 @@
             attr = getattr(quantum, attrName)
 
             for type, refs in attr.items():
-                if not isinstance(refs, list):
+                if not isinstance(refs, (list, tuple)):
                     refs = [refs]
                 if type.component() is not None:
                     type = type.makeCompositeDatasetType()
diff --git a/python/lsst/pipe/base/graph/_implDetails.py b/python/lsst/pipe/base/graph/_implDetails.py
index 3f27faec..6472ab06 100644
--- a/python/lsst/pipe/base/graph/_implDetails.py
+++ b/python/lsst/pipe/base/graph/_implDetails.py
@@ -313,13 +313,15 @@ def _pruner(
             # from the graph.
             try:
                 helper.adjust_in_place(node.taskDef.connections, node.taskDef.label, node.quantum.dataId)
+                # ignore the types because quantum really can take a sequence
+                # of inputs
                 newQuantum = Quantum(
                     taskName=node.quantum.taskName,
                     taskClass=node.quantum.taskClass,
                     dataId=node.quantum.dataId,
                     initInputs=node.quantum.initInputs,
-                    inputs=helper.inputs,
-                    outputs=helper.outputs,
+                    inputs=helper.inputs,  # type: ignore
+                    outputs=helper.outputs,  # type: ignore
                 )
                 # If the inputs or outputs were adjusted to something different
                 # than what was supplied by the graph builder, dissassociate
diff --git a/python/lsst/pipe/base/graph/_loadHelpers.py b/python/lsst/pipe/base/graph/_loadHelpers.py
index 5d112d2a..0190842f 100644
--- a/python/lsst/pipe/base/graph/_loadHelpers.py
+++ b/python/lsst/pipe/base/graph/_loadHelpers.py
@@ -31,7 +31,7 @@
 from typing import TYPE_CHECKING, BinaryIO
 from uuid import UUID
 
-from lsst.daf.butler import DimensionUniverse
+from lsst.daf.butler import DimensionUniverse, PersistenceContextVars
 from lsst.resources import ResourceHandleProtocol, ResourcePath
 
 if TYPE_CHECKING:
@@ -219,7 +219,11 @@ def load(
         _readBytes = self._readBytes
         if universe is None:
            universe = headerInfo.universe
-        return self.deserializer.constructGraph(nodeSet, _readBytes, universe)
+        # use the daf butler context vars to aid in ensuring deduplication in
+        # object instantiation.
+        runner = PersistenceContextVars()
+        graph = runner.run(self.deserializer.constructGraph, nodeSet, _readBytes, universe)
+        return graph
 
     def _readBytes(self, start: int, stop: int) -> bytes:
         """Load the specified byte range from the ResourcePath object
diff --git a/python/lsst/pipe/base/graph/_versionDeserializers.py b/python/lsst/pipe/base/graph/_versionDeserializers.py
index 8ea98463..8b98f214 100644
--- a/python/lsst/pipe/base/graph/_versionDeserializers.py
+++ b/python/lsst/pipe/base/graph/_versionDeserializers.py
@@ -557,6 +557,8 @@ def constructGraph(
            # Turn the json back into the pydandtic model
            nodeDeserialized = SerializedQuantumNode.direct(**dump)
 
+            del dump
+
            # attach the dictionary of dimension records to the pydantic model
            # these are stored separately because the are stored over and over
            # and this saves a lot of space and time.
diff --git a/python/lsst/pipe/base/graph/graph.py b/python/lsst/pipe/base/graph/graph.py
index 9fb294a6..8f0bbb31 100644
--- a/python/lsst/pipe/base/graph/graph.py
+++ b/python/lsst/pipe/base/graph/graph.py
@@ -1276,49 +1276,32 @@ def updateRun(self, run: str, *, metadata_key: str | None = None, update_graph_i
         update_graph_id : `bool`, optional
             If `True` then also update graph ID with a new unique value.
         """
-        dataset_id_map = {}
 
-        def _update_output_refs_in_place(refs: list[DatasetRef], run: str) -> None:
+        def _update_refs_in_place(refs: list[DatasetRef], run: str) -> None:
             """Update list of `~lsst.daf.butler.DatasetRef` with new run and
             dataset IDs.
             """
-            new_refs = []
             for ref in refs:
-                new_ref = DatasetRef(ref.datasetType, ref.dataId, run=run, conform=False)
-                dataset_id_map[ref.id] = new_ref.id
-                new_refs.append(new_ref)
-            refs[:] = new_refs
-
-        def _update_input_refs_in_place(refs: list[DatasetRef], run: str) -> None:
-            """Update list of `~lsst.daf.butler.DatasetRef` with IDs from
-            dataset_id_map.
-            """
-            new_refs = []
-            for ref in refs:
-                if (new_id := dataset_id_map.get(ref.id)) is not None:
-                    new_ref = DatasetRef(ref.datasetType, ref.dataId, id=new_id, run=run, conform=False)
-                    new_refs.append(new_ref)
-                else:
-                    new_refs.append(ref)
-            refs[:] = new_refs
+                # hack the run to be replaced explicitly
+                object.__setattr__(ref, "run", run)
 
         # Loop through all outputs and update their datasets.
         for node in self._connectedQuanta:
             for refs in node.quantum.outputs.values():
-                _update_output_refs_in_place(refs, run)
+                _update_refs_in_place(refs, run)
 
         for refs in self._initOutputRefs.values():
-            _update_output_refs_in_place(refs, run)
+            _update_refs_in_place(refs, run)
 
-        _update_output_refs_in_place(self._globalInitOutputRefs, run)
+        _update_refs_in_place(self._globalInitOutputRefs, run)
 
         # Update all intermediates from their matching outputs.
         for node in self._connectedQuanta:
             for refs in node.quantum.inputs.values():
-                _update_input_refs_in_place(refs, run)
+                _update_refs_in_place(refs, run)
 
         for refs in self._initInputRefs.values():
-            _update_input_refs_in_place(refs, run)
+            _update_refs_in_place(refs, run)
 
         if update_graph_id:
             self._buildId = BuildId(f"{time.time()}-{os.getpid()}")
diff --git a/python/lsst/pipe/base/graph/quantumNode.py b/python/lsst/pipe/base/graph/quantumNode.py
index 2c3c9606..b5df5d01 100644
--- a/python/lsst/pipe/base/graph/quantumNode.py
+++ b/python/lsst/pipe/base/graph/quantumNode.py
@@ -23,6 +23,7 @@
 __all__ = ("QuantumNode", "NodeId", "BuildId")
 
 import uuid
+import warnings
 from dataclasses import dataclass
 from typing import Any, NewType
 
@@ -34,6 +35,7 @@
     Quantum,
     SerializedQuantum,
 )
+from lsst.utils.introspection import find_outside_stacklevel
 from pydantic import BaseModel
 
 from ..pipeline import TaskDef
@@ -96,6 +98,8 @@ class QuantumNode:
     creation.
     """
 
+    __slots__ = ("quantum", "taskDef", "nodeId", "_precomputedHash")
+
     def __post_init__(self) -> None:
         # use setattr here to preserve the frozenness of the QuantumNode
         self._precomputedHash: int
@@ -135,15 +139,22 @@ def from_simple(
         universe: DimensionUniverse,
         recontitutedDimensions: dict[int, tuple[str, DimensionRecord]] | None = None,
     ) -> QuantumNode:
+        if recontitutedDimensions is not None:
+            warnings.warn(
+                "The recontitutedDimensions argument is now ignored and may be removed after v 27",
+                category=FutureWarning,
+                stacklevel=find_outside_stacklevel("lsst.pipe.base"),
+            )
         return QuantumNode(
-            quantum=Quantum.from_simple(
-                simple.quantum, universe, reconstitutedDimensions=recontitutedDimensions
-            ),
+            quantum=Quantum.from_simple(simple.quantum, universe),
             taskDef=taskDefMap[simple.taskLabel],
             nodeId=simple.nodeId,
         )
 
 
+_fields_set = {"quantum", "taskLabel", "nodeId"}
+
+
 class SerializedQuantumNode(BaseModel):
     quantum: SerializedQuantum
     taskLabel: str
@@ -156,5 +167,5 @@ def direct(cls, *, quantum: dict[str, Any], taskLabel: str, nodeId: str) -> Seri
         setter(node, "quantum", SerializedQuantum.direct(**quantum))
         setter(node, "taskLabel", taskLabel)
         setter(node, "nodeId", uuid.UUID(nodeId))
-        setter(node, "__fields_set__", {"quantum", "taskLabel", "nodeId"})
+        setter(node, "__fields_set__", _fields_set)
         return node
diff --git a/python/lsst/pipe/base/graphBuilder.py b/python/lsst/pipe/base/graphBuilder.py
index c9ee2978..fb5620f0 100644
--- a/python/lsst/pipe/base/graphBuilder.py
+++ b/python/lsst/pipe/base/graphBuilder.py
@@ -455,13 +455,14 @@ def makeQuantum(self, datastore_records: Mapping[str, DatastoreRecordData] | Non
                 matching_records = records.subset(input_ids)
                 if matching_records is not None:
                     quantum_records[datastore_name] = matching_records
 
+        # ignore the types because quantum really can take a sequence of inputs
         return Quantum(
             taskName=self.task.taskDef.taskName,
             taskClass=self.task.taskDef.taskClass,
             dataId=self.dataId,
             initInputs=initInputs,
-            inputs=helper.inputs,
-            outputs=helper.outputs,
+            inputs=helper.inputs,  # type: ignore
+            outputs=helper.outputs,  # type: ignore
             datastore_records=quantum_records,
         )
diff --git a/python/lsst/pipe/base/pipeline.py b/python/lsst/pipe/base/pipeline.py
index 781defcc..48f3ad05 100644
--- a/python/lsst/pipe/base/pipeline.py
+++ b/python/lsst/pipe/base/pipeline.py
@@ -192,7 +192,7 @@ def logOutputDatasetName(self) -> str | None:
         """Name of a dataset type for log output from this task, `None` if
         logs are not to be saved (`str`)
         """
-        if cast(PipelineTaskConfig, self.config).saveLogOutput:
+        if self.config.saveLogOutput:
             return acc.LOG_OUTPUT_TEMPLATE.format(label=self.label)
         else:
             return None
@@ -623,7 +623,7 @@ def get_data_id(self, universe: DimensionUniverse) -> DataCoordinate:
         """
         instrument_class_name = self._pipelineIR.instrument
         if instrument_class_name is not None:
-            instrument_class = doImportType(instrument_class_name)
+            instrument_class = cast(PipeBaseInstrument, doImportType(instrument_class_name))
             if instrument_class is not None:
                 return DataCoordinate.standardize(instrument=instrument_class.getName(), universe=universe)
         return DataCoordinate.makeEmpty(universe)
@@ -654,8 +654,8 @@ def addTask(self, task: type[PipelineTask] | str, label: str) -> None:
         # be defined without label which is not acceptable, use task
         # _DefaultName in that case
         if isinstance(task, str):
-            task_class = doImportType(task)
-            label = task_class._DefaultName
+            task_class = cast(PipelineTask, doImportType(task))
+            label = task_class._DefaultName
         self._pipelineIR.tasks[label] = pipelineIR.TaskIR(label, taskName)
 
     def removeTask(self, label: str) -> None:
diff --git a/python/lsst/pipe/base/script/transfer_from_graph.py b/python/lsst/pipe/base/script/transfer_from_graph.py
index 547885ae..ad1a2549 100644
--- a/python/lsst/pipe/base/script/transfer_from_graph.py
+++ b/python/lsst/pipe/base/script/transfer_from_graph.py
@@ -66,8 +66,8 @@ def transfer_from_graph(
         if refs := qgraph.initOutputRefs(task_def):
             original_output_refs.update(refs)
     for qnode in qgraph:
-        for refs in qnode.quantum.outputs.values():
-            original_output_refs.update(refs)
+        for otherRefs in qnode.quantum.outputs.values():
+            original_output_refs.update(otherRefs)
 
     # Get data repository definitions from the QuantumGraph; these can have
     # different storage classes than those in the quanta.
diff --git a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py
index 61edcd24..514f0a04 100644
--- a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py
+++ b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py
@@ -283,9 +283,7 @@ def runQuantum(
 
         # store mock outputs
         for name, refs in outputRefs:
-            if not isinstance(refs, list):
-                refs = [refs]
-            for ref in refs:
+            for ref in ensure_iterable(refs):
                 output = MockDataset(
                     ref=ref.to_simple(), quantum=mock_dataset_quantum, output_connection_name=name
                 )
diff --git a/python/lsst/pipe/base/tests/simpleQGraph.py b/python/lsst/pipe/base/tests/simpleQGraph.py
index 8b50bf42..6a5289d4 100644
--- a/python/lsst/pipe/base/tests/simpleQGraph.py
+++ b/python/lsst/pipe/base/tests/simpleQGraph.py
@@ -307,7 +307,7 @@ def populateButler(
     instrument = pipeline.getInstrument()
     if instrument is not None:
         instrument_class = doImportType(instrument)
-        instrumentName = instrument_class.getName()
+        instrumentName = cast(Instrument, instrument_class).getName()
         instrumentClass = get_full_type_name(instrument_class)
     else:
         instrumentName = "INSTR"
diff --git a/python/lsst/pipe/base/tests/util.py b/python/lsst/pipe/base/tests/util.py
index 832cd91c..7d9aa2b5 100644
--- a/python/lsst/pipe/base/tests/util.py
+++ b/python/lsst/pipe/base/tests/util.py
@@ -46,8 +46,8 @@ def check_output_run(graph: QuantumGraph, run: str) -> list[DatasetRef]:
         the specified run.
     """
     # Collect all inputs/outputs, so that we can build intermediate refs.
-    output_refs = []
-    input_refs = []
+    output_refs: list[DatasetRef] = []
+    input_refs: list[DatasetRef] = []
     for node in graph:
         for refs in node.quantum.outputs.values():
             output_refs += refs
@@ -61,10 +61,10 @@ def check_output_run(graph: QuantumGraph, run: str) -> list[DatasetRef]:
     if init_refs:
         input_refs += init_refs
     output_refs += graph.globalInitOutputRefs()
 
-    refs = [ref for ref in output_refs if ref.run != run]
+    newRefs = [ref for ref in output_refs if ref.run != run]
     output_ids = {ref.id for ref in output_refs}
     intermediates = [ref for ref in input_refs if ref.id in output_ids]
-    refs += [ref for ref in intermediates if ref.run != run]
+    newRefs += [ref for ref in intermediates if ref.run != run]
 
-    return refs
+    return newRefs
diff --git a/requirements.txt b/requirements.txt
index 882f2164..79d5616f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 pyyaml >= 5.1
-pydantic
+pydantic < 2
 numpy >= 1.17
 networkx
 frozendict
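Editor's note (illustration only, not part of the patch): the `_loadHelpers.py` hunk above is the heart of the memory optimization. The sketch below spells out the pattern it adopts, assuming only the `lsst.daf.butler.PersistenceContextVars` API that the hunk itself imports and calls (its `run(callable, *args)` method executes the callable with a caching context active). The names `deserializer`, `node_set`, `read_bytes`, and `universe` are placeholders for objects the real `load()` method already holds.

```python
# Sketch only: run graph deserialization inside a PersistenceContextVars
# context so that repeated dimension records, dataset types, and other
# pydantic objects resolve to shared instances instead of fresh copies.
from lsst.daf.butler import PersistenceContextVars


def construct_graph_with_dedup(deserializer, node_set, read_bytes, universe):
    runner = PersistenceContextVars()
    # run() invokes the callable while the deduplication context is active and
    # returns whatever the callable returns, as the patched load() does above.
    return runner.run(deserializer.constructGraph, node_set, read_bytes, universe)
```

Because the sharing happens inside the loader, callers of the public graph-loading methods see no API change; they simply get back a graph whose duplicate objects are shared references, which is what drives the memory and load-time improvement described in the changelog.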
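A second editor's note on the `quantumNode.py` hunk: the `recontitutedDimensions` argument is deprecated rather than removed, so existing callers keep working but now receive a `FutureWarning`. A hypothetical caller-side check of that behaviour might look like the following; `serialized_node`, `task_def_map`, and `universe` are placeholders for objects a caller would already have, not names from the patch.

```python
# Hypothetical usage sketch: the deprecated argument is accepted but ignored,
# and a FutureWarning is emitted, attributed to the caller's frame.
import warnings

from lsst.pipe.base.graph.quantumNode import QuantumNode

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    node = QuantumNode.from_simple(
        serialized_node,  # placeholder SerializedQuantumNode
        task_def_map,  # placeholder mapping of task label to TaskDef
        universe,  # placeholder DimensionUniverse
        recontitutedDimensions={},  # deprecated; triggers the warning
    )
assert any(issubclass(w.category, FutureWarning) for w in caught)
```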