add support for arrow schema creation from known dlt schema
sh-rp committed Sep 19, 2024
1 parent 8497036 commit 3dc2c90
Showing 9 changed files with 129 additions and 72 deletions.
17 changes: 3 additions & 14 deletions dlt/common/data_writers/writers.py
@@ -320,22 +320,11 @@ def _create_writer(self, schema: "pa.Schema") -> "pa.parquet.ParquetWriter":
)

def write_header(self, columns_schema: TTableSchemaColumns) -> None:
from dlt.common.libs.pyarrow import pyarrow, get_py_arrow_datatype
from dlt.common.libs.pyarrow import table_schema_columns_to_py_arrow

# build schema
self.schema = pyarrow.schema(
[
pyarrow.field(
name,
get_py_arrow_datatype(
schema_item,
self._caps,
self.timestamp_timezone,
),
nullable=is_nullable_column(schema_item),
)
for name, schema_item in columns_schema.items()
]
self.schema = table_schema_columns_to_py_arrow(
columns_schema, self._caps, self.timestamp_timezone
)
# find row items that are of the json type (could be abstracted out for use in other writers?)
self.nested_indices = [
16 changes: 12 additions & 4 deletions dlt/common/destination/reference.py
@@ -38,6 +38,7 @@
from dlt.common.exceptions import TerminalValueError
from dlt.common.metrics import LoadJobMetrics
from dlt.common.normalizers.naming import NamingConvention
from dlt.common.schema.typing import TTableSchemaColumns

from dlt.common.schema import Schema, TTableSchema, TSchemaTables

@@ -506,14 +507,17 @@ class DBApiCursor(SupportsReadableRelation):
native_cursor: "DBApiCursor"
"""Cursor implementation native to current destination"""

columns: TTableSchemaColumns
"""Known dlt table columns for this cursor"""

def execute(self, query: AnyStr, *args: Any, **kwargs: Any) -> None: ...
def close(self) -> None: ...


class SupportsReadableDataset(Protocol):
"""A readable dataset retrieved from a destination, has support for creating readable relations for a query or table"""

def query(self, sql: str) -> SupportsReadableRelation: ...
def query(self, query: str) -> SupportsReadableRelation: ...

def __getitem__(self, table: str) -> SupportsReadableRelation: ...

@@ -709,11 +713,15 @@ class WithReadableRelations(ABC):
"""Add support for getting readable reletions form a destination"""

@abstractmethod
def get_readable_relation(
def table_relation(
self, *, table: str, columns: TTableSchemaColumns
) -> ContextManager[SupportsReadableRelation]: ...

@abstractmethod
def query_relation(
self,
*,
table: str = None,
sql: str = None,
query: str,
) -> ContextManager[SupportsReadableRelation]: ...


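For orientation, a minimal usage sketch of the reworked protocol (not part of the commit): `client` stands for any job client implementing `WithReadableRelations`, and the `items` table and its column definitions are hypothetical.

```python
from dlt.common.destination.reference import WithReadableRelations
from dlt.common.schema.typing import TTableSchemaColumns

# hypothetical dlt column definitions for an "items" table
ITEMS_COLUMNS: TTableSchemaColumns = {
    "id": {"name": "id", "data_type": "bigint", "nullable": False},
    "value": {"name": "value", "data_type": "text", "nullable": True},
}


def read_items(client: WithReadableRelations) -> None:
    # table access: the known dlt columns travel with the cursor, so the
    # arrow conversion downstream can build a proper schema from them
    with client.table_relation(table="items", columns=ITEMS_COLUMNS) as cursor:
        arrow_table = cursor.arrow(chunk_size=1000)
        print(arrow_table.num_rows if arrow_table is not None else 0)

    # ad-hoc query: no dlt columns are attached in this case
    with client.query_relation(query="SELECT id FROM items WHERE value IS NOT NULL") as cursor:
        print(cursor.fetchall())
```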
30 changes: 30 additions & 0 deletions dlt/common/libs/pyarrow.py
@@ -395,6 +395,36 @@ def py_arrow_to_table_schema_columns(schema: pyarrow.Schema) -> TTableSchemaColu
return result


def table_schema_columns_to_py_arrow(
columns: TTableSchemaColumns,
caps: DestinationCapabilitiesContext,
timestamp_timezone: str = "UTC",
) -> pyarrow.Schema:
"""Convert a table schema columns dict to a pyarrow schema.
Args:
columns (TTableSchemaColumns): table schema columns
Returns:
pyarrow.Schema: pyarrow schema
"""
return pyarrow.schema(
[
pyarrow.field(
name,
get_py_arrow_datatype(
schema_item,
caps,
timestamp_timezone,
),
nullable=schema_item.get("nullable", True),
)
for name, schema_item in columns.items()
]
)


def get_parquet_metadata(parquet_file: TFileOrPath) -> Tuple[int, pyarrow.Schema]:
"""Gets parquet file metadata (including row count and schema)
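A short sketch of calling the new helper directly (illustrative, not taken from the commit): the column definitions are made up, and generic destination capabilities stand in for a concrete destination's, mirroring what the default `iter_arrow` implementation later in this diff does.

```python
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.libs.pyarrow import table_schema_columns_to_py_arrow
from dlt.common.schema.typing import TTableSchemaColumns

# hypothetical dlt columns; dlt data types are mapped to arrow types via get_py_arrow_datatype
columns: TTableSchemaColumns = {
    "id": {"name": "id", "data_type": "bigint", "nullable": False},
    "created_at": {"name": "created_at", "data_type": "timestamp"},
    "score": {"name": "score", "data_type": "double"},
}

caps = DestinationCapabilitiesContext.generic_capabilities()
arrow_schema = table_schema_columns_to_py_arrow(columns, caps, timestamp_timezone="UTC")
# roughly: id: int64 not null, created_at: timestamp[..., tz=UTC], score: double
print(arrow_schema)
```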
46 changes: 20 additions & 26 deletions dlt/destinations/dataset.py
Expand Up @@ -8,6 +8,8 @@
)

from dlt.destinations.typing import DataFrame, ArrowTable
from dlt.common.schema.typing import TTableSchemaColumns
from dlt.common.schema import Schema


class ReadableRelation(SupportsReadableRelation):
@@ -16,54 +18,44 @@ def __init__(
*,
client: WithReadableRelations,
table: str = None,
sql: str = None,
query: str = None,
columns: TTableSchemaColumns = None
) -> None:
"""Create a lazy evaluated relation to for the dataset of a destination"""
self.client = client
self.sql = sql
self.query = query
self.table = table
self.columns = columns

@contextmanager
def cursor(self) -> Generator[SupportsReadableRelation, Any, Any]:
"""Gets a DBApiCursor for the current relation"""
with self.client.get_readable_relation(sql=self.sql, table=self.table) as cursor:
with self.client.table_relation(table=self.table, columns=self.columns) as cursor:
yield cursor

def df(
self,
chunk_size: int = None,
) -> Optional[DataFrame]:
def df(self, chunk_size: int = None) -> Optional[DataFrame]:
"""Get first batch of table as dataframe"""
with self.cursor() as cursor:
return cursor.df(chunk_size=chunk_size)

def arrow(
self,
chunk_size: int = None,
) -> Optional[ArrowTable]:
def arrow(self, chunk_size: int = None) -> Optional[ArrowTable]:
"""Get first batch of table as arrow table"""
with self.cursor() as cursor:
return cursor.arrow(chunk_size=chunk_size)

def iter_df(
self,
chunk_size: int,
self, chunk_size: int, columns: TTableSchemaColumns = None
) -> Generator[DataFrame, None, None]:
"""iterates over the whole table in dataframes of the given chunk_size"""
with self.cursor() as cursor:
yield from cursor.iter_df(
chunk_size=chunk_size,
)
yield from cursor.iter_df(chunk_size=chunk_size)

def iter_arrow(
self,
chunk_size: int,
self, chunk_size: int, columns: TTableSchemaColumns = None
) -> Generator[ArrowTable, None, None]:
"""iterates over the whole table in arrow tables of the given chunk_size"""
with self.cursor() as cursor:
yield from cursor.iter_arrow(
chunk_size=chunk_size,
)
yield from cursor.iter_arrow(chunk_size=chunk_size)

def fetchall(self) -> List[Tuple[Any, ...]]:
"""does a dbapi fetch all"""
@@ -89,16 +81,18 @@ def fetchone(self) -> Optional[Tuple[Any, ...]]:
class ReadableDataset(SupportsReadableDataset):
"""Access to dataframes and arrowtables in the destination dataset"""

def __init__(self, client: WithReadableRelations) -> None:
def __init__(self, client: WithReadableRelations, schema: Schema) -> None:
self.client = client
self.schema = schema

def query(self, sql: str) -> SupportsReadableRelation:
return ReadableRelation(client=self.client, sql=sql)
def query(self, query: str) -> SupportsReadableRelation:
return ReadableRelation(client=self.client, query=query)

def __getitem__(self, table: str) -> SupportsReadableRelation:
"""access of table via dict notation"""
return ReadableRelation(client=self.client, table=table)
table_columns = self.schema.tables[table]["columns"]
return ReadableRelation(client=self.client, table=table, columns=table_columns)

def __getattr__(self, table: str) -> SupportsReadableRelation:
"""access of table via property notation"""
return ReadableRelation(client=self.client, table=table)
return self[table]
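To show how the pieces compose, a hedged usage sketch: it assumes a dataset obtained from a job client's `dataset()` method (see the filesystem and SQL job client changes below) and an `items` table present in the pipeline schema; names are illustrative.

```python
from dlt.common.destination.reference import SupportsReadableDataset


def peek_items(dataset: SupportsReadableDataset) -> None:
    # item or attribute access resolves the table's known dlt columns from the schema
    items = dataset["items"]  # equivalent: dataset.items
    df = items.df(chunk_size=10_000)  # first batch as a pandas DataFrame
    print(None if df is None else df.shape)

    # whole table streamed as arrow tables of at most chunk_size rows
    for batch in items.iter_arrow(chunk_size=10_000):
        print(batch.num_rows)

    # ad-hoc SQL returns a relation without attached column information
    print(dataset.query("SELECT count(*) FROM items").fetchall())
```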
1 change: 0 additions & 1 deletion dlt/destinations/impl/duckdb/sql_client.py
@@ -48,7 +48,6 @@ def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]:
yield df

def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]:
# full table
if not chunk_size:
yield self.native_cursor.fetch_arrow_table()
return
17 changes: 11 additions & 6 deletions dlt/destinations/impl/filesystem/filesystem.py
@@ -4,6 +4,7 @@
from contextlib import contextmanager
from types import TracebackType
from typing import (
ContextManager,
List,
Type,
Iterable,
@@ -41,6 +42,7 @@
from dlt.common.destination.reference import (
FollowupJobRequest,
PreparedTableSchema,
SupportsReadableRelation,
TLoadJobState,
RunnableLoadJob,
JobClientBase,
@@ -697,14 +699,17 @@ def create_table_chain_completed_followup_jobs(
return jobs

@contextmanager
def get_readable_relation(
self, *, table: str = None, sql: str = None
def table_relation(
self, *, table: str, columns: TTableSchemaColumns
) -> Generator[DBApiCursor, Any, Any]:
if table:
sql = f"SELECT * FROM {table}"
with self.sql_client.execute_query(f"SELECT * FROM {table}") as cursor:
cursor.columns = columns
yield cursor

with self.sql_client.execute_query(sql) as cursor:
@contextmanager
def query_relation(self, *, query: str) -> Generator[DBApiCursor, Any, Any]:
with self.sql_client.execute_query(query) as cursor:
yield cursor

def dataset(self) -> SupportsReadableDataset:
return ReadableDataset(self)
return ReadableDataset(self, self.schema)
20 changes: 13 additions & 7 deletions dlt/destinations/job_client_impl.py
@@ -714,18 +714,24 @@ def _set_query_tags_for_job(self, load_id: str, table: PreparedTableSchema) -> N
)

@contextmanager
def get_readable_relation(
self, *, table: str = None, sql: str = None
def table_relation(
self, *, table: str, columns: TTableSchemaColumns
) -> Generator[SupportsReadableRelation, Any, Any]:
with self.sql_client as sql_client:
if not sql:
table = sql_client.make_qualified_table_name(table)
sql = f"SELECT * FROM {table}"
with sql_client.execute_query(sql) as cursor:
table = sql_client.make_qualified_table_name(table)
query = f"SELECT * FROM {table}"
with sql_client.execute_query(query) as cursor:
cursor.columns = columns
yield cursor

@contextmanager
def query_relation(self, *, query: str) -> Generator[SupportsReadableRelation, Any, Any]:
with self.sql_client as sql_client:
with sql_client.execute_query(query) as cursor:
yield cursor

def dataset(self) -> SupportsReadableDataset:
return ReadableDataset(self)
return ReadableDataset(self, self.schema)


class SqlJobClientWithStagingDataset(SqlJobClientBase, WithStagingDataset):
47 changes: 34 additions & 13 deletions dlt/destinations/sql_client.py
@@ -21,6 +21,7 @@
)

from dlt.common.typing import TFun
from dlt.common.schema.typing import TTableSchemaColumns
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.utils import concat_strings_with_limit
from dlt.common.destination.reference import JobClientBase
@@ -362,25 +363,45 @@ def iter_fetchmany(self, chunk_size: int) -> Generator[List[Tuple[Any, ...]], An
yield result

def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]:
from dlt.common.libs.pandas_sql import _wrap_result
"""Default implementation converts arrow to df"""
from dlt.common.libs.pandas import pandas as pd

columns = self._get_columns()
for table in self.iter_arrow(chunk_size=chunk_size):
# NOTE: we go via arrow table
# https://github.com/apache/arrow/issues/38644 for reference on types_mapper
yield table.to_pandas(types_mapper=pd.ArrowDtype)

def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]:
"""Default implementation converts query result to arrow table"""
from dlt.common.libs.pyarrow import table_schema_columns_to_py_arrow, pyarrow

def _result_to_arrow_table(
result: List[Tuple[Any, ...]], columns: List[str], schema: pyarrow.schema
) -> ArrowTable:
# TODO: it might be faster to create pyarrow arrays and create tables from them
pylist = [dict(zip(columns, t)) for t in result]
return ArrowTable.from_pylist(pylist, schema=schema)

cursor_columns = self._get_columns()

# we can create the arrow schema if columns are present
# TODO: when using this dataset as a source for a new pipeline, we should
# get the capabilities of the destination that it will end up in
arrow_schema = (
table_schema_columns_to_py_arrow(
self.columns, caps=DestinationCapabilitiesContext.generic_capabilities()
)
if self.columns
else None
)

# if no chunk size, fetch all
if not chunk_size:
yield _wrap_result(self.fetchall(), columns)
result = self.fetchall()
yield _result_to_arrow_table(result, cursor_columns, arrow_schema)
return

# otherwise iterate over results in batch size chunks
for result in self.iter_fetchmany(chunk_size=chunk_size):
# TODO: ensure that this is arrow backed
yield _wrap_result(result, columns, dtype_backend="pyarrow")

def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]:
"""Default implementation converts df to arrow"""
for df in self.iter_df(chunk_size=chunk_size):
# TODO: is this efficient?
yield ArrowTable.from_pandas(df)
yield _result_to_arrow_table(result, cursor_columns, arrow_schema)


def raise_database_error(f: TFun) -> TFun:
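Stripped of the cursor plumbing, the rows-to-arrow conversion performed by the new default `iter_arrow` follows this pattern. This is a standalone sketch with made-up rows and a hand-written schema; in the commit the schema comes from `table_schema_columns_to_py_arrow` and may be `None`, in which case pyarrow infers the types.

```python
import pyarrow as pa

# hypothetical DBAPI-style result: row tuples plus the cursor's column names
rows = [(1, "a"), (2, None)]
column_names = ["id", "value"]

# an explicit schema keeps column types stable even for empty or all-NULL batches
schema = pa.schema(
    [pa.field("id", pa.int64(), nullable=False), pa.field("value", pa.string())]
)

# same construction as the _result_to_arrow_table helper: rows -> list of dicts -> table
pylist = [dict(zip(column_names, row)) for row in rows]
table = pa.Table.from_pylist(pylist, schema=schema)
print(table.schema, table.num_rows)
```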
7 changes: 6 additions & 1 deletion tests/load/test_read_interfaces.py
@@ -7,7 +7,12 @@
from typing import List
from functools import reduce

from tests.load.utils import destinations_configs, DestinationTestConfiguration, AZ_BUCKET, ABFS_BUCKET
from tests.load.utils import (
destinations_configs,
DestinationTestConfiguration,
AZ_BUCKET,
ABFS_BUCKET,
)
from pandas import DataFrame


