From 97137a213f11d4a77f145d6826f15e9cd82db7ce Mon Sep 17 00:00:00 2001
From: Nate Lust
Date: Tue, 27 Jun 2023 14:52:38 -0400
Subject: [PATCH] Optimize memory and load times on deserialization

Often many butler primitives are deserialized at the same time, and it is
useful for these objects to share references to each other. This reduces
load time and memory usage.
---
 python/lsst/daf/butler/core/datasets/ref.py | 40 ++++++++---
 python/lsst/daf/butler/core/datasets/type.py | 19 ++++-
 .../daf/butler/core/datastoreRecordData.py | 17 ++++-
 .../daf/butler/core/dimensions/_coordinate.py | 13 ++++
 .../daf/butler/core/dimensions/_records.py | 23 +++++-
 python/lsst/daf/butler/core/quantum.py | 70 +++++++++----------
 6 files changed, 132 insertions(+), 50 deletions(-)

diff --git a/python/lsst/daf/butler/core/datasets/ref.py b/python/lsst/daf/butler/core/datasets/ref.py
index 005fd5b1ff..4bc0232f25 100644
--- a/python/lsst/daf/butler/core/datasets/ref.py
+++ b/python/lsst/daf/butler/core/datasets/ref.py
@@ -30,6 +30,7 @@
 ]
 
 import enum
+import sys
 import uuid
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, Any, ClassVar
@@ -41,6 +42,7 @@
 from ..dimensions import DataCoordinate, DimensionGraph, DimensionUniverse, SerializedDataCoordinate
 from ..json import from_json_pydantic, to_json_pydantic
 from ..named import NamedKeyDict
+from ..persistenceContext import PersistenceContextVars
 from .type import DatasetType, SerializedDatasetType
 
 if TYPE_CHECKING:
@@ -142,6 +144,10 @@ def makeDatasetId(
         return uuid.uuid5(self.NS_UUID, data)
 
 
+# This is constant, so don't recreate a set for each instance
+_serializedDatasetRefFieldsSet = {"id", "datasetType", "dataId", "run", "component"}
+
+
 class SerializedDatasetRef(BaseModel):
     """Simplified model of a `DatasetRef` suitable for serialization."""
 
@@ -202,9 +208,9 @@ def direct(
             datasetType if datasetType is None else SerializedDatasetType.direct(**datasetType),
         )
         setter(node, "dataId", dataId if dataId is None else SerializedDataCoordinate.direct(**dataId))
-        setter(node, "run", run)
+        setter(node, "run", sys.intern(run))
         setter(node, "component", component)
-        setter(node, "__fields_set__", {"id", "datasetType", "dataId", "run", "component"})
+        setter(node, "__fields_set__", _serializedDatasetRefFieldsSet)
         return node
 
@@ -254,7 +260,7 @@ class DatasetRef:
 
     _serializedType = SerializedDatasetRef
     __slots__ = (
-        "id",
+        "_id",
         "datasetType",
        "dataId",
         "run",
@@ -277,11 +283,15 @@ def __init__(
         self.dataId = dataId
         self.run = run
         if id is not None:
-            self.id = id
+            self._id = id.int
         else:
-            self.id = DatasetIdFactory().makeDatasetId(
+            self._id = DatasetIdFactory().makeDatasetId(
                 self.run, self.datasetType, self.dataId, id_generation_mode
-            )
+            ).int
+
+    @property
+    def id(self) -> DatasetId:
+        return uuid.UUID(int=self._id)
 
     def __eq__(self, other: Any) -> bool:
         try:
@@ -396,9 +406,18 @@ def from_simple(
         ref : `DatasetRef`
             Newly-constructed object.
""" + cache = PersistenceContextVars.datasetRefs.get() + localName = sys.intern( + datasetType.name + if datasetType is not None + else (x.name if (x := simple.datasetType) is not None else "") + ) + key = (simple.id.int, localName) + if cache is not None and (cachedRef := cache.get(key, None)) is not None: + return cachedRef # Minimalist component will just specify component and id and # require registry to reconstruct - if set(simple.dict(exclude_unset=True, exclude_defaults=True)).issubset({"id", "component"}): + if not (simple.datasetType is not None or simple.dataId is not None or simple.run is not None): if registry is None: raise ValueError("Registry is required to construct component DatasetRef from integer id") if simple.id is None: @@ -408,6 +427,8 @@ def from_simple( raise RuntimeError(f"No matching dataset found in registry for id {simple.id}") if simple.component: ref = ref.makeComponentRef(simple.component) + if cache is not None: + cache[key] = ref return ref if universe is None and registry is None: @@ -443,7 +464,10 @@ def from_simple( f"Encountered with {simple!r}{dstr}." ) - return cls(datasetType, dataId, id=simple.id, run=simple.run) + newRef = cls(datasetType, dataId, id=simple.id, run=simple.run) + if cache is not None: + cache[key] = newRef + return newRef to_json = to_json_pydantic from_json: ClassVar = classmethod(from_json_pydantic) diff --git a/python/lsst/daf/butler/core/datasets/type.py b/python/lsst/daf/butler/core/datasets/type.py index 1ddbc018c0..72f714d624 100644 --- a/python/lsst/daf/butler/core/datasets/type.py +++ b/python/lsst/daf/butler/core/datasets/type.py @@ -34,6 +34,7 @@ from ..configSupport import LookupKey from ..dimensions import DimensionGraph, SerializedDimensionGraph from ..json import from_json_pydantic, to_json_pydantic +from ..persistenceContext import PersistenceContextVars from ..storageClass import StorageClass, StorageClassFactory if TYPE_CHECKING: @@ -74,6 +75,10 @@ def direct( This method should only be called when the inputs are trusted. """ + cache = PersistenceContextVars.serializedDatasetTypeMapping.get() + key = (name, storageClass or "") + if cache is not None and (type_ := cache.get(key, None)) is not None: + return type_ node = SerializedDatasetType.__new__(cls) setter = object.__setattr__ setter(node, "name", name) @@ -90,6 +95,8 @@ def direct( "__fields_set__", {"name", "storageClass", "dimensions", "parentStorageClass", "isCalibration"}, ) + if cache is not None: + cache[key] = node return node @@ -685,6 +692,13 @@ def from_simple( datasetType : `DatasetType` Newly-constructed object. 
""" + # check to see if there is a cache, and if there is, if there is a + # cached dataset type + cache = PersistenceContextVars.loadedTypes.get() + key = (simple.name, simple.storageClass or "") + if cache is not None and (type_ := cache.get(key, None)) is not None: + return type_ + if simple.storageClass is None: # Treat this as minimalist representation if registry is None: @@ -708,7 +722,7 @@ def from_simple( # mypy hint raise ValueError(f"Dimensions must be specified in {simple}") - return cls( + newType = cls( name=simple.name, dimensions=DimensionGraph.from_simple(simple.dimensions, universe=universe), storageClass=simple.storageClass, @@ -716,6 +730,9 @@ def from_simple( parentStorageClass=simple.parentStorageClass, universe=universe, ) + if cache is not None: + cache[key] = newType + return newType to_json = to_json_pydantic from_json: ClassVar = classmethod(from_json_pydantic) diff --git a/python/lsst/daf/butler/core/datastoreRecordData.py b/python/lsst/daf/butler/core/datastoreRecordData.py index c6d9f31e0b..58b8da8b42 100644 --- a/python/lsst/daf/butler/core/datastoreRecordData.py +++ b/python/lsst/daf/butler/core/datastoreRecordData.py @@ -36,6 +36,7 @@ from .datasets import DatasetId from .dimensions import DimensionUniverse +from .persistenceContext import PersistenceContextVars from .storedFileInfo import StoredDatastoreItemInfo if TYPE_CHECKING: @@ -70,6 +71,11 @@ def direct( This method should only be called when the inputs are trusted. """ + key = frozenset(dataset_ids) + cache = PersistenceContextVars.serializedDatastoreRecordMapping.get() + if cache is not None and (value := cache.get(key)) is not None: + return value + data = SerializedDatastoreRecordData.__new__(cls) setter = object.__setattr__ # JSON makes strings out of UUIDs, need to convert them back @@ -83,6 +89,8 @@ def direct( if (id := record.get("dataset_id")) is not None: record["dataset_id"] = uuid.UUID(id) if isinstance(id, str) else id setter(data, "records", records) + if cache is not None: + cache[key] = data return data @@ -204,6 +212,10 @@ def from_simple( item_info : `StoredDatastoreItemInfo` De-serialized instance of `StoredDatastoreItemInfo`. """ + cache = PersistenceContextVars.dataStoreRecords.get() + key = frozenset(simple.dataset_ids) + if cache is not None and (record := cache.get(key)) is not None: + return record records: dict[DatasetId, dict[str, list[StoredDatastoreItemInfo]]] = {} # make sure that all dataset IDs appear in the dict even if they don't # have records. 
@@ -216,4 +228,7 @@ def from_simple( info = klass.from_record(record) dataset_type_records = records.setdefault(info.dataset_id, {}) dataset_type_records.setdefault(table_name, []).append(info) - return cls(records=records) + record = cls(records=records) + if cache is not None: + cache[key] = record + return record diff --git a/python/lsst/daf/butler/core/dimensions/_coordinate.py b/python/lsst/daf/butler/core/dimensions/_coordinate.py index 2c104b1d46..34576f9c3b 100644 --- a/python/lsst/daf/butler/core/dimensions/_coordinate.py +++ b/python/lsst/daf/butler/core/dimensions/_coordinate.py @@ -39,6 +39,7 @@ from ..json import from_json_pydantic, to_json_pydantic from ..named import NamedKeyDict, NamedKeyMapping, NamedValueAbstractSet, NameLookupMapping +from ..persistenceContext import PersistenceContextVars from ..timespan import Timespan from ._elements import Dimension, DimensionElement from ._graph import DimensionGraph @@ -76,6 +77,10 @@ def direct(cls, *, dataId: dict[str, DataIdValue], records: dict[str, dict]) -> This method should only be called when the inputs are trusted. """ + key = (frozenset(dataId.items()), records is not None) + cache = PersistenceContextVars.serializedDataCoordinateMapping.get() + if cache is not None and (result := cache.get(key)) is not None: + return result node = SerializedDataCoordinate.__new__(cls) setter = object.__setattr__ setter(node, "dataId", dataId) @@ -87,6 +92,8 @@ def direct(cls, *, dataId: dict[str, DataIdValue], records: dict[str, dict]) -> else {k: SerializedDimensionRecord.direct(**v) for k, v in records.items()}, ) setter(node, "__fields_set__", {"dataId", "records"}) + if cache is not None: + cache[key] = node return node @@ -730,6 +737,10 @@ def from_simple( dataId : `DataCoordinate` Newly-constructed object. """ + key = (frozenset(simple.dataId.items()), simple.records is not None) + cache = PersistenceContextVars.dataCoordinates.get() + if cache is not None and (result := cache.get(key)) is not None: + return result if universe is None and registry is None: raise ValueError("One of universe or registry is required to convert a dict to a DataCoordinate") if universe is None and registry is not None: @@ -743,6 +754,8 @@ def from_simple( dataId = dataId.expanded( {k: DimensionRecord.from_simple(v, universe=universe) for k, v in simple.records.items()} ) + if cache is not None: + cache[key] = dataId return dataId to_json = to_json_pydantic diff --git a/python/lsst/daf/butler/core/dimensions/_records.py b/python/lsst/daf/butler/core/dimensions/_records.py index 95f7f12bd8..b5a0e6a04f 100644 --- a/python/lsst/daf/butler/core/dimensions/_records.py +++ b/python/lsst/daf/butler/core/dimensions/_records.py @@ -30,6 +30,7 @@ from pydantic import BaseModel, Field, StrictBool, StrictFloat, StrictInt, StrictStr, create_model from ..json import from_json_pydantic, to_json_pydantic +from ..persistenceContext import PersistenceContextVars from ..timespan import Timespan, TimespanDatabaseRepresentation from ._elements import Dimension, DimensionElement @@ -166,7 +167,13 @@ def direct( This method should only be called when the inputs are trusted. 
""" - node = cls.construct(definition=definition, record=record) + key = ( + definition, + frozenset((k, v if not isinstance(v, list) else tuple(v)) for k, v in record.items()), + ) + cache = PersistenceContextVars.serializedDimensionRecordMapping.get() + if cache is not None and (result := cache.get(key)) is not None: + return result node = SerializedDimensionRecord.__new__(cls) setter = object.__setattr__ setter(node, "definition", definition) @@ -177,6 +184,8 @@ def direct( node, "record", {k: v if type(v) != list else tuple(v) for k, v in record.items()} # type: ignore ) setter(node, "__fields_set__", {"definition", "record"}) + if cache is not None: + cache[key] = node return node @@ -367,6 +376,13 @@ def from_simple( if universe is None: # this is for mypy raise ValueError("Unable to determine a usable universe") + key = ( + simple.definition, + frozenset((k, v if not isinstance(v, list) else tuple(v)) for k, v in simple.record.items()), + ) + cache = PersistenceContextVars.dimensionRecords.get() + if cache is not None and (result := cache.get(key)) is not None: + return result definition = DimensionElement.from_simple(simple.definition, universe=universe) @@ -389,7 +405,10 @@ def from_simple( if (hsh := "hash") in rec: rec[hsh] = bytes.fromhex(rec[hsh].decode()) - return _reconstructDimensionRecord(definition, rec) + dimRec = _reconstructDimensionRecord(definition, rec) + if cache is not None: + cache[key] = dimRec + return dimRec to_json = to_json_pydantic from_json: ClassVar = classmethod(from_json_pydantic) diff --git a/python/lsst/daf/butler/core/quantum.py b/python/lsst/daf/butler/core/quantum.py index d3cdb77e89..58954c63d2 100644 --- a/python/lsst/daf/butler/core/quantum.py +++ b/python/lsst/daf/butler/core/quantum.py @@ -25,6 +25,8 @@ from collections.abc import Iterable, Mapping, MutableMapping from typing import Any +import sys +import warnings from lsst.utils import doImportType from pydantic import BaseModel @@ -46,7 +48,6 @@ def _reconstructDatasetRef( type_: DatasetType | None, ids: Iterable[int], dimensionRecords: dict[int, SerializedDimensionRecord] | None, - reconstitutedDimensions: dict[int, tuple[str, DimensionRecord]], universe: DimensionUniverse, ) -> DatasetRef: """Reconstruct a DatasetRef stored in a Serialized Quantum.""" @@ -55,19 +56,13 @@ def _reconstructDatasetRef( for dId in ids: # if the dimension record has been loaded previously use that, # otherwise load it from the dict of Serialized DimensionRecords - if (recId := reconstitutedDimensions.get(dId)) is None: - if dimensionRecords is None: - raise ValueError( - "Cannot construct from a SerializedQuantum with no dimension records. " - "Reconstituted Dimensions must be supplied and populated in method call." - ) - tmpSerialized = dimensionRecords[dId] - reconstructedDim = DimensionRecord.from_simple(tmpSerialized, universe=universe) - definition = tmpSerialized.definition - reconstitutedDimensions[dId] = (definition, reconstructedDim) - else: - definition, reconstructedDim = recId - records[definition] = reconstructedDim + if dimensionRecords is None: + raise ValueError( + "Cannot construct from a SerializedQuantum with no dimension records. 
" + ) + tmpSerialized = dimensionRecords[dId] + reconstructedDim = DimensionRecord.from_simple(tmpSerialized, universe=universe) + records[sys.intern(reconstructedDim.definition.name)] = reconstructedDim # turn the serialized form into an object and attach the dimension records rebuiltDatasetRef = DatasetRef.from_simple(simple, universe, datasetType=type_) if records: @@ -110,13 +105,15 @@ def direct( """ node = SerializedQuantum.__new__(cls) setter = object.__setattr__ - setter(node, "taskName", taskName) + setter(node, "taskName", sys.intern(taskName)) setter(node, "dataId", dataId if dataId is None else SerializedDataCoordinate.direct(**dataId)) + setter( node, "datasetTypeMapping", {k: SerializedDatasetType.direct(**v) for k, v in datasetTypeMapping.items()}, ) + setter( node, "initInputs", @@ -207,7 +204,6 @@ class Quantum: "_initInputs", "_inputs", "_outputs", - "_hash", "_datastore_records", ) @@ -236,8 +232,12 @@ def __init__( if outputs is None: outputs = {} self._initInputs = NamedKeyDict[DatasetType, DatasetRef](initInputs).freeze() - self._inputs = NamedKeyDict[DatasetType, list[DatasetRef]](inputs).freeze() - self._outputs = NamedKeyDict[DatasetType, list[DatasetRef]](outputs).freeze() + self._inputs = NamedKeyDict[DatasetType, tuple[DatasetRef]]( + (k, tuple(v)) for k, v in inputs.items() + ).freeze() + self._outputs = NamedKeyDict[DatasetType, tuple[DatasetRef]]( + (k, tuple(v)) for k, v in outputs.items() + ).freeze() if datastore_records is None: datastore_records = {} self._datastore_records = datastore_records @@ -412,23 +412,21 @@ def from_simple( required dimension has already been loaded. Otherwise the record will be unpersisted from the SerializedQuatnum and added to the reconstitutedDimensions dict (if not None). Defaults to None. + Deprecated, any argument will be ignored. """ - loadedTypes: MutableMapping[str, DatasetType] = {} initInputs: MutableMapping[DatasetType, DatasetRef] = {} - if reconstitutedDimensions is None: - reconstitutedDimensions = {} + if reconstitutedDimensions is not None: + warnings.warn( + "The reconstitutedDimensions argument is now ignored and may be removed after v 27", + category=DeprecationWarning, + ) # Unpersist all the init inputs for key, (value, dimensionIds) in simple.initInputs.items(): - # If a datasetType has already been created use that instead of - # unpersisting. - if (type_ := loadedTypes.get(key)) is None: - type_ = loadedTypes.setdefault( - key, DatasetType.from_simple(simple.datasetTypeMapping[key], universe=universe) - ) + type_ = DatasetType.from_simple(simple.datasetTypeMapping[key], universe=universe) # reconstruct the dimension records rebuiltDatasetRef = _reconstructDatasetRef( - value, type_, dimensionIds, simple.dimensionRecords, reconstitutedDimensions, universe + value, type_, dimensionIds, simple.dimensionRecords, universe ) initInputs[type_] = rebuiltDatasetRef @@ -438,17 +436,12 @@ def from_simple( for container, simpleRefs in ((inputs, simple.inputs), (outputs, simple.outputs)): for key, values in simpleRefs.items(): - # If a datasetType has already been created use that instead of - # unpersisting. 
- if (type_ := loadedTypes.get(key)) is None: - type_ = loadedTypes.setdefault( - key, DatasetType.from_simple(simple.datasetTypeMapping[key], universe=universe) - ) + type_ = DatasetType.from_simple(simple.datasetTypeMapping[key], universe=universe) # reconstruct the list of DatasetRefs for this DatasetType tmp: list[DatasetRef] = [] for v, recIds in values: rebuiltDatasetRef = _reconstructDatasetRef( - v, type_, recIds, simple.dimensionRecords, reconstitutedDimensions, universe + v, type_, recIds, simple.dimensionRecords, universe ) tmp.append(rebuiltDatasetRef) container[type_] = tmp @@ -466,7 +459,7 @@ def from_simple( for datastore_name, record_data in simple.datastoreRecords.items() } - return Quantum( + quant = Quantum( taskName=simple.taskName, dataId=dataId, initInputs=initInputs, @@ -474,6 +467,7 @@ def from_simple( outputs=outputs, datastore_records=datastore_records, ) + return quant @property def taskClass(self) -> type | None: @@ -508,7 +502,7 @@ def initInputs(self) -> NamedKeyMapping[DatasetType, DatasetRef]: return self._initInputs @property - def inputs(self) -> NamedKeyMapping[DatasetType, list[DatasetRef]]: + def inputs(self) -> NamedKeyMapping[DatasetType, tuple[DatasetRef]]: """Return mapping of input datasets that were expected to be used. Has `DatasetType` instances as keys (names can also be used for @@ -523,7 +517,7 @@ def inputs(self) -> NamedKeyMapping[DatasetType, list[DatasetRef]]: return self._inputs @property - def outputs(self) -> NamedKeyMapping[DatasetType, list[DatasetRef]]: + def outputs(self) -> NamedKeyMapping[DatasetType, tuple[DatasetRef]]: """Return mapping of output datasets (to be) generated by this quantum. Has the same form as `predictedInputs`.
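Note on the pattern this patch introduces: each from_simple/direct method first consults a per-type cache obtained from PersistenceContextVars (for example PersistenceContextVars.datasetRefs.get()), returns the shared instance on a hit, and stores newly built objects on a miss, so identical serialized fragments across a batch resolve to one object. The sketch below is only an illustration of that context-variable-backed caching pattern under assumed names (run_with_cache, deserialize_ref, _dataset_refs are hypothetical and simplified); it is not the daf_butler PersistenceContextVars API.

from __future__ import annotations

import uuid
from contextvars import ContextVar
from typing import Any, Callable, TypeVar

_T = TypeVar("_T")

# None means "no batch deserialization in progress", so callers skip caching.
_dataset_refs: ContextVar[dict[tuple[int, str], Any] | None] = ContextVar("_dataset_refs", default=None)


def run_with_cache(func: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
    # Install a fresh shared cache, run the batch load, then drop the cache so
    # nothing stays pinned in memory after deserialization finishes.
    token = _dataset_refs.set({})
    try:
        return func(*args, **kwargs)
    finally:
        _dataset_refs.reset(token)


def deserialize_ref(ref_id: uuid.UUID, dataset_type_name: str) -> object:
    # Stand-in for a from_simple method: if an equivalent object was already
    # built during this batch, return the shared instance instead of a copy.
    cache = _dataset_refs.get()
    key = (ref_id.int, dataset_type_name)
    if cache is not None and (cached := cache.get(key)) is not None:
        return cached
    ref = object()  # placeholder for the expensive-to-construct object
    if cache is not None:
        cache[key] = ref
    return ref


if __name__ == "__main__":
    shared_id = uuid.uuid4()

    def load_twice() -> bool:
        return deserialize_ref(shared_id, "calexp") is deserialize_ref(shared_id, "calexp")

    print(run_with_cache(load_twice))  # True: both calls return one shared object
    print(load_twice())                # False: no sharing outside the context

Scoping the caches to a context variable rather than a module-level dict is the key design choice: sharing only lasts for the duration of one deserialization batch and is isolated per thread or task, which is why the patch can drop the separate loadedTypes/reconstitutedDimensions bookkeeping that Quantum.from_simple previously threaded through its helpers.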