Skip to content

Commit fc9a2e6

Browse files
committed
📝 Refactor type alias handling and add tests for NewType and TypeVar support
1 parent a05a685 commit fc9a2e6

File tree

2 files changed

+75
-13
lines changed

2 files changed

+75
-13
lines changed

‎sqlmodel/_compat.py‎

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
PYDANTIC_MINOR_VERSION = tuple(int(i) for i in P_VERSION.split(".")[:2])
3232
IS_PYDANTIC_V2 = PYDANTIC_MINOR_VERSION[0] == 2
3333

34-
3534
if TYPE_CHECKING:
3635
from .main import RelationshipInfo, SQLModel
3736

@@ -201,29 +200,43 @@ def is_field_noneable(field: "FieldInfo") -> bool:
201200
return False
202201
return False
203202

204-
def _is_type_alias_type_instance(annotation: Any) -> bool:
205-
type_to_check = "TypeAliasType"
206-
in_typing = hasattr(typing, type_to_check)
207-
in_typing_extensions = hasattr(typing_extensions, type_to_check)
208-
203+
def _is_typing_type_instance(annotation: Any, type_name: str) -> bool:
209204
check_type = []
210-
if in_typing:
211-
check_type.append(typing.TypeAliasType)
212-
if in_typing_extensions:
213-
check_type.append(typing_extensions.TypeAliasType)
205+
if hasattr(typing, type_name):
206+
check_type.append(getattr(typing, type_name))
207+
if hasattr(typing_extensions, type_name):
208+
check_type.append(getattr(typing_extensions, type_name))
209+
210+
return check_type and isinstance(annotation, tuple(check_type))
211+
212+
def _is_new_type_instance(annotation: Any) -> bool:
213+
return _is_typing_type_instance(annotation, "NewType")
214214

215+
def _is_type_var_instance(annotation: Any) -> bool:
216+
return _is_typing_type_instance(annotation, "TypeVar")
217+
218+
def _is_type_alias_type_instance(annotation: Any) -> bool:
215219
if sys.version_info[:2] == (3, 10):
216220
if type(annotation) is types.GenericAlias:
217-
# In Python 3.10, TypeAliasType instances are of type GenericAlias
221+
# In Python 3.10, GenericAlias instances are of type TypeAliasType
218222
return False
219223

220-
return check_type and isinstance(annotation, tuple(check_type))
224+
return _is_typing_type_instance(annotation, "TypeAliasType")
221225

222226
def get_sa_type_from_type_annotation(annotation: Any) -> Any:
223227
# Resolve Optional fields
224228
if annotation is None:
225229
raise ValueError("Missing field type")
226-
if _is_type_alias_type_instance(annotation):
230+
if _is_type_var_instance(annotation):
231+
annotation = annotation.__bound__
232+
if not annotation:
233+
raise ValueError(
234+
"TypeVars without a bound type cannot be converted to SQLAlchemy types"
235+
)
236+
# annotations.__constraints__ could be used and defined Union[*constraints], but ORM does not support it
237+
elif _is_new_type_instance(annotation):
238+
annotation = annotation.__supertype__
239+
elif _is_type_alias_type_instance(annotation):
227240
annotation = annotation.__value__
228241
origin = get_origin(annotation)
229242
if origin is None:

‎tests/test_field_sa_type.py‎

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import typing as t
22

3+
import pytest
34
import typing_extensions as te
45
from sqlmodel import Field, SQLModel
56

@@ -54,6 +55,29 @@ class Hero(SQLModel, table=True):
5455
weapon: Type6_t = "sword"
5556

5657

58+
def test_sa_type_typing_7() -> None:
59+
Type7_t = t.NewType("Type7_t", str)
60+
61+
class Hero(SQLModel, table=True):
62+
pk: int = Field(primary_key=True)
63+
weapon: Type7_t = "sword"
64+
65+
66+
def test_sa_type_typing_8() -> None:
67+
Type8_t = t.TypeVar("Type8_t", bound=str)
68+
69+
class Hero(SQLModel, table=True):
70+
pk: int = Field(primary_key=True)
71+
weapon: Type8_t = "sword"
72+
73+
def test_sa_type_typing_9() -> None:
74+
Type9_t = t.TypeVar("Type9_t", str, bytes)
75+
76+
with pytest.raises(ValueError):
77+
class Hero(SQLModel, table=True):
78+
pk: int = Field(primary_key=True)
79+
weapon: Type9_t = "sword"
80+
5781
def test_sa_type_typing_extensions_1() -> None:
5882
Type1_te = str
5983

@@ -102,3 +126,28 @@ def test_sa_type_typing_extensions_6() -> None:
102126
class Hero(SQLModel, table=True):
103127
pk: int = Field(primary_key=True)
104128
weapon: Type6_te = "sword"
129+
130+
131+
def test_sa_type_typing_extensions_7() -> None:
132+
Type7_te = te.NewType("Type7_te", str)
133+
134+
class Hero(SQLModel, table=True):
135+
pk: int = Field(primary_key=True)
136+
weapon: Type7_te = "sword"
137+
138+
139+
def test_sa_type_typing_extensions_8() -> None:
140+
Type8_te = te.TypeVar("Type8_te", bound=str)
141+
142+
class Hero(SQLModel, table=True):
143+
pk: int = Field(primary_key=True)
144+
weapon: Type8_te = "sword"
145+
146+
147+
def test_sa_type_typing_extensions_9() -> None:
148+
Type9_te = te.TypeVar("Type9_te", str, bytes)
149+
150+
with pytest.raises(ValueError):
151+
class Hero(SQLModel, table=True):
152+
pk: int = Field(primary_key=True)
153+
weapon: Type9_te = "sword"

0 commit comments

Comments
 (0)