Commit
Use model_construct for pydantic v2
This is meant to be the equivalent of using __setattr__
and __fields_set__ directly. The latter has been renamed in
pydantic v2.
timj committed Jul 18, 2023
1 parent 7e19163 commit a4fcc53
Showing 8 changed files with 212 additions and 129 deletions.
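
For context, a minimal sketch of the equivalence this commit relies on (the Point model below is a hypothetical stand-in, not butler code): in pydantic v1 the codebase bypassed validation by allocating an instance with __new__ and filling it in with object.__setattr__, recording the populated fields in __fields_set__; pydantic v2 renames that attribute internally and instead provides model_construct for the same unvalidated construction.

from pydantic import BaseModel

class Point(BaseModel):
    x: int
    y: int

# pydantic v1 pattern: allocate without validation and assign directly.
p1 = Point.__new__(Point)
object.__setattr__(p1, "x", 1)
object.__setattr__(p1, "y", 2)
object.__setattr__(p1, "__fields_set__", {"x", "y"})

# pydantic v2 pattern: model_construct also skips validation;
# _fields_set records which fields were explicitly provided.
p2 = Point.model_construct(_fields_set={"x", "y"}, x=1, y=2)

Both patterns assume the inputs are already valid, which is why the diffs below confine them to direct() methods documented as trusted-input only.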
53 changes: 37 additions & 16 deletions python/lsst/daf/butler/_quantum_backed.py
@@ -31,7 +31,7 @@
from typing import TYPE_CHECKING, Any

from deprecated.sphinx import deprecated
from lsst.daf.butler._compat import _BaseModelCompat
from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
from lsst.resources import ResourcePathExpression

from ._butlerConfig import ButlerConfig
@@ -745,19 +745,40 @@ def _to_uuid_set(uuids: Iterable[str | uuid.UUID]) -> set[uuid.UUID]:
"""
return {uuid.UUID(id) if isinstance(id, str) else id for id in uuids}

data = QuantumProvenanceData.__new__(cls)
setter = object.__setattr__
setter(data, "predicted_inputs", _to_uuid_set(predicted_inputs))
setter(data, "available_inputs", _to_uuid_set(available_inputs))
setter(data, "actual_inputs", _to_uuid_set(actual_inputs))
setter(data, "predicted_outputs", _to_uuid_set(predicted_outputs))
setter(data, "actual_outputs", _to_uuid_set(actual_outputs))
setter(
data,
"datastore_records",
{
key: SerializedDatastoreRecordData.direct(**records)
for key, records in datastore_records.items()
},
)
if PYDANTIC_V2:
data = cls.model_construct(
_fields_set={
"predicted_inputs",
"available_inputs",
"actual_inputs",
"predicted_outputs",
"actual_outputs",
"datastore_records",
},
predicted_inputs=_to_uuid_set(predicted_inputs),
available_inputs=_to_uuid_set(available_inputs),
actual_inputs=_to_uuid_set(actual_inputs),
predicted_outputs=_to_uuid_set(predicted_outputs),
actual_outputs=_to_uuid_set(actual_outputs),
datastore_records={
key: SerializedDatastoreRecordData.direct(**records)
for key, records in datastore_records.items()
},
)
else:
data = QuantumProvenanceData.__new__(cls)
setter = object.__setattr__
setter(data, "predicted_inputs", _to_uuid_set(predicted_inputs))
setter(data, "available_inputs", _to_uuid_set(available_inputs))
setter(data, "actual_inputs", _to_uuid_set(actual_inputs))
setter(data, "predicted_outputs", _to_uuid_set(predicted_outputs))
setter(data, "actual_outputs", _to_uuid_set(actual_outputs))
setter(
data,
"datastore_records",
{
key: SerializedDatastoreRecordData.direct(**records)
for key, records in datastore_records.items()
},
)
return data
40 changes: 26 additions & 14 deletions python/lsst/daf/butler/core/datasets/ref.py
@@ -33,9 +33,9 @@
import sys
import uuid
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, runtime_checkable

from lsst.daf.butler._compat import _BaseModelCompat
from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
from lsst.utils.classes import immutable
from pydantic import StrictStr, validator

@@ -221,22 +221,34 @@ def direct(
This method should only be called when the inputs are trusted.
"""
node = SerializedDatasetRef.__new__(cls)
setter = object.__setattr__
setter(node, "id", uuid.UUID(id))
setter(
node,
"datasetType",
datasetType if datasetType is None else SerializedDatasetType.direct(**datasetType),
serialized_datasetType = (
SerializedDatasetType.direct(**datasetType) if datasetType is not None else None
)
setter(node, "dataId", dataId if dataId is None else SerializedDataCoordinate.direct(**dataId))
setter(node, "run", sys.intern(run))
setter(node, "component", component)
setter(node, "__fields_set__", _serializedDatasetRefFieldsSet)
serialized_dataId = SerializedDataCoordinate.direct(**dataId) if dataId is not None else None

if PYDANTIC_V2:
node = cls.model_construct(
_fields_set=_serializedDatasetRefFieldsSet,
id=uuid.UUID(id),
datasetType=serialized_datasetType,
dataId=serialized_dataId,
run=sys.intern(run),
component=component,
)
else:
node = SerializedDatasetRef.__new__(cls)
setter = object.__setattr__
setter(node, "id", uuid.UUID(id))
setter(node, "datasetType", serialized_datasetType)
setter(node, "dataId", serialized_dataId)
setter(node, "run", sys.intern(run))
setter(node, "component", component)
setter(node, "__fields_set__", _serializedDatasetRefFieldsSet)

return node


DatasetId = uuid.UUID
DatasetId: TypeAlias = uuid.UUID
"""A type-annotation alias for dataset ID providing typing flexibility.
"""

43 changes: 27 additions & 16 deletions python/lsst/daf/butler/core/datasets/type.py
@@ -29,7 +29,7 @@
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, ClassVar

from lsst.daf.butler._compat import _BaseModelCompat
from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
from pydantic import StrictBool, StrictStr

from ..configSupport import LookupKey
@@ -80,22 +80,33 @@ def direct(
key = (name, storageClass or "")
if cache is not None and (type_ := cache.get(key, None)) is not None:
return type_
node = SerializedDatasetType.__new__(cls)
setter = object.__setattr__
setter(node, "name", name)
setter(node, "storageClass", storageClass)
setter(
node,
"dimensions",
dimensions if dimensions is None else SerializedDimensionGraph.direct(**dimensions),
)
setter(node, "parentStorageClass", parentStorageClass)
setter(node, "isCalibration", isCalibration)
setter(
node,
"__fields_set__",
{"name", "storageClass", "dimensions", "parentStorageClass", "isCalibration"},

serialized_dimensions = (
SerializedDimensionGraph.direct(**dimensions) if dimensions is not None else None
)

if PYDANTIC_V2:
node = cls.model_construct(
name=name,
storageClass=storageClass,
dimensions=serialized_dimensions,
parentStorageClass=parentStorageClass,
isCalibration=isCalibration,
)
else:
node = SerializedDatasetType.__new__(cls)
setter = object.__setattr__
setter(node, "name", name)
setter(node, "storageClass", storageClass)
setter(node, "dimensions", serialized_dimensions)
setter(node, "parentStorageClass", parentStorageClass)
setter(node, "isCalibration", isCalibration)
setter(
node,
"__fields_set__",
{"name", "storageClass", "dimensions", "parentStorageClass", "isCalibration"},
)

if cache is not None:
cache[key] = node
return node
21 changes: 15 additions & 6 deletions python/lsst/daf/butler/core/datastoreRecordData.py
@@ -30,7 +30,7 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any

from lsst.daf.butler._compat import _BaseModelCompat
from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
from lsst.utils import doImportType
from lsst.utils.introspection import get_full_type_name

@@ -71,10 +71,6 @@ def direct(
This method should only be called when the inputs are trusted.
"""
data = SerializedDatastoreRecordData.__new__(cls)
setter = object.__setattr__
# JSON makes strings out of UUIDs, need to convert them back
setter(data, "dataset_ids", [uuid.UUID(id) if isinstance(id, str) else id for id in dataset_ids])
# See also comments in record_ids_to_uuid()
for table_data in records.values():
for table_records in table_data.values():
@@ -83,7 +79,20 @@
# columns that are UUIDs we'd need more generic approach.
if (id := record.get("dataset_id")) is not None:
record["dataset_id"] = uuid.UUID(id) if isinstance(id, str) else id
setter(data, "records", records)

if PYDANTIC_V2:
print("INSIDE RECORD DATA DIRECT")
data = cls.model_construct(
_fields_set={"dataset_ids", "records"},
dataset_ids=[uuid.UUID(id) if isinstance(id, str) else id for id in dataset_ids],
records=records,
)
else:
data = SerializedDatastoreRecordData.__new__(cls)
setter = object.__setattr__
# JSON makes strings out of UUIDs, need to convert them back
setter(data, "dataset_ids", [uuid.UUID(id) if isinstance(id, str) else id for id in dataset_ids])
setter(data, "records", records)
return data


28 changes: 16 additions & 12 deletions python/lsst/daf/butler/core/dimensions/_coordinate.py
@@ -34,7 +34,7 @@
from typing import TYPE_CHECKING, Any, ClassVar, Literal, overload

from deprecated.sphinx import deprecated
from lsst.daf.butler._compat import _BaseModelCompat
from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
from lsst.sphgeom import IntersectionRegion, Region

from ..json import from_json_pydantic, to_json_pydantic
@@ -81,17 +81,21 @@ def direct(cls, *, dataId: dict[str, DataIdValue], records: dict[str, dict]) ->
cache = PersistenceContextVars.serializedDataCoordinateMapping.get()
if cache is not None and (result := cache.get(key)) is not None:
return result
node = SerializedDataCoordinate.__new__(cls)
setter = object.__setattr__
setter(node, "dataId", dataId)
setter(
node,
"records",
records
if records is None
else {k: SerializedDimensionRecord.direct(**v) for k, v in records.items()},
)
setter(node, "__fields_set__", {"dataId", "records"})

if records is None:
serialized_records = None
else:
serialized_records = {k: SerializedDimensionRecord.direct(**v) for k, v in records.items()}

if PYDANTIC_V2:
node = cls.model_construct(dataId=dataId, records=serialized_records)
else:
node = SerializedDataCoordinate.__new__(cls)
setter = object.__setattr__
setter(node, "dataId", dataId)
setter(node, "records", serialized_records)
setter(node, "__fields_set__", {"dataId", "records"})

if cache is not None:
cache[key] = node
return node
13 changes: 8 additions & 5 deletions python/lsst/daf/butler/core/dimensions/_graph.py
@@ -28,7 +28,7 @@
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, ClassVar

from lsst.daf.butler._compat import _BaseModelCompat
from lsst.daf.butler._compat import PYDANTIC_V2, _BaseModelCompat
from lsst.utils.classes import cached_getter, immutable

from .._topology import TopologicalFamily, TopologicalSpace
@@ -57,10 +57,13 @@ def direct(cls, *, names: list[str]) -> SerializedDimensionGraph:
This method should only be called when the inputs are trusted.
"""
node = SerializedDimensionGraph.__new__(cls)
object.__setattr__(node, "names", names)
object.__setattr__(node, "__fields_set__", {"names"})
return node
if PYDANTIC_V2:
return cls.model_construct(names=names)
else:
node = SerializedDimensionGraph.__new__(cls)
object.__setattr__(node, "names", names)
object.__setattr__(node, "__fields_set__", {"names"})
return node


@immutable
20 changes: 13 additions & 7 deletions python/lsst/daf/butler/core/dimensions/_records.py
@@ -180,16 +180,22 @@ def direct(
cache = PersistenceContextVars.serializedDimensionRecordMapping.get()
if cache is not None and (result := cache.get(key)) is not None:
return result
node = SerializedDimensionRecord.__new__(cls)
setter = object.__setattr__
setter(node, "definition", definition)

# This method requires tuples as values of the mapping, but JSON
# readers will read things in as lists. Be kind and transparently
# transform to tuples
setter(
node, "record", {k: v if type(v) != list else tuple(v) for k, v in record.items()} # type: ignore
)
setter(node, "__fields_set__", {"definition", "record"})
serialized_record = {k: v if type(v) != list else tuple(v) for k, v in record.items()}

if PYDANTIC_V2:
node = cls.model_construct(definition=definition, record=serialized_record)
else:
node = SerializedDimensionRecord.__new__(cls)
setter = object.__setattr__
setter(node, "definition", definition)
setter(node, "record", serialized_record) # type: ignore

setter(node, "__fields_set__", {"definition", "record"})

if cache is not None:
cache[key] = node
return node