Sqlalchemy staging dataset support and docs #1841

Merged · 5 commits · Sep 20, 2024
Changes from 3 commits
26 changes: 23 additions & 3 deletions dlt/destinations/impl/sqlalchemy/db_api_client.py
@@ -3,12 +3,12 @@
Iterator,
Any,
Sequence,
ContextManager,
AnyStr,
Union,
Tuple,
List,
Dict,
Set,
)
from contextlib import contextmanager
from functools import wraps
@@ -19,6 +19,7 @@
from sqlalchemy.engine import Connection

from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import PreparedTableSchema
from dlt.destinations.exceptions import (
DatabaseUndefinedRelation,
DatabaseTerminalException,
@@ -122,6 +123,8 @@ def __init__(
self._current_connection: Optional[Connection] = None
self._current_transaction: Optional[SqlaTransactionWrapper] = None
self.metadata = sa.MetaData()
# Keep a set of datasets already attached on the current connection
self._sqlite_attached_datasets: Set[str] = set()

@property
def engine(self) -> sa.engine.Engine:
@@ -155,6 +158,7 @@ def close_connection(self) -> None:
self._current_connection.close()
self.engine.dispose()
finally:
self._sqlite_attached_datasets.clear()
self._current_connection = None
self._current_transaction = None

@@ -234,13 +238,17 @@ def _sqlite_create_dataset(self, dataset_name: str) -> None:
"""Mimic multiple schemas in sqlite using ATTACH DATABASE to
attach a new database file to the current connection.
"""
if dataset_name == "main":
# main always exists
return
if self._sqlite_is_memory_db():
new_db_fn = ":memory:"
else:
new_db_fn = self._sqlite_dataset_filename(dataset_name)

statement = "ATTACH DATABASE :fn AS :name"
self.execute_sql(statement, fn=new_db_fn, name=dataset_name)
self._sqlite_attached_datasets.add(dataset_name)

def _sqlite_drop_dataset(self, dataset_name: str) -> None:
"""Drop a dataset in sqlite by detaching the database file
@@ -252,13 +260,23 @@ def _sqlite_drop_dataset(self, dataset_name: str) -> None:
if dataset_name != "main": # main is the default database, it cannot be detached
statement = "DETACH DATABASE :name"
self.execute_sql(statement, name=dataset_name)
self._sqlite_attached_datasets.discard(dataset_name)

fn = dbs[dataset_name]
if not fn: # It's a memory database, nothing to do
return
# Delete the database file
Path(fn).unlink()

@contextmanager
def with_alternative_dataset_name(
self, dataset_name: str
) -> Iterator[SqlClientBase[Connection]]:
with super().with_alternative_dataset_name(dataset_name):
if self.dialect_name == "sqlite" and dataset_name not in self._sqlite_attached_datasets:
self._sqlite_reattach_dataset_if_exists(dataset_name)
yield self

def create_dataset(self) -> None:
if self.dialect_name == "sqlite":
return self._sqlite_create_dataset(self.dataset_name)
@@ -332,8 +350,10 @@ def make_qualified_table_name(self, table_name: str, escape: bool = True) -> str

def fully_qualified_dataset_name(self, escape: bool = True, staging: bool = False) -> str:
if staging:
raise NotImplementedError("Staging not supported")
return self.dialect.identifier_preparer.format_schema(self.dataset_name) # type: ignore[attr-defined, no-any-return]
dataset_name = self.staging_dataset_name
else:
dataset_name = self.dataset_name
return self.dialect.identifier_preparer.format_schema(dataset_name) # type: ignore[attr-defined, no-any-return]

def alter_table_add_columns(self, columns: Sequence[sa.Column]) -> None:
if not columns:
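Note on the SQLite path above: multiple datasets are mimicked by attaching one database file per dataset to the same connection, which is why attached dataset names are tracked and re-attached when switching datasets. A minimal standalone sketch of the underlying `ATTACH`/`DETACH` mechanism using only the standard library (file and dataset names are illustrative, not taken from dlt):

```python
import sqlite3

# Each "dataset" is a separate database file attached under a schema-like name,
# so tables can be addressed as <dataset>.<table> on a single connection.
conn = sqlite3.connect("main_db.sqlite")
conn.execute("ATTACH DATABASE 'my_dataset.sqlite' AS my_dataset")
conn.execute(
    "CREATE TABLE IF NOT EXISTS my_dataset.items (id INTEGER PRIMARY KEY, value TEXT)"
)
conn.execute("INSERT INTO my_dataset.items (value) VALUES ('a')")
print(conn.execute("SELECT * FROM my_dataset.items").fetchall())
# DETACH only removes the alias from the connection; deleting the file is a
# separate step, which mirrors the unlink in _sqlite_drop_dataset above.
conn.execute("DETACH DATABASE my_dataset")
conn.close()
```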
5 changes: 4 additions & 1 deletion dlt/destinations/impl/sqlalchemy/factory.py
@@ -21,6 +21,7 @@
if t.TYPE_CHECKING:
# from dlt.destinations.impl.sqlalchemy.sqlalchemy_client import SqlalchemyJobClient
from dlt.destinations.impl.sqlalchemy.sqlalchemy_job_client import SqlalchemyJobClient
from sqlalchemy.engine import Engine


class sqlalchemy(Destination[SqlalchemyClientConfiguration, "SqlalchemyJobClient"]):
@@ -45,6 +46,8 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext:
caps.supports_ddl_transactions = True
caps.max_query_parameters = 20_0000
caps.max_rows_per_insert = 10_000 # Set a default to avoid OOM on large datasets
# Multiple concatenated statements are not supported by all engines, so leave them off by default
caps.supports_multiple_statements = False
caps.type_mapper = SqlalchemyTypeMapper

return caps
@@ -74,7 +77,7 @@ def client_class(self) -> t.Type["SqlalchemyJobClient"]:

def __init__(
self,
credentials: t.Union[SqlalchemyCredentials, t.Dict[str, t.Any], str] = None,
credentials: t.Union[SqlalchemyCredentials, t.Dict[str, t.Any], str, "Engine"] = None,
destination_name: t.Optional[str] = None,
environment: t.Optional[str] = None,
engine_args: t.Optional[t.Dict[str, t.Any]] = None,
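With the widened `credentials` annotation, the factory accepts a connection string, a credentials object/dict, or a pre-built SQLAlchemy `Engine`. A minimal usage sketch, assuming the factory is exposed as `dlt.destinations.sqlalchemy` like other destinations (connection URLs and pipeline/dataset names are placeholders):

```python
import dlt
import sqlalchemy as sa
from dlt.destinations import sqlalchemy as sqlalchemy_destination

# Option 1: a plain connection string as credentials.
dest_from_url = sqlalchemy_destination("sqlite:///dlt_data.sqlite")

# Option 2: a pre-configured Engine, e.g. to reuse pooling or echo settings.
engine = sa.create_engine("sqlite:///dlt_data.sqlite")
dest_from_engine = sqlalchemy_destination(credentials=engine)

pipeline = dlt.pipeline(
    pipeline_name="demo_pipeline",
    destination=dest_from_engine,
    dataset_name="demo_dataset",
)
```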
136 changes: 136 additions & 0 deletions dlt/destinations/impl/sqlalchemy/load_jobs.py
@@ -0,0 +1,136 @@
from typing import IO, Any, Dict, Iterator, List, Sequence, TYPE_CHECKING, Optional
import math

import sqlalchemy as sa

from dlt.common.destination.reference import (
RunnableLoadJob,
HasFollowupJobs,
PreparedTableSchema,
)
from dlt.common.storages import FileStorage
from dlt.common.json import json, PY_DATETIME_DECODERS
from dlt.destinations.sql_jobs import SqlFollowupJob, SqlJobParams

from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient

if TYPE_CHECKING:
from dlt.destinations.impl.sqlalchemy.sqlalchemy_job_client import SqlalchemyJobClient


class SqlalchemyJsonLInsertJob(RunnableLoadJob, HasFollowupJobs):
def __init__(self, file_path: str, table: sa.Table) -> None:
super().__init__(file_path)
self._job_client: "SqlalchemyJobClient" = None
self.table = table

def _open_load_file(self) -> IO[bytes]:
return FileStorage.open_zipsafe_ro(self._file_path, "rb")

def _iter_data_items(self) -> Iterator[Dict[str, Any]]:
all_cols = {col.name: None for col in self.table.columns}
with FileStorage.open_zipsafe_ro(self._file_path, "rb") as f:
for line in f:
# Decode date/time to py datetime objects. Some drivers have issues with pendulum objects
for item in json.typed_loadb(line, decoders=PY_DATETIME_DECODERS):
# Fill any missing columns in item with None. Bulk insert fails when items have different keys
if item.keys() != all_cols.keys():
yield {**all_cols, **item}
else:
yield item

def _iter_data_item_chunks(self) -> Iterator[Sequence[Dict[str, Any]]]:
max_rows = self._job_client.capabilities.max_rows_per_insert or math.inf
# Limiting by max query length should not be needed:
# bulk insert generates an INSERT template with a single VALUES tuple of placeholders.
# If any dialect doesn't do that we need to check the string length of the query.
# TODO: Max params may not be needed. Limits only apply to placeholders in the SQL string (mysql/sqlite)
max_params = self._job_client.capabilities.max_query_parameters or math.inf
chunk: List[Dict[str, Any]] = []
params_count = 0
for item in self._iter_data_items():
if len(chunk) + 1 == max_rows or params_count + len(item) > max_params:
# Rotate chunk
yield chunk
chunk = []
params_count = 0
params_count += len(item)
chunk.append(item)

if chunk:
yield chunk

def run(self) -> None:
_sql_client = self._job_client.sql_client
# Copy the table to the current dataset (i.e. staging) if needed
# This is a no-op if the table is already in the correct schema
table = self.table.to_metadata(
self.table.metadata, schema=_sql_client.dataset_name # type: ignore[attr-defined]
)

with _sql_client.begin_transaction():
for chunk in self._iter_data_item_chunks():
_sql_client.execute_sql(table.insert(), chunk)


class SqlalchemyParquetInsertJob(SqlalchemyJsonLInsertJob):
def _iter_data_item_chunks(self) -> Iterator[Sequence[Dict[str, Any]]]:
from dlt.common.libs.pyarrow import ParquetFile

num_cols = len(self.table.columns)
max_rows = self._job_client.capabilities.max_rows_per_insert or None
max_params = self._job_client.capabilities.max_query_parameters or None
read_limit = None

with ParquetFile(self._file_path) as reader:
if max_params is not None:
read_limit = math.floor(max_params / num_cols)

if max_rows is not None:
if read_limit is None:
read_limit = max_rows
else:
read_limit = min(read_limit, max_rows)

if read_limit is None:
yield reader.read().to_pylist()
return

for chunk in reader.iter_batches(batch_size=read_limit):
yield chunk.to_pylist()


class SqlalchemyStagingCopyJob(SqlFollowupJob):
@classmethod
def generate_sql(
cls,
table_chain: Sequence[PreparedTableSchema],
sql_client: SqlalchemyClient, # type: ignore[override]
params: Optional[SqlJobParams] = None,
) -> List[str]:
statements: List[str] = []
for table in table_chain:
# Tables must have already been created in metadata
table_obj = sql_client.get_existing_table(table["name"])
staging_table_obj = table_obj.to_metadata(
sql_client.metadata, schema=sql_client.staging_dataset_name
)
if params["replace"]:
stmt = str(table_obj.delete().compile(dialect=sql_client.dialect))
if not stmt.endswith(";"):
stmt += ";"
statements.append(stmt)

stmt = str(
table_obj.insert()
.from_select(
[col.name for col in staging_table_obj.columns], staging_table_obj.select()
)
.compile(dialect=sql_client.dialect)
)
if not stmt.endswith(";"):
stmt += ";"

statements.append(stmt)

return statements
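For readers less familiar with SQLAlchemy Core: the staging copy job above compiles a `DELETE` (for the replace disposition) followed by an `INSERT ... FROM SELECT` that moves rows from the staging schema into the destination schema. A standalone sketch of the same pattern (table, column, and schema names are illustrative):

```python
import sqlalchemy as sa

metadata = sa.MetaData()

# Destination table; the staging table is the same structure in another schema.
items = sa.Table(
    "items",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("value", sa.String),
    schema="my_dataset",
)
staging_items = items.to_metadata(sa.MetaData(), schema="my_dataset_staging")

# Replace disposition: clear the destination table before copying.
delete_stmt = items.delete()

# INSERT INTO my_dataset.items (id, value) SELECT id, value FROM my_dataset_staging.items
copy_stmt = items.insert().from_select(
    [col.name for col in staging_items.columns], staging_items.select()
)

# Stringifying compiles with the default dialect; the job compiles with the engine's dialect.
print(delete_stmt)
print(copy_stmt)
```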