Skip to content

Commit 46c8d62

Browse files
committed
✏️Fix model_validate in presence of inherited Relationship fields, add unit test
1 parent 594aac3 commit 46c8d62

File tree

2 files changed

+81
-2
lines changed

2 files changed

+81
-2
lines changed

sqlmodel/main.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,8 +503,28 @@ def __new__(
503503
**kwargs: Any,
504504
) -> Any:
505505
relationships: Dict[str, RelationshipInfo] = {}
506+
backup_base_annotations: Dict[Type[Any], Dict[str, Any]] = {}
506507
for base in bases:
507-
relationships.update(getattr(base, "__sqlmodel_relationships__", {}))
508+
base_relationships = getattr(base, "__sqlmodel_relationships__", None)
509+
if base_relationships:
510+
relationships.update(base_relationships)
511+
#
512+
# Temporarily pluck out `__annotations__` corresponding to relationships from base classes, otherwise these annotations
513+
# make their way into `cls.model_fields` as `FieldInfo(..., required=True)`, even when the relationships are declared
514+
# optional. When a model instance is then constructed using `model_validate` and an optional relationship field is not
515+
# passed, this leads to an incorrect `pydantic.ValidationError`.
516+
#
517+
# We can't just clean up `new_cls.model_fields` after `new_cls` is constructed because by this time
518+
# Pydantic has created model schema and validation rules, so this won't fix the problem.
519+
#
520+
base_annotations = getattr(base, "__annotations__", None)
521+
if base_annotations:
522+
backup_base_annotations[base] = base_annotations
523+
base.__annotations__ = {
524+
name: typ
525+
for name, typ in base_annotations.items()
526+
if name not in base_relationships
527+
}
508528
dict_for_pydantic = {}
509529
original_annotations = get_annotations(class_dict)
510530
pydantic_annotations = {}
@@ -539,6 +559,9 @@ def __new__(
539559
key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs
540560
}
541561
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
562+
# Restore base annotations
563+
for base, annotations in backup_base_annotations.items():
564+
base.__annotations__ = annotations
542565
new_cls.__annotations__ = {
543566
**relationship_annotations,
544567
**pydantic_annotations,

tests/test_relationship_inheritance.py renamed to tests/test_inherit_relationship.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import datetime
22
from typing import Optional
33

4+
import pydantic
45
from sqlalchemy import DateTime, func
56
from sqlalchemy.orm import declared_attr, relationship
67
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
8+
from sqlmodel._compat import IS_PYDANTIC_V2
79

810

9-
def test_relationship_inheritance() -> None:
11+
def test_inherit_relationship(clear_sqlmodel) -> None:
1012
def now():
1113
return datetime.datetime.now(tz=datetime.timezone.utc)
1214

@@ -90,3 +92,57 @@ class Document(CreatedUpdatedMixin, table=True):
9092
doc = session.exec(select(Document)).one()
9193
assert doc.created_by.name == "Jane"
9294
assert doc.updated_by.name == "John"
95+
96+
97+
def test_inherit_relationship_model_validate(clear_sqlmodel) -> None:
98+
class User(SQLModel, table=True):
99+
id: Optional[int] = Field(default=None, primary_key=True)
100+
101+
class Mixin(SQLModel):
102+
owner_id: Optional[int] = Field(default=None, foreign_key="user.id")
103+
owner: Optional[User] = Relationship(
104+
sa_relationship=declared_attr(
105+
lambda cls: relationship(User, foreign_keys=cls.owner_id)
106+
)
107+
)
108+
109+
class Asset(Mixin, table=True):
110+
id: Optional[int] = Field(default=None, primary_key=True)
111+
112+
class AssetCreate(pydantic.BaseModel):
113+
pass
114+
115+
asset_create = AssetCreate()
116+
117+
engine = create_engine("sqlite://")
118+
119+
SQLModel.metadata.create_all(engine)
120+
121+
user = User()
122+
123+
# Owner must be optional
124+
asset = Asset.model_validate(asset_create)
125+
with Session(engine) as session:
126+
session.add(asset)
127+
session.commit()
128+
session.refresh(asset)
129+
assert asset.id is not None
130+
assert asset.owner_id is None
131+
assert asset.owner is None
132+
133+
# When set, owner must be saved
134+
#
135+
# Under Pydantic V2, relationship fields set it `model_validate` are not saved,
136+
# with or without inheritance. Consider it a known issue.
137+
#
138+
if IS_PYDANTIC_V2:
139+
asset = Asset.model_validate(asset_create, update={"owner": user})
140+
with Session(engine) as session:
141+
session.add(asset)
142+
session.commit()
143+
session.refresh(asset)
144+
session.refresh(user)
145+
assert asset.id is not None
146+
assert user.id is not None
147+
assert asset.owner_id == user.id
148+
assert asset.owner.id == user.id

0 commit comments

Comments
 (0)