Skip to content

Commit 1e9cf12

Browse files
committed
📝 Refactor type alias handling to improve compatibility with Python 3.10 and enhance type resolution logic
1 parent d128318 commit 1e9cf12

File tree

3 files changed

+54
-46
lines changed

3 files changed

+54
-46
lines changed

sqlmodel/_compat.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import sys
22
import types
3-
import typing
43
from contextlib import contextmanager
54
from contextvars import ContextVar
65
from dataclasses import dataclass
@@ -20,7 +19,6 @@
2019
Union,
2120
)
2221

23-
import typing_extensions
2422
from pydantic import VERSION as P_VERSION
2523
from pydantic import BaseModel
2624
from pydantic.fields import FieldInfo
@@ -200,47 +198,10 @@ def is_field_noneable(field: "FieldInfo") -> bool:
200198
return False
201199
return False
202200

203-
def _is_typing_type_instance(annotation: Any, type_name: str) -> bool:
204-
check_type = []
205-
if hasattr(typing, type_name):
206-
check_type.append(getattr(typing, type_name))
207-
if hasattr(typing_extensions, type_name):
208-
check_type.append(getattr(typing_extensions, type_name))
209-
210-
return bool(check_type) and isinstance(annotation, tuple(check_type))
211-
212-
def _is_new_type_instance(annotation: Any) -> bool:
213-
if sys.version_info >= (3, 10):
214-
return _is_typing_type_instance(annotation, "NewType")
215-
else:
216-
return hasattr(annotation, "__supertype__")
217-
218-
def _is_type_var_instance(annotation: Any) -> bool:
219-
return _is_typing_type_instance(annotation, "TypeVar")
220-
221-
def _is_type_alias_type_instance(annotation: Any) -> bool:
222-
if sys.version_info[:2] == (3, 10):
223-
if type(annotation) is types.GenericAlias:
224-
# In Python 3.10, GenericAlias instances are of type TypeAliasType
225-
return False
226-
227-
return _is_typing_type_instance(annotation, "TypeAliasType")
228-
229201
def get_sa_type_from_type_annotation(annotation: Any) -> Any:
230202
# Resolve Optional fields
231203
if annotation is None:
232204
raise ValueError("Missing field type")
233-
if _is_type_var_instance(annotation):
234-
annotation = annotation.__bound__
235-
if not annotation:
236-
raise ValueError(
237-
"TypeVars without a bound type cannot be converted to SQLAlchemy types"
238-
)
239-
# annotations.__constraints__ could be used and defined Union[*constraints], but ORM does not support it
240-
elif _is_new_type_instance(annotation):
241-
annotation = annotation.__supertype__
242-
elif _is_type_alias_type_instance(annotation):
243-
annotation = annotation.__value__
244205
origin = get_origin(annotation)
245206
if origin is None:
246207
return annotation

sqlmodel/main.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from __future__ import annotations
22

33
import ipaddress
4+
import sys
5+
import types
6+
import typing
47
import uuid
58
import weakref
69
from datetime import date, datetime, time, timedelta
@@ -27,6 +30,7 @@
2730
overload,
2831
)
2932

33+
import typing_extensions
3034
from pydantic import BaseModel, EmailStr
3135
from pydantic.fields import FieldInfo as PydanticFieldInfo
3236
from sqlalchemy import (
@@ -519,7 +523,7 @@ def __new__(
519523
if k in relationships:
520524
relationship_annotations[k] = v
521525
else:
522-
pydantic_annotations[k] = v
526+
pydantic_annotations[k] = resolve_type_alias(v)
523527
dict_used = {
524528
**dict_for_pydantic,
525529
"__weakref__": None,
@@ -763,6 +767,54 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
763767
return Column(sa_type, *args, **kwargs) # type: ignore
764768

765769

770+
def _is_typing_type_instance(annotation: Any, type_name: str) -> bool:
771+
check_type = []
772+
if hasattr(typing, type_name):
773+
check_type.append(getattr(typing, type_name))
774+
if hasattr(typing_extensions, type_name):
775+
check_type.append(getattr(typing_extensions, type_name))
776+
777+
return bool(check_type) and isinstance(annotation, tuple(check_type))
778+
779+
780+
def _is_new_type_instance(annotation: Any) -> bool:
781+
if sys.version_info >= (3, 10):
782+
return _is_typing_type_instance(annotation, "NewType")
783+
else:
784+
return hasattr(annotation, "__supertype__")
785+
786+
787+
def _is_type_var_instance(annotation: Any) -> bool:
788+
return _is_typing_type_instance(annotation, "TypeVar")
789+
790+
791+
def _is_type_alias_type_instance(annotation: Any) -> bool:
792+
if sys.version_info[:2] == (3, 10):
793+
if type(annotation) is types.GenericAlias:
794+
# In Python 3.10, GenericAlias instances are of type TypeAliasType
795+
return False
796+
797+
return _is_typing_type_instance(annotation, "TypeAliasType")
798+
799+
800+
def resolve_type_alias(annotation: Any) -> Any:
801+
if _is_type_var_instance(annotation):
802+
resolution = annotation.__bound__
803+
if not annotation:
804+
raise ValueError(
805+
"TypeVars without a bound type cannot be converted to SQLAlchemy types"
806+
)
807+
# annotations.__constraints__ could be used and defined Union[*constraints], but ORM does not support it
808+
elif _is_new_type_instance(annotation):
809+
resolution = annotation.__supertype__
810+
elif _is_type_alias_type_instance(annotation):
811+
resolution = annotation.__value__
812+
else:
813+
resolution = annotation
814+
815+
return resolution
816+
817+
766818
class_registry = weakref.WeakValueDictionary() # type: ignore
767819

768820
default_registry = registry()

tests/test_field_sa_type.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import typing_extensions as te
66
from sqlmodel import Field, SQLModel
77

8-
from tests.conftest import needs_py312, needs_pydanticv2
8+
from tests.conftest import needs_py312
99

1010

1111
def test_sa_type_typing_1() -> None:
@@ -44,7 +44,6 @@ class Hero(SQLModel, table=True):
4444

4545

4646
@needs_py312
47-
@needs_pydanticv2
4847
def test_sa_type_typing_5() -> None:
4948
test_code = dedent("""
5049
type Type5_t = str
@@ -57,7 +56,6 @@ class Hero(SQLModel, table=True):
5756

5857

5958
@needs_py312
60-
@needs_pydanticv2
6159
def test_sa_type_typing_6() -> None:
6260
test_code = dedent("""
6361
type Type6_t = t.Annotated[str, "Just a comment"]
@@ -131,7 +129,6 @@ class Hero(SQLModel, table=True):
131129

132130

133131
@needs_py312
134-
@needs_pydanticv2
135132
def test_sa_type_typing_extensions_5() -> None:
136133
test_code = dedent("""
137134
type Type5_te = str
@@ -144,7 +141,6 @@ class Hero(SQLModel, table=True):
144141

145142

146143
@needs_py312
147-
@needs_pydanticv2
148144
def test_sa_type_typing_extensions_6() -> None:
149145
test_code = dedent("""
150146
type Type6_te = te.Annotated[str, "Just a comment"]
@@ -156,7 +152,6 @@ class Hero(SQLModel, table=True):
156152
exec(test_code, globals())
157153

158154

159-
@needs_pydanticv2
160155
def test_sa_type_typing_extensions_7() -> None:
161156
Type7_te = te.NewType("Type7_te", str)
162157

0 commit comments

Comments
 (0)