Skip to content

Commit

Permalink
Pydantic 2 support (#847)
Browse files Browse the repository at this point in the history
* Update api utils to work with pydantic 2

* Fix api sanitize

* Fix MPDataDoc creation

* Lazy load all nested resters

* Bump to python>=3.9

* Cache api sanitize

* Cache emmet version retrieval

* Fix client tests

* Migrate __fields__

* Change materials test input

* Move flat models util func to emmet

* api_sanitize allow_dict behavior

* Bump emmet

* Update maggma util import

* Linting and deprecated changes

* Spelling

* Final linting

* Bump emmet

* Linting

* Remove repeat code

* Linting

* Add maggma to deps
  • Loading branch information
Jason Munro authored Sep 27, 2023
1 parent ffd9713 commit ff7f643
Show file tree
Hide file tree
Showing 10 changed files with 44 additions and 41 deletions.
22 changes: 11 additions & 11 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ def _submit_request_and_process(
data_model(
**{
field: value
for field, value in raw_doc.dict().items()
for field, value in raw_doc.model_dump().items()
if field in set_fields
}
)
Expand Down Expand Up @@ -877,29 +877,29 @@ def _submit_request_and_process(

def _generate_returned_model(self, doc):
set_fields = [
field for field, _ in doc if field in doc.dict(exclude_unset=True)
field for field, _ in doc if field in doc.model_dump(exclude_unset=True)
]
unset_fields = [field for field in doc.__fields__ if field not in set_fields]
unset_fields = [field for field in doc.model_fields if field not in set_fields]

data_model = create_model(
"MPDataDoc",
fields_not_requested=unset_fields,
fields_not_requested=(list[str], unset_fields),
__base__=self.document_model,
)

data_model.__fields__ = {
data_model.model_fields = {
**{
name: description
for name, description in data_model.__fields__.items()
for name, description in data_model.model_fields.items()
if name in set_fields
},
"fields_not_requested": data_model.__fields__["fields_not_requested"],
"fields_not_requested": data_model.model_fields["fields_not_requested"],
}

def new_repr(self) -> str:
extra = ",\n".join(
f"\033[1m{n}\033[0;0m={getattr(self, n)!r}"
for n in data_model.__fields__
for n in data_model.model_fields
)

s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501
Expand All @@ -908,7 +908,7 @@ def new_repr(self) -> str:
def new_str(self) -> str:
extra = ",\n".join(
f"\033[1m{n}\033[0;0m={getattr(self, n)!r}"
for n in data_model.__fields__
for n in data_model.model_fields
if n != "fields_not_requested"
)

Expand All @@ -927,7 +927,7 @@ def new_getattr(self, attr) -> str:
)

def new_dict(self, *args, **kwargs):
d = super(data_model, self).dict(*args, **kwargs)
d = super(data_model, self).model_dump(*args, **kwargs)
return jsanitize(d)

data_model.__repr__ = new_repr
Expand Down Expand Up @@ -1155,7 +1155,7 @@ def count(self, criteria: dict | None = None) -> int | str:
def available_fields(self) -> list[str]:
if self.document_model is None:
return ["Unknown fields."]
return list(self.document_model.schema()["properties"].keys()) # type: ignore
return list(self.document_model.model_json_schema()["properties"].keys()) # type: ignore

def __repr__(self): # pragma: no cover
return f"<{self.__class__.__name__} {self.endpoint}>"
Expand Down
3 changes: 2 additions & 1 deletion mp_api/client/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from multiprocessing import cpu_count
from typing import List

from pydantic import BaseSettings, Field
from pydantic import Field
from pydantic_settings import BaseSettings
from pymatgen.core import _load_pmg_settings

from mp_api.client import __file__ as root_dir
Expand Down
35 changes: 18 additions & 17 deletions mp_api/client/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import re
from functools import cache
from typing import get_args
from typing import Optional, get_args

from maggma.utils import get_flat_models_from_model
from monty.json import MSONable
from pydantic import BaseModel
from pydantic.schema import get_flat_models_from_model
from pydantic.utils import lenient_issubclass
from pydantic._internal._utils import lenient_issubclass
from pydantic.fields import FieldInfo


def validate_ids(id_list: list[str]):
Expand Down Expand Up @@ -62,33 +63,33 @@ def api_sanitize(

for model in models:
model_fields_to_leave = {f[1] for f in fields_tuples if model.__name__ == f[0]}
for name, field in model.__fields__.items():
field_type = field.type_

if name not in model_fields_to_leave:
field.required = False
field.default = None
field.default_factory = None
field.allow_none = True
field.field_info.default = None
field.field_info.default_factory = None
for name in model.model_fields:
field = model.model_fields[name]
field_type = field.annotation

if field_type is not None and allow_dict_msonable:
if lenient_issubclass(field_type, MSONable):
field.type_ = allow_msonable_dict(field_type)
field_type = allow_msonable_dict(field_type)
else:
for sub_type in get_args(field_type):
if lenient_issubclass(sub_type, MSONable):
allow_msonable_dict(sub_type)
field.populate_validators()

if name not in model_fields_to_leave:
new_field = FieldInfo.from_annotated_attribute(
Optional[field_type], None
)
model.model_fields[name] = new_field

model.model_rebuild(force=True)

return pydantic_model


def allow_msonable_dict(monty_cls: type[MSONable]):
"""Patch Monty to allow for dict values for MSONable."""

def validate_monty(cls, v):
def validate_monty(cls, v, _):
"""Stub validator for MSONable as a dictionary only."""
if isinstance(v, cls):
return v
Expand All @@ -110,6 +111,6 @@ def validate_monty(cls, v):
else:
raise ValueError(f"Must provide {cls.__name__} or MSONable dictionary")

monty_cls.validate_monty = classmethod(validate_monty)
monty_cls.validate_monty_v2 = classmethod(validate_monty)

return monty_cls
2 changes: 1 addition & 1 deletion mp_api/client/mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ def get_entries(
if property_data:
for property in property_data:
entry_dict["data"][property] = (
doc.dict()[property]
doc.model_dump()[property]
if self.use_document_model
else doc[property]
)
Expand Down
6 changes: 3 additions & 3 deletions mp_api/client/routes/materials/electronic_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def get_bandstructure_from_material_id(
f"No {path_type.value} band structure data found for {material_id}"
)
else:
bs_data = bs_data.dict()
bs_data = bs_data.model_dump()

if bs_data.get(path_type.value, None):
bs_task_id = bs_data[path_type.value]["task_id"]
Expand All @@ -303,7 +303,7 @@ def get_bandstructure_from_material_id(
f"No uniform band structure data found for {material_id}"
)
else:
bs_data = bs_data.dict()
bs_data = bs_data.model_dump()

if bs_data.get("total", None):
bs_task_id = bs_data["total"]["1"]["task_id"]
Expand Down Expand Up @@ -444,7 +444,7 @@ def get_dos_from_material_id(self, material_id: str):

dos_data = es_rester.get_data_by_id(
document_id=material_id, fields=["dos"]
).dict()
).model_dump()

if dos_data["dos"]:
dos_task_id = dos_data["dos"]["total"]["1"]["task_id"]
Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,17 @@ classifiers = [
dependencies = [
"setuptools",
"msgpack",
"maggma",
"pymatgen>=2022.3.7",
"typing-extensions>=3.7.4.1",
"requests>=2.23.0",
"monty>=2021.3.12",
"emmet-core>=0.54.0",
"monty>=2023.9.25",
"emmet-core>=0.69.2",
]
dynamic = ["version"]

[project.optional-dependencies]
all = ["emmet-core[all]>=0.54.0", "custodian", "mpcontribs-client", "boto3"]
all = ["emmet-core[all]>=0.69.1", "custodian", "mpcontribs-client", "boto3"]
test = [
"pre-commit",
"pytest",
Expand Down
2 changes: 1 addition & 1 deletion tests/materials/core_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def client_search_testing(
"num_chunks": 1,
}

doc = search_method(**q)[0].dict()
doc = search_method(**q)[0].model_dump()

for sub_field in sub_doc_fields:
if sub_field in doc:
Expand Down
4 changes: 2 additions & 2 deletions tests/materials/test_electronic_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_bs_client(bs_rester):
"chunk_size": 1,
"num_chunks": 1,
}
doc = search_method(**q)[0].dict()
doc = search_method(**q)[0].model_dump()

for sub_field in bs_sub_doc_fields:
if sub_field in doc:
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_dos_client(dos_rester):
"chunk_size": 1,
"num_chunks": 1,
}
doc = search_method(**q)[0].dict()
doc = search_method(**q)[0].model_dump()
for sub_field in dos_sub_doc_fields:
if sub_field in doc:
doc = doc[sub_field]
Expand Down
2 changes: 1 addition & 1 deletion tests/molecules/core_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def client_search_testing(
docs = search_method(**q)

if len(docs) > 0:
doc = docs[0].dict()
doc = docs[0].model_dump()
else:
raise ValueError("No documents returned")

Expand Down
2 changes: 1 addition & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_generic_get_methods(rester):

if name not in search_only_resters:
doc = rester.get_data_by_id(
doc.dict()[rester.primary_key], fields=[rester.primary_key]
doc.model_dump()[rester.primary_key], fields=[rester.primary_key]
)
assert isinstance(doc, rester.document_model)

Expand Down

0 comments on commit ff7f643

Please sign in to comment.