Skip to content

Commit

Permalink
Handle non-iterator transformers
Browse files Browse the repository at this point in the history
  • Loading branch information
steinitzu committed Feb 16, 2024
1 parent 871cb41 commit 6f05f2e
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 42 deletions.
8 changes: 4 additions & 4 deletions dlt/extract/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ def _vacate_slot(self, _: TItemFuture) -> None:
self.used_slots -= 1

@overload
def submit(self, pipe_item: ResolvablePipeItem, block: Literal[False]) -> Optional[TItemFuture]:
...
def submit(
self, pipe_item: ResolvablePipeItem, block: Literal[False]
) -> Optional[TItemFuture]: ...

@overload
def submit(self, pipe_item: ResolvablePipeItem, block: Literal[True]) -> TItemFuture:
...
def submit(self, pipe_item: ResolvablePipeItem, block: Literal[True]) -> TItemFuture: ...

def submit(self, pipe_item: ResolvablePipeItem, block: bool = False) -> Optional[TItemFuture]:
"""Submit an item to the pool.
Expand Down
6 changes: 3 additions & 3 deletions dlt/extract/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,8 @@ def transformer(
merge_key: TTableHintTemplate[TColumnNames] = None,
selected: bool = True,
spec: Type[BaseConfiguration] = None,
standalone: Literal[True] = True,
parallelized: bool = False,
standalone: Literal[True] = True,
) -> Callable[
[Callable[Concatenate[TDataItem, TResourceFunParams], Any]],
Callable[TResourceFunParams, DltResource],
Expand Down Expand Up @@ -605,8 +605,8 @@ def transformer(
merge_key: TTableHintTemplate[TColumnNames] = None,
selected: bool = True,
spec: Type[BaseConfiguration] = None,
standalone: Literal[True] = True,
parallelized: bool = False,
standalone: Literal[True] = True,
) -> Callable[TResourceFunParams, DltResource]: ...


Expand All @@ -622,8 +622,8 @@ def transformer(
merge_key: TTableHintTemplate[TColumnNames] = None,
selected: bool = True,
spec: Type[BaseConfiguration] = None,
standalone: bool = False,
parallelized: bool = False,
standalone: bool = False,
) -> Any:
"""A form of `dlt resource` that takes input from other resources via `data_from` argument in order to enrich or transform the data.
Expand Down
2 changes: 1 addition & 1 deletion dlt/extract/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def parallelize(self) -> "DltResource":
if (
not inspect.isgenerator(self._pipe.gen)
and not inspect.isgeneratorfunction(self._pipe.gen)
and not self.is_transformer
and not (callable(self._pipe.gen) and self.is_transformer)
):
raise InvalidParallelResourceDataType(self.name, self._pipe.gen, type(self._pipe.gen))

Expand Down
6 changes: 0 additions & 6 deletions dlt/extract/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,6 @@
TTableHintTemplate = Union[TDynHintType, TFunHintTemplate[TDynHintType]]


TGenOrGenFunction = Union[
Generator[TDataItems, Optional[Any], Optional[Any]],
Callable[..., Generator[TDataItems, Optional[Any], Optional[Any]]],
] # ]


class DataItemWithMeta:
__slots__ = "meta", "data"

Expand Down
31 changes: 23 additions & 8 deletions dlt/extract/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Iterator,
)
from collections.abc import Mapping as C_Mapping
from functools import wraps
from functools import wraps, partial

from dlt.common.exceptions import MissingDependencyException
from dlt.common.pipeline import reset_resource_state
Expand All @@ -33,7 +33,6 @@
TDataItem,
TFunHintTemplate,
SupportsPipe,
TGenOrGenFunction,
)

try:
Expand Down Expand Up @@ -179,21 +178,32 @@ async def run() -> TDataItems:
exhausted = True


def wrap_parallel_iterator(f: TGenOrGenFunction) -> TGenOrGenFunction:
def wrap_parallel_iterator(
f: Union[Generator[TDataItems, Optional[Any], Optional[Any]], AnyFun]
) -> Union[Generator[TDataItems, Optional[Any], Optional[Any]], AnyFun]:
"""Wraps a generator for parallel extraction"""

def _wrapper(*args: Any, **kwargs: Any) -> Generator[TDataItems, None, None]:
gen = f(*args, **kwargs) if callable(f) else f
is_generator = True
gen: Union[Generator[TDataItems, Optional[Any], Optional[Any]], AnyFun]
if callable(f):
if inspect.isgeneratorfunction(f):
gen = f(*args, **kwargs)
else:
is_generator = False
gen = f
else:
gen = f

exhausted = False
busy = False

def _parallel_gen() -> TDataItems:
nonlocal busy
nonlocal exhausted
try:
return next(gen)
return next(gen) # type: ignore[arg-type]
except StopIteration:
nonlocal exhausted
exhausted = True
return None
finally:
Expand All @@ -204,9 +214,14 @@ def _parallel_gen() -> TDataItems:
while busy:
yield None
busy = True
yield _parallel_gen
if is_generator:
yield _parallel_gen
else:
exhausted = True
yield partial(gen, *args, **kwargs) # type: ignore[arg-type]
except GeneratorExit:
# gen.close()
if is_generator:
gen.close() # type: ignore[union-attr]
raise

if callable(f):
Expand Down
64 changes: 44 additions & 20 deletions tests/pipeline/test_resources_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, List
import time
import threading
import random
Expand Down Expand Up @@ -326,35 +326,59 @@ def some_data(resource_num: int):


def test_test_parallelized_transformers() -> None:
exec_order = []
item_count = 6

@dlt.resource(parallelized=True)
def pos_data():
for i in range(1, item_count + 1):
time.sleep(0.1)
yield i

@dlt.resource(parallelized=True)
def neg_data():
time.sleep(0.05)
for i in range(-1, -item_count - 1, -1):
time.sleep(0.1)
yield i

@dlt.transformer(parallelized=True)
def multiply(item):
time.sleep(0.05)
exec_order.append("+" if item > 0 else "-")
yield item * 10

@dlt.source
def some_source():
@dlt.resource(parallelized=True)
def pos_data():
for i in range(1, item_count + 1):
time.sleep(0.1)
yield i

@dlt.resource(parallelized=True)
def neg_data():
time.sleep(0.05)
for i in range(-1, -item_count - 1, -1):
time.sleep(0.1)
yield i

@dlt.transformer(parallelized=True)
def multiply(item):
time.sleep(0.05)
exec_order.append("+" if item > 0 else "-")
yield item * 10
return [
neg_data | multiply.with_name("t_a"),
pos_data | multiply.with_name("t_b"),
]

exec_order: List[str] = []
result = list(some_source())

expected_result = [i * 10 for i in range(-item_count, item_count + 1)]
expected_result.remove(0)

assert sorted(result) == expected_result
assert exec_order == ["+", "-"] * item_count

@dlt.transformer(parallelized=True) # type: ignore[no-redef]
def multiply(item):
# Transformer that is not a generator
time.sleep(0.05)
exec_order.append("+" if item > 0 else "-")
return item * 10

@dlt.source # type: ignore[no-redef]
def some_source():
return [
neg_data | multiply.with_name("t_a"),
pos_data | multiply.with_name("t_b"),
]

exec_order = []

result = list(some_source())

expected_result = [i * 10 for i in range(-item_count, item_count + 1)]
Expand Down

0 comments on commit 6f05f2e

Please sign in to comment.