diff --git a/python/lsst/daf/butler/core/datasets/ref.py b/python/lsst/daf/butler/core/datasets/ref.py
index 005fd5b1ff..4bc0232f25 100644
--- a/python/lsst/daf/butler/core/datasets/ref.py
+++ b/python/lsst/daf/butler/core/datasets/ref.py
@@ -30,6 +30,7 @@
 ]

 import enum
+import sys
 import uuid
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, Any, ClassVar
@@ -41,6 +42,7 @@
 from ..dimensions import DataCoordinate, DimensionGraph, DimensionUniverse, SerializedDataCoordinate
 from ..json import from_json_pydantic, to_json_pydantic
 from ..named import NamedKeyDict
+from ..persistenceContext import PersistenceContextVars
 from .type import DatasetType, SerializedDatasetType

 if TYPE_CHECKING:
@@ -142,6 +144,10 @@ def makeDatasetId(
         return uuid.uuid5(self.NS_UUID, data)


+# This is constant, so don't recreate a set for each instance
+_serializedDatasetRefFieldsSet = {"id", "datasetType", "dataId", "run", "component"}
+
+
 class SerializedDatasetRef(BaseModel):
     """Simplified model of a `DatasetRef` suitable for serialization."""

@@ -202,9 +208,9 @@ def direct(
             datasetType if datasetType is None else SerializedDatasetType.direct(**datasetType),
         )
         setter(node, "dataId", dataId if dataId is None else SerializedDataCoordinate.direct(**dataId))
-        setter(node, "run", run)
+        setter(node, "run", sys.intern(run))
         setter(node, "component", component)
-        setter(node, "__fields_set__", {"id", "datasetType", "dataId", "run", "component"})
+        setter(node, "__fields_set__", _serializedDatasetRefFieldsSet)
         return node


@@ -254,7 +260,7 @@ class DatasetRef:
     _serializedType = SerializedDatasetRef

     __slots__ = (
-        "id",
+        "_id",
         "datasetType",
         "dataId",
         "run",
@@ -277,11 +283,15 @@ def __init__(
         self.dataId = dataId
         self.run = run
         if id is not None:
-            self.id = id
+            self._id = id.int
         else:
-            self.id = DatasetIdFactory().makeDatasetId(
+            self._id = DatasetIdFactory().makeDatasetId(
                 self.run, self.datasetType, self.dataId, id_generation_mode
-            )
+            ).int
+
+    @property
+    def id(self) -> DatasetId:
+        return uuid.UUID(int=self._id)

     def __eq__(self, other: Any) -> bool:
         try:
@@ -396,9 +406,18 @@ def from_simple(
         ref : `DatasetRef`
             Newly-constructed object.
         """
+        cache = PersistenceContextVars.datasetRefs.get()
+        localName = sys.intern(
+            datasetType.name
+            if datasetType is not None
+            else (x.name if (x := simple.datasetType) is not None else "")
+        )
+        key = (simple.id.int, localName)
+        if cache is not None and (cachedRef := cache.get(key, None)) is not None:
+            return cachedRef
         # Minimalist component will just specify component and id and
         # require registry to reconstruct
-        if set(simple.dict(exclude_unset=True, exclude_defaults=True)).issubset({"id", "component"}):
+        if simple.datasetType is None and simple.dataId is None and simple.run is None:
             if registry is None:
                 raise ValueError("Registry is required to construct component DatasetRef from integer id")
             if simple.id is None:
@@ -408,6 +427,8 @@
                 raise RuntimeError(f"No matching dataset found in registry for id {simple.id}")
             if simple.component:
                 ref = ref.makeComponentRef(simple.component)
+            if cache is not None:
+                cache[key] = ref
             return ref

         if universe is None and registry is None:
@@ -443,7 +464,10 @@
                 f"Encountered with {simple!r}{dstr}."
             )

-        return cls(datasetType, dataId, id=simple.id, run=simple.run)
+        newRef = cls(datasetType, dataId, id=simple.id, run=simple.run)
+        if cache is not None:
+            cache[key] = newRef
+        return newRef

     to_json = to_json_pydantic
     from_json: ClassVar = classmethod(from_json_pydantic)
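A note on the `sys.intern` calls introduced above: a run name is repeated across every `DatasetRef` deserialized from that run, so interning collapses all of those repeats into one shared `str` object, which saves memory and makes later equality checks effectively pointer comparisons. A standalone illustration (the run name below is invented):

    import sys

    a = sys.intern("HSC/runs/RC2/w_2023_32")
    b = sys.intern("/".join(["HSC", "runs", "RC2", "w_2023_32"]))
    # Equal strings, built independently, now reference a single object:
    assert a == b
    assert a is b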
diff --git a/python/lsst/daf/butler/core/datasets/type.py b/python/lsst/daf/butler/core/datasets/type.py
index 1ddbc018c0..72f714d624 100644
--- a/python/lsst/daf/butler/core/datasets/type.py
+++ b/python/lsst/daf/butler/core/datasets/type.py
@@ -34,6 +34,7 @@
 from ..configSupport import LookupKey
 from ..dimensions import DimensionGraph, SerializedDimensionGraph
 from ..json import from_json_pydantic, to_json_pydantic
+from ..persistenceContext import PersistenceContextVars
 from ..storageClass import StorageClass, StorageClassFactory

 if TYPE_CHECKING:
@@ -74,6 +75,10 @@ def direct(

         This method should only be called when the inputs are trusted.
         """
+        cache = PersistenceContextVars.serializedDatasetTypeMapping.get()
+        key = (name, storageClass or "")
+        if cache is not None and (type_ := cache.get(key, None)) is not None:
+            return type_
         node = SerializedDatasetType.__new__(cls)
         setter = object.__setattr__
         setter(node, "name", name)
@@ -90,6 +95,8 @@
             "__fields_set__",
             {"name", "storageClass", "dimensions", "parentStorageClass", "isCalibration"},
         )
+        if cache is not None:
+            cache[key] = node
         return node


@@ -685,6 +692,13 @@ def from_simple(
         datasetType : `DatasetType`
             Newly-constructed object.
         """
+        # Check whether a cache is active and, if so, whether it already
+        # holds this dataset type.
+        cache = PersistenceContextVars.loadedTypes.get()
+        key = (simple.name, simple.storageClass or "")
+        if cache is not None and (type_ := cache.get(key, None)) is not None:
+            return type_
+
         if simple.storageClass is None:
             # Treat this as minimalist representation
             if registry is None:
@@ -708,7 +722,7 @@
             # mypy hint
             raise ValueError(f"Dimensions must be specified in {simple}")

-        return cls(
+        newType = cls(
             name=simple.name,
             dimensions=DimensionGraph.from_simple(simple.dimensions, universe=universe),
             storageClass=simple.storageClass,
@@ -716,6 +730,9 @@
             parentStorageClass=simple.parentStorageClass,
             universe=universe,
         )
+        if cache is not None:
+            cache[key] = newType
+        return newType

     to_json = to_json_pydantic
     from_json: ClassVar = classmethod(from_json_pydantic)
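Every cache added in this patch follows the same get-or-compute shape: fetch the per-context dict from `PersistenceContextVars`, return a hit if present, otherwise build the object and store it. `persistenceContext.py` itself is not part of this diff; the sketch below shows how such a class could plausibly be built on `contextvars`, under the assumption that each attribute is a `ContextVar` holding an optional per-context dict (all names here are illustrative, not the real API):

    from contextvars import ContextVar, copy_context
    from typing import Any, Callable, TypeVar

    _T = TypeVar("_T")

    class _CacheVars:
        # A default of None makes every cache lookup a no-op unless run()
        # is active, which matches the `if cache is not None` guards above.
        loadedTypes: ContextVar[dict[tuple[str, str], Any] | None] = ContextVar(
            "loadedTypes", default=None
        )

        def run(self, function: Callable[[], _T]) -> _T:
            # Execute in a copied context: the cache dict exists for the
            # duration of the call and is discarded with the context after.
            def _with_caches() -> _T:
                type(self).loadedTypes.set({})
                return function()

            return copy_context().run(_with_caches)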
""" + key = frozenset(dataset_ids) + cache = PersistenceContextVars.serializedDatastoreRecordMapping.get() + if cache is not None and (value := cache.get(key)) is not None: + return value + data = SerializedDatastoreRecordData.__new__(cls) setter = object.__setattr__ # JSON makes strings out of UUIDs, need to convert them back @@ -83,6 +89,8 @@ def direct( if (id := record.get("dataset_id")) is not None: record["dataset_id"] = uuid.UUID(id) if isinstance(id, str) else id setter(data, "records", records) + if cache is not None: + cache[key] = data return data @@ -204,6 +212,10 @@ def from_simple( item_info : `StoredDatastoreItemInfo` De-serialized instance of `StoredDatastoreItemInfo`. """ + cache = PersistenceContextVars.dataStoreRecords.get() + key = frozenset(simple.dataset_ids) + if cache is not None and (record := cache.get(key)) is not None: + return record records: dict[DatasetId, dict[str, list[StoredDatastoreItemInfo]]] = {} # make sure that all dataset IDs appear in the dict even if they don't # have records. @@ -216,4 +228,7 @@ def from_simple( info = klass.from_record(record) dataset_type_records = records.setdefault(info.dataset_id, {}) dataset_type_records.setdefault(table_name, []).append(info) - return cls(records=records) + record = cls(records=records) + if cache is not None: + cache[key] = record + return record diff --git a/python/lsst/daf/butler/core/dimensions/_coordinate.py b/python/lsst/daf/butler/core/dimensions/_coordinate.py index 2c104b1d46..34576f9c3b 100644 --- a/python/lsst/daf/butler/core/dimensions/_coordinate.py +++ b/python/lsst/daf/butler/core/dimensions/_coordinate.py @@ -39,6 +39,7 @@ from ..json import from_json_pydantic, to_json_pydantic from ..named import NamedKeyDict, NamedKeyMapping, NamedValueAbstractSet, NameLookupMapping +from ..persistenceContext import PersistenceContextVars from ..timespan import Timespan from ._elements import Dimension, DimensionElement from ._graph import DimensionGraph @@ -76,6 +77,10 @@ def direct(cls, *, dataId: dict[str, DataIdValue], records: dict[str, dict]) -> This method should only be called when the inputs are trusted. """ + key = (frozenset(dataId.items()), records is not None) + cache = PersistenceContextVars.serializedDataCoordinateMapping.get() + if cache is not None and (result := cache.get(key)) is not None: + return result node = SerializedDataCoordinate.__new__(cls) setter = object.__setattr__ setter(node, "dataId", dataId) @@ -87,6 +92,8 @@ def direct(cls, *, dataId: dict[str, DataIdValue], records: dict[str, dict]) -> else {k: SerializedDimensionRecord.direct(**v) for k, v in records.items()}, ) setter(node, "__fields_set__", {"dataId", "records"}) + if cache is not None: + cache[key] = node return node @@ -730,6 +737,10 @@ def from_simple( dataId : `DataCoordinate` Newly-constructed object. 
""" + key = (frozenset(simple.dataId.items()), simple.records is not None) + cache = PersistenceContextVars.dataCoordinates.get() + if cache is not None and (result := cache.get(key)) is not None: + return result if universe is None and registry is None: raise ValueError("One of universe or registry is required to convert a dict to a DataCoordinate") if universe is None and registry is not None: @@ -743,6 +754,8 @@ def from_simple( dataId = dataId.expanded( {k: DimensionRecord.from_simple(v, universe=universe) for k, v in simple.records.items()} ) + if cache is not None: + cache[key] = dataId return dataId to_json = to_json_pydantic diff --git a/python/lsst/daf/butler/core/dimensions/_records.py b/python/lsst/daf/butler/core/dimensions/_records.py index 95f7f12bd8..b5a0e6a04f 100644 --- a/python/lsst/daf/butler/core/dimensions/_records.py +++ b/python/lsst/daf/butler/core/dimensions/_records.py @@ -30,6 +30,7 @@ from pydantic import BaseModel, Field, StrictBool, StrictFloat, StrictInt, StrictStr, create_model from ..json import from_json_pydantic, to_json_pydantic +from ..persistenceContext import PersistenceContextVars from ..timespan import Timespan, TimespanDatabaseRepresentation from ._elements import Dimension, DimensionElement @@ -166,7 +167,13 @@ def direct( This method should only be called when the inputs are trusted. """ - node = cls.construct(definition=definition, record=record) + key = ( + definition, + frozenset((k, v if not isinstance(v, list) else tuple(v)) for k, v in record.items()), + ) + cache = PersistenceContextVars.serializedDimensionRecordMapping.get() + if cache is not None and (result := cache.get(key)) is not None: + return result node = SerializedDimensionRecord.__new__(cls) setter = object.__setattr__ setter(node, "definition", definition) @@ -177,6 +184,8 @@ def direct( node, "record", {k: v if type(v) != list else tuple(v) for k, v in record.items()} # type: ignore ) setter(node, "__fields_set__", {"definition", "record"}) + if cache is not None: + cache[key] = node return node @@ -367,6 +376,13 @@ def from_simple( if universe is None: # this is for mypy raise ValueError("Unable to determine a usable universe") + key = ( + simple.definition, + frozenset((k, v if not isinstance(v, list) else tuple(v)) for k, v in simple.record.items()), + ) + cache = PersistenceContextVars.dimensionRecords.get() + if cache is not None and (result := cache.get(key)) is not None: + return result definition = DimensionElement.from_simple(simple.definition, universe=universe) @@ -389,7 +405,10 @@ def from_simple( if (hsh := "hash") in rec: rec[hsh] = bytes.fromhex(rec[hsh].decode()) - return _reconstructDimensionRecord(definition, rec) + dimRec = _reconstructDimensionRecord(definition, rec) + if cache is not None: + cache[key] = dimRec + return dimRec to_json = to_json_pydantic from_json: ClassVar = classmethod(from_json_pydantic) diff --git a/python/lsst/daf/butler/core/quantum.py b/python/lsst/daf/butler/core/quantum.py index d3cdb77e89..58954c63d2 100644 --- a/python/lsst/daf/butler/core/quantum.py +++ b/python/lsst/daf/butler/core/quantum.py @@ -25,6 +25,8 @@ from collections.abc import Iterable, Mapping, MutableMapping from typing import Any +import sys +import warnings from lsst.utils import doImportType from pydantic import BaseModel @@ -46,7 +48,6 @@ def _reconstructDatasetRef( type_: DatasetType | None, ids: Iterable[int], dimensionRecords: dict[int, SerializedDimensionRecord] | None, - reconstitutedDimensions: dict[int, tuple[str, DimensionRecord]], universe: 
diff --git a/python/lsst/daf/butler/core/dimensions/_records.py b/python/lsst/daf/butler/core/dimensions/_records.py
index 95f7f12bd8..b5a0e6a04f 100644
--- a/python/lsst/daf/butler/core/dimensions/_records.py
+++ b/python/lsst/daf/butler/core/dimensions/_records.py
@@ -30,6 +30,7 @@
 from pydantic import BaseModel, Field, StrictBool, StrictFloat, StrictInt, StrictStr, create_model

 from ..json import from_json_pydantic, to_json_pydantic
+from ..persistenceContext import PersistenceContextVars
 from ..timespan import Timespan, TimespanDatabaseRepresentation
 from ._elements import Dimension, DimensionElement

@@ -166,7 +167,13 @@ def direct(

         This method should only be called when the inputs are trusted.
         """
-        node = cls.construct(definition=definition, record=record)
+        key = (
+            definition,
+            frozenset((k, v if not isinstance(v, list) else tuple(v)) for k, v in record.items()),
+        )
+        cache = PersistenceContextVars.serializedDimensionRecordMapping.get()
+        if cache is not None and (result := cache.get(key)) is not None:
+            return result
         node = SerializedDimensionRecord.__new__(cls)
         setter = object.__setattr__
         setter(node, "definition", definition)
@@ -177,6 +184,8 @@
             node, "record", {k: v if type(v) != list else tuple(v) for k, v in record.items()}  # type: ignore
         )
         setter(node, "__fields_set__", {"definition", "record"})
+        if cache is not None:
+            cache[key] = node
         return node


@@ -367,6 +376,13 @@ def from_simple(
         if universe is None:
             # this is for mypy
             raise ValueError("Unable to determine a usable universe")
+        key = (
+            simple.definition,
+            frozenset((k, v if not isinstance(v, list) else tuple(v)) for k, v in simple.record.items()),
+        )
+        cache = PersistenceContextVars.dimensionRecords.get()
+        if cache is not None and (result := cache.get(key)) is not None:
+            return result

         definition = DimensionElement.from_simple(simple.definition, universe=universe)

@@ -389,7 +405,10 @@
         if (hsh := "hash") in rec:
             rec[hsh] = bytes.fromhex(rec[hsh].decode())

-        return _reconstructDimensionRecord(definition, rec)
+        dimRec = _reconstructDimensionRecord(definition, rec)
+        if cache is not None:
+            cache[key] = dimRec
+        return dimRec

     to_json = to_json_pydantic
     from_json: ClassVar = classmethod(from_json_pydantic)
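Dimension-record values may include lists, which are unhashable, so the key builders above convert them to tuples before freezing the items. The same transform in isolation (field names invented):

    record = {"id": 42, "vertices": [1.0, 2.0, 3.0]}

    key = frozenset(
        (k, v if not isinstance(v, list) else tuple(v)) for k, v in record.items()
    )
    hash(key)  # does not raise: every value is now hashable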
diff --git a/python/lsst/daf/butler/core/quantum.py b/python/lsst/daf/butler/core/quantum.py
index d3cdb77e89..58954c63d2 100644
--- a/python/lsst/daf/butler/core/quantum.py
+++ b/python/lsst/daf/butler/core/quantum.py
@@ -25,6 +25,8 @@

 from collections.abc import Iterable, Mapping, MutableMapping
 from typing import Any
+import sys
+import warnings

 from lsst.utils import doImportType
 from pydantic import BaseModel

@@ -46,7 +48,6 @@ def _reconstructDatasetRef(
     type_: DatasetType | None,
     ids: Iterable[int],
     dimensionRecords: dict[int, SerializedDimensionRecord] | None,
-    reconstitutedDimensions: dict[int, tuple[str, DimensionRecord]],
     universe: DimensionUniverse,
 ) -> DatasetRef:
     """Reconstruct a DatasetRef stored in a Serialized Quantum."""
@@ -55,19 +56,13 @@
     for dId in ids:
         # if the dimension record has been loaded previously use that,
         # otherwise load it from the dict of Serialized DimensionRecords
-        if (recId := reconstitutedDimensions.get(dId)) is None:
-            if dimensionRecords is None:
-                raise ValueError(
-                    "Cannot construct from a SerializedQuantum with no dimension records. "
-                    "Reconstituted Dimensions must be supplied and populated in method call."
-                )
-            tmpSerialized = dimensionRecords[dId]
-            reconstructedDim = DimensionRecord.from_simple(tmpSerialized, universe=universe)
-            definition = tmpSerialized.definition
-            reconstitutedDimensions[dId] = (definition, reconstructedDim)
-        else:
-            definition, reconstructedDim = recId
-        records[definition] = reconstructedDim
+        if dimensionRecords is None:
+            raise ValueError(
+                "Cannot construct from a SerializedQuantum with no dimension records."
+            )
+        tmpSerialized = dimensionRecords[dId]
+        reconstructedDim = DimensionRecord.from_simple(tmpSerialized, universe=universe)
+        records[sys.intern(reconstructedDim.definition.name)] = reconstructedDim
     # turn the serialized form into an object and attach the dimension records
     rebuiltDatasetRef = DatasetRef.from_simple(simple, universe, datasetType=type_)
     if records:
@@ -110,13 +105,15 @@
         """
         node = SerializedQuantum.__new__(cls)
         setter = object.__setattr__
-        setter(node, "taskName", taskName)
+        setter(node, "taskName", sys.intern(taskName))
         setter(node, "dataId", dataId if dataId is None else SerializedDataCoordinate.direct(**dataId))
+
         setter(
             node,
             "datasetTypeMapping",
             {k: SerializedDatasetType.direct(**v) for k, v in datasetTypeMapping.items()},
         )
+
         setter(
             node,
             "initInputs",
@@ -207,7 +204,6 @@ class Quantum:
         "_initInputs",
         "_inputs",
         "_outputs",
-        "_hash",
         "_datastore_records",
     )

@@ -236,8 +232,12 @@ def __init__(
         if outputs is None:
             outputs = {}
         self._initInputs = NamedKeyDict[DatasetType, DatasetRef](initInputs).freeze()
-        self._inputs = NamedKeyDict[DatasetType, list[DatasetRef]](inputs).freeze()
-        self._outputs = NamedKeyDict[DatasetType, list[DatasetRef]](outputs).freeze()
+        self._inputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](
+            (k, tuple(v)) for k, v in inputs.items()
+        ).freeze()
+        self._outputs = NamedKeyDict[DatasetType, tuple[DatasetRef, ...]](
+            (k, tuple(v)) for k, v in outputs.items()
+        ).freeze()
         if datastore_records is None:
             datastore_records = {}
         self._datastore_records = datastore_records
@@ -412,23 +412,21 @@ def from_simple(
             required dimension has already been loaded. Otherwise the record
             will be unpersisted from the SerializedQuantum and added to the
             reconstitutedDimensions dict (if not None). Defaults to None.
+            Deprecated: this argument is ignored and may be removed after v27.
         """
-        loadedTypes: MutableMapping[str, DatasetType] = {}
         initInputs: MutableMapping[DatasetType, DatasetRef] = {}
-        if reconstitutedDimensions is None:
-            reconstitutedDimensions = {}
+        if reconstitutedDimensions is not None:
+            warnings.warn(
+                "The reconstitutedDimensions argument is now ignored and may be removed after v27",
+                category=DeprecationWarning,
+            )

         # Unpersist all the init inputs
         for key, (value, dimensionIds) in simple.initInputs.items():
-            # If a datasetType has already been created use that instead of
-            # unpersisting.
-            if (type_ := loadedTypes.get(key)) is None:
-                type_ = loadedTypes.setdefault(
-                    key, DatasetType.from_simple(simple.datasetTypeMapping[key], universe=universe)
-                )
+            type_ = DatasetType.from_simple(simple.datasetTypeMapping[key], universe=universe)
             # reconstruct the dimension records
             rebuiltDatasetRef = _reconstructDatasetRef(
-                value, type_, dimensionIds, simple.dimensionRecords, reconstitutedDimensions, universe
+                value, type_, dimensionIds, simple.dimensionRecords, universe
             )
             initInputs[type_] = rebuiltDatasetRef

@@ -438,17 +436,12 @@

         for container, simpleRefs in ((inputs, simple.inputs), (outputs, simple.outputs)):
             for key, values in simpleRefs.items():
-                # If a datasetType has already been created use that instead of
-                # unpersisting.
-                if (type_ := loadedTypes.get(key)) is None:
-                    type_ = loadedTypes.setdefault(
-                        key, DatasetType.from_simple(simple.datasetTypeMapping[key], universe=universe)
-                    )
+                type_ = DatasetType.from_simple(simple.datasetTypeMapping[key], universe=universe)
                 # reconstruct the list of DatasetRefs for this DatasetType
                 tmp: list[DatasetRef] = []
                 for v, recIds in values:
                     rebuiltDatasetRef = _reconstructDatasetRef(
-                        v, type_, recIds, simple.dimensionRecords, reconstitutedDimensions, universe
+                        v, type_, recIds, simple.dimensionRecords, universe
                     )
                     tmp.append(rebuiltDatasetRef)
                 container[type_] = tmp
@@ -466,7 +459,7 @@
             for datastore_name, record_data in simple.datastoreRecords.items()
         }

-        return Quantum(
+        quant = Quantum(
             taskName=simple.taskName,
             dataId=dataId,
             initInputs=initInputs,
@@ -474,6 +467,7 @@
             outputs=outputs,
             datastore_records=datastore_records,
         )
+        return quant

     @property
     def taskClass(self) -> type | None:
@@ -508,7 +502,7 @@ def initInputs(self) -> NamedKeyMapping[DatasetType, DatasetRef]:
         return self._initInputs

     @property
-    def inputs(self) -> NamedKeyMapping[DatasetType, list[DatasetRef]]:
+    def inputs(self) -> NamedKeyMapping[DatasetType, tuple[DatasetRef, ...]]:
         """Return mapping of input datasets that were expected to be used.

         Has `DatasetType` instances as keys (names can also be used for
@@ -523,7 +517,7 @@
         return self._inputs

     @property
-    def outputs(self) -> NamedKeyMapping[DatasetType, list[DatasetRef]]:
+    def outputs(self) -> NamedKeyMapping[DatasetType, tuple[DatasetRef, ...]]:
         """Return mapping of output datasets (to be) generated by this
         quantum.

         Has the same form as `predictedInputs`.