|
| 1 | +from typing import Optional |
| 2 | + |
| 3 | +from sqlalchemy.orm import declared_attr, relationship |
| 4 | +from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select |
| 5 | + |
| 6 | + |
| 7 | +def test_relationship_inheritance() -> None: |
| 8 | + class User(SQLModel, table=True): |
| 9 | + id: Optional[int] = Field(default=None, primary_key=True) |
| 10 | + name: str |
| 11 | + |
| 12 | + class CreatedUpdatedMixin(SQLModel): |
| 13 | + # With Pydantic V2, it is also possible to define `created_by` like this: |
| 14 | + # |
| 15 | + # ```python |
| 16 | + # @declared_attr |
| 17 | + # def _created_by(cls): |
| 18 | + # return relationship(User, foreign_keys=cls.created_by_id) |
| 19 | + # |
| 20 | + # created_by: Optional[User] = Relationship(sa_relationship=_created_by)) |
| 21 | + # ``` |
| 22 | + # |
| 23 | + # The difference from Pydantic V1 is that Pydantic V2 plucks attributes with names starting with '_' (but not '__') |
| 24 | + # from class attributes and stores them separately as instances of `pydantic.ModelPrivateAttr` somewhere in depths of |
| 25 | + # Pydantic internals. Under Pydantic V1 this doesn't happen, so SQLAlchemy ends up having two class attributes |
| 26 | + # (`_created_by` and `created_by`) corresponding to one database attribute, causing a conflict and unreliable behavior. |
| 27 | + # The approach with a lambda always works because it doesn't produce the second class attribute and thus eliminates |
| 28 | + # the possibility of a conflict entirely. |
| 29 | + # |
| 30 | + created_by_id: Optional[int] = Field(default=None, foreign_key="user.id") |
| 31 | + created_by: Optional[User] = Relationship( |
| 32 | + sa_relationship=declared_attr( |
| 33 | + lambda cls: relationship(User, foreign_keys=cls.created_by_id) |
| 34 | + ) |
| 35 | + ) |
| 36 | + |
| 37 | + updated_by_id: Optional[int] = Field(default=None, foreign_key="user.id") |
| 38 | + updated_by: Optional[User] = Relationship( |
| 39 | + sa_relationship=declared_attr( |
| 40 | + lambda cls: relationship(User, foreign_keys=cls.updated_by_id) |
| 41 | + ) |
| 42 | + ) |
| 43 | + |
| 44 | + class Asset(CreatedUpdatedMixin, table=True): |
| 45 | + id: Optional[int] = Field(default=None, primary_key=True) |
| 46 | + |
| 47 | + engine = create_engine("sqlite://") |
| 48 | + |
| 49 | + SQLModel.metadata.create_all(engine) |
| 50 | + |
| 51 | + john = User(name="John") |
| 52 | + jane = User(name="Jane") |
| 53 | + asset = Asset(created_by=john, updated_by=jane) |
| 54 | + |
| 55 | + with Session(engine) as session: |
| 56 | + session.add(asset) |
| 57 | + session.commit() |
| 58 | + |
| 59 | + with Session(engine) as session: |
| 60 | + asset = session.exec(select(Asset)).one() |
| 61 | + assert asset.created_by.name == "John" |
| 62 | + assert asset.updated_by.name == "Jane" |
0 commit comments