From 46e022686c565e47749236f74ae8a849c6676b0c Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 6 Aug 2024 17:27:03 +0200 Subject: [PATCH] move code for accessing frames and tables to the cursor and use duckdb dbapi cursor in filesystem --- dlt/common/destination/reference.py | 17 ++----- dlt/dataset.py | 46 +++++++++--------- .../impl/filesystem/filesystem.py | 45 +++++------------ dlt/destinations/sql_client.py | 48 ++++++++----------- dlt/destinations/typing.py | 6 ++- tests/load/test_read_interfaces.py | 12 ++--- 6 files changed, 71 insertions(+), 103 deletions(-) diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index b1c1b88421..7ca981a47c 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod import dataclasses from importlib import import_module +from contextlib import contextmanager + from types import TracebackType from typing import ( Callable, @@ -570,24 +572,13 @@ class SupportsDataAccess(ABC): """Add support for accessing data as arrow tables or pandas dataframes""" @abstractmethod - def iter_df( - self, - *, - table: str = None, - batch_size: int = 1000, - sql: str = None, - prepare_tables: List[str] = None, - ) -> Generator[DataFrame, None, None]: ... - - @abstractmethod - def iter_arrow( + def cursor_for_relation( self, *, table: str = None, - batch_size: int = 1000, sql: str = None, prepare_tables: List[str] = None, - ) -> Generator[ArrowTable, None, None]: ... + ) -> ContextManager[Any]: ... # TODO: type Destination properly diff --git a/dlt/dataset.py b/dlt/dataset.py index aba3d98896..a79355a80f 100644 --- a/dlt/dataset.py +++ b/dlt/dataset.py @@ -1,4 +1,4 @@ -from typing import cast, Any, TYPE_CHECKING, Generator, List +from typing import cast, Any, TYPE_CHECKING, Generator, List, ContextManager from contextlib import contextmanager @@ -21,7 +21,7 @@ def __init__( self.table = table @contextmanager - def _client(self) -> Generator[SupportsDataAccess, None, None]: + def _client(self) -> Generator[SupportsDataAccess, Any, Any]: from dlt.destinations.job_client_impl import SqlJobClientBase from dlt.destinations.fs_client import FSClientBase @@ -41,58 +41,60 @@ def _client(self) -> Generator[SupportsDataAccess, None, None]: " dataset." 
) + @contextmanager + def _cursor_for_relation(self) -> Generator[Any, Any, Any]: + with self._client() as client: + with client.cursor_for_relation( + sql=self.sql, table=self.table, prepare_tables=self.prepare_tables + ) as cursor: + yield cursor + def df( self, *, - batch_size: int = 1000, + chunk_size: int = 1000, ) -> DataFrame: """Get first batch of table as dataframe""" return next( self.iter_df( - batch_size=batch_size, + chunk_size=chunk_size, ) ) def arrow( self, *, - batch_size: int = 1000, + chunk_size: int = 1000, ) -> ArrowTable: """Get first batch of table as arrow table""" return next( self.iter_arrow( - batch_size=batch_size, + chunk_size=chunk_size, ) ) def iter_df( self, *, - batch_size: int = 1000, + chunk_size: int = 1000, ) -> Generator[DataFrame, None, None]: - """iterates over the whole table in dataframes of the given batch_size, batch_size of -1 will return the full table in the first batch""" + """iterates over the whole table in dataframes of the given chunk_size, chunk_size of -1 will return the full table in the first batch""" # if no table is given, take the bound table - with self._client() as data_access: - yield from data_access.iter_df( - sql=self.sql, - table=self.table, - batch_size=batch_size, - prepare_tables=self.prepare_tables, + with self._cursor_for_relation() as cursor: + yield from cursor.iter_df( + chunk_size=chunk_size, ) def iter_arrow( self, *, - batch_size: int = 1000, + chunk_size: int = 1000, ) -> Generator[ArrowTable, None, None]: - """iterates over the whole table in arrow tables of the given batch_size, batch_size of -1 will return the full table in the first batch""" + """iterates over the whole table in arrow tables of the given chunk_size, chunk_size of -1 will return the full table in the first batch""" # if no table is given, take the bound table - with self._client() as data_access: - yield from data_access.iter_arrow( - sql=self.sql, - table=self.table, - batch_size=batch_size, - prepare_tables=self.prepare_tables, + with self._cursor_for_relation() as cursor: + yield from cursor.iter_arrow( + chunk_size=chunk_size, ) diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index 4968d0ba35..c53a43f728 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -1,7 +1,7 @@ import posixpath import os import base64 - +from contextlib import contextmanager from types import TracebackType from typing import ( List, @@ -14,6 +14,7 @@ cast, Generator, Literal, + Any, ) from fsspec import AbstractFileSystem from contextlib import contextmanager @@ -31,6 +32,7 @@ TPipelineStateDoc, load_package as current_load_package, ) +from dlt.destinations.sql_client import DBApiCursor from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( FollowupJob, @@ -646,15 +648,11 @@ def get_duckdb( return db - def iter_df( - self, - *, - table: str = None, - batch_size: int = 1000, - sql: str = None, - prepare_tables: List[str] = None, - ) -> Generator[DataFrame, None, None]: - """Provide dataframes via duckdb""" + @contextmanager + def cursor_for_relation( + self, *, table: str = None, sql: str = None, prepare_tables: List[str] = None + ) -> Generator[DBApiCursor, Any, Any]: + from dlt.destinations.impl.duckdb.sql_client import DuckDBDBApiCursorImpl if table: prepare_tables = [table] @@ -669,27 +667,8 @@ def iter_df( db = self.get_duckdb(tables=prepare_tables) - # yield in batches - offset = 0 
- while True: - df = db.sql(sql + f" OFFSET {offset} LIMIT {batch_size}").df() - if len(df.index) == 0: - break - yield df - offset += batch_size + if not sql: + sql = f"SELECT * FROM {table}" - def iter_arrow( - self, - *, - table: str = None, - batch_size: int = 1000, - sql: str = None, - prepare_tables: List[str] = None, - ) -> Generator[ArrowTable, None, None]: - """Default implementation converts df to arrow""" - - # TODO: duckdb supports iterating in batches natively.. - for df in self.iter_df( - sql=sql, table=table, batch_size=batch_size, prepare_tables=prepare_tables - ): - yield ArrowTable.from_pandas(df) + db.execute(sql) + yield DuckDBDBApiCursorImpl(db) # type: ignore diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index 138c38bbea..60f0bf7acc 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -286,38 +286,15 @@ def _truncate_table_sql(self, qualified_table_name: str) -> str: else: return f"DELETE FROM {qualified_table_name} WHERE 1=1;" - def iter_df( - self, - *, - table: str = None, - batch_size: int = 1000, - sql: str = None, - prepare_tables: List[str] = None, - ) -> Generator[DataFrame, None, None]: + @contextmanager + def cursor_for_relation( + self, *, table: str = None, sql: str = None, prepare_tables: List[str] = None + ) -> Generator[DBApiCursor, Any, Any]: if not sql: table = self.make_qualified_table_name(table) sql = f"SELECT * FROM {table}" - - # iterate over results in batch size chunks with self.execute_query(sql) as cursor: - while True: - if not (result := cursor.fetchmany(batch_size)): - return - df = DataFrame(result) - df.columns = [x[0] for x in cursor.description] - yield df - - def iter_arrow( - self, - *, - table: str = None, - batch_size: int = 1000, - sql: str = None, - prepare_tables: List[str] = None, - ) -> Generator[ArrowTable, None, None]: - """Default implementation converts df to arrow""" - for df in self.iter_df(sql=sql, table=table, batch_size=batch_size): - yield ArrowTable.from_pandas(df) + yield cursor class DBApiCursorImpl(DBApiCursor): @@ -357,6 +334,21 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: else: return df + def iter_df(self, chunk_size: int = 1000) -> Generator[DataFrame, None, None]: + from dlt.common.libs.pandas_sql import _wrap_result + + # iterate over results in batch size chunks + columns = self._get_columns() + while True: + if not (result := self.fetchmany(chunk_size)): + return + yield _wrap_result(result, columns) + + def iter_arrow(self, chunk_size: int = 1000) -> Generator[ArrowTable, None, None]: + """Default implementation converts df to arrow""" + for df in self.iter_df(chunk_size=chunk_size): + yield ArrowTable.from_pandas(df) + def raise_database_error(f: TFun) -> TFun: @wraps(f) diff --git a/dlt/destinations/typing.py b/dlt/destinations/typing.py index 4d50729f67..78c4c512a1 100644 --- a/dlt/destinations/typing.py +++ b/dlt/destinations/typing.py @@ -1,4 +1,4 @@ -from typing import Any, AnyStr, List, Type, Optional, Protocol, Tuple, TypeVar +from typing import Any, AnyStr, List, Type, Optional, Protocol, Tuple, TypeVar, Generator from dlt.common.typing import DataFrame, ArrowTable @@ -47,3 +47,7 @@ def df(self, chunk_size: int = None, **kwargs: None) -> Optional[DataFrame]: Optional[DataFrame]: A data frame with query results. If chunk_size > 0, None will be returned if there is no more data in results """ ... + + def iter_df(self, chunk_size: int = 1000) -> Generator[DataFrame, None, None]: ... 
+ + def iter_arrow(self, chunk_size: int = 1000) -> Generator[ArrowTable, None, None]: ... diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py index 37d7cc6a22..c4386f698d 100644 --- a/tests/load/test_read_interfaces.py +++ b/tests/load/test_read_interfaces.py @@ -40,13 +40,13 @@ def test_read_interfaces_sql(destination_config: DestinationTestConfiguration) - ) # get one df - df = pipeline.dataset["items"].df(batch_size=5) + df = pipeline.dataset["items"].df(chunk_size=5) assert len(df.index) == 5 assert set(df.columns.values) == {"id", "_dlt_load_id", "_dlt_id"} # iterate all dataframes frames = [] - for df in pipeline.dataset["items"].iter_df(batch_size=70): + for df in pipeline.dataset["items"].iter_df(chunk_size=70): frames.append(df) # check frame amount and items counts @@ -58,7 +58,7 @@ def test_read_interfaces_sql(destination_config: DestinationTestConfiguration) - assert set(ids) == set(range(300)) # basic check of arrow table - table = pipeline.dataset.items.arrow(batch_size=5) + table = pipeline.dataset.items.arrow(chunk_size=5) assert set(table.column_names) == {"id", "_dlt_load_id", "_dlt_id"} assert table.num_rows == 5 @@ -93,13 +93,13 @@ def test_read_interfaces_filesystem(destination_config: DestinationTestConfigura pipeline.run(s, loader_file_format=destination_config.file_format) # get one df - df = pipeline.dataset["items"].df(batch_size=5) + df = pipeline.dataset["items"].df(chunk_size=5) assert len(df.index) == 5 assert set(df.columns.values) == {"id", "_dlt_load_id", "_dlt_id"} # iterate all dataframes frames = [] - for df in pipeline.dataset.items.iter_df(batch_size=70): + for df in pipeline.dataset.items.iter_df(chunk_size=70): frames.append(df) # check frame amount and items counts @@ -111,6 +111,6 @@ def test_read_interfaces_filesystem(destination_config: DestinationTestConfigura assert set(ids) == set(range(300)) # basic check of arrow table - table = pipeline.dataset["items"].arrow(batch_size=5) + table = pipeline.dataset["items"].arrow(chunk_size=5) assert set(table.column_names) == {"id", "_dlt_load_id", "_dlt_id"} assert table.num_rows == 5
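
Usage note (reviewer addition, not part of the patch): the sketch below shows how the relocated iter_df / iter_arrow helpers are called after this change, mirroring the calls in tests/load/test_read_interfaces.py. The pipeline setup is assumed for illustration only; the resource contents, row count, and destination are hypothetical and not taken from the diff.

    import dlt

    # minimal pipeline producing an "items" table (illustrative setup, not from the patch)
    pipeline = dlt.pipeline(
        pipeline_name="read_example", destination="duckdb", dataset_name="read_example_data"
    )
    pipeline.run([{"id": i} for i in range(300)], table_name="items")

    # first chunk as a pandas DataFrame; chunk_size replaces the former batch_size argument
    df = pipeline.dataset["items"].df(chunk_size=5)

    # iterate the whole table in DataFrame chunks of 70 rows
    for frame in pipeline.dataset["items"].iter_df(chunk_size=70):
        print(len(frame.index))

    # first chunk as a pyarrow Table
    table = pipeline.dataset["items"].arrow(chunk_size=5)

After this patch, both iter_df and iter_arrow delegate to a single relation cursor obtained via cursor_for_relation, so the chunking logic lives on the DBApiCursor implementation rather than being duplicated in the SQL client and the filesystem client.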