diff --git a/.gitignore b/.gitignore index 4196e0fd..8add6e6b 100644 --- a/.gitignore +++ b/.gitignore @@ -142,3 +142,5 @@ cython_debug/ .benchmarks/ .python-version + +env3.7/ diff --git a/pydantic_redis/model.py b/pydantic_redis/model.py index 13f950c5..6dce89bc 100644 --- a/pydantic_redis/model.py +++ b/pydantic_redis/model.py @@ -50,25 +50,27 @@ def initialize(cls): is_generic = hasattr(field_type, "__origin__") if ( is_generic - and field_type.__origin__ == Union - and field_type.__args__[-1] == None.__class__ + and typing_get_origin(field_type) == Union + and typing_get_args(field_type)[-1] == None.__class__ ): - field_type = field_type.__args__[0] + field_type = typing_get_args(field_type)[0] is_generic = hasattr(field_type, "__origin__") if ( is_generic - and field_type.__origin__ == List - and issubclass(field_type.__args__[0], Model) + and typing_get_origin(field_type) in (List, list) + and issubclass(typing_get_args(field_type)[0], Model) ): - cls._nested_model_list_fields[field] = field_type.__args__[0] + cls._nested_model_list_fields[field] = typing_get_args(field_type)[ + 0 + ] elif ( is_generic - and field_type.__origin__ == Tuple - and any([issubclass(v, Model) for v in field_type.__args__]) + and typing_get_origin(field_type) in (Tuple, tuple) + and any([issubclass(v, Model) for v in typing_get_args(field_type)]) ): - cls._nested_model_tuple_fields[field] = field_type.__args__ + cls._nested_model_tuple_fields[field] = typing_get_args(field_type) elif issubclass(field_type, Model): cls._nested_model_fields[field] = field_type @@ -401,7 +403,7 @@ def __get_select_fields(cls, columns: Optional[List[str]]) -> Optional[List[str] if isinstance(field_type, type(Model)): fields.append(f"{NESTED_MODEL_PREFIX}{col}") elif issubclass(field_type, List) and isinstance( - field_type.__args__[0], type(Model) + typing_get_args(field_type)[0], type(Model) ): fields.append(f"{NESTED_MODEL_LIST_FIELD_PREFIX}{col}") else: @@ -487,3 +489,19 @@ def strip_leading(word: str, substring: str) -> str: if word.startswith(substring): return word[len(substring) :] return word + + +def typing_get_args(v: Any) -> Tuple[Any, ...]: + """Gets the __args__ of the annotations of a given typing""" + try: + return typing.get_args(v) + except AttributeError: + return getattr(v, "__args__", ()) if v is not typing.Generic else typing.Generic + + +def typing_get_origin(v: Any) -> Optional[Any]: + """Gets the __origin__ of the annotations of a given typing""" + try: + return typing.get_origin(v) + except AttributeError: + return getattr(v, "__origin__", None)