
Commit

move code for accessing frames and tables to the cursor and use duckdb dbapi cursor in filesystem
sh-rp committed Aug 6, 2024
1 parent 13ec73b commit 46e0226
Showing 6 changed files with 71 additions and 103 deletions.
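
In summary, this commit replaces the per-client iter_df/iter_arrow methods with a single cursor_for_relation context manager and moves chunked iteration onto the returned DB API cursor; the batch_size argument becomes chunk_size throughout. A minimal sketch of the resulting user-facing access pattern, assuming a pipeline that has already loaded an "items" table (the calls mirror the updated tests at the end of this commit; the pipeline setup itself is illustrative):

# Hedged sketch of the access pattern after this commit; only the
# chunk_size-based dataset calls reflect the diff below.
import dlt

pipeline = dlt.pipeline("my_pipeline", destination="duckdb", dataset_name="my_data")
# ... run the pipeline so an "items" table exists ...

# first chunk of 5 rows as a pandas DataFrame
df = pipeline.dataset["items"].df(chunk_size=5)

# iterate the whole table in DataFrame chunks of 70 rows
for frame in pipeline.dataset["items"].iter_df(chunk_size=70):
    print(len(frame.index))

# same data as a pyarrow Table
table = pipeline.dataset["items"].arrow(chunk_size=5)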
17 changes: 4 additions & 13 deletions dlt/common/destination/reference.py
@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
import dataclasses
from importlib import import_module
from contextlib import contextmanager

from types import TracebackType
from typing import (
Callable,
@@ -570,24 +572,13 @@ class SupportsDataAccess(ABC):
"""Add support for accessing data as arrow tables or pandas dataframes"""

@abstractmethod
def iter_df(
self,
*,
table: str = None,
batch_size: int = 1000,
sql: str = None,
prepare_tables: List[str] = None,
) -> Generator[DataFrame, None, None]: ...

@abstractmethod
def iter_arrow(
def cursor_for_relation(
self,
*,
table: str = None,
batch_size: int = 1000,
sql: str = None,
prepare_tables: List[str] = None,
) -> Generator[ArrowTable, None, None]: ...
) -> ContextManager[Any]: ...


# TODO: type Destination properly
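
For orientation, a minimal sketch of how a caller consumes the new abstract contract (client stands for any SupportsDataAccess implementation and is hypothetical here; the cursor's iter_df/iter_arrow methods are the ones declared in dlt/destinations/typing.py further down):

# Hedged sketch: consuming cursor_for_relation as a context manager.
# `client` is a hypothetical SupportsDataAccess implementation.
with client.cursor_for_relation(table="items") as cursor:
    for df in cursor.iter_df(chunk_size=1000):
        print(len(df.index))  # one pandas DataFrame per chunk; iter_arrow works the same way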
46 changes: 24 additions & 22 deletions dlt/dataset.py
@@ -1,4 +1,4 @@
from typing import cast, Any, TYPE_CHECKING, Generator, List
from typing import cast, Any, TYPE_CHECKING, Generator, List, ContextManager

from contextlib import contextmanager

@@ -21,7 +21,7 @@ def __init__(
self.table = table

@contextmanager
def _client(self) -> Generator[SupportsDataAccess, None, None]:
def _client(self) -> Generator[SupportsDataAccess, Any, Any]:
from dlt.destinations.job_client_impl import SqlJobClientBase
from dlt.destinations.fs_client import FSClientBase

@@ -41,58 +41,60 @@ def _client(self) -> Generator[SupportsDataAccess, None, None]:
" dataset."
)

@contextmanager
def _cursor_for_relation(self) -> Generator[Any, Any, Any]:
with self._client() as client:
with client.cursor_for_relation(
sql=self.sql, table=self.table, prepare_tables=self.prepare_tables
) as cursor:
yield cursor

def df(
self,
*,
batch_size: int = 1000,
chunk_size: int = 1000,
) -> DataFrame:
"""Get first batch of table as dataframe"""
return next(
self.iter_df(
batch_size=batch_size,
chunk_size=chunk_size,
)
)

def arrow(
self,
*,
batch_size: int = 1000,
chunk_size: int = 1000,
) -> ArrowTable:
"""Get first batch of table as arrow table"""
return next(
self.iter_arrow(
batch_size=batch_size,
chunk_size=chunk_size,
)
)

def iter_df(
self,
*,
batch_size: int = 1000,
chunk_size: int = 1000,
) -> Generator[DataFrame, None, None]:
"""iterates over the whole table in dataframes of the given batch_size, batch_size of -1 will return the full table in the first batch"""
"""iterates over the whole table in dataframes of the given chunk_size, chunk_size of -1 will return the full table in the first batch"""
# if no table is given, take the bound table
with self._client() as data_access:
yield from data_access.iter_df(
sql=self.sql,
table=self.table,
batch_size=batch_size,
prepare_tables=self.prepare_tables,
with self._cursor_for_relation() as cursor:
yield from cursor.iter_df(
chunk_size=chunk_size,
)

def iter_arrow(
self,
*,
batch_size: int = 1000,
chunk_size: int = 1000,
) -> Generator[ArrowTable, None, None]:
"""iterates over the whole table in arrow tables of the given batch_size, batch_size of -1 will return the full table in the first batch"""
"""iterates over the whole table in arrow tables of the given chunk_size, chunk_size of -1 will return the full table in the first batch"""
# if no table is given, take the bound table
with self._client() as data_access:
yield from data_access.iter_arrow(
sql=self.sql,
table=self.table,
batch_size=batch_size,
prepare_tables=self.prepare_tables,
with self._cursor_for_relation() as cursor:
yield from cursor.iter_arrow(
chunk_size=chunk_size,
)


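
A short usage note on the wrappers above: df() and arrow() simply take the first chunk of the corresponding iterator, so the two calls below are equivalent (a sketch assuming the same pipeline and "items" table as in the tests further down):

# Hedged sketch: df() is shorthand for the first chunk of iter_df().
relation = pipeline.dataset["items"]
first_via_df = relation.df(chunk_size=100)
first_via_iter = next(relation.iter_df(chunk_size=100))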
45 changes: 12 additions & 33 deletions dlt/destinations/impl/filesystem/filesystem.py
@@ -1,7 +1,7 @@
import posixpath
import os
import base64

from contextlib import contextmanager
from types import TracebackType
from typing import (
List,
@@ -14,6 +14,7 @@
cast,
Generator,
Literal,
Any,
)
from fsspec import AbstractFileSystem
from contextlib import contextmanager
@@ -31,6 +32,7 @@
TPipelineStateDoc,
load_package as current_load_package,
)
from dlt.destinations.sql_client import DBApiCursor
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import (
FollowupJob,
@@ -646,15 +648,11 @@ def get_duckdb(

return db

def iter_df(
self,
*,
table: str = None,
batch_size: int = 1000,
sql: str = None,
prepare_tables: List[str] = None,
) -> Generator[DataFrame, None, None]:
"""Provide dataframes via duckdb"""
@contextmanager
def cursor_for_relation(
self, *, table: str = None, sql: str = None, prepare_tables: List[str] = None
) -> Generator[DBApiCursor, Any, Any]:
from dlt.destinations.impl.duckdb.sql_client import DuckDBDBApiCursorImpl

if table:
prepare_tables = [table]
@@ -669,27 +667,8 @@ def iter_df(

db = self.get_duckdb(tables=prepare_tables)

# yield in batches
offset = 0
while True:
df = db.sql(sql + f" OFFSET {offset} LIMIT {batch_size}").df()
if len(df.index) == 0:
break
yield df
offset += batch_size
if not sql:
sql = f"SELECT * FROM {table}"

def iter_arrow(
self,
*,
table: str = None,
batch_size: int = 1000,
sql: str = None,
prepare_tables: List[str] = None,
) -> Generator[ArrowTable, None, None]:
"""Default implementation converts df to arrow"""

# TODO: duckdb supports iterating in batches natively..
for df in self.iter_df(
sql=sql, table=table, batch_size=batch_size, prepare_tables=prepare_tables
):
yield ArrowTable.from_pandas(df)
db.execute(sql)
yield DuckDBDBApiCursorImpl(db) # type: ignore
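
The net effect for the filesystem destination: instead of paginating with OFFSET/LIMIT over repeated db.sql() calls, the requested tables are registered once in an in-process DuckDB database and the resulting DB API cursor streams chunks. A rough, self-contained illustration of that cursor-level chunking written against plain duckdb (the table contents are made up; in the real code get_duckdb() registers views over the destination's table files instead):

# Hedged sketch, plain duckdb used only to illustrate cursor-level chunking.
import duckdb

con = duckdb.connect()  # stands in for self.get_duckdb(tables=prepare_tables)
con.execute("CREATE TABLE items AS SELECT range AS id FROM range(10)")  # fake data
con.execute("SELECT * FROM items")
while rows := con.fetchmany(4):  # chunking happens at the cursor, not via OFFSET/LIMIT
    print(len(rows))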
48 changes: 20 additions & 28 deletions dlt/destinations/sql_client.py
@@ -286,38 +286,15 @@ def _truncate_table_sql(self, qualified_table_name: str) -> str:
else:
return f"DELETE FROM {qualified_table_name} WHERE 1=1;"

def iter_df(
self,
*,
table: str = None,
batch_size: int = 1000,
sql: str = None,
prepare_tables: List[str] = None,
) -> Generator[DataFrame, None, None]:
@contextmanager
def cursor_for_relation(
self, *, table: str = None, sql: str = None, prepare_tables: List[str] = None
) -> Generator[DBApiCursor, Any, Any]:
if not sql:
table = self.make_qualified_table_name(table)
sql = f"SELECT * FROM {table}"

# iterate over results in batch size chunks
with self.execute_query(sql) as cursor:
while True:
if not (result := cursor.fetchmany(batch_size)):
return
df = DataFrame(result)
df.columns = [x[0] for x in cursor.description]
yield df

def iter_arrow(
self,
*,
table: str = None,
batch_size: int = 1000,
sql: str = None,
prepare_tables: List[str] = None,
) -> Generator[ArrowTable, None, None]:
"""Default implementation converts df to arrow"""
for df in self.iter_df(sql=sql, table=table, batch_size=batch_size):
yield ArrowTable.from_pandas(df)
yield cursor


class DBApiCursorImpl(DBApiCursor):
@@ -357,6 +334,21 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]:
else:
return df

def iter_df(self, chunk_size: int = 1000) -> Generator[DataFrame, None, None]:
from dlt.common.libs.pandas_sql import _wrap_result

# iterate over results in batch size chunks
columns = self._get_columns()
while True:
if not (result := self.fetchmany(chunk_size)):
return
yield _wrap_result(result, columns)

def iter_arrow(self, chunk_size: int = 1000) -> Generator[ArrowTable, None, None]:
"""Default implementation converts df to arrow"""
for df in self.iter_df(chunk_size=chunk_size):
yield ArrowTable.from_pandas(df)


def raise_database_error(f: TFun) -> TFun:
@wraps(f)
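
The chunking that the new DBApiCursorImpl.iter_df performs is the standard DB API fetchmany loop. A self-contained sketch of the same pattern using the stdlib sqlite3 driver (sqlite3 and pandas are used here purely for illustration; pd.DataFrame stands in for the _wrap_result helper the real code calls):

# Hedged sketch of the fetchmany-based chunking used by DBApiCursorImpl.iter_df,
# demonstrated with sqlite3.
import sqlite3
import pandas as pd

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE items (id INTEGER)")
con.executemany("INSERT INTO items VALUES (?)", [(i,) for i in range(10)])

cur = con.execute("SELECT * FROM items")
columns = [c[0] for c in cur.description]
while result := cur.fetchmany(4):
    df = pd.DataFrame(result, columns=columns)  # stands in for _wrap_result(result, columns)
    print(len(df.index))  # prints 4, 4, 2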
6 changes: 5 additions & 1 deletion dlt/destinations/typing.py
@@ -1,4 +1,4 @@
from typing import Any, AnyStr, List, Type, Optional, Protocol, Tuple, TypeVar
from typing import Any, AnyStr, List, Type, Optional, Protocol, Tuple, TypeVar, Generator

from dlt.common.typing import DataFrame, ArrowTable

@@ -47,3 +47,7 @@ def df(self, chunk_size: int = None, **kwargs: None) -> Optional[DataFrame]:
Optional[DataFrame]: A data frame with query results. If chunk_size > 0, None will be returned if there is no more data in results
"""
...

def iter_df(self, chunk_size: int = 1000) -> Generator[DataFrame, None, None]: ...

def iter_arrow(self, chunk_size: int = 1000) -> Generator[ArrowTable, None, None]: ...
12 changes: 6 additions & 6 deletions tests/load/test_read_interfaces.py
@@ -40,13 +40,13 @@ def test_read_interfaces_sql(destination_config: DestinationTestConfiguration) -
)

# get one df
df = pipeline.dataset["items"].df(batch_size=5)
df = pipeline.dataset["items"].df(chunk_size=5)
assert len(df.index) == 5
assert set(df.columns.values) == {"id", "_dlt_load_id", "_dlt_id"}

# iterate all dataframes
frames = []
for df in pipeline.dataset["items"].iter_df(batch_size=70):
for df in pipeline.dataset["items"].iter_df(chunk_size=70):
frames.append(df)

# check frame amount and items counts
@@ -58,7 +58,7 @@ def test_read_interfaces_sql(destination_config: DestinationTestConfiguration) -
assert set(ids) == set(range(300))

# basic check of arrow table
table = pipeline.dataset.items.arrow(batch_size=5)
table = pipeline.dataset.items.arrow(chunk_size=5)
assert set(table.column_names) == {"id", "_dlt_load_id", "_dlt_id"}
assert table.num_rows == 5

@@ -93,13 +93,13 @@ def test_read_interfaces_filesystem(destination_config: DestinationTestConfigura
pipeline.run(s, loader_file_format=destination_config.file_format)

# get one df
df = pipeline.dataset["items"].df(batch_size=5)
df = pipeline.dataset["items"].df(chunk_size=5)
assert len(df.index) == 5
assert set(df.columns.values) == {"id", "_dlt_load_id", "_dlt_id"}

# iterate all dataframes
frames = []
for df in pipeline.dataset.items.iter_df(batch_size=70):
for df in pipeline.dataset.items.iter_df(chunk_size=70):
frames.append(df)

# check frame amount and items counts
@@ -111,6 +111,6 @@ def test_read_interfaces_filesystem(destination_config: DestinationTestConfigura
assert set(ids) == set(range(300))

# basic check of arrow table
table = pipeline.dataset["items"].arrow(batch_size=5)
table = pipeline.dataset["items"].arrow(chunk_size=5)
assert set(table.column_names) == {"id", "_dlt_load_id", "_dlt_id"}
assert table.num_rows == 5
