From 02600bbc9c6374e23fb6f67b25c86db7f36b7d7d Mon Sep 17 00:00:00 2001 From: Tian Gao Date: Mon, 21 Oct 2024 20:17:06 -0700 Subject: [PATCH] Fix type hint of add_objprint (#113) --- src/objprint/decorator.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/objprint/decorator.py b/src/objprint/decorator.py index cfcc00a..e50b9fa 100644 --- a/src/objprint/decorator.py +++ b/src/objprint/decorator.py @@ -3,12 +3,29 @@ import functools -from typing import Callable, Optional, Type, Set, Union +from typing import Callable, Optional, Type, TypeVar, Set, Union, overload + + +T = TypeVar("T", bound=Type) + + +@overload +def add_objprint( + orig_class: None = None, + format: str = "string", **kwargs) -> Callable[[T], T]: + ... + + +@overload +def add_objprint( + orig_class: T, + format: str = "string", **kwargs) -> T: + ... def add_objprint( - orig_class: Optional[Type] = None, - format: str = "string", **kwargs) -> Union[Type, Callable[[Type], Type]]: + orig_class: Optional[T] = None, + format: str = "string", **kwargs) -> Union[T, Callable[[T], T]]: from . import _objprint @@ -24,7 +41,7 @@ def __str__(self) -> str: return _objprint._get_custom_object_str(self, memo, indent_level=0, cfg=cfg) if orig_class is None: - def wrapper(cls: Type) -> Type: + def wrapper(cls: T) -> T: cls.__str__ = functools.wraps(cls.__str__)(__str__) # type: ignore return cls return wrapper