Commit 6971b8c

temp

sh-rp committed Jan 22, 2024
1 parent 38eb953 commit 6971b8c
Showing 12 changed files with 146 additions and 95 deletions.
53 changes: 30 additions & 23 deletions dlt/extract/pipe.py
@@ -55,6 +55,7 @@
simulate_func_call,
wrap_compat_transformer,
wrap_resource_gen,
wrap_async_generator,
)

if TYPE_CHECKING:
@@ -321,6 +322,9 @@ def evaluate_gen(self) -> None:
# verify if transformer can be called
self._ensure_transform_step(self._gen_idx, gen)

# ensure that async gens are wrapped
self.replace_gen(wrap_async_generator(self.gen))

# evaluate transforms
for step_no, step in enumerate(self._steps):
# print(f"pipe {self.name} step no {step_no} step({step})")
@@ -366,7 +370,7 @@ def _wrap_gen(self, *args: Any, **kwargs: Any) -> Any:

def _verify_head_step(self, step: TPipeStep) -> None:
# first element must be Iterable, Iterator or Callable in resource pipe
if not isinstance(step, (Iterable, Iterator)) and not callable(step):
if not isinstance(step, (Iterable, Iterator, AsyncIterator)) and not callable(step):
raise CreatePipeException(
self.name, "A head of a resource pipe must be Iterable, Iterator or a Callable"
)
@@ -791,6 +795,27 @@ def _get_source_item(self) -> ResolvablePipeItem:
elif self._next_item_mode == "round_robin":
return self._get_source_item_round_robin()

def _get_next_item_from_generator(
self,
gen: Any,
step: int,
pipe: Pipe,
meta: Any,
) -> ResolvablePipeItem:
item: ResolvablePipeItem = next(gen)
if not item:
return item
# full pipe item may be returned, this is used by ForkPipe step
# to redirect execution of an item to another pipe
if isinstance(item, ResolvablePipeItem):
return item
else:
# keep the item assigned step and pipe when creating resolvable item
if isinstance(item, DataItemWithMeta):
return ResolvablePipeItem(item.data, step, pipe, item.meta)
else:
return ResolvablePipeItem(item, step, pipe, meta)

def _get_source_item_current(self) -> ResolvablePipeItem:
# no more sources to iterate
if len(self._sources) == 0:
@@ -803,17 +828,8 @@ def _get_source_item_current(self) -> ResolvablePipeItem:
set_current_pipe_name(pipe.name)
item = None
while item is None:
item = next(gen)
# full pipe item may be returned, this is used by ForkPipe step
# to redirect execution of an item to another pipe
if isinstance(item, ResolvablePipeItem):
return item
else:
# keep the item assigned step and pipe when creating resolvable item
if isinstance(item, DataItemWithMeta):
return ResolvablePipeItem(item.data, step, pipe, item.meta)
else:
return ResolvablePipeItem(item, step, pipe, meta)
item = self._get_next_item_from_generator(gen, step, pipe, meta)
return item
except StopIteration:
# remove empty iterator and try another source
self._sources.pop()
@@ -839,17 +855,8 @@ def _get_source_item_round_robin(self) -> ResolvablePipeItem:
self._round_robin_index = (self._round_robin_index + 1) % sources_count
gen, step, pipe, meta = self._sources[self._round_robin_index]
set_current_pipe_name(pipe.name)
item = next(gen)
# full pipe item may be returned, this is used by ForkPipe step
# to redirect execution of an item to another pipe
if isinstance(item, ResolvablePipeItem):
return item
else:
# keep the item assigned step and pipe when creating resolvable item
if isinstance(item, DataItemWithMeta):
return ResolvablePipeItem(item.data, step, pipe, item.meta)
else:
return ResolvablePipeItem(item, step, pipe, meta)
item = self._get_next_item_from_generator(gen, step, pipe, meta)
return item
except StopIteration:
# remove empty iterator and try another source
self._sources.pop(self._round_robin_index)
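Both _get_source_item_current and _get_source_item_round_robin now delegate to the shared _get_next_item_from_generator helper instead of duplicating the item-resolution logic. For orientation, here is a rough standalone sketch of the round-robin rotation these methods implement (simplified, illustrative names rather than the actual dlt internals):

from typing import Any, Iterator, List

def round_robin(sources: List[Iterator[Any]]) -> Iterator[Any]:
    """Yield one item from each source in turn, dropping exhausted sources."""
    index = 0
    while sources:
        index = index % len(sources)
        try:
            yield next(sources[index])
            index += 1
        except StopIteration:
            # remove the exhausted iterator and try the remaining sources
            sources.pop(index)

print(list(round_robin([iter([1, 2]), iter(["a", "b"]), iter([10])])))
# -> [1, 'a', 10, 2, 'b']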
4 changes: 1 addition & 3 deletions dlt/extract/resource.py
@@ -123,8 +123,6 @@ def from_data(
data = wrap_additional_type(data)

# several iterable types are not allowed and must be excluded right away
if isinstance(data, (AsyncIterator, AsyncIterable)):
raise InvalidResourceDataTypeAsync(name, data, type(data))
if isinstance(data, (str, dict)):
raise InvalidResourceDataTypeBasic(name, data, type(data))

@@ -135,7 +133,7 @@ def from_data(
parent_pipe = DltResource._get_parent_pipe(name, data_from)

# create resource from iterator, iterable or generator function
if isinstance(data, (Iterable, Iterator)) or callable(data):
if isinstance(data, (Iterable, Iterator, AsyncIterable)) or callable(data):
pipe = Pipe.from_data(name, data, parent=parent_pipe)
return cls(
pipe,
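With the InvalidResourceDataTypeAsync guard removed and AsyncIterable added to the accepted head types, an async generator can now be passed as resource data. A minimal sketch of what this change appears to enable (assumed from the diff; the runtime collection behavior comes from the wrapping added in dlt/extract/utils.py below):

import dlt

@dlt.resource(name="numbers")
async def async_numbers():
    for n in range(3):
        yield {"value": n}

# accepted like a regular generator; at extraction time the async generator
# is wrapped so its items are drained by awaiting a single awaitable
# pipeline = dlt.pipeline(destination="duckdb")
# load_info = pipeline.run(async_numbers)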
32 changes: 29 additions & 3 deletions dlt/extract/utils.py
@@ -1,6 +1,6 @@
import inspect
import makefun
from typing import Optional, Tuple, Union, List, Any, Sequence, cast
from typing import Optional, Tuple, Union, List, Any, Sequence, cast, Iterator
from collections.abc import Mapping as C_Mapping

from dlt.common.exceptions import MissingDependencyException
@@ -119,6 +119,28 @@ def check_compat_transformer(name: str, f: AnyFun, sig: inspect.Signature) -> in
return meta_arg


def wrap_async_generator(wrapped: Any) -> Any:
    """Wraps an async generator into a sync generator that yields one awaitable collecting all items"""
    if not inspect.isasyncgen(wrapped):
        return wrapped

    async def run() -> List[TDataItem]:
        result: List[TDataItem] = []
        # async for ends cleanly on StopAsyncIteration and, unlike
        # `while item := await wrapped.__anext__()`, does not stop early on falsy items
        async for item in wrapped:
            if isinstance(item, Iterator):
                result.extend(item)
            else:
                result.append(item)
        return result

    def _gen() -> Iterator[Any]:
        yield run()

    return _gen()


def wrap_compat_transformer(
name: str, f: AnyFun, sig: inspect.Signature, *args: Any, **kwargs: Any
) -> AnyFun:
@@ -142,8 +164,12 @@ def wrap_resource_gen(
name: str, f: AnyFun, sig: inspect.Signature, *args: Any, **kwargs: Any
) -> AnyFun:
"""Wraps a generator or generator function so it is evaluated on extraction"""
if inspect.isgeneratorfunction(inspect.unwrap(f)) or inspect.isgenerator(f):
# always wrap generators and generator functions. evaluate only at runtime!

if (
inspect.isgeneratorfunction(inspect.unwrap(f))
or inspect.isgenerator(f)
or inspect.isasyncgenfunction(f)
):

def _partial() -> Any:
# print(f"_PARTIAL: {args} {kwargs}")
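For illustration, wrap_async_generator can be driven by hand like this (assumed usage, mirroring how the pipe evaluates the wrapped generator):

import asyncio

async def agen():
    yield 1
    yield 2

sync_gen = wrap_async_generator(agen())
awaitable = next(sync_gen)     # the single awaitable draining the async generator
print(asyncio.run(awaitable))  # -> [1, 2]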
5 changes: 4 additions & 1 deletion docs/examples/google_sheets/google_sheets.py
@@ -9,13 +9,15 @@
)
from dlt.common.typing import DictStrAny, StrAny


def _initialize_sheets(
credentials: Union[GcpOAuthCredentials, GcpServiceAccountCredentials]
) -> Any:
# Build the service object.
service = build("sheets", "v4", credentials=credentials.to_native_credentials())
return service


@dlt.source
def google_spreadsheet(
spreadsheet_id: str,
@@ -55,6 +57,7 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]:
for name in sheet_names
]


if __name__ == "__main__":
pipeline = dlt.pipeline(destination="duckdb")
# see example.secrets.toml to where to put credentials
@@ -67,4 +70,4 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]:
sheet_names=range_names,
)
)
print(info)
print(info)
18 changes: 5 additions & 13 deletions docs/examples/pdf_to_weaviate/pdf_to_weaviate.py
@@ -10,11 +10,7 @@ def list_files(folder_path: str):
folder_path = os.path.abspath(folder_path)
for filename in os.listdir(folder_path):
file_path = os.path.join(folder_path, filename)
yield {
"file_name": filename,
"file_path": file_path,
"mtime": os.path.getmtime(file_path)
}
yield {"file_name": filename, "file_path": file_path, "mtime": os.path.getmtime(file_path)}


@dlt.transformer(primary_key="page_id", write_disposition="merge")
@@ -30,10 +26,8 @@ def pdf_to_text(file_item, separate_pages: bool = False):
page_item["page_id"] = file_item["file_name"] + "_" + str(page_no)
yield page_item

pipeline = dlt.pipeline(
pipeline_name='pdf_to_text',
destination='weaviate'
)

pipeline = dlt.pipeline(pipeline_name="pdf_to_text", destination="weaviate")

# this constructs a simple pipeline that: (1) reads files from "invoices" folder (2) filters only those ending with ".pdf"
# (3) sends them to pdf_to_text transformer with pipe (|) operator
@@ -46,9 +40,7 @@ def pdf_to_text(file_item, separate_pages: bool = False):
pdf_pipeline.table_name = "InvoiceText"

# use weaviate_adapter to tell destination to vectorize "text" column
load_info = pipeline.run(
weaviate_adapter(pdf_pipeline, vectorize="text")
)
load_info = pipeline.run(weaviate_adapter(pdf_pipeline, vectorize="text"))
row_counts = pipeline.last_trace.last_normalize_info
print(row_counts)
print("------")
@@ -58,4 +50,4 @@

client = weaviate.Client("http://localhost:8080")
# get text of all the invoices in InvoiceText class we just created above
print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do())
print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do())
@@ -80,8 +80,8 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]:
)
)
print(info)
# @@@DLT_SNIPPET_END google_sheets_run
# @@@DLT_SNIPPET_END example
# @@@DLT_SNIPPET_END google_sheets_run
# @@@DLT_SNIPPET_END example
row_counts = pipeline.last_trace.last_normalize_info.row_counts
print(row_counts.keys())
assert row_counts["hidden_columns_merged_cells"] == 7
@@ -1,5 +1,6 @@
from tests.pipeline.utils import assert_load_info


def pdf_to_weaviate_snippet() -> None:
# @@@DLT_SNIPPET_START example
# @@@DLT_SNIPPET_START pdf_to_weaviate
@@ -9,7 +10,6 @@ def pdf_to_weaviate_snippet() -> None:
from dlt.destinations.impl.weaviate import weaviate_adapter
from PyPDF2 import PdfReader


@dlt.resource(selected=False)
def list_files(folder_path: str):
folder_path = os.path.abspath(folder_path)
@@ -18,10 +18,9 @@ def list_files(folder_path: str):
yield {
"file_name": filename,
"file_path": file_path,
"mtime": os.path.getmtime(file_path)
"mtime": os.path.getmtime(file_path),
}


@dlt.transformer(primary_key="page_id", write_disposition="merge")
def pdf_to_text(file_item, separate_pages: bool = False):
if not separate_pages:
@@ -35,10 +34,7 @@ def pdf_to_text(file_item, separate_pages: bool = False):
page_item["page_id"] = file_item["file_name"] + "_" + str(page_no)
yield page_item

pipeline = dlt.pipeline(
pipeline_name='pdf_to_text',
destination='weaviate'
)
pipeline = dlt.pipeline(pipeline_name="pdf_to_text", destination="weaviate")

# this constructs a simple pipeline that: (1) reads files from "invoices" folder (2) filters only those ending with ".pdf"
# (3) sends them to pdf_to_text transformer with pipe (|) operator
@@ -51,9 +47,7 @@ def pdf_to_text(file_item, separate_pages: bool = False):
pdf_pipeline.table_name = "InvoiceText"

# use weaviate_adapter to tell destination to vectorize "text" column
load_info = pipeline.run(
weaviate_adapter(pdf_pipeline, vectorize="text")
)
load_info = pipeline.run(weaviate_adapter(pdf_pipeline, vectorize="text"))
row_counts = pipeline.last_trace.last_normalize_info
print(row_counts)
print("------")
11 changes: 3 additions & 8 deletions docs/website/docs/intro-snippets.py
@@ -18,14 +18,13 @@ def intro_snippet() -> None:
response.raise_for_status()
data.append(response.json())
# Extract, normalize, and load the data
load_info = pipeline.run(data, table_name='player')
load_info = pipeline.run(data, table_name="player")
# @@@DLT_SNIPPET_END api

assert_load_info(load_info)


def csv_snippet() -> None:

# @@@DLT_SNIPPET_START csv
import dlt
import pandas as pd
@@ -50,8 +49,8 @@ def csv_snippet() -> None:

assert_load_info(load_info)

def db_snippet() -> None:

def db_snippet() -> None:
# @@@DLT_SNIPPET_START db
import dlt
from sqlalchemy import create_engine
@@ -74,13 +73,9 @@ def db_snippet() -> None:
)

# Convert the rows into dictionaries on the fly with a map function
load_info = pipeline.run(
map(lambda row: dict(row._mapping), rows),
table_name="genome"
)
load_info = pipeline.run(map(lambda row: dict(row._mapping), rows), table_name="genome")

print(load_info)
# @@@DLT_SNIPPET_END db

assert_load_info(load_info)

