From acb3b45863d091a79fed6dde42ebe937d3b7aba8 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 23 Feb 2024 15:13:10 -0500 Subject: [PATCH] Typing --- dlt/common/typing.py | 5 ++++- dlt/extract/resource.py | 2 +- dlt/extract/utils.py | 14 ++++++-------- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 0682be5062..7c20b0df43 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -25,6 +25,7 @@ runtime_checkable, IO, Iterator, + Generator, ) from typing_extensions import ( @@ -70,7 +71,9 @@ AnyFun: TypeAlias = Callable[..., Any] TFun = TypeVar("TFun", bound=AnyFun) # any function TAny = TypeVar("TAny", bound=Any) -TAnyFunOrIterator = TypeVar("TAnyFunOrIterator", AnyFun, Iterator[Any]) +TAnyFunOrGenerator = TypeVar( + "TAnyFunOrGenerator", AnyFun, Generator[Any, Optional[Any], Optional[Any]] +) TAnyClass = TypeVar("TAnyClass", bound=object) TimedeltaSeconds = Union[int, float, timedelta] # represent secret value ie. coming from Kubernetes/Docker secrets or other providers diff --git a/dlt/extract/resource.py b/dlt/extract/resource.py index 28a071729e..f73ee676b6 100644 --- a/dlt/extract/resource.py +++ b/dlt/extract/resource.py @@ -360,7 +360,7 @@ def parallelize(self) -> "DltResource": ): raise InvalidParallelResourceDataType(self.name, self._pipe.gen, type(self._pipe.gen)) - self._pipe.replace_gen(wrap_parallel_iterator(self._pipe.gen)) + self._pipe.replace_gen(wrap_parallel_iterator(self._pipe.gen)) # type: ignore # TODO return self def add_step( diff --git a/dlt/extract/utils.py b/dlt/extract/utils.py index 7d5bea5e28..f9a35cd9ae 100644 --- a/dlt/extract/utils.py +++ b/dlt/extract/utils.py @@ -21,7 +21,7 @@ from dlt.common.exceptions import MissingDependencyException from dlt.common.pipeline import reset_resource_state from dlt.common.schema.typing import TColumnNames, TAnySchemaColumns, TTableSchemaColumns -from dlt.common.typing import AnyFun, DictStrAny, TDataItem, TDataItems +from dlt.common.typing import AnyFun, DictStrAny, TDataItem, TDataItems, TAnyFunOrGenerator from dlt.common.utils import get_callable_name from dlt.extract.exceptions import ( @@ -178,14 +178,12 @@ async def run() -> TDataItems: exhausted = True -def wrap_parallel_iterator( - f: Union[Generator[TDataItems, Optional[Any], Optional[Any]], AnyFun] -) -> Union[Generator[TDataItems, Optional[Any], Optional[Any]], AnyFun]: +def wrap_parallel_iterator(f: TAnyFunOrGenerator) -> TAnyFunOrGenerator: """Wraps a generator for parallel extraction""" def _wrapper(*args: Any, **kwargs: Any) -> Generator[TDataItems, None, None]: is_generator = True - gen: Union[Generator[TDataItems, Optional[Any], Optional[Any]], AnyFun] + gen: TAnyFunOrGenerator if callable(f): if inspect.isgeneratorfunction(inspect.unwrap(f)): gen = f(*args, **kwargs) @@ -204,7 +202,7 @@ def _parallel_gen() -> TDataItems: nonlocal exhausted try: if is_generator: - return next(gen) # type: ignore[arg-type] + return next(gen) # type: ignore[call-overload] else: return gen(*args, **kwargs) # type: ignore[operator] except StopIteration: @@ -223,11 +221,11 @@ def _parallel_gen() -> TDataItems: yield _parallel_gen except GeneratorExit: if is_generator: - gen.close() # type: ignore[union-attr] + gen.close() # type: ignore[attr-defined] raise if callable(f): - return wraps(f)(_wrapper) + return wraps(f)(_wrapper) # type: ignore[arg-type] return _wrapper()