Skip to content

Commit

Permalink
Typing
Browse files Browse the repository at this point in the history
  • Loading branch information
steinitzu committed Feb 23, 2024
1 parent 0882987 commit acb3b45
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
5 changes: 4 additions & 1 deletion dlt/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
runtime_checkable,
IO,
Iterator,
Generator,
)

from typing_extensions import (
Expand Down Expand Up @@ -70,7 +71,9 @@
AnyFun: TypeAlias = Callable[..., Any]
TFun = TypeVar("TFun", bound=AnyFun) # any function
TAny = TypeVar("TAny", bound=Any)
TAnyFunOrIterator = TypeVar("TAnyFunOrIterator", AnyFun, Iterator[Any])
TAnyFunOrGenerator = TypeVar(
"TAnyFunOrGenerator", AnyFun, Generator[Any, Optional[Any], Optional[Any]]
)
TAnyClass = TypeVar("TAnyClass", bound=object)
TimedeltaSeconds = Union[int, float, timedelta]
# represent secret value ie. coming from Kubernetes/Docker secrets or other providers
Expand Down
2 changes: 1 addition & 1 deletion dlt/extract/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def parallelize(self) -> "DltResource":
):
raise InvalidParallelResourceDataType(self.name, self._pipe.gen, type(self._pipe.gen))

self._pipe.replace_gen(wrap_parallel_iterator(self._pipe.gen))
self._pipe.replace_gen(wrap_parallel_iterator(self._pipe.gen)) # type: ignore # TODO
return self

def add_step(
Expand Down
14 changes: 6 additions & 8 deletions dlt/extract/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from dlt.common.exceptions import MissingDependencyException
from dlt.common.pipeline import reset_resource_state
from dlt.common.schema.typing import TColumnNames, TAnySchemaColumns, TTableSchemaColumns
from dlt.common.typing import AnyFun, DictStrAny, TDataItem, TDataItems
from dlt.common.typing import AnyFun, DictStrAny, TDataItem, TDataItems, TAnyFunOrGenerator
from dlt.common.utils import get_callable_name

from dlt.extract.exceptions import (
Expand Down Expand Up @@ -178,14 +178,12 @@ async def run() -> TDataItems:
exhausted = True


def wrap_parallel_iterator(
f: Union[Generator[TDataItems, Optional[Any], Optional[Any]], AnyFun]
) -> Union[Generator[TDataItems, Optional[Any], Optional[Any]], AnyFun]:
def wrap_parallel_iterator(f: TAnyFunOrGenerator) -> TAnyFunOrGenerator:
"""Wraps a generator for parallel extraction"""

def _wrapper(*args: Any, **kwargs: Any) -> Generator[TDataItems, None, None]:
is_generator = True
gen: Union[Generator[TDataItems, Optional[Any], Optional[Any]], AnyFun]
gen: TAnyFunOrGenerator
if callable(f):
if inspect.isgeneratorfunction(inspect.unwrap(f)):
gen = f(*args, **kwargs)
Expand All @@ -204,7 +202,7 @@ def _parallel_gen() -> TDataItems:
nonlocal exhausted
try:
if is_generator:
return next(gen) # type: ignore[arg-type]
return next(gen) # type: ignore[call-overload]
else:
return gen(*args, **kwargs) # type: ignore[operator]
except StopIteration:
Expand All @@ -223,11 +221,11 @@ def _parallel_gen() -> TDataItems:
yield _parallel_gen
except GeneratorExit:
if is_generator:
gen.close() # type: ignore[union-attr]
gen.close() # type: ignore[attr-defined]
raise

if callable(f):
return wraps(f)(_wrapper)
return wraps(f)(_wrapper) # type: ignore[arg-type]
return _wrapper()


Expand Down

0 comments on commit acb3b45

Please sign in to comment.