diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py
index d6be15abdd..ca417c1f6c 100644
--- a/dlt/common/data_writers/writers.py
+++ b/dlt/common/data_writers/writers.py
@@ -320,22 +320,11 @@ def _create_writer(self, schema: "pa.Schema") -> "pa.parquet.ParquetWriter":
         )
 
     def write_header(self, columns_schema: TTableSchemaColumns) -> None:
-        from dlt.common.libs.pyarrow import pyarrow, get_py_arrow_datatype
+        from dlt.common.libs.pyarrow import table_schema_columns_to_py_arrow
 
         # build schema
-        self.schema = pyarrow.schema(
-            [
-                pyarrow.field(
-                    name,
-                    get_py_arrow_datatype(
-                        schema_item,
-                        self._caps,
-                        self.timestamp_timezone,
-                    ),
-                    nullable=is_nullable_column(schema_item),
-                )
-                for name, schema_item in columns_schema.items()
-            ]
+        self.schema = table_schema_columns_to_py_arrow(
+            columns_schema, self._caps, self.timestamp_timezone
         )
         # find row items that are of the json type (could be abstracted out for use in other writers?)
         self.nested_indices = [
diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py
index b922a39b6f..277247cb7c 100644
--- a/dlt/common/destination/reference.py
+++ b/dlt/common/destination/reference.py
@@ -38,6 +38,7 @@
 from dlt.common.exceptions import TerminalValueError
 from dlt.common.metrics import LoadJobMetrics
 from dlt.common.normalizers.naming import NamingConvention
+from dlt.common.schema.typing import TTableSchemaColumns
 from dlt.common.schema import Schema, TTableSchema, TSchemaTables
 
@@ -506,6 +507,9 @@ class DBApiCursor(SupportsReadableRelation):
     native_cursor: "DBApiCursor"
     """Cursor implementation native to current destination"""
 
+    columns: TTableSchemaColumns
+    """Known dlt table columns for this cursor"""
+
     def execute(self, query: AnyStr, *args: Any, **kwargs: Any) -> None: ...
     def close(self) -> None: ...
 
@@ -513,7 +517,7 @@ def close(self) -> None: ...
 class SupportsReadableDataset(Protocol):
     """A readable dataset retrieved from a destination, with support for creating readable relations from a query or table"""
 
-    def query(self, sql: str) -> SupportsReadableRelation: ...
+    def query(self, query: str) -> SupportsReadableRelation: ...
 
     def __getitem__(self, table: str) -> SupportsReadableRelation: ...
 
@@ -709,11 +713,15 @@ class WithReadableRelations(ABC):
     """Add support for getting readable relations from a destination"""
 
     @abstractmethod
-    def get_readable_relation(
+    def table_relation(
+        self, *, table: str, columns: TTableSchemaColumns
+    ) -> ContextManager[SupportsReadableRelation]: ...
+
+    @abstractmethod
+    def query_relation(
         self,
         *,
-        table: str = None,
-        sql: str = None,
+        query: str,
     ) -> ContextManager[SupportsReadableRelation]: ...
 
diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py
index 0c3c8c21cc..2afb2e663d 100644
--- a/dlt/common/libs/pyarrow.py
+++ b/dlt/common/libs/pyarrow.py
@@ -395,6 +395,38 @@ def py_arrow_to_table_schema_columns(schema: pyarrow.Schema) -> TTableSchemaColu
     return result
 
 
+def table_schema_columns_to_py_arrow(
+    columns: TTableSchemaColumns,
+    caps: DestinationCapabilitiesContext,
+    timestamp_timezone: str = "UTC",
+) -> pyarrow.Schema:
+    """Convert a table schema columns dict to a pyarrow schema.
+
+    Args:
+        columns (TTableSchemaColumns): table schema columns
+        caps (DestinationCapabilitiesContext): destination capabilities used to map dlt data types to arrow types
+        timestamp_timezone (str): timezone assigned to timestamp columns, defaults to "UTC"
+
+    Returns:
+        pyarrow.Schema: pyarrow schema
+
+    """
+    return pyarrow.schema(
+        [
+            pyarrow.field(
+                name,
+                get_py_arrow_datatype(
+                    schema_item,
+                    caps,
+                    timestamp_timezone,
+                ),
+                nullable=schema_item.get("nullable", True),
+            )
+            for name, schema_item in columns.items()
+        ]
+    )
+
+
 def get_parquet_metadata(parquet_file: TFileOrPath) -> Tuple[int, pyarrow.Schema]:
     """Gets parquet file metadata (including row count and schema)
diff --git a/dlt/destinations/dataset.py b/dlt/destinations/dataset.py
index b46bf56b86..4dbae60996 100644
--- a/dlt/destinations/dataset.py
+++ b/dlt/destinations/dataset.py
@@ -8,6 +8,8 @@
 )
 from dlt.destinations.typing import DataFrame, ArrowTable
 
+from dlt.common.schema.typing import TTableSchemaColumns
+from dlt.common.schema import Schema
 
 
 class ReadableRelation(SupportsReadableRelation):
@@ -16,54 +18,50 @@ def __init__(
         *,
         client: WithReadableRelations,
         table: str = None,
-        sql: str = None,
+        query: str = None,
+        columns: TTableSchemaColumns = None
     ) -> None:
         """Create a lazily evaluated relation for the dataset of a destination"""
         self.client = client
-        self.sql = sql
+        self.query = query
         self.table = table
+        self.columns = columns
 
     @contextmanager
     def cursor(self) -> Generator[SupportsReadableRelation, Any, Any]:
         """Gets a DBApiCursor for the current relation"""
-        with self.client.get_readable_relation(sql=self.sql, table=self.table) as cursor:
-            yield cursor
+        # dispatch on how this relation was created: table access goes through
+        # table_relation, ad hoc queries through query_relation
+        if self.table is not None:
+            with self.client.table_relation(table=self.table, columns=self.columns) as cursor:
+                yield cursor
+        else:
+            with self.client.query_relation(query=self.query) as cursor:
+                yield cursor
 
-    def df(
-        self,
-        chunk_size: int = None,
-    ) -> Optional[DataFrame]:
+    def df(self, chunk_size: int = None) -> Optional[DataFrame]:
         """Get first batch of table as dataframe"""
         with self.cursor() as cursor:
             return cursor.df(chunk_size=chunk_size)
 
-    def arrow(
-        self,
-        chunk_size: int = None,
-    ) -> Optional[ArrowTable]:
+    def arrow(self, chunk_size: int = None) -> Optional[ArrowTable]:
         """Get first batch of table as arrow table"""
         with self.cursor() as cursor:
             return cursor.arrow(chunk_size=chunk_size)
 
     def iter_df(
-        self,
-        chunk_size: int,
+        self, chunk_size: int, columns: TTableSchemaColumns = None
     ) -> Generator[DataFrame, None, None]:
         """iterates over the whole table in dataframes of the given chunk_size"""
         with self.cursor() as cursor:
-            yield from cursor.iter_df(
-                chunk_size=chunk_size,
-            )
+            yield from cursor.iter_df(chunk_size=chunk_size)
 
     def iter_arrow(
-        self,
-        chunk_size: int,
+        self, chunk_size: int, columns: TTableSchemaColumns = None
     ) -> Generator[ArrowTable, None, None]:
         """iterates over the whole table in arrow tables of the given chunk_size"""
         with self.cursor() as cursor:
-            yield from cursor.iter_arrow(
-                chunk_size=chunk_size,
-            )
+            yield from cursor.iter_arrow(chunk_size=chunk_size)
 
     def fetchall(self) -> List[Tuple[Any, ...]]:
         """does a dbapi fetch all"""
 
@@ -89,16 +81,18 @@ def fetchone(self) -> Optional[Tuple[Any, ...]]:
 class ReadableDataset(SupportsReadableDataset):
     """Access to dataframes and arrow tables in the destination dataset"""
 
-    def __init__(self, client: WithReadableRelations) -> None:
+    def __init__(self, client: WithReadableRelations, schema: Schema) -> None:
         self.client = client
+        self.schema = schema
 
-    def query(self, sql: str) -> SupportsReadableRelation:
-        return ReadableRelation(client=self.client, sql=sql)
+    def query(self, query: str) -> SupportsReadableRelation:
+        return ReadableRelation(client=self.client, query=query)
 
     def __getitem__(self, table: str) -> SupportsReadableRelation:
         """access of table via dict notation"""
-        return ReadableRelation(client=self.client, table=table)
+        table_columns = self.schema.tables[table]["columns"]
+        return ReadableRelation(client=self.client, table=table, columns=table_columns)
 
     def __getattr__(self, table: str) -> SupportsReadableRelation:
         """access of table via property notation"""
-        return ReadableRelation(client=self.client, table=table)
+        return self[table]
diff --git a/dlt/destinations/impl/duckdb/sql_client.py b/dlt/destinations/impl/duckdb/sql_client.py
index 014ae9d674..89a522c8f7 100644
--- a/dlt/destinations/impl/duckdb/sql_client.py
+++ b/dlt/destinations/impl/duckdb/sql_client.py
@@ -48,7 +48,6 @@ def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]:
         yield df
 
     def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]:
-        # full table
         if not chunk_size:
             yield self.native_cursor.fetch_arrow_table()
             return
diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py
index 63eb383b5b..b863cf7534 100644
--- a/dlt/destinations/impl/filesystem/filesystem.py
+++ b/dlt/destinations/impl/filesystem/filesystem.py
@@ -4,6 +4,7 @@
 from contextlib import contextmanager
 from types import TracebackType
 from typing import (
+    ContextManager,
     List,
     Type,
     Iterable,
@@ -41,6 +42,7 @@
 from dlt.common.destination.reference import (
     FollowupJobRequest,
     PreparedTableSchema,
+    SupportsReadableRelation,
     TLoadJobState,
     RunnableLoadJob,
     JobClientBase,
@@ -697,14 +699,18 @@ def create_table_chain_completed_followup_jobs(
         return jobs
 
     @contextmanager
-    def get_readable_relation(
-        self, *, table: str = None, sql: str = None
+    def table_relation(
+        self, *, table: str, columns: TTableSchemaColumns
     ) -> Generator[DBApiCursor, Any, Any]:
-        if table:
-            sql = f"SELECT * FROM {table}"
+        with self.sql_client.execute_query(f"SELECT * FROM {table}") as cursor:
+            cursor.columns = columns
+            yield cursor
 
-        with self.sql_client.execute_query(sql) as cursor:
+    @contextmanager
+    def query_relation(self, *, query: str) -> Generator[DBApiCursor, Any, Any]:
+        with self.sql_client.execute_query(query) as cursor:
+            cursor.columns = None  # no dlt schema columns are known for ad hoc queries
             yield cursor
 
     def dataset(self) -> SupportsReadableDataset:
-        return ReadableDataset(self)
+        return ReadableDataset(self, self.schema)
diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py
index b3b4ed2351..c626b613b8 100644
--- a/dlt/destinations/job_client_impl.py
+++ b/dlt/destinations/job_client_impl.py
@@ -714,18 +714,25 @@ def _set_query_tags_for_job(self, load_id: str, table: PreparedTableSchema) -> N
         )
 
     @contextmanager
-    def get_readable_relation(
-        self, *, table: str = None, sql: str = None
+    def table_relation(
+        self, *, table: str, columns: TTableSchemaColumns
     ) -> Generator[SupportsReadableRelation, Any, Any]:
         with self.sql_client as sql_client:
-            if not sql:
-                table = sql_client.make_qualified_table_name(table)
-                sql = f"SELECT * FROM {table}"
-            with sql_client.execute_query(sql) as cursor:
+            table = sql_client.make_qualified_table_name(table)
+            query = f"SELECT * FROM {table}"
+            with sql_client.execute_query(query) as cursor:
+                cursor.columns = columns
+                yield cursor
+
+    @contextmanager
+    def query_relation(self, *, query: str) -> Generator[SupportsReadableRelation, Any, Any]:
+        with self.sql_client as sql_client:
+            with sql_client.execute_query(query) as cursor:
+                cursor.columns = None  # no dlt schema columns are known for ad hoc queries
                 yield cursor
 
     def dataset(self) -> SupportsReadableDataset:
-        return ReadableDataset(self)
+        return ReadableDataset(self, self.schema)
 
 
 class SqlJobClientWithStagingDataset(SqlJobClientBase, WithStagingDataset):
diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py
index 69f95c5e1c..c21eb8d673 100644
--- a/dlt/destinations/sql_client.py
+++ b/dlt/destinations/sql_client.py
@@ -21,6 +21,7 @@
 )
 
 from dlt.common.typing import TFun
+from dlt.common.schema.typing import TTableSchemaColumns
 from dlt.common.destination import DestinationCapabilitiesContext
 from dlt.common.utils import concat_strings_with_limit
 from dlt.common.destination.reference import JobClientBase
@@ -362,25 +363,45 @@ def iter_fetchmany(self, chunk_size: int) -> Generator[List[Tuple[Any, ...]], An
         yield result
 
     def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]:
-        from dlt.common.libs.pandas_sql import _wrap_result
+        """Default implementation converts arrow to df"""
+        from dlt.common.libs.pandas import pandas as pd
 
-        columns = self._get_columns()
+        for table in self.iter_arrow(chunk_size=chunk_size):
+            # NOTE: we go via arrow table
+            # https://github.com/apache/arrow/issues/38644 for reference on types_mapper
+            yield table.to_pandas(types_mapper=pd.ArrowDtype)
+
+    def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]:
+        """Default implementation converts query result to arrow table"""
+        from dlt.common.libs.pyarrow import table_schema_columns_to_py_arrow, pyarrow
+
+        def _result_to_arrow_table(
+            result: List[Tuple[Any, ...]], columns: List[str], schema: pyarrow.Schema
+        ) -> ArrowTable:
+            # TODO: it might be faster to create pyarrow arrays and build the table from them
+            pylist = [dict(zip(columns, t)) for t in result]
+            return ArrowTable.from_pylist(pylist, schema=schema)
+
+        cursor_columns = self._get_columns()
+
+        # we can create the arrow schema if columns are present
+        # TODO: when using this dataset as a source for a new pipeline, we should
+        # get the capabilities of the destination that it will end up in
+        arrow_schema = (
+            table_schema_columns_to_py_arrow(
+                self.columns, caps=DestinationCapabilitiesContext.generic_capabilities()
+            )
+            if self.columns
+            else None
+        )
 
-        # if no chunk size, fetch all
         if not chunk_size:
-            yield _wrap_result(self.fetchall(), columns)
+            result = self.fetchall()
+            yield _result_to_arrow_table(result, cursor_columns, arrow_schema)
             return
 
-        # otherwise iterate over results in batch size chunks
         for result in self.iter_fetchmany(chunk_size=chunk_size):
-            # TODO: ensure that this is arrow backed
-            yield _wrap_result(result, columns, dtype_backend="pyarrow")
-
-    def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]:
-        """Default implementation converts df to arrow"""
-        for df in self.iter_df(chunk_size=chunk_size):
-            # TODO: is this efficient?
-            yield ArrowTable.from_pandas(df)
+            yield _result_to_arrow_table(result, cursor_columns, arrow_schema)
 
 
 def raise_database_error(f: TFun) -> TFun:
diff --git a/tests/load/test_read_interfaces.py b/tests/load/test_read_interfaces.py
index 734e5506ca..2d78737c47 100644
--- a/tests/load/test_read_interfaces.py
+++ b/tests/load/test_read_interfaces.py
@@ -7,7 +7,12 @@
 from typing import List
 from functools import reduce
 
-from tests.load.utils import destinations_configs, DestinationTestConfiguration, AZ_BUCKET, ABFS_BUCKET
+from tests.load.utils import (
+    destinations_configs,
+    DestinationTestConfiguration,
+    AZ_BUCKET,
+    ABFS_BUCKET,
+)
 from pandas import DataFrame
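
Note: the new table_schema_columns_to_py_arrow helper can be exercised on its own. A minimal sketch, assuming the patch above is applied; the column dicts follow dlt's TTableSchemaColumns shape, and generic_capabilities() is the same fallback the sql_client base implementation uses when no destination capabilities are at hand:

    from dlt.common.destination import DestinationCapabilitiesContext
    from dlt.common.libs.pyarrow import table_schema_columns_to_py_arrow

    # a TTableSchemaColumns dict: column name -> column schema
    columns = {
        "id": {"name": "id", "data_type": "bigint", "nullable": False},
        "created_at": {"name": "created_at", "data_type": "timestamp"},
    }
    # maps dlt data types to arrow types using generic destination capabilities
    arrow_schema = table_schema_columns_to_py_arrow(
        columns, caps=DestinationCapabilitiesContext.generic_capabilities()
    )
    assert arrow_schema.field("id").nullable is False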
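
Note: a sketch of the read access this diff enables. Only dataset(), query(), the df/arrow accessors, the iterators and the dict/attribute table lookup come from the code above; how the job client is obtained and the "items" table are assumed for illustration:

    # dataset is the SupportsReadableDataset returned by a job client's dataset()
    items = dataset["items"]  # dict-style table access, carries the dlt schema columns
    first_batch = items.df(chunk_size=1000)  # first batch as a pandas DataFrame
    whole_table = items.arrow()  # whole table as a pyarrow Table

    # stream a large table in arrow batches (attribute access resolves via __getattr__)
    for arrow_table in dataset.items.iter_arrow(chunk_size=10000):
        ...

    # ad hoc SQL goes through query(); no dlt schema columns are attached to the cursor
    row = dataset.query("SELECT COUNT(*) FROM items").fetchone()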