Skip to content

Commit

Permalink
make column code a bit clearer and fix mssql again
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Sep 23, 2024
1 parent 10e04d6 commit 5077ce1
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
5 changes: 5 additions & 0 deletions dlt/destinations/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ def __init__(
def cursor(self) -> Generator[SupportsReadableRelation, Any, Any]:
"""Gets a DBApiCursor for the current relation"""
with self.client as client:
# this hacky code is needed for mssql to disable autocommit, read iterators
# will not work otherwise. in the future we should be able to create a readony
# client which will do this automatically
if hasattr(self.client, "_conn") and hasattr(self.client._conn, "autocommit"):
self.client._conn.autocommit = False
with client.execute_query(self.query) as cursor:
cursor.schema_columns = self.schema_columns
yield cursor
Expand Down
4 changes: 2 additions & 2 deletions dlt/destinations/impl/sqlalchemy/db_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ def _wrap(self: "SqlalchemyClient", *args: Any, **kwargs: Any) -> Any:

class SqlaDbApiCursor(DBApiCursorImpl):
def __init__(self, curr: sa.engine.CursorResult) -> None:
self.schema_columns = None

# Sqlalchemy CursorResult is *mostly* compatible with DB-API cursor
self.native_cursor = curr # type: ignore[assignment]
curr.columns
Expand All @@ -81,6 +79,8 @@ def __init__(self, curr: sa.engine.CursorResult) -> None:
self.fetchone = curr.fetchone # type: ignore[assignment]
self.fetchmany = curr.fetchmany # type: ignore[assignment]

self.set_default_schema_columns()

def _get_columns(self) -> List[str]:
return list(self.native_cursor.keys()) # type: ignore[attr-defined]

Expand Down
19 changes: 10 additions & 9 deletions dlt/destinations/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,20 +322,26 @@ class DBApiCursorImpl(DBApiCursor):

def __init__(self, curr: DBApiCursor) -> None:
self.native_cursor = curr
self.schema_columns = None

# wire protocol methods
self.execute = curr.execute # type: ignore
self.fetchall = curr.fetchall # type: ignore
self.fetchmany = curr.fetchmany # type: ignore
self.fetchone = curr.fetchone # type: ignore

self.set_default_schema_columns()

def __getattr__(self, name: str) -> Any:
return getattr(self.native_cursor, name)

def _get_columns(self) -> List[str]:
return [c[0] for c in self.native_cursor.description]

def set_default_schema_columns(self) -> None:
self.schema_columns = cast(
TTableSchemaColumns, {c: {"name": c, "nullable": True} for c in self._get_columns()}
)

def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]:
"""Fetches results as data frame in full or in specified chunks.
Expand Down Expand Up @@ -371,7 +377,7 @@ def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]:
for table in self.iter_arrow(chunk_size=chunk_size):
# NOTE: we go via arrow table, types are created for arrow is columns are known
# https://github.com/apache/arrow/issues/38644 for reference on types_mapper
yield table.to_pandas(types_mapper=pd.ArrowDtype)
yield table.to_pandas()

def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]:
"""Default implementation converts query result to arrow table"""
Expand All @@ -380,18 +386,13 @@ def iter_arrow(self, chunk_size: int) -> Generator[ArrowTable, None, None]:
# if loading to a specific pipeline, it would be nice to have the correct caps here
caps = DestinationCapabilitiesContext.generic_capabilities()

# provide default columns in case not known
columns = self.schema_columns or cast(
TTableSchemaColumns, {c: {"name": c, "nullable": True} for c in self._get_columns()}
)

if not chunk_size:
result = self.fetchall()
yield row_tuples_to_arrow(result, caps, columns, tz="UTC")
yield row_tuples_to_arrow(result, caps, self.schema_columns, tz="UTC")
return

for result in self.iter_fetchmany(chunk_size=chunk_size):
yield row_tuples_to_arrow(result, caps, columns, tz="UTC")
yield row_tuples_to_arrow(result, caps, self.schema_columns, tz="UTC")


def raise_database_error(f: TFun) -> TFun:
Expand Down

0 comments on commit 5077ce1

Please sign in to comment.