|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import ipaddress |
| 4 | +import sys |
| 5 | +import types |
| 6 | +import typing |
4 | 7 | import uuid |
5 | 8 | import weakref |
6 | 9 | from datetime import date, datetime, time, timedelta |
|
27 | 30 | overload, |
28 | 31 | ) |
29 | 32 |
|
| 33 | +import typing_extensions |
30 | 34 | from pydantic import BaseModel, EmailStr |
31 | 35 | from pydantic.fields import FieldInfo as PydanticFieldInfo |
32 | 36 | from sqlalchemy import ( |
@@ -519,7 +523,7 @@ def __new__( |
519 | 523 | if k in relationships: |
520 | 524 | relationship_annotations[k] = v |
521 | 525 | else: |
522 | | - pydantic_annotations[k] = v |
| 526 | + pydantic_annotations[k] = resolve_type_alias(v) |
523 | 527 | dict_used = { |
524 | 528 | **dict_for_pydantic, |
525 | 529 | "__weakref__": None, |
@@ -763,6 +767,54 @@ def get_column_from_field(field: Any) -> Column: # type: ignore |
763 | 767 | return Column(sa_type, *args, **kwargs) # type: ignore |
764 | 768 |
|
765 | 769 |
|
| 770 | +def _is_typing_type_instance(annotation: Any, type_name: str) -> bool: |
| 771 | + check_type = [] |
| 772 | + if hasattr(typing, type_name): |
| 773 | + check_type.append(getattr(typing, type_name)) |
| 774 | + if hasattr(typing_extensions, type_name): |
| 775 | + check_type.append(getattr(typing_extensions, type_name)) |
| 776 | + |
| 777 | + return bool(check_type) and isinstance(annotation, tuple(check_type)) |
| 778 | + |
| 779 | + |
| 780 | +def _is_new_type_instance(annotation: Any) -> bool: |
| 781 | + if sys.version_info >= (3, 10): |
| 782 | + return _is_typing_type_instance(annotation, "NewType") |
| 783 | + else: |
| 784 | + return hasattr(annotation, "__supertype__") |
| 785 | + |
| 786 | + |
| 787 | +def _is_type_var_instance(annotation: Any) -> bool: |
| 788 | + return _is_typing_type_instance(annotation, "TypeVar") |
| 789 | + |
| 790 | + |
| 791 | +def _is_type_alias_type_instance(annotation: Any) -> bool: |
| 792 | + if sys.version_info[:2] == (3, 10): |
| 793 | + if type(annotation) is types.GenericAlias: |
| 794 | + # In Python 3.10, GenericAlias instances are of type TypeAliasType |
| 795 | + return False |
| 796 | + |
| 797 | + return _is_typing_type_instance(annotation, "TypeAliasType") |
| 798 | + |
| 799 | + |
| 800 | +def resolve_type_alias(annotation: Any) -> Any: |
| 801 | + if _is_type_var_instance(annotation): |
| 802 | + resolution = annotation.__bound__ |
| 803 | + if not annotation: |
| 804 | + raise ValueError( |
| 805 | + "TypeVars without a bound type cannot be converted to SQLAlchemy types" |
| 806 | + ) |
| 807 | + # annotations.__constraints__ could be used and defined Union[*constraints], but ORM does not support it |
| 808 | + elif _is_new_type_instance(annotation): |
| 809 | + resolution = annotation.__supertype__ |
| 810 | + elif _is_type_alias_type_instance(annotation): |
| 811 | + resolution = annotation.__value__ |
| 812 | + else: |
| 813 | + resolution = annotation |
| 814 | + |
| 815 | + return resolution |
| 816 | + |
| 817 | + |
766 | 818 | class_registry = weakref.WeakValueDictionary() # type: ignore |
767 | 819 |
|
768 | 820 | default_registry = registry() |
|
0 commit comments