From 3f06e1e2e079f0894f090b20c23bf001123e9fca Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 11 Sep 2024 12:15:58 -0400 Subject: [PATCH 01/37] Implement sqlalchemy loader Begin implementing sqlalchemy loader SQLA load job, factory, schema storage, POC sqlalchemy tests attempt Implement SqlJobClient interface Parquet load, some tests running on mysql update lockfile Limit bulk insert chunk size, sqlite create/drop schema, fixes Generate schema update Get more tests running with mysql More tests passing Fix state, schema restore --- dlt/common/destination/capabilities.py | 6 + dlt/common/destination/reference.py | 17 +- dlt/common/json/__init__.py | 83 +-- dlt/common/json/_orjson.py | 12 +- dlt/common/json/_simplejson.py | 12 +- dlt/common/libs/pyarrow.py | 1 + dlt/common/time.py | 8 + dlt/common/utils.py | 14 +- dlt/destinations/__init__.py | 2 + dlt/destinations/impl/sqlalchemy/__init__.py | 0 .../impl/sqlalchemy/configuration.py | 58 ++ .../impl/sqlalchemy/db_api_client.py | 381 +++++++++++++ dlt/destinations/impl/sqlalchemy/factory.py | 90 ++++ .../impl/sqlalchemy/sqlalchemy_job_client.py | 503 ++++++++++++++++++ dlt/destinations/sql_client.py | 17 + pyproject.toml | 1 + tests/cases.py | 17 +- tests/load/pipeline/test_arrow_loading.py | 8 +- tests/load/pipeline/test_merge_disposition.py | 4 +- tests/load/pipeline/test_pipelines.py | 4 +- tests/load/test_job_client.py | 55 +- tests/load/test_sql_client.py | 12 +- tests/load/utils.py | 8 +- tests/utils.py | 1 + 24 files changed, 1238 insertions(+), 76 deletions(-) create mode 100644 dlt/destinations/impl/sqlalchemy/__init__.py create mode 100644 dlt/destinations/impl/sqlalchemy/configuration.py create mode 100644 dlt/destinations/impl/sqlalchemy/db_api_client.py create mode 100644 dlt/destinations/impl/sqlalchemy/factory.py create mode 100644 dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index eed1d6189e..722c82c176 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -176,6 +176,12 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): loader_parallelism_strategy: Optional[TLoaderParallelismStrategy] = None """The destination can override the parallelism strategy""" + max_query_parameters: Optional[int] = None + """The maximum number of parameters that can be supplied in a single parametrized query""" + + supports_native_boolean: bool = True + """The destination supports a native boolean type, otherwise bool columns are usually stored as integers""" + def generates_case_sensitive_identifiers(self) -> bool: """Tells if capabilities as currently adjusted, will generate case sensitive identifiers""" # must have case sensitive support and folding function must preserve casing diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 05ea5f3515..b394a15e60 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -86,6 +86,20 @@ def from_normalized_mapping( schema=normalized_doc[naming_convention.normalize_identifier("schema")], ) + def to_normalized_mapping(self, naming_convention: NamingConvention) -> Dict[str, Any]: + """Convert this instance to mapping where keys are normalized according to given naming convention + + Args: + naming_convention: Naming convention that should be used to normalize keys + + Returns: + Dict[str, Any]: Mapping with normalized keys (e.g. 
{Version: ..., SchemaName: ...}) + """ + return { + naming_convention.normalize_identifier(key): value + for key, value in self._asdict().items() + } + @dataclasses.dataclass class StateInfo: @@ -379,6 +393,7 @@ def run_managed( self.run() self._state = "completed" except (DestinationTerminalException, TerminalValueError) as e: + logger.exception(f"Job {self.job_id()} failed terminally") self._state = "failed" self._exception = e logger.exception(f"Terminal exception in job {self.job_id()} in file {self._file_path}") @@ -439,7 +454,7 @@ def __init__( self.capabilities = capabilities @abstractmethod - def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: + def initialize_storage(self, truncate_tables: Optional[Iterable[str]] = None) -> None: """Prepares storage to be used ie. creates database schema or file system folder. Truncates requested tables.""" pass diff --git a/dlt/common/json/__init__.py b/dlt/common/json/__init__.py index 72ab453cbf..fe762cdf11 100644 --- a/dlt/common/json/__init__.py +++ b/dlt/common/json/__init__.py @@ -19,35 +19,7 @@ from dlt.common.utils import map_nested_in_place -class SupportsJson(Protocol): - """Minimum adapter for different json parser implementations""" - - _impl_name: str - """Implementation name""" - - def dump( - self, obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty: bool = False - ) -> None: ... - - def typed_dump(self, obj: Any, fp: IO[bytes], pretty: bool = False) -> None: ... - - def typed_dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: ... - - def typed_loads(self, s: str) -> Any: ... - - def typed_dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: ... - - def typed_loadb(self, s: Union[bytes, bytearray, memoryview]) -> Any: ... - - def dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: ... - - def dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: ... - - def load(self, fp: Union[IO[bytes], IO[str]]) -> Any: ... - - def loads(self, s: str) -> Any: ... - - def loadb(self, s: Union[bytes, bytearray, memoryview]) -> Any: ... +TPuaDecoders = List[Callable[[Any], Any]] def custom_encode(obj: Any) -> str: @@ -104,7 +76,7 @@ def _datetime_decoder(obj: str) -> datetime: # define decoder for each prefix -DECODERS: List[Callable[[Any], Any]] = [ +DECODERS: TPuaDecoders = [ Decimal, _datetime_decoder, pendulum.Date.fromisoformat, @@ -114,6 +86,11 @@ def _datetime_decoder(obj: str) -> datetime: Wei, pendulum.Time.fromisoformat, ] +# Alternate decoders that decode date/time/datetime to stdlib types instead of pendulum +PY_DATETIME_DECODERS = list(DECODERS) +PY_DATETIME_DECODERS[1] = datetime.fromisoformat +PY_DATETIME_DECODERS[2] = date.fromisoformat +PY_DATETIME_DECODERS[7] = time.fromisoformat # how many decoders? 
PUA_CHARACTER_MAX = len(DECODERS) @@ -151,13 +128,13 @@ def custom_pua_encode(obj: Any) -> str: raise TypeError(repr(obj) + " is not JSON serializable") -def custom_pua_decode(obj: Any) -> Any: +def custom_pua_decode(obj: Any, decoders: TPuaDecoders = DECODERS) -> Any: if isinstance(obj, str) and len(obj) > 1: c = ord(obj[0]) - PUA_START # decode only the PUA space defined in DECODERS if c >= 0 and c <= PUA_CHARACTER_MAX: try: - return DECODERS[c](obj[1:]) + return decoders[c](obj[1:]) except Exception: # return strings that cannot be parsed # this may be due @@ -167,11 +144,11 @@ def custom_pua_decode(obj: Any) -> Any: return obj -def custom_pua_decode_nested(obj: Any) -> Any: +def custom_pua_decode_nested(obj: Any, decoders: TPuaDecoders = DECODERS) -> Any: if isinstance(obj, str): - return custom_pua_decode(obj) + return custom_pua_decode(obj, decoders) elif isinstance(obj, (list, dict)): - return map_nested_in_place(custom_pua_decode, obj) + return map_nested_in_place(custom_pua_decode, obj, decoders=decoders) return obj @@ -190,6 +167,39 @@ def may_have_pua(line: bytes) -> bool: return PUA_START_UTF8_MAGIC in line +class SupportsJson(Protocol): + """Minimum adapter for different json parser implementations""" + + _impl_name: str + """Implementation name""" + + def dump( + self, obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty: bool = False + ) -> None: ... + + def typed_dump(self, obj: Any, fp: IO[bytes], pretty: bool = False) -> None: ... + + def typed_dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: ... + + def typed_loads(self, s: str) -> Any: ... + + def typed_dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: ... + + def typed_loadb( + self, s: Union[bytes, bytearray, memoryview], decoders: TPuaDecoders = DECODERS + ) -> Any: ... + + def dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: ... + + def dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: ... + + def load(self, fp: Union[IO[bytes], IO[str]]) -> Any: ... + + def loads(self, s: str) -> Any: ... + + def loadb(self, s: Union[bytes, bytearray, memoryview]) -> Any: ... 
+ + # pick the right impl json: SupportsJson = None if os.environ.get(known_env.DLT_USE_JSON) == "simplejson": @@ -216,4 +226,7 @@ def may_have_pua(line: bytes) -> bool: "custom_pua_remove", "SupportsJson", "may_have_pua", + "TPuaDecoders", + "DECODERS", + "PY_DATETIME_DECODERS", ] diff --git a/dlt/common/json/_orjson.py b/dlt/common/json/_orjson.py index d2d960e6ce..d066ffe875 100644 --- a/dlt/common/json/_orjson.py +++ b/dlt/common/json/_orjson.py @@ -1,7 +1,13 @@ from typing import IO, Any, Union import orjson -from dlt.common.json import custom_pua_encode, custom_pua_decode_nested, custom_encode +from dlt.common.json import ( + custom_pua_encode, + custom_pua_decode_nested, + custom_encode, + TPuaDecoders, + DECODERS, +) from dlt.common.typing import AnyFun _impl_name = "orjson" @@ -38,8 +44,8 @@ def typed_loads(s: str) -> Any: return custom_pua_decode_nested(loads(s)) -def typed_loadb(s: Union[bytes, bytearray, memoryview]) -> Any: - return custom_pua_decode_nested(loadb(s)) +def typed_loadb(s: Union[bytes, bytearray, memoryview], decoders: TPuaDecoders = DECODERS) -> Any: + return custom_pua_decode_nested(loadb(s), decoders) def dumps(obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: diff --git a/dlt/common/json/_simplejson.py b/dlt/common/json/_simplejson.py index 10ee17e2f6..e5adcc7120 100644 --- a/dlt/common/json/_simplejson.py +++ b/dlt/common/json/_simplejson.py @@ -4,7 +4,13 @@ import simplejson import platform -from dlt.common.json import custom_pua_encode, custom_pua_decode_nested, custom_encode +from dlt.common.json import ( + custom_pua_encode, + custom_pua_decode_nested, + custom_encode, + TPuaDecoders, + DECODERS, +) if platform.python_implementation() == "PyPy": # disable speedups on PyPy, it can be actually faster than Python C @@ -73,8 +79,8 @@ def typed_dumpb(obj: Any, sort_keys: bool = False, pretty: bool = False) -> byte return typed_dumps(obj, sort_keys, pretty).encode("utf-8") -def typed_loadb(s: Union[bytes, bytearray, memoryview]) -> Any: - return custom_pua_decode_nested(loadb(s)) +def typed_loadb(s: Union[bytes, bytearray, memoryview], decoders: TPuaDecoders = DECODERS) -> Any: + return custom_pua_decode_nested(loadb(s), decoders) def dumps(obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 51d67275af..adba832c43 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -30,6 +30,7 @@ import pyarrow.parquet import pyarrow.compute import pyarrow.dataset + from pyarrow.parquet import ParquetFile except ModuleNotFoundError: raise MissingDependencyException( "dlt pyarrow helpers", diff --git a/dlt/common/time.py b/dlt/common/time.py index 8532f566b8..26de0b5645 100644 --- a/dlt/common/time.py +++ b/dlt/common/time.py @@ -143,6 +143,14 @@ def ensure_pendulum_time(value: Union[str, datetime.time]) -> pendulum.Time: return result else: raise ValueError(f"{value} is not a valid ISO time string.") + elif isinstance(value, timedelta): + # Assume timedelta is seconds passed since midnight. 
Some drivers (mysqlclient) return time in this format + return pendulum.time( + value.seconds // 3600, + (value.seconds // 60) % 60, + value.seconds % 60, + value.microseconds, + ) raise TypeError(f"Cannot coerce {value} to a pendulum.Time object.") diff --git a/dlt/common/utils.py b/dlt/common/utils.py index 37b644c0b5..9980e725ee 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -282,8 +282,10 @@ def clone_dict_nested(src: TDict) -> TDict: return update_dict_nested({}, src, copy_src_dicts=True) # type: ignore[return-value] -def map_nested_in_place(func: AnyFun, _nested: TAny) -> TAny: - """Applies `func` to all elements in `_dict` recursively, replacing elements in nested dictionaries and lists in place.""" +def map_nested_in_place(func: AnyFun, _nested: TAny, *args: Any, **kwargs: Any) -> TAny: + """Applies `func` to all elements in `_dict` recursively, replacing elements in nested dictionaries and lists in place. + Additional `*args` and `**kwargs` are passed to `func`. + """ if isinstance(_nested, tuple): if hasattr(_nested, "_asdict"): _nested = _nested._asdict() @@ -293,15 +295,15 @@ def map_nested_in_place(func: AnyFun, _nested: TAny) -> TAny: if isinstance(_nested, dict): for k, v in _nested.items(): if isinstance(v, (dict, list, tuple)): - _nested[k] = map_nested_in_place(func, v) + _nested[k] = map_nested_in_place(func, v, *args, **kwargs) else: - _nested[k] = func(v) + _nested[k] = func(v, *args, **kwargs) elif isinstance(_nested, list): for idx, _l in enumerate(_nested): if isinstance(_l, (dict, list, tuple)): - _nested[idx] = map_nested_in_place(func, _l) + _nested[idx] = map_nested_in_place(func, _l, *args, **kwargs) else: - _nested[idx] = func(_l) + _nested[idx] = func(_l, *args, **kwargs) else: raise ValueError(_nested, "Not a nested type") return _nested diff --git a/dlt/destinations/__init__.py b/dlt/destinations/__init__.py index 0546d16bcd..a856f574d8 100644 --- a/dlt/destinations/__init__.py +++ b/dlt/destinations/__init__.py @@ -16,6 +16,7 @@ from dlt.destinations.impl.databricks.factory import databricks from dlt.destinations.impl.dremio.factory import dremio from dlt.destinations.impl.clickhouse.factory import clickhouse +from dlt.destinations.impl.sqlalchemy.factory import sqlalchemy __all__ = [ @@ -37,4 +38,5 @@ "dremio", "clickhouse", "destination", + "sqlalchemy", ] diff --git a/dlt/destinations/impl/sqlalchemy/__init__.py b/dlt/destinations/impl/sqlalchemy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dlt/destinations/impl/sqlalchemy/configuration.py b/dlt/destinations/impl/sqlalchemy/configuration.py new file mode 100644 index 0000000000..0236f84086 --- /dev/null +++ b/dlt/destinations/impl/sqlalchemy/configuration.py @@ -0,0 +1,58 @@ +from typing import TYPE_CHECKING, Optional, Any, Final, Type, Dict +import dataclasses + +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import ConnectionStringCredentials +from dlt.common.destination.reference import DestinationClientDwhConfiguration + +if TYPE_CHECKING: + from sqlalchemy.engine import Engine, Dialect + + +@configspec +class SqlalchemyCredentials(ConnectionStringCredentials): + if TYPE_CHECKING: + _engine: Optional["Engine"] = None + + username: Optional[str] = None # e.g. 
sqlite doesn't need username + + def parse_native_representation(self, native_value: Any) -> None: + from sqlalchemy.engine import Engine + + if isinstance(native_value, Engine): + self.engine = native_value + super().parse_native_representation( + native_value.url.render_as_string(hide_password=False) + ) + else: + super().parse_native_representation(native_value) + + @property + def engine(self) -> Optional["Engine"]: + return getattr(self, "_engine", None) # type: ignore[no-any-return] + + @engine.setter + def engine(self, value: "Engine") -> None: + self._engine = value + + def get_dialect(self) -> Optional[Type["Dialect"]]: + if not self.drivername: + return None + # Type-ignore because of ported URL class has no get_dialect method, + # but here sqlalchemy should be available + if engine := self.engine: + return type(engine.dialect) + return self.to_url().get_dialect() # type: ignore[attr-defined,no-any-return] + + +@configspec +class SqlalchemyClientConfiguration(DestinationClientDwhConfiguration): + destination_type: Final[str] = dataclasses.field(default="sqlalchemy", init=False, repr=False, compare=False) # type: ignore + credentials: SqlalchemyCredentials = None + """SQLAlchemy connection string""" + + engine_args: Dict[str, Any] = dataclasses.field(default_factory=dict) + """Additional arguments passed to `sqlalchemy.create_engine`""" + + def get_dialect(self) -> Type["Dialect"]: + return self.credentials.get_dialect() diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py new file mode 100644 index 0000000000..0969fcb628 --- /dev/null +++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -0,0 +1,381 @@ +from typing import ( + Optional, + Iterator, + Any, + Sequence, + ContextManager, + AnyStr, + Union, + Tuple, + List, + Dict, +) +from contextlib import contextmanager +from functools import wraps +import inspect +from pathlib import Path + +import sqlalchemy as sa +from sqlalchemy.engine import Connection + +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.destinations.exceptions import ( + DatabaseUndefinedRelation, + DatabaseTerminalException, + DatabaseTransientException, + LoadClientNotConnected, + DatabaseException, +) +from dlt.destinations.typing import DBTransaction, DBApiCursor +from dlt.destinations.sql_client import SqlClientBase, DBApiCursorImpl +from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyCredentials +from dlt.common.typing import TFun + + +class SqlaTransactionWrapper(DBTransaction): + def __init__(self, sqla_transaction: sa.engine.Transaction) -> None: + self.sqla_transaction = sqla_transaction + + def commit_transaction(self) -> None: + if self.sqla_transaction.is_active: + self.sqla_transaction.commit() + + def rollback_transaction(self) -> None: + if self.sqla_transaction.is_active: + self.sqla_transaction.rollback() + + +def raise_database_error(f: TFun) -> TFun: + @wraps(f) + def _wrap_gen(self: "SqlalchemyClient", *args: Any, **kwargs: Any) -> Any: + try: + return (yield from f(self, *args, **kwargs)) + except Exception as e: + raise self._make_database_exception(e) from e + + @wraps(f) + def _wrap(self: "SqlalchemyClient", *args: Any, **kwargs: Any) -> Any: + try: + return f(self, *args, **kwargs) + except Exception as e: + raise self._make_database_exception(e) from e + + if inspect.isgeneratorfunction(f): + return _wrap_gen # type: ignore[return-value] + return _wrap # type: ignore[return-value] + + +class SqlaDbApiCursor(DBApiCursorImpl): 
+ def __init__(self, curr: sa.engine.CursorResult) -> None: + # Sqlalchemy CursorResult is *mostly* compatible with DB-API cursor + self.native_cursor = curr # type: ignore[assignment] + curr.columns + + self.fetchall = curr.fetchall # type: ignore[method-assign] + self.fetchone = curr.fetchone # type: ignore[method-assign] + self.fetchmany = curr.fetchmany # type: ignore[method-assign] + + def _get_columns(self) -> List[str]: + return list(self.native_cursor.keys()) # type: ignore[attr-defined] + + # @property + # def description(self) -> Any: + # # Get the underlying driver's cursor description, this is mostly used in tests + # return self.native_cursor.cursor.description # type: ignore[attr-defined] + + def execute(self, query: AnyStr, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError("execute not implemented") + + +class DbApiProps: + # Only needed for some tests + paramstyle = "named" + + +class SqlalchemyClient(SqlClientBase[Connection]): + external_engine: bool = False + dialect: sa.engine.interfaces.Dialect + dialect_name: str + dbapi = DbApiProps # type: ignore[assignment] + + def __init__( + self, + dataset_name: str, + staging_dataset_name: str, + credentials: SqlalchemyCredentials, + capabilities: DestinationCapabilitiesContext, + engine_args: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(credentials.database, dataset_name, staging_dataset_name, capabilities) + self.credentials = credentials + + if credentials.engine: + self.engine = credentials.engine + self.external_engine = True + else: + self.engine = sa.create_engine( + credentials.to_url().render_as_string(hide_password=False), **(engine_args or {}) + ) + + self._current_connection: Optional[Connection] = None + self._current_transaction: Optional[SqlaTransactionWrapper] = None + self.metadata = sa.MetaData() + self.dialect = self.engine.dialect + self.dialect_name = self.dialect.name # type: ignore[attr-defined] + + def open_connection(self) -> Connection: + if self._current_connection is None: + self._current_connection = self.engine.connect() + return self._current_connection + + def close_connection(self) -> None: + if not self.external_engine: + try: + self.engine.dispose() + finally: + self._current_connection = None + self._current_transaction = None + + @property + def native_connection(self) -> Connection: + if not self._current_connection: + raise LoadClientNotConnected(type(self).__name__, self.dataset_name) + return self._current_connection + + @contextmanager + @raise_database_error + def begin_transaction(self) -> Iterator[DBTransaction]: + if self._current_transaction is not None: + raise DatabaseTerminalException("Transaction already started") + trans = self._current_transaction = SqlaTransactionWrapper(self._current_connection.begin()) + try: + yield trans + except Exception: + self.rollback_transaction() + raise + else: + self.commit_transaction() + finally: + self._current_transaction = None + + def commit_transaction(self) -> None: + """Commits the current transaction.""" + self._current_transaction.commit_transaction() + + def rollback_transaction(self) -> None: + """Rolls back the current transaction.""" + self._current_transaction.rollback_transaction() + + @contextmanager + def _transaction(self) -> Iterator[DBTransaction]: + """Context manager yielding either a new or the currently open transaction. + New transaction will be committed/rolled back on exit. + If the transaction is already open, finalization is handled by the top level context manager. 
+ """ + if self._current_transaction is not None: + yield self._current_transaction + return + with self.begin_transaction() as tx: + yield tx + + def has_dataset(self) -> bool: + schema_names = self.engine.dialect.get_schema_names(self._current_connection) + return self.dataset_name in schema_names + + def _sqlite_create_dataset(self, dataset_name: str) -> None: + """Mimic multiple schemas in sqlite using ATTACH DATABASE to + attach a new database file to the current connection. + """ + db_name = self.engine.url.database + if db_name == ":memory:": + new_db_fn = ":memory:" + else: + current_file_path = Path(db_name) + # New filename e.g. ./results.db -> ./results__new_dataset_name.db + new_db_fn = str( + current_file_path.parent + / f"{current_file_path.stem}__{dataset_name}{current_file_path.suffix}" + ) + + statement = "ATTACH DATABASE :fn AS :name" + self.execute(statement, fn=new_db_fn, name=dataset_name) + + def _sqlite_drop_dataset(self, dataset_name: str) -> None: + """Drop a dataset in sqlite by detaching the database file + attached to the current connection. + """ + # Get a list of attached databases and filenames + rows = self.execute_sql("PRAGMA database_list") + dbs = {row[1]: row[2] for row in rows} # db_name: filename + if dataset_name not in dbs: + return + + statement = "DETACH DATABASE :name" + self.execute(statement, name=dataset_name) + + fn = dbs[dataset_name] + if not fn: # It's a memory database, nothing to do + return + # Delete the database file + Path(fn).unlink() + + def create_dataset(self) -> None: + if self.dialect_name == "sqlite": + return self._sqlite_create_dataset(self.dataset_name) + self.execute_sql(sa.schema.CreateSchema(self.dataset_name)) + + def drop_dataset(self) -> None: + if self.dialect_name == "sqlite": + return self._sqlite_drop_dataset(self.dataset_name) + try: + self.execute_sql(sa.schema.DropSchema(self.dataset_name, cascade=True)) + except DatabaseTransientException as e: + if isinstance(e.__cause__, sa.exc.ProgrammingError): + # May not support CASCADE + self.execute_sql(sa.schema.DropSchema(self.dataset_name)) + else: + raise + + def truncate_tables(self, *tables: str) -> None: + # TODO: alchemy doesn't have a construct for TRUNCATE TABLE + for table in tables: + tbl = sa.Table(table, self.metadata, schema=self.dataset_name, keep_existing=True) + self.execute_sql(tbl.delete()) + + def drop_tables(self, *tables: str) -> None: + for table in tables: + tbl = sa.Table(table, self.metadata, schema=self.dataset_name, keep_existing=True) + self.execute_sql(sa.schema.DropTable(tbl, if_exists=True)) + + def execute_sql( + self, sql: Union[AnyStr, sa.sql.Executable], *args: Any, **kwargs: Any + ) -> Optional[Sequence[Sequence[Any]]]: + with self.execute_query(sql, *args, **kwargs) as cursor: + if cursor.returns_rows: + return cursor.fetchall() + return None + + @contextmanager + def execute_query( + self, query: Union[AnyStr, sa.sql.Executable], *args: Any, **kwargs: Any + ) -> Iterator[DBApiCursor]: + if isinstance(query, str): + if args: + # Sqlalchemy text supports :named paramstyle for all dialects + query, kwargs = self._to_named_paramstyle(query, args) # type: ignore[assignment] + args = () + query = sa.text(query) + with self._transaction(): + yield SqlaDbApiCursor(self._current_connection.execute(query, *args, **kwargs)) # type: ignore[abstract] + + def get_existing_table(self, table_name: str) -> Optional[sa.Table]: + """Get a table object from metadata if it exists""" + key = self.dataset_name + "." 
+ table_name + return self.metadata.tables.get(key) # type: ignore[no-any-return] + + def create_table(self, table_obj: sa.Table) -> None: + table_obj.create(self._current_connection) + + def _make_qualified_table_name(self, table: sa.Table, escape: bool = True) -> str: + if escape: + return self.dialect.identifier_preparer.format_table(table) # type: ignore[attr-defined,no-any-return] + return table.fullname # type: ignore[no-any-return] + + def make_qualified_table_name(self, table_name: str, escape: bool = True) -> str: + tbl = self.get_existing_table(table_name) + if tbl is None: + tmp_metadata = sa.MetaData() + tbl = sa.Table(table_name, tmp_metadata, schema=self.dataset_name) + return self._make_qualified_table_name(tbl, escape) + + def alter_table_add_column(self, column: sa.Column) -> None: + """Execute an ALTER TABLE ... ADD COLUMN ... statement for the given column. + The column must be fully defined and attached to a table. + """ + # TODO: May need capability to override ALTER TABLE statement for different dialects + alter_tmpl = "ALTER TABLE {table} ADD COLUMN {column};" + statement = alter_tmpl.format( + table=self._make_qualified_table_name(self._make_qualified_table_name(column.table)), + column=self.compile_column_def(column), + ) + self.execute_sql(statement) + + def escape_column_name(self, column_name: str, escape: bool = True) -> str: + if self.dialect.requires_name_normalize: + column_name = self.dialect.normalize_name(column_name) + if escape: + return self.dialect.identifier_preparer.format_column(sa.Column(column_name)) # type: ignore[attr-defined,no-any-return] + return column_name + + def compile_column_def(self, column: sa.Column) -> str: + """Compile a column definition including type for ADD COLUMN clause""" + return str(sa.schema.CreateColumn(column).compile(self.engine)) + + def reflect_table( + self, + table_name: str, + metadata: Optional[sa.MetaData] = None, + include_columns: Optional[Sequence[str]] = None, + ) -> Optional[sa.Table]: + """Reflect a table from the database and return the Table object""" + if metadata is None: + metadata = self.metadata + try: + return sa.Table( + table_name, + metadata, + autoload_with=self._current_connection, + schema=self.dataset_name, + include_columns=include_columns, + extend_existing=True, + ) + except sa.exc.NoSuchTableError: + return None + + def compare_storage_table(self, table_name: str) -> Tuple[sa.Table, List[sa.Column], bool]: + """Reflect the table from database and compare it with the version already in metadata. 
+ Returns a 3 part tuple: + - The current version of the table in metadata + - List of columns that are missing from the storage table (all columns if it doesn't exist in storage) + - boolean indicating whether the table exists in storage + """ + existing = self.get_existing_table(table_name) + assert existing is not None, "Table must be present in metadata" + all_columns = list(existing.columns) + all_column_names = [c.name for c in all_columns] + tmp_metadata = sa.MetaData() + reflected = self.reflect_table( + table_name, include_columns=all_column_names, metadata=tmp_metadata + ) + if reflected is None: + missing_columns = all_columns + else: + missing_columns = [c for c in all_columns if c.name not in reflected.columns] + return existing, missing_columns, reflected is not None + + @staticmethod + def _make_database_exception(e: Exception) -> Exception: + if isinstance(e, sa.exc.NoSuchTableError): + return DatabaseUndefinedRelation(e) + msg = str(e).lower() + if isinstance(e, (sa.exc.ProgrammingError, sa.exc.OperationalError)): + if "exist" in msg: # TODO: Hack + return DatabaseUndefinedRelation(e) + elif "no such" in msg: # sqlite # TODO: Hack + return DatabaseUndefinedRelation(e) + elif "unknown table" in msg: + return DatabaseUndefinedRelation(e) + elif "unknown database" in msg: + return DatabaseUndefinedRelation(e) + elif isinstance(e, (sa.exc.OperationalError, sa.exc.IntegrityError)): + raise DatabaseTerminalException(e) + return DatabaseTransientException(e) + elif isinstance(e, sa.exc.SQLAlchemyError): + return DatabaseTransientException(e) + else: + return e + # return DatabaseTerminalException(e) + + def _ensure_native_conn(self) -> None: + if not self.native_connection: + raise LoadClientNotConnected(type(self).__name__, self.dataset_name) diff --git a/dlt/destinations/impl/sqlalchemy/factory.py b/dlt/destinations/impl/sqlalchemy/factory.py new file mode 100644 index 0000000000..c6b78baf3b --- /dev/null +++ b/dlt/destinations/impl/sqlalchemy/factory.py @@ -0,0 +1,90 @@ +import typing as t + +from dlt.common.data_writers.configuration import CsvFormatConfiguration +from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.normalizers import NamingConvention + +from dlt.destinations.impl.sqlalchemy.configuration import ( + SqlalchemyCredentials, + SqlalchemyClientConfiguration, +) + +if t.TYPE_CHECKING: + # from dlt.destinations.impl.sqlalchemy.sqlalchemy_client import SqlalchemyJobClient + from dlt.destinations.impl.sqlalchemy.sqlalchemy_job_client import SqlalchemyJobClient + + +class sqlalchemy(Destination[SqlalchemyClientConfiguration, "SqlalchemyJobClient"]): + spec = SqlalchemyClientConfiguration + + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + # https://www.sqlalchemyql.org/docs/current/limits.html + caps = DestinationCapabilitiesContext.generic_capabilities() + caps.preferred_loader_file_format = "typed-jsonl" + caps.supported_loader_file_formats = ["typed-jsonl", "parquet"] + caps.preferred_staging_file_format = None + caps.supported_staging_file_formats = [] + caps.has_case_sensitive_identifiers = True + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 63 + caps.max_column_identifier_length = 63 + caps.max_query_length = 32 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 
1024 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = True + caps.max_query_parameters = 20_0000 + caps.max_rows_per_insert = 10_000 # Set a default to avoid OOM on large datasets + + return caps + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: SqlalchemyClientConfiguration, + naming: t.Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + caps = super(sqlalchemy, cls).adjust_capabilities(caps, config, naming) + dialect = config.get_dialect() + if dialect is None: + return caps + caps.max_identifier_length = dialect.max_identifier_length + caps.max_column_identifier_length = dialect.max_identifier_length + caps.supports_native_boolean = dialect.supports_native_boolean + + return caps + + @property + def client_class(self) -> t.Type["SqlalchemyJobClient"]: + from dlt.destinations.impl.sqlalchemy.sqlalchemy_job_client import SqlalchemyJobClient + + return SqlalchemyJobClient + + def __init__( + self, + credentials: t.Union[SqlalchemyCredentials, t.Dict[str, t.Any], str] = None, + destination_name: t.Optional[str] = None, + environment: t.Optional[str] = None, + engine_args: t.Optional[t.Dict[str, t.Any]] = None, + **kwargs: t.Any, + ) -> None: + """Configure the Sqlalchemy destination to use in a pipeline. + + All arguments provided here supersede other configuration sources such as environment variables and dlt config files. + + Args: + credentials: Credentials to connect to the sqlalchemy database. Can be an instance of `SqlalchemyCredentials` or + a connection string in the format `mysql://user:password@host:port/database` + destination_name: The name of the destination + environment: The environment to use + **kwargs: Additional arguments passed to the destination + """ + super().__init__( + credentials=credentials, + destination_name=destination_name, + environment=environment, + **kwargs, + ) diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py new file mode 100644 index 0000000000..3cd6133d47 --- /dev/null +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -0,0 +1,503 @@ +from typing import Iterable, Optional, Type, Dict, Any, Iterator, Sequence, List, Tuple, IO +from types import TracebackType +from contextlib import suppress +import inspect +import math + +import sqlalchemy as sa +from sqlalchemy.sql import sqltypes + +from dlt.common import logger +from dlt.common import pendulum +from dlt.common.exceptions import TerminalValueError +from dlt.common.destination.reference import ( + JobClientBase, + LoadJob, + RunnableLoadJob, + StorageSchemaInfo, + StateInfo, +) +from dlt.destinations.job_client_impl import SqlJobClientBase +from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.common.schema import Schema, TTableSchema, TColumnSchema, TSchemaTables +from dlt.common.schema.typing import TColumnType, TTableFormat, TTableSchemaColumns +from dlt.common.schema.utils import pipeline_state_table, normalize_table_identifiers +from dlt.common.storages import FileStorage +from dlt.common.json import json, PY_DATETIME_DECODERS +from dlt.destinations.exceptions import DatabaseUndefinedRelation + + +# from dlt.destinations.impl.sqlalchemy.sql_client import SqlalchemyClient +from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient +from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyClientConfiguration + + +class 
SqlaTypeMapper: + # TODO: May be merged with TypeMapper as a generic + def __init__( + self, capabilities: DestinationCapabilitiesContext, dialect: sa.engine.Dialect + ) -> None: + self.capabilities = capabilities + self.dialect = dialect + + def _db_integer_type(self, precision: Optional[int]) -> sa.types.TypeEngine: + if precision is None: + return sa.BigInteger() + elif precision <= 16: + return sa.SmallInteger() + elif precision <= 32: + return sa.Integer() + elif precision <= 64: + return sa.BigInteger() + raise TerminalValueError(f"Unsupported precision for integer type: {precision}") + + def _create_date_time_type(self, sc_t: str, precision: Optional[int]) -> sa.types.TypeEngine: + """Use the dialect specific datetime/time type if possible since the generic type doesn't accept precision argument""" + precision = precision if precision is not None else self.capabilities.timestamp_precision + if sc_t == "timestamp": + base_type = sa.DateTime() + if self.dialect.name == "mysql": # type: ignore[attr-defined] + # Special case, type_descriptor does not return the specifc datetime type + from sqlalchemy.dialects.mysql import DATETIME + + return DATETIME(fsp=precision) + elif sc_t == "time": + base_type = sa.Time() + + dialect_type = type( + self.dialect.type_descriptor(base_type) + ) # Get the dialect specific subtype + precision = precision if precision is not None else self.capabilities.timestamp_precision + + # Find out whether the dialect type accepts precision or fsp argument + params = inspect.signature(dialect_type).parameters + kwargs: Dict[str, Any] = dict(timezone=True) + if "fsp" in params: + kwargs["fsp"] = precision # MySQL uses fsp for fractional seconds + elif "precision" in params: + kwargs["precision"] = precision + return dialect_type(**kwargs) # type: ignore[no-any-return] + + def _create_double_type(self) -> sa.types.TypeEngine: + if dbl := getattr(sa, "Double", None): + # Sqlalchemy 2 has generic double type + return dbl() # type: ignore[no-any-return] + elif self.dialect.name == "mysql": + # MySQL has a specific double type + from sqlalchemy.dialects.mysql import DOUBLE + return sa.Float(precision=53) # Otherwise use float + + def _to_db_decimal_type(self, column: TColumnSchema) -> sa.types.TypeEngine: + precision, scale = column.get("precision"), column.get("scale") + if precision is None and scale is None: + precision, scale = self.capabilities.decimal_precision + return sa.Numeric(precision, scale) + + def to_db_type(self, column: TColumnSchema, table_format: TTableSchema) -> sa.types.TypeEngine: + sc_t = column["data_type"] + precision = column.get("precision") + # TODO: Precision and scale for supported types + if sc_t == "text": + length = precision + if length is None and column.get("unique"): + length = 128 + if length is None: + return sa.Text() # type: ignore[no-any-return] + return sa.String(length=length) # type: ignore[no-any-return] + elif sc_t == "double": + return self._create_double_type() + elif sc_t == "bool": + return sa.Boolean() + elif sc_t == "timestamp": + return self._create_date_time_type(sc_t, precision) + elif sc_t == "bigint": + return self._db_integer_type(precision) + elif sc_t == "binary": + return sa.LargeBinary(length=precision) + elif sc_t == "complex": + return sa.JSON(none_as_null=True) + elif sc_t == "decimal": + return self._to_db_decimal_type(column) + elif sc_t == "wei": + wei_precision, wei_scale = self.capabilities.wei_precision + return sa.Numeric(precision=wei_precision, scale=wei_scale) + elif sc_t == "date": + return 
sa.Date() + elif sc_t == "time": + return self._create_date_time_type(sc_t, precision) + raise TerminalValueError(f"Unsupported data type: {sc_t}") + + def _from_db_integer_type(self, db_type: sa.Integer) -> TColumnType: + if isinstance(db_type, sa.SmallInteger): + return dict(data_type="bigint", precision=16) + elif isinstance(db_type, sa.Integer): + return dict(data_type="bigint", precision=32) + elif isinstance(db_type, sa.BigInteger): + return dict(data_type="bigint") + return dict(data_type="bigint") + + def _from_db_decimal_type(self, db_type: sa.Numeric) -> TColumnType: + precision, scale = db_type.precision, db_type.scale + if (precision, scale) == self.capabilities.wei_precision: + return dict(data_type="wei") + + return dict(data_type="decimal", precision=precision, scale=scale) + + def from_db_type(self, db_type: sa.types.TypeEngine) -> TColumnType: + # TODO: pass the sqla type through dialect.type_descriptor before instance check + # Possibly need to check both dialect specific and generic types + if isinstance(db_type, sa.String): + return dict(data_type="text") + elif isinstance(db_type, sa.Float): + return dict(data_type="double") + elif isinstance(db_type, sa.Boolean): + return dict(data_type="bool") + elif isinstance(db_type, sa.DateTime): + return dict(data_type="timestamp") + elif isinstance(db_type, sa.Integer): + return self._from_db_integer_type(db_type) + elif isinstance(db_type, sqltypes._Binary): + return dict(data_type="binary", precision=db_type.length) + elif isinstance(db_type, sa.JSON): + return dict(data_type="complex") + elif isinstance(db_type, sa.Numeric): + return self._from_db_decimal_type(db_type) + elif isinstance(db_type, sa.Date): + return dict(data_type="date") + elif isinstance(db_type, sa.Time): + return dict(data_type="time") + raise TerminalValueError(f"Unsupported db type: {db_type}") + + +class SqlalchemyJsonLInsertJob(RunnableLoadJob): + def __init__(self, file_path: str, table: sa.Table) -> None: + super().__init__(file_path) + self._job_client: "SqlalchemyJobClient" = None + self.table = table + + def _open_load_file(self) -> IO[bytes]: + return FileStorage.open_zipsafe_ro(self._file_path, "rb") + + def _iter_data_items(self) -> Iterator[Dict[str, Any]]: + all_cols = {col.name: None for col in self.table.columns} + with FileStorage.open_zipsafe_ro(self._file_path, "rb") as f: + for line in f: + # Decode date/time to py datetime objects. Some drivers have issues with pendulum objects + for item in json.typed_loadb(line, decoders=PY_DATETIME_DECODERS): + # Fill any missing columns in item with None. Bulk insert fails when items have different keys + if item.keys() != all_cols.keys(): + yield {**all_cols, **item} + else: + yield item + + def _iter_data_item_chunks(self) -> Iterator[Sequence[Dict[str, Any]]]: + max_rows = self._job_client.capabilities.max_rows_per_insert or math.inf + # Limit by max query length should not be needed, + # bulk insert generates an INSERT template with a single VALUES tuple of placeholders + # If any dialects don't do that we need to check the str length of the query + # TODO: Max params may not be needed. 
Limits only apply to placeholders in sql string (mysql/sqlite) + max_params = self._job_client.capabilities.max_query_parameters or math.inf + chunk: List[Dict[str, Any]] = [] + params_count = 0 + for item in self._iter_data_items(): + if len(chunk) + 1 == max_rows or params_count + len(item) > max_params: + # Rotate chunk + yield chunk + chunk = [] + params_count = 0 + params_count += len(item) + chunk.append(item) + + if chunk: + yield chunk + + def run(self) -> None: + _sql_client = self._job_client.sql_client + + with _sql_client.begin_transaction(): + for chunk in self._iter_data_item_chunks(): + _sql_client.execute_sql(self.table.insert(), chunk) + + +class SqlalchemyParquetInsertJob(SqlalchemyJsonLInsertJob): + def _iter_data_item_chunks(self) -> Iterator[Sequence[Dict[str, Any]]]: + from dlt.common.libs.pyarrow import ParquetFile + + num_cols = len(self.table.columns) + max_rows = self._job_client.capabilities.max_rows_per_insert or None + max_params = self._job_client.capabilities.max_query_parameters or None + read_limit = None + + with ParquetFile(self._file_path) as reader: + if max_params is not None: + read_limit = math.floor(max_params / num_cols) + + if max_rows is not None: + if read_limit is None: + read_limit = max_rows + else: + read_limit = min(read_limit, max_rows) + + if read_limit is None: + yield reader.read().to_pylist() + return + + for chunk in reader.iter_batches(batch_size=read_limit): + yield chunk.to_pylist() + + +class SqlalchemyJobClient(SqlJobClientBase): + sql_client: SqlalchemyClient # type: ignore[assignment] + + def __init__( + self, + schema: Schema, + config: SqlalchemyClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + self.sql_client = SqlalchemyClient( + config.normalize_dataset_name(schema), + None, + config.credentials, + capabilities, + engine_args=config.engine_args, + ) + + self.schema = schema + self.capabilities = capabilities + self.config = config + self.type_mapper = SqlaTypeMapper(capabilities, self.sql_client.dialect) + + def _to_table_object(self, schema_table: TTableSchema) -> sa.Table: + existing = self.sql_client.get_existing_table(schema_table["name"]) + if existing is not None: + existing_col_names = set(col.name for col in existing.columns) + new_col_names = set(schema_table["columns"]) + # Re-generate the table if columns have changed + if existing_col_names == new_col_names: + return existing + return sa.Table( + schema_table["name"], + self.sql_client.metadata, + *[ + self._to_column_object(col, schema_table) + for col in schema_table["columns"].values() + ], + extend_existing=True, + schema=self.sql_client.dataset_name, + ) + + def _to_column_object( + self, schema_column: TColumnSchema, table_format: TTableSchema + ) -> sa.Column: + return sa.Column( + schema_column["name"], + self.type_mapper.to_db_type(schema_column, table_format), + nullable=schema_column.get("nullable", True), + unique=schema_column.get("unique", False), + ) + + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + if file_path.endswith(".typed-jsonl"): + table_obj = self._to_table_object(table) + return SqlalchemyJsonLInsertJob(file_path, table_obj) + elif file_path.endswith(".parquet"): + table_obj = self._to_table_object(table) + return SqlalchemyParquetInsertJob(file_path, table_obj) + return None + + def complete_load(self, load_id: str) -> None: + loads_table = self._to_table_object(self.schema.tables[self.schema.loads_table_name]) + now_ts = 
pendulum.now() + self.sql_client.execute_sql( + loads_table.insert().values( + ( + load_id, + self.schema.name, + 0, + now_ts, + self.schema.version_hash, + ) + ) + ) + + def _get_table_key(self, name: str, schema: Optional[str]) -> str: + if schema is None: + return name + else: + return schema + "." + name + + def get_storage_tables( + self, table_names: Iterable[str] + ) -> Iterable[Tuple[str, TTableSchemaColumns]]: + metadata = sa.MetaData() + for table_name in table_names: + table_obj = self.sql_client.reflect_table(table_name, metadata) + if table_obj is None: + yield table_name, {} + continue + yield table_name, { + col.name: { + "name": col.name, + "nullable": col.nullable, + **self.type_mapper.from_db_type(col.type), + } + for col in table_obj.columns + } + + def update_stored_schema( + self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None + ) -> Optional[TSchemaTables]: + # super().update_stored_schema(only_tables, expected_update) + JobClientBase.update_stored_schema(self, only_tables, expected_update) + + schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) + if schema_info is not None: + logger.info( + "Schema with hash %s inserted at %s found in storage, no upgrade required", + self.schema.stored_version_hash, + schema_info.inserted_at, + ) + return {} + else: + logger.info( + "Schema with hash %s not found in storage, upgrading", + self.schema.stored_version_hash, + ) + + # Create all schema tables in metadata + for table_name in only_tables or self.schema.tables: + self._to_table_object(self.schema.tables[table_name]) + + schema_update: TSchemaTables = {} + tables_to_create: List[sa.Table] = [] + columns_to_add: List[sa.Column] = [] + + for table_name in only_tables or self.schema.tables: + table = self.schema.tables[table_name] + table_obj, new_columns, exists = self.sql_client.compare_storage_table(table["name"]) + if not new_columns: # Nothing to do, don't create table without columns + continue + if not exists: + tables_to_create.append(table_obj) + else: + columns_to_add.extend(new_columns) + partial_table = self.prepare_load_table(table_name) + new_column_names = set(col.name for col in new_columns) + partial_table["columns"] = { + col_name: col_def + for col_name, col_def in partial_table["columns"].items() + if col_name in new_column_names + } + schema_update[table_name] = partial_table + + with self.sql_client.begin_transaction(): + for table_obj in tables_to_create: + self.sql_client.create_table(table_obj) + for col in columns_to_add: + alter = "ALTER TABLE {} ADD COLUMN {}".format( + self.sql_client.make_qualified_table_name(col.table.name), + self.sql_client.compile_column_def(col), + ) + self.sql_client.execute_sql(alter) + + self._update_schema_in_storage(self.schema) + + return schema_update + + def _delete_schema_in_storage(self, schema: Schema) -> None: + version_table = schema.tables[schema.version_table_name] + table_obj = self._to_table_object(version_table) + schema_name_col = schema.naming.normalize_identifier("schema_name") + self.sql_client.execute_sql( + table_obj.delete().where(table_obj.c[schema_name_col] == schema.name) + ) + + def _update_schema_in_storage(self, schema: Schema) -> None: + version_table = schema.tables[schema.version_table_name] + table_obj = self._to_table_object(version_table) + schema_str = json.dumps(schema.to_dict()) + + schema_mapping = StorageSchemaInfo( + version=schema.version, + engine_version=schema.ENGINE_VERSION, + schema_name=schema.name, + 
version_hash=schema.stored_version_hash, + schema=schema_str, + inserted_at=pendulum.now(), + ).to_normalized_mapping(schema.naming) + + self.sql_client.execute_sql(table_obj.insert().values(schema_mapping)) + + def _get_stored_schema( + self, version_hash: Optional[str] = None, schema_name: Optional[str] = None + ) -> Optional[StorageSchemaInfo]: + version_table = self.schema.tables[self.schema.version_table_name] + table_obj = self._to_table_object(version_table) + with suppress(DatabaseUndefinedRelation): + q = sa.select([table_obj]) + if version_hash is not None: + version_hash_col = self.schema.naming.normalize_identifier("version_hash") + q = q.where(table_obj.c[version_hash_col] == version_hash) + if schema_name is not None: + schema_name_col = self.schema.naming.normalize_identifier("schema_name") + q = q.where(table_obj.c[schema_name_col] == schema_name) + inserted_at_col = self.schema.naming.normalize_identifier("inserted_at") + q = q.order_by(table_obj.c[inserted_at_col].desc()) + with self.sql_client.execute_query(q) as cur: + row = cur.fetchone() + if row is None: + return None + + # TODO: Decode compressed schema str if needed + return StorageSchemaInfo.from_normalized_mapping(row._mapping, self.schema.naming) + + def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]: + return self._get_stored_schema(version_hash) + + def get_stored_schema(self) -> Optional[StorageSchemaInfo]: + """Get the latest stored schema""" + return self._get_stored_schema(schema_name=self.schema.name) + + def get_stored_state(self, pipeline_name: str) -> StateInfo: + state_table = self.schema.tables.get( + self.schema.state_table_name + ) or normalize_table_identifiers(pipeline_state_table(), self.schema.naming) + state_table_obj = self._to_table_object(state_table) + loads_table = self.schema.tables[self.schema.loads_table_name] + loads_table_obj = self._to_table_object(loads_table) + + c_load_id, c_dlt_load_id, c_pipeline_name, c_status = map( + self.schema.naming.normalize_identifier, + ("load_id", "_dlt_load_id", "pipeline_name", "status"), + ) + + query = ( + sa.select(state_table_obj) + .join(loads_table_obj, loads_table_obj.c[c_load_id] == state_table_obj.c[c_dlt_load_id]) + .where( + sa.and_( + state_table_obj.c[c_pipeline_name] == pipeline_name, + loads_table_obj.c[c_status] == 0, + ) + ) + .order_by(loads_table_obj.c[c_load_id].desc()) + ) + + with self.sql_client.execute_query(query) as cur: + row = cur.fetchone() + if not row: + return None + mapping = dict(row._mapping) + + return StateInfo.from_normalized_mapping(mapping, self.schema.naming) + + def _from_db_type( + self, db_type: str, precision: Optional[int], scale: Optional[int] + ) -> TColumnType: + raise NotImplementedError() + + def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + raise NotImplementedError() diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index 27d1bc7ce5..96f18cea3d 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -133,6 +133,23 @@ def drop_tables(self, *tables: str) -> None: ] self.execute_many(statements) + def _to_named_paramstyle(self, query: str, args: Sequence[Any]) -> Tuple[str, Dict[str, Any]]: + """Convert a query from "format" ( %s ) paramstyle to "named" ( :param_name ) paramstyle. + The %s are replaced with :arg0, :arg1, ... and the arguments are returned as a dictionary. 
+ + Args: + query: SQL query with %s placeholders + args: arguments to be passed to the query + + Returns: + Tuple of the new query and a dictionary of named arguments + """ + keys = [f"arg{i}" for i in range(len(args))] + # Replace position arguments (%s) with named arguments (:arg0, :arg1, ...) + query = query % tuple(f":{key}" for key in keys) + db_args = {key: db_arg for key, db_arg in zip(keys, args)} + return query, db_args + @abstractmethod def execute_sql( self, sql: AnyStr, *args: Any, **kwargs: Any diff --git a/pyproject.toml b/pyproject.toml index 4ca80d0993..1d4c9cd768 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ dremio = ["pyarrow"] lancedb = ["lancedb", "pyarrow", "tantivy"] deltalake = ["deltalake", "pyarrow"] sql_database = ["sqlalchemy"] +sqlalchemy = ["sqlalchemy"] [tool.poetry.scripts] diff --git a/tests/cases.py b/tests/cases.py index 11358441ee..dd362b538f 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -1,3 +1,4 @@ +import datetime import hashlib from typing import Dict, List, Any, Sequence, Tuple, Literal, Union import base64 @@ -5,6 +6,7 @@ from copy import deepcopy from string import ascii_lowercase import random +import secrets from dlt.common import Decimal, pendulum, json from dlt.common.data_types import TDataType @@ -315,7 +317,7 @@ def arrow_table_all_data_types( import numpy as np data = { - "string": [random.choice(ascii_lowercase) + "\"'\\🦆\n\r" for _ in range(num_rows)], + "string": [secrets.token_urlsafe(8) + "\"'\\🦆\n\r" for _ in range(num_rows)], "float": [round(random.uniform(0, 100), 4) for _ in range(num_rows)], "int": [random.randrange(0, 100) for _ in range(num_rows)], "datetime": pd.date_range("2021-01-01T01:02:03.1234", periods=num_rows, tz=tz, unit="us"), @@ -340,7 +342,18 @@ def arrow_table_all_data_types( data["json"] = [{"a": random.randrange(0, 100)} for _ in range(num_rows)] if include_time: - data["time"] = pd.date_range("2021-01-01", periods=num_rows, tz="UTC").time + # data["time"] = pd.date_range("2021-01-01", periods=num_rows, tz="UTC").time + # data["time"] = pd.date_range("2021-01-01T01:02:03.1234", periods=num_rows, tz=tz, unit="us").time + # random time objects with different hours/minutes/seconds/microseconds + data["time"] = [ + datetime.time( + random.randrange(0, 24), + random.randrange(0, 60), + random.randrange(0, 60), + random.randrange(0, 1000000), + ) + for _ in range(num_rows) + ] if include_binary: # "binary": [hashlib.sha3_256(random.choice(ascii_lowercase).encode()).digest() for _ in range(num_rows)], diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index a86deea799..4a810cb755 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -1,4 +1,4 @@ -from datetime import datetime # noqa: I251 +from datetime import datetime, timedelta, time as dt_time # noqa: I251 import os import pytest @@ -9,7 +9,7 @@ import dlt from dlt.common import pendulum -from dlt.common.time import reduce_pendulum_datetime_precision +from dlt.common.time import reduce_pendulum_datetime_precision, ensure_pendulum_time from dlt.common.utils import uniq_id from tests.load.utils import destinations_configs, DestinationTestConfiguration @@ -121,6 +121,7 @@ def some_data(): if "binary" in record and destination_config.file_format == "parquet": record["binary"] = record["binary"].decode("ascii") + first_record = list(records[0].values()) for row in rows: for i in range(len(row)): if isinstance(row[i], datetime): @@ -132,6 +133,9 
@@ def some_data(): and isinstance(row[i], float) ): row[i] = round(row[i], 4) + if isinstance(row[i], timedelta) and isinstance(first_record[i], dt_time): + # Some drivers (mysqlclient) return TIME columns as timedelta as seconds since midnight + row[i] = ensure_pendulum_time(row[i]) expected = sorted([list(r.values()) for r in records]) diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 6c6ef21140..6bfeaf0df5 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -541,7 +541,9 @@ def _get_shuffled_events(shuffle: bool = dlt.secrets.value): @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, exclude=["sqlalchemy"]), + ids=lambda x: x.name, ) @pytest.mark.parametrize("github_resource", [github_repo_events, github_repo_events_table_meta]) def test_merge_with_dispatch_and_incremental( diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 95fed2343c..4cb43ae322 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -755,7 +755,9 @@ def table_3(make_data=False): @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, exclude=["sqlalchemy"]), + ids=lambda x: x.name, ) def test_query_all_info_tables_fallback(destination_config: DestinationTestConfiguration) -> None: pipeline = destination_config.setup_pipeline( diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 891ca5e809..df342d4f01 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -417,7 +417,9 @@ def test_get_storage_table_with_all_types(client: SqlJobClientBase) -> None: assert len(storage_table) > 0 # column order must match TABLE_UPDATE storage_columns = list(storage_table.values()) - for c, expected_c in zip(TABLE_UPDATE, storage_columns): + for c, expected_c in zip( + TABLE_UPDATE, storage_columns + ): # TODO: c and expected_c need to be swapped # storage columns are returned with column names as in information schema assert client.capabilities.casefold_identifier(c["name"]) == expected_c["name"] # athena does not know wei data type and has no JSON type, time is not supported with parquet tables @@ -442,11 +444,19 @@ def test_get_storage_table_with_all_types(client: SqlJobClientBase) -> None: continue if client.config.destination_type == "dremio" and c["data_type"] == "json": continue + if not client.capabilities.supports_native_boolean and c["data_type"] == "bool": + # The reflected data type is probably either int or boolean depending on how the client is implemented + assert expected_c["data_type"] in ("bigint", "bool") + continue + assert c["data_type"] == expected_c["data_type"] @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True, exclude=("sqlalchemy",)), + indirect=True, + ids=lambda x: x.name, ) def test_preserve_column_order(client: SqlJobClientBase) -> None: schema = client.schema @@ -576,18 +586,21 @@ def test_load_with_all_types( client.schema._bump_version() client.update_stored_schema() - should_load_to_staging = 
client.should_load_data_to_staging_dataset(table_name) # type: ignore[attr-defined] - if should_load_to_staging: - with client.with_staging_dataset(): # type: ignore[attr-defined] - # create staging for merge dataset - client.initialize_storage() - client.update_stored_schema() + if isinstance(client, WithStagingDataset): + should_load_to_staging = client.should_load_data_to_staging_dataset(table_name) # type: ignore[attr-defined] + if should_load_to_staging: + with client.with_staging_dataset(): # type: ignore[attr-defined] + # create staging for merge dataset + client.initialize_storage() + client.update_stored_schema() - with client.sql_client.with_alternative_dataset_name( - client.sql_client.staging_dataset_name - if should_load_to_staging - else client.sql_client.dataset_name - ): + with client.sql_client.with_alternative_dataset_name( + client.sql_client.staging_dataset_name + if should_load_to_staging + else client.sql_client.dataset_name + ): + canonical_name = client.sql_client.make_qualified_table_name(table_name) + else: canonical_name = client.sql_client.make_qualified_table_name(table_name) # write row print(data_row) @@ -646,6 +659,8 @@ def test_write_dispositions( client.update_stored_schema() if write_disposition == "merge": + if not client.capabilities.supported_merge_strategies: + pytest.skip("destination does not support merge") # add root key client.schema.tables[table_name]["columns"]["col1"]["root_key"] = True # create staging for merge dataset @@ -665,9 +680,11 @@ def test_write_dispositions( with io.BytesIO() as f: write_dataset(client, f, [data_row], column_schemas) query = f.getvalue() - if client.should_load_data_to_staging_dataset(table_name): # type: ignore[attr-defined] + if isinstance( + client, WithStagingDataset + ) and client.should_load_data_to_staging_dataset(table_name): # load to staging dataset on merge - with client.with_staging_dataset(): # type: ignore[attr-defined] + with client.with_staging_dataset(): expect_load_file(client, file_storage, query, t) else: # load directly on other @@ -814,6 +831,8 @@ def test_get_stored_state( # get state stored_state = client.get_stored_state("pipeline") + # Ensure timezone aware datetime for comparing + stored_state.created_at = pendulum.instance(stored_state.created_at) assert doc == stored_state.as_doc() @@ -909,7 +928,11 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None: "mandatory_column", "text", nullable=False ) client.schema._bump_version() - if destination_config.destination == "clickhouse": + if destination_config.destination == "clickhouse" or ( + # mysql allows adding not-null columns (they have an implicit default) + destination_config.destination == "sqlalchemy" + and client.sql_client.dialect_name == "mysql" + ): client.update_stored_schema() else: with pytest.raises(DatabaseException) as py_ex: diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index e167f0ceda..fa154c65dc 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -20,7 +20,7 @@ from dlt.destinations.sql_client import DBApiCursor, SqlClientBase from dlt.destinations.job_client_impl import SqlJobClientBase from dlt.destinations.typing import TNativeConn -from dlt.common.time import ensure_pendulum_datetime +from dlt.common.time import ensure_pendulum_datetime, to_py_datetime from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage from tests.load.utils import ( @@ -62,7 +62,9 @@ def naming(request) -> str: @pytest.mark.parametrize( "client", 
destinations_configs( - default_sql_configs=True, exclude=["mssql", "synapse", "dremio", "clickhouse"] + # Only databases that support search path or equivalent + default_sql_configs=True, + exclude=["mssql", "synapse", "dremio", "clickhouse", "sqlalchemy"], ), indirect=True, ids=lambda x: x.name, @@ -212,10 +214,10 @@ def test_execute_sql(client: SqlJobClientBase) -> None: assert rows[0][0] == "event" # print(rows[0][1]) # print(type(rows[0][1])) - # convert to pendulum to make sure it is supported by dbapi + # ensure datetime obj to make sure it is supported by dbapi rows = client.sql_client.execute_sql( f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at = %s", - ensure_pendulum_datetime(rows[0][1]), + to_py_datetime(ensure_pendulum_datetime(rows[0][1])), ) assert len(rows) == 1 # use rows in subsequent test @@ -620,7 +622,7 @@ def test_max_column_identifier_length(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( "client", - destinations_configs(default_sql_configs=True, exclude=["databricks"]), + destinations_configs(default_sql_configs=True, exclude=["databricks", "sqlalchemy"]), indirect=True, ids=lambda x: x.name, ) diff --git a/tests/load/utils.py b/tests/load/utils.py index 0eaf68d8f8..0fad4e1ddc 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -295,7 +295,8 @@ def destinations_configs( destination_configs += [ DestinationTestConfiguration(destination=destination) for destination in SQL_DESTINATIONS - if destination not in ("athena", "synapse", "databricks", "dremio", "clickhouse") + if destination + not in ("athena", "synapse", "databricks", "dremio", "clickhouse", "sqlalchemy") ] destination_configs += [ DestinationTestConfiguration(destination="duckdb", file_format="parquet"), @@ -305,6 +306,11 @@ def destinations_configs( # add Athena staging configs destination_configs += default_sql_configs_with_staging + destination_configs += [ + DestinationTestConfiguration( + destination="sqlalchemy", supports_merge=False, supports_dbt=False + ) + ] destination_configs += [ DestinationTestConfiguration( destination="clickhouse", file_format="jsonl", supports_dbt=False diff --git a/tests/utils.py b/tests/utils.py index 9facdfc375..e90ac5a626 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -54,6 +54,7 @@ "databricks", "clickhouse", "dremio", + "sqlalchemy", } NON_SQL_DESTINATIONS = { "filesystem", From 31d903357aca691c4f8931a0c12d9444bc79ffc0 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 11 Sep 2024 12:18:53 -0400 Subject: [PATCH 02/37] Support destination name in tests --- tests/load/pipeline/test_arrow_loading.py | 16 +- tests/load/pipeline/test_dbt_helper.py | 6 +- tests/load/pipeline/test_merge_disposition.py | 4 +- tests/load/pipeline/test_pipelines.py | 33 ++-- tests/load/pipeline/test_refresh_modes.py | 2 +- tests/load/pipeline/test_restore_state.py | 22 +-- tests/load/pipeline/test_stage_loading.py | 22 +-- .../test_write_disposition_changes.py | 2 +- tests/load/test_job_client.py | 16 +- tests/load/test_sql_client.py | 4 +- tests/load/utils.py | 146 ++++++++++-------- tests/pipeline/test_pipeline_extra.py | 4 +- 12 files changed, 149 insertions(+), 128 deletions(-) diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index 4a810cb755..86711e679a 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -41,7 +41,7 @@ def test_load_arrow_item( # os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" 
os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "True" os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_ID"] = "True" - include_time = destination_config.destination not in ( + include_time = destination_config.destination_type not in ( "athena", "redshift", "databricks", @@ -49,15 +49,15 @@ def test_load_arrow_item( "clickhouse", ) # athena/redshift can't load TIME columns include_binary = not ( - destination_config.destination in ("redshift", "databricks") + destination_config.destination_type in ("redshift", "databricks") and destination_config.file_format == "jsonl" ) include_decimal = not ( - destination_config.destination == "databricks" and destination_config.file_format == "jsonl" + destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl" ) include_date = not ( - destination_config.destination == "databricks" and destination_config.file_format == "jsonl" + destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl" ) item, records, _ = arrow_table_all_data_types( @@ -77,7 +77,7 @@ def some_data(): # use csv for postgres to get native arrow processing destination_config.file_format = ( - destination_config.file_format if destination_config.destination != "postgres" else "csv" + destination_config.file_format if destination_config.destination_type != "postgres" else "csv" ) load_info = pipeline.run(some_data(), **destination_config.run_kwargs) @@ -107,13 +107,13 @@ def some_data(): if isinstance(row[i], memoryview): row[i] = row[i].tobytes() - if destination_config.destination == "redshift": + if destination_config.destination_type == "redshift": # Redshift needs hex string for record in records: if "binary" in record: record["binary"] = record["binary"].hex() - if destination_config.destination == "clickhouse": + if destination_config.destination_type == "clickhouse": for record in records: # Clickhouse needs base64 string for jsonl if "binary" in record and destination_config.file_format == "jsonl": @@ -128,7 +128,7 @@ def some_data(): row[i] = pendulum.instance(row[i]) # clickhouse produces rounding errors on double with jsonl, so we round the result coming from there if ( - destination_config.destination == "clickhouse" + destination_config.destination_type == "clickhouse" and destination_config.file_format == "jsonl" and isinstance(row[i], float) ): diff --git a/tests/load/pipeline/test_dbt_helper.py b/tests/load/pipeline/test_dbt_helper.py index 0ee02ba4b5..d55c81e998 100644 --- a/tests/load/pipeline/test_dbt_helper.py +++ b/tests/load/pipeline/test_dbt_helper.py @@ -36,7 +36,7 @@ def dbt_venv() -> Iterator[Venv]: def test_run_jaffle_package( destination_config: DestinationTestConfiguration, dbt_venv: Venv ) -> None: - if destination_config.destination == "athena": + if destination_config.destination_type == "athena": pytest.skip( "dbt-athena requires database to be created and we don't do it in case of Jaffle" ) @@ -71,7 +71,7 @@ def test_run_jaffle_package( ids=lambda x: x.name, ) def test_run_chess_dbt(destination_config: DestinationTestConfiguration, dbt_venv: Venv) -> None: - if destination_config.destination == "mssql": + if destination_config.destination_type == "mssql": pytest.skip( "mssql requires non standard SQL syntax and we do not have specialized dbt package" " for it" @@ -130,7 +130,7 @@ def test_run_chess_dbt(destination_config: DestinationTestConfiguration, dbt_ven def test_run_chess_dbt_to_other_dataset( destination_config: DestinationTestConfiguration, dbt_venv: 
Venv ) -> None: - if destination_config.destination == "mssql": + if destination_config.destination_type == "mssql": pytest.skip( "mssql requires non standard SQL syntax and we do not have specialized dbt package" " for it" diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 6bfeaf0df5..8a1ef5c2d1 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -1120,7 +1120,7 @@ def data_resource(data): assert sorted(observed, key=lambda d: d["id"]) == expected # additional tests with two records, run only on duckdb to limit test load - if destination_config.destination == "duckdb": + if destination_config.destination_type == "duckdb": # two records with same primary key # record with highest value in sort column is a delete # existing record is deleted and no record will be inserted @@ -1201,7 +1201,7 @@ def r(): ids=lambda x: x.name, ) def test_upsert_merge_strategy_config(destination_config: DestinationTestConfiguration) -> None: - if destination_config.destination == "filesystem": + if destination_config.destination_type == "filesystem": # TODO: implement validation and remove this test exception pytest.skip( "`upsert` merge strategy configuration validation has not yet been" diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 4cb43ae322..8435c13dd1 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -100,7 +100,7 @@ def data_fun() -> Iterator[Any]: # mock the correct destinations (never do that in normal code) with p.managed_state(): p._set_destinations( - destination=Destination.from_reference(destination_config.destination), + destination=Destination.from_reference(destination_config.destination_type), staging=( Destination.from_reference(destination_config.staging) if destination_config.staging @@ -162,13 +162,12 @@ def test_default_schema_name( for idx, alpha in [(0, "A"), (0, "B"), (0, "C")] ] - p = dlt.pipeline( + p = destination_config.setup_pipeline( "test_default_schema_name", - TEST_STORAGE_ROOT, - destination=destination_config.destination, - staging=destination_config.staging, dataset_name=dataset_name, + pipelines_dir=TEST_STORAGE_ROOT, ) + p.config.use_single_dataset = use_single_dataset p.extract( data, @@ -212,7 +211,7 @@ def _data(): destination_config.setup() info = dlt.run( _data(), - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name="specific" + uniq_id(), **destination_config.run_kwargs, @@ -288,7 +287,7 @@ def _data(): p = dlt.pipeline(dev_mode=True) info = p.run( _data(), - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name="iteration" + uniq_id(), **destination_config.run_kwargs, @@ -380,7 +379,7 @@ def extended_rows(): assert "new_column" not in schema.get_table("simple_rows")["columns"] # lets violate unique constraint on postgres, redshift and BQ ignore unique indexes - if destination_config.destination == "postgres": + if destination_config.destination_type == "postgres": # let it complete even with PK violation (which is a teminal error) os.environ["RAISE_ON_FAILED_JOBS"] = "false" assert p.dataset_name == dataset_name @@ -466,7 +465,7 @@ def nested_data(): info = dlt.run( nested_data(), - destination=destination_config.destination, + 
destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name="ds_" + uniq_id(), **destination_config.run_kwargs, @@ -509,11 +508,11 @@ def other_data(): column_schemas = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA) # parquet on bigquery and clickhouse does not support JSON but we still want to run the test - if destination_config.destination in ["bigquery"]: + if destination_config.destination_type in ["bigquery"]: column_schemas["col9_null"]["data_type"] = column_schemas["col9"]["data_type"] = "text" # duckdb 0.9.1 does not support TIME other than 6 - if destination_config.destination in ["duckdb", "motherduck"]: + if destination_config.destination_type in ["duckdb", "motherduck"]: column_schemas["col11_precision"]["precision"] = None # also we do not want to test col4_precision (datetime) because # those timestamps are not TZ aware in duckdb and we'd need to @@ -522,7 +521,7 @@ def other_data(): column_schemas["col4_precision"]["precision"] = 6 # drop TIME from databases not supporting it via parquet - if destination_config.destination in [ + if destination_config.destination_type in [ "redshift", "athena", "synapse", @@ -536,7 +535,7 @@ def other_data(): column_schemas.pop("col11_null") column_schemas.pop("col11_precision") - if destination_config.destination in ("redshift", "dremio"): + if destination_config.destination_type in ("redshift", "dremio"): data_types.pop("col7_precision") column_schemas.pop("col7_precision") @@ -583,10 +582,10 @@ def some_source(): assert_all_data_types_row( db_row, schema=column_schemas, - parse_json_strings=destination_config.destination + parse_json_strings=destination_config.destination_type in ["snowflake", "bigquery", "redshift"], - allow_string_binary=destination_config.destination == "clickhouse", - timestamp_precision=3 if destination_config.destination in ("athena", "dremio") else 6, + allow_string_binary=destination_config.destination_type == "clickhouse", + timestamp_precision=3 if destination_config.destination_type in ("athena", "dremio") else 6, ) @@ -839,7 +838,7 @@ def _data(): p = dlt.pipeline( pipeline_name=f"pipeline_{dataset_name}", dev_mode=dev_mode, - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name=dataset_name, ) diff --git a/tests/load/pipeline/test_refresh_modes.py b/tests/load/pipeline/test_refresh_modes.py index 59063aacea..dcb2be44dc 100644 --- a/tests/load/pipeline/test_refresh_modes.py +++ b/tests/load/pipeline/test_refresh_modes.py @@ -275,7 +275,7 @@ def test_refresh_drop_data_only(destination_config: DestinationTestConfiguration data = load_tables_to_dicts(pipeline, "some_data_1", "some_data_2", "some_data_3") # name column still remains when table was truncated instead of dropped # (except on filesystem where truncate and drop are the same) - if destination_config.destination == "filesystem": + if destination_config.destination_type == "filesystem": result = sorted([row["id"] for row in data["some_data_1"]]) assert result == [3, 4] diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index 0e03119b7b..050636c491 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -219,13 +219,13 @@ def test_get_schemas_from_destination( use_single_dataset: bool, naming_convention: str, ) -> None: - set_naming_env(destination_config.destination, naming_convention) + 
set_naming_env(destination_config.destination_type, naming_convention) pipeline_name = "pipe_" + uniq_id() dataset_name = "state_test_" + uniq_id() p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) - assert_naming_to_caps(destination_config.destination, p.destination.capabilities()) + assert_naming_to_caps(destination_config.destination_type, p.destination.capabilities()) p.config.use_single_dataset = use_single_dataset def _make_dn_name(schema_name: str) -> str: @@ -318,13 +318,13 @@ def _make_dn_name(schema_name: str) -> str: def test_restore_state_pipeline( destination_config: DestinationTestConfiguration, naming_convention: str ) -> None: - set_naming_env(destination_config.destination, naming_convention) + set_naming_env(destination_config.destination_type, naming_convention) # enable restoring from destination os.environ["RESTORE_FROM_DESTINATION"] = "True" pipeline_name = "pipe_" + uniq_id() dataset_name = "state_test_" + uniq_id() p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) - assert_naming_to_caps(destination_config.destination, p.destination.capabilities()) + assert_naming_to_caps(destination_config.destination_type, p.destination.capabilities()) def some_data_gen(param: str) -> Any: dlt.current.source_state()[param] = param @@ -550,7 +550,7 @@ def test_restore_schemas_while_import_schemas_exist( ) # use run to get changes p.run( - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name=dataset_name, **destination_config.run_kwargs, @@ -601,7 +601,7 @@ def some_data(param: str) -> Any: p.run( [data1, some_data("state2")], schema=Schema("default"), - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name=dataset_name, **destination_config.run_kwargs, @@ -611,7 +611,7 @@ def some_data(param: str) -> Any: # create a production pipeline in separate pipelines_dir production_p = dlt.pipeline(pipeline_name=pipeline_name, pipelines_dir=TEST_STORAGE_ROOT) production_p.run( - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name=dataset_name, **destination_config.run_kwargs, @@ -691,7 +691,9 @@ def some_data(param: str) -> Any: [5, 4, 4, 3, 2], ) except SqlClientNotAvailable: - pytest.skip(f"destination {destination_config.destination} does not support sql client") + pytest.skip( + f"destination {destination_config.destination_type} does not support sql client" + ) @pytest.mark.parametrize( @@ -719,7 +721,7 @@ def some_data(param: str) -> Any: p.run( data4, schema=Schema("sch1"), - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name=dataset_name, **destination_config.run_kwargs, @@ -749,7 +751,7 @@ def some_data(param: str) -> Any: p.run( data4, schema=Schema("sch1"), - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name=dataset_name, **destination_config.run_kwargs, diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index 0fc6a37dd9..42c288397a 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -126,7 +126,7 @@ def 
test_staging_load(destination_config: DestinationTestConfiguration) -> None: # check item of first row in db with pipeline.sql_client() as sql_client: qual_name = sql_client.make_qualified_table_name - if destination_config.destination in ["mssql", "synapse"]: + if destination_config.destination_type in ["mssql", "synapse"]: rows = sql_client.execute_sql( f"SELECT TOP 1 url FROM {qual_name('issues')} WHERE id = 388089021" ) @@ -148,7 +148,7 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: # check changes where merged in with pipeline.sql_client() as sql_client: - if destination_config.destination in ["mssql", "synapse"]: + if destination_config.destination_type in ["mssql", "synapse"]: qual_name = sql_client.make_qualified_table_name rows_1 = sql_client.execute_sql( f"SELECT TOP 1 number FROM {qual_name('issues')} WHERE id = 1232152492" @@ -279,7 +279,7 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non # redshift and athena, parquet and jsonl, exclude time types exclude_types: List[TDataType] = [] exclude_columns: List[str] = [] - if destination_config.destination in ( + if destination_config.destination_type in ( "redshift", "athena", "databricks", @@ -287,10 +287,10 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non ) and destination_config.file_format in ("parquet", "jsonl"): # Redshift copy doesn't support TIME column exclude_types.append("time") - if destination_config.destination == "synapse" and destination_config.file_format == "parquet": + if destination_config.destination_type == "synapse" and destination_config.file_format == "parquet": # TIME columns are not supported for staged parquet loads into Synapse exclude_types.append("time") - if destination_config.destination in ( + if destination_config.destination_type in ( "redshift", "dremio", ) and destination_config.file_format in ( @@ -299,7 +299,7 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non ): # Redshift can't load fixed width binary columns from parquet exclude_columns.append("col7_precision") - if destination_config.destination == "databricks" and destination_config.file_format == "jsonl": + if destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl": exclude_types.extend(["decimal", "binary", "wei", "json", "date"]) exclude_columns.append("col1_precision") @@ -309,12 +309,12 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non # bigquery and clickhouse cannot load into JSON fields from parquet if destination_config.file_format == "parquet": - if destination_config.destination in ["bigquery"]: + if destination_config.destination_type in ["bigquery"]: # change datatype to text and then allow for it in the assert (parse_json_strings) column_schemas["col9_null"]["data_type"] = column_schemas["col9"]["data_type"] = "text" # redshift cannot load from json into VARBYTE if destination_config.file_format == "jsonl": - if destination_config.destination == "redshift": + if destination_config.destination_type == "redshift": # change the datatype to text which will result in inserting base64 (allow_base64_binary) binary_cols = ["col7", "col7_null"] for col in binary_cols: @@ -344,15 +344,15 @@ def my_source(): # parquet is not really good at inserting json, best we get are strings in JSON columns parse_json_strings = ( destination_config.file_format == "parquet" - and destination_config.destination in ["redshift", "bigquery", 
"snowflake"] + and destination_config.destination_type in ["redshift", "bigquery", "snowflake"] ) allow_base64_binary = ( destination_config.file_format == "jsonl" - and destination_config.destination in ["redshift", "clickhouse"] + and destination_config.destination_type in ["redshift", "clickhouse"] ) allow_string_binary = ( destination_config.file_format == "parquet" - and destination_config.destination in ["clickhouse"] + and destination_config.destination_type in ["clickhouse"] ) # content must equal assert_all_data_types_row( diff --git a/tests/load/pipeline/test_write_disposition_changes.py b/tests/load/pipeline/test_write_disposition_changes.py index 1a799059d0..c28890e2c6 100644 --- a/tests/load/pipeline/test_write_disposition_changes.py +++ b/tests/load/pipeline/test_write_disposition_changes.py @@ -126,7 +126,7 @@ def source(): # schemaless destinations allow adding of root key without the pipeline failing # they do not mind adding NOT NULL columns to tables with existing data (id NOT NULL is supported at all) # doing this will result in somewhat useless behavior - destination_allows_adding_root_key = destination_config.destination in [ + destination_allows_adding_root_key = destination_config.destination_type in [ "dremio", "clickhouse", "athena", diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index df342d4f01..de4e987aa2 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -62,7 +62,7 @@ def file_storage() -> FileStorage: @pytest.fixture(scope="function") def client(request, naming) -> Iterator[SqlJobClientBase]: - yield from yield_client_with_storage(request.param.destination) + yield from yield_client_with_storage(request.param.destination_factory()) @pytest.fixture(scope="function") @@ -754,7 +754,7 @@ def test_get_resumed_job(client: SqlJobClientBase, file_storage: FileStorage) -> ) def test_default_schema_name_init_storage(destination_config: DestinationTestConfiguration) -> None: with cm_yield_client_with_storage( - destination_config.destination, + destination_config.destination_factory(), default_config_values={ "default_schema_name": ( # pass the schema that is a default schema. that should create dataset with the name `dataset_name` "event" @@ -765,7 +765,7 @@ def test_default_schema_name_init_storage(destination_config: DestinationTestCon assert client.sql_client.has_dataset() with cm_yield_client_with_storage( - destination_config.destination, + destination_config.destination_factory(), default_config_values={ "default_schema_name": ( None # no default_schema. that should create dataset with the name `dataset_name` @@ -776,7 +776,7 @@ def test_default_schema_name_init_storage(destination_config: DestinationTestCon assert client.sql_client.has_dataset() with cm_yield_client_with_storage( - destination_config.destination, + destination_config.destination_factory(), default_config_values={ "default_schema_name": ( # the default schema is not event schema . 
that should create dataset with the name `dataset_name` with schema suffix "event_2" @@ -805,7 +805,7 @@ def test_get_stored_state( os.environ["SCHEMA__NAMING"] = naming_convention with cm_yield_client_with_storage( - destination_config.destination, default_config_values={"default_schema_name": None} + destination_config.destination_factory(), default_config_values={"default_schema_name": None} ) as client: # event schema with event table if not client.capabilities.preferred_loader_file_format: @@ -869,7 +869,7 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None: assert len(db_rows) == expected_rows with cm_yield_client_with_storage( - destination_config.destination, default_config_values={"default_schema_name": None} + destination_config.destination_factory(), default_config_values={"default_schema_name": None} ) as client: # event schema with event table if not client.capabilities.preferred_loader_file_format: @@ -928,9 +928,9 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None: "mandatory_column", "text", nullable=False ) client.schema._bump_version() - if destination_config.destination == "clickhouse" or ( + if destination_config.destination_type == "clickhouse" or ( # mysql allows adding not-null columns (they have an implicit default) - destination_config.destination == "sqlalchemy" + destination_config.destination_type == "sqlalchemy" and client.sql_client.dialect_name == "mysql" ): client.update_stored_schema() diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index fa154c65dc..9743eb8ec6 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -28,6 +28,7 @@ prepare_table, AWS_BUCKET, destinations_configs, + DestinationTestConfiguration, ) # mark all tests as essential, do not remove @@ -46,7 +47,8 @@ def file_storage() -> FileStorage: @pytest.fixture(scope="function") def client(request, naming) -> Iterator[SqlJobClientBase]: - yield from yield_client_with_storage(request.param.destination) + param: DestinationTestConfiguration = request.param + yield from yield_client_with_storage(param.destination_factory()) @pytest.fixture(scope="function") diff --git a/tests/load/utils.py b/tests/load/utils.py index 0fad4e1ddc..da936a311c 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -144,7 +144,7 @@ class DestinationTestConfiguration: """Class for defining test setup for one destination.""" - destination: str + destination_type: str staging: Optional[TDestinationReferenceArg] = None file_format: Optional[TLoaderFileFormat] = None table_format: Optional[TTableFormat] = None @@ -160,10 +160,17 @@ class DestinationTestConfiguration: dev_mode: bool = False credentials: Optional[Union[CredentialsConfiguration, Dict[str, Any]]] = None env_vars: Optional[Dict[str, str]] = None + destination_name: Optional[str] = None + + def destination_factory(self, **kwargs) -> Destination[Any, Any]: + dest_type = kwargs.pop("destination", self.destination_type) + dest_name = kwargs.pop("destination_name", self.destination_name) + self.setup() + return Destination.from_reference(dest_type, destination_name=dest_name, **kwargs) @property def name(self) -> str: - name: str = self.destination + name: str = self.destination_name or self.destination_type if self.file_format: name += f"-{self.file_format}" if self.table_format: @@ -196,7 +203,7 @@ def setup(self) -> None: os.environ[f"DESTINATION__{k.upper()}"] = str(v) # For the filesystem destinations we disable compression to make analyzing the 
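The destination_factory/destination_name split introduced above lets one destination type (for example sqlalchemy) appear under several named test configurations. A minimal sketch of the intended usage, assuming the patched tests.load.utils is importable and dlt's sqlalchemy extra is installed; the sqlalchemy_sqlite name mirrors the configs registered later in this patch series.

from tests.load.utils import DestinationTestConfiguration

cfg = DestinationTestConfiguration(
    destination_type="sqlalchemy",
    destination_name="sqlalchemy_sqlite",
    supports_merge=False,
    supports_dbt=False,
)
# test ids and config sections now derive from the destination name when it is set
assert cfg.name == "sqlalchemy_sqlite"
# resolves to Destination.from_reference("sqlalchemy", destination_name="sqlalchemy_sqlite")
dest = cfg.destination_factory()
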
result easier - if self.destination == "filesystem" or self.disable_compression: + if self.destination_type == "filesystem" or self.disable_compression: os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" if self.credentials is not None: @@ -211,11 +218,12 @@ def setup_pipeline( self, pipeline_name: str, dataset_name: str = None, dev_mode: bool = False, **kwargs ) -> dlt.Pipeline: """Convenience method to setup pipeline with this configuration""" + self.dev_mode = dev_mode - self.setup() + factory = self.destination_factory(**kwargs) pipeline = dlt.pipeline( pipeline_name=pipeline_name, - destination=kwargs.pop("destination", self.destination), + destination=factory, staging=kwargs.pop("staging", self.staging), dataset_name=dataset_name or pipeline_name, dev_mode=dev_mode, @@ -274,13 +282,13 @@ def destinations_configs( default_sql_configs_with_staging = [ # Athena needs filesystem staging, which will be automatically set; we have to supply a bucket url though. DestinationTestConfiguration( - destination="athena", + destination_type="athena", file_format="parquet", supports_merge=False, bucket_url=AWS_BUCKET, ), DestinationTestConfiguration( - destination="athena", + destination_type="athena", file_format="parquet", bucket_url=AWS_BUCKET, supports_merge=True, @@ -293,14 +301,16 @@ def destinations_configs( # default non staging sql based configs, one per destination if default_sql_configs: destination_configs += [ - DestinationTestConfiguration(destination=destination) + DestinationTestConfiguration(destination_type=destination) for destination in SQL_DESTINATIONS if destination not in ("athena", "synapse", "databricks", "dremio", "clickhouse", "sqlalchemy") ] destination_configs += [ - DestinationTestConfiguration(destination="duckdb", file_format="parquet"), - DestinationTestConfiguration(destination="motherduck", file_format="insert_values"), + DestinationTestConfiguration(destination_type="duckdb", file_format="parquet"), + DestinationTestConfiguration( + destination_type="motherduck", file_format="insert_values" + ), ] # add Athena staging configs @@ -308,17 +318,21 @@ def destinations_configs( destination_configs += [ DestinationTestConfiguration( - destination="sqlalchemy", supports_merge=False, supports_dbt=False + destination_type="sqlalchemy", + supports_merge=False, + supports_dbt=False, + destination_name="mysql_driver", ) ] + destination_configs += [ DestinationTestConfiguration( - destination="clickhouse", file_format="jsonl", supports_dbt=False + destination_type="clickhouse", file_format="jsonl", supports_dbt=False ) ] destination_configs += [ DestinationTestConfiguration( - destination="databricks", + destination_type="databricks", file_format="parquet", bucket_url=AZ_BUCKET, extra_info="az-authorization", @@ -327,7 +341,7 @@ def destinations_configs( destination_configs += [ DestinationTestConfiguration( - destination="dremio", + destination_type="dremio", staging=filesystem(destination_name="minio"), file_format="parquet", bucket_url=AWS_BUCKET, @@ -335,24 +349,24 @@ def destinations_configs( ) ] destination_configs += [ - # DestinationTestConfiguration(destination="mssql", supports_dbt=False), - DestinationTestConfiguration(destination="synapse", supports_dbt=False), + # DestinationTestConfiguration(destination_type="mssql", supports_dbt=False), + DestinationTestConfiguration(destination_type="synapse", supports_dbt=False), ] # sanity check that when selecting default destinations, one of each sql destination is actually # provided - assert set(SQL_DESTINATIONS) == 
{d.destination for d in destination_configs} + assert set(SQL_DESTINATIONS) == {d.destination_type for d in destination_configs} if default_vector_configs: destination_configs += [ - DestinationTestConfiguration(destination="weaviate"), - DestinationTestConfiguration(destination="lancedb"), + DestinationTestConfiguration(destination_type="weaviate"), + DestinationTestConfiguration(destination_type="lancedb"), DestinationTestConfiguration( - destination="qdrant", + destination_type="qdrant", credentials=dict(path=str(Path(FILE_BUCKET) / "qdrant_data")), extra_info="local-file", ), - DestinationTestConfiguration(destination="qdrant", extra_info="server"), + DestinationTestConfiguration(destination_type="qdrant", extra_info="server"), ] if (default_sql_configs or all_staging_configs) and not default_sql_configs: @@ -362,7 +376,7 @@ def destinations_configs( if default_staging_configs or all_staging_configs: destination_configs += [ DestinationTestConfiguration( - destination="redshift", + destination_type="redshift", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, @@ -370,14 +384,14 @@ def destinations_configs( extra_info="s3-role", ), DestinationTestConfiguration( - destination="bigquery", + destination_type="bigquery", staging="filesystem", file_format="parquet", bucket_url=GCS_BUCKET, extra_info="gcs-authorization", ), DestinationTestConfiguration( - destination="snowflake", + destination_type="snowflake", staging="filesystem", file_format="jsonl", bucket_url=GCS_BUCKET, @@ -385,14 +399,14 @@ def destinations_configs( extra_info="gcs-integration", ), DestinationTestConfiguration( - destination="snowflake", + destination_type="snowflake", staging="filesystem", file_format="jsonl", bucket_url=AWS_BUCKET, extra_info="s3-integration", ), DestinationTestConfiguration( - destination="snowflake", + destination_type="snowflake", staging="filesystem", file_format="jsonl", bucket_url=AWS_BUCKET, @@ -400,7 +414,7 @@ def destinations_configs( extra_info="s3-integration", ), DestinationTestConfiguration( - destination="snowflake", + destination_type="snowflake", staging="filesystem", file_format="jsonl", bucket_url=AZ_BUCKET, @@ -408,14 +422,14 @@ def destinations_configs( extra_info="az-integration", ), DestinationTestConfiguration( - destination="snowflake", + destination_type="snowflake", staging="filesystem", file_format="jsonl", bucket_url=AZ_BUCKET, extra_info="az-authorization", ), DestinationTestConfiguration( - destination="databricks", + destination_type="databricks", staging="filesystem", file_format="jsonl", bucket_url=AWS_BUCKET, @@ -423,7 +437,7 @@ def destinations_configs( disable_compression=True, ), DestinationTestConfiguration( - destination="databricks", + destination_type="databricks", staging="filesystem", file_format="jsonl", bucket_url=AZ_BUCKET, @@ -431,14 +445,14 @@ def destinations_configs( disable_compression=True, ), DestinationTestConfiguration( - destination="databricks", + destination_type="databricks", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, extra_info="s3-authorization", ), DestinationTestConfiguration( - destination="synapse", + destination_type="synapse", staging="filesystem", file_format="parquet", bucket_url=AZ_BUCKET, @@ -446,35 +460,35 @@ def destinations_configs( disable_compression=True, ), DestinationTestConfiguration( - destination="clickhouse", + destination_type="clickhouse", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, extra_info="s3-authorization", ), DestinationTestConfiguration( - 
destination="clickhouse", + destination_type="clickhouse", staging="filesystem", file_format="parquet", bucket_url=AZ_BUCKET, extra_info="az-authorization", ), DestinationTestConfiguration( - destination="clickhouse", + destination_type="clickhouse", staging="filesystem", file_format="jsonl", bucket_url=AZ_BUCKET, extra_info="az-authorization", ), DestinationTestConfiguration( - destination="clickhouse", + destination_type="clickhouse", staging="filesystem", file_format="jsonl", bucket_url=AWS_BUCKET, extra_info="s3-authorization", ), DestinationTestConfiguration( - destination="dremio", + destination_type="dremio", staging=filesystem(destination_name="minio"), file_format="parquet", bucket_url=AWS_BUCKET, @@ -485,35 +499,35 @@ def destinations_configs( if all_staging_configs: destination_configs += [ DestinationTestConfiguration( - destination="redshift", + destination_type="redshift", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, extra_info="credential-forwarding", ), DestinationTestConfiguration( - destination="snowflake", + destination_type="snowflake", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, extra_info="credential-forwarding", ), DestinationTestConfiguration( - destination="redshift", + destination_type="redshift", staging="filesystem", file_format="jsonl", bucket_url=AWS_BUCKET, extra_info="credential-forwarding", ), DestinationTestConfiguration( - destination="bigquery", + destination_type="bigquery", staging="filesystem", file_format="jsonl", bucket_url=GCS_BUCKET, extra_info="gcs-authorization", ), DestinationTestConfiguration( - destination="synapse", + destination_type="synapse", staging="filesystem", file_format="parquet", bucket_url=AZ_BUCKET, @@ -526,7 +540,7 @@ def destinations_configs( if local_filesystem_configs: destination_configs += [ DestinationTestConfiguration( - destination="filesystem", + destination_type="filesystem", bucket_url=FILE_BUCKET, file_format="insert_values", supports_merge=False, @@ -534,7 +548,7 @@ def destinations_configs( ] destination_configs += [ DestinationTestConfiguration( - destination="filesystem", + destination_type="filesystem", bucket_url=FILE_BUCKET, file_format="parquet", supports_merge=False, @@ -542,7 +556,7 @@ def destinations_configs( ] destination_configs += [ DestinationTestConfiguration( - destination="filesystem", + destination_type="filesystem", bucket_url=FILE_BUCKET, file_format="jsonl", supports_merge=False, @@ -553,7 +567,7 @@ def destinations_configs( for bucket in DEFAULT_BUCKETS: destination_configs += [ DestinationTestConfiguration( - destination="filesystem", + destination_type="filesystem", bucket_url=bucket, extra_info=bucket, supports_merge=False, @@ -564,7 +578,7 @@ def destinations_configs( for bucket in DEFAULT_BUCKETS: destination_configs += [ DestinationTestConfiguration( - destination="filesystem", + destination_type="filesystem", bucket_url=bucket, extra_info=bucket, table_format="delta", @@ -583,27 +597,29 @@ def destinations_configs( # filter out non active destinations destination_configs = [ - conf for conf in destination_configs if conf.destination in ACTIVE_DESTINATIONS + conf for conf in destination_configs if conf.destination_type in ACTIVE_DESTINATIONS ] # filter out destinations not in subset if subset: - destination_configs = [conf for conf in destination_configs if conf.destination in subset] + destination_configs = [ + conf for conf in destination_configs if conf.destination_type in subset + ] if bucket_subset: destination_configs = [ conf for conf 
in destination_configs - if conf.destination != "filesystem" or conf.bucket_url in bucket_subset + if conf.destination_type != "filesystem" or conf.bucket_url in bucket_subset ] if exclude: destination_configs = [ - conf for conf in destination_configs if conf.destination not in exclude + conf for conf in destination_configs if conf.destination_type not in exclude ] if bucket_exclude: destination_configs = [ conf for conf in destination_configs - if conf.destination != "filesystem" or conf.bucket_url not in bucket_exclude + if conf.destination_type != "filesystem" or conf.bucket_url not in bucket_exclude ] if with_file_format: if not isinstance(with_file_format, Sequence): @@ -780,14 +796,14 @@ def prepare_table( def yield_client( - destination_type: str, + destination_ref: TDestinationReferenceArg, dataset_name: str = None, default_config_values: StrAny = None, schema_name: str = "event", ) -> Iterator[SqlJobClientBase]: os.environ.pop("DATASET_NAME", None) # import destination reference by name - destination = Destination.from_reference(destination_type) + destination = Destination.from_reference(destination_ref) # create initial config dest_config: DestinationClientDwhConfiguration = None dest_config = destination.spec() # type: ignore @@ -812,7 +828,7 @@ def yield_client( client: SqlJobClientBase = None # athena requires staging config to be present, so stick this in there here - if destination_type == "athena": + if destination.destination_type == "athena": staging_config = DestinationClientStagingConfiguration( bucket_url=AWS_BUCKET, )._bind_dataset_name(dataset_name=dest_config.dataset_name) @@ -825,7 +841,7 @@ def yield_client( ConfigSectionContext( sections=( "destination", - destination_type, + destination.destination_name, ) ) ): @@ -835,23 +851,23 @@ def yield_client( @contextlib.contextmanager def cm_yield_client( - destination_type: str, + destination: TDestinationReferenceArg, dataset_name: str, default_config_values: StrAny = None, schema_name: str = "event", ) -> Iterator[SqlJobClientBase]: - return yield_client(destination_type, dataset_name, default_config_values, schema_name) + return yield_client(destination, dataset_name, default_config_values, schema_name) def yield_client_with_storage( - destination_type: str, default_config_values: StrAny = None, schema_name: str = "event" + destination: TDestinationReferenceArg, + default_config_values: StrAny = None, + schema_name: str = "event", ) -> Iterator[SqlJobClientBase]: # create dataset with random name dataset_name = "test_" + uniq_id() - with cm_yield_client( - destination_type, dataset_name, default_config_values, schema_name - ) as client: + with cm_yield_client(destination, dataset_name, default_config_values, schema_name) as client: client.initialize_storage() yield client if client.is_storage_initialized(): @@ -872,9 +888,11 @@ def delete_dataset(client: SqlClientBase[Any], normalized_dataset_name: str) -> @contextlib.contextmanager def cm_yield_client_with_storage( - destination_type: str, default_config_values: StrAny = None, schema_name: str = "event" + destination: TDestinationReferenceArg, + default_config_values: StrAny = None, + schema_name: str = "event", ) -> Iterator[SqlJobClientBase]: - return yield_client_with_storage(destination_type, default_config_values, schema_name) + return yield_client_with_storage(destination, default_config_values, schema_name) def write_dataset( diff --git a/tests/pipeline/test_pipeline_extra.py b/tests/pipeline/test_pipeline_extra.py index 1fe6231279..821bec8e08 100644 --- 
a/tests/pipeline/test_pipeline_extra.py +++ b/tests/pipeline/test_pipeline_extra.py @@ -56,8 +56,8 @@ class BaseModel: # type: ignore[no-redef] def test_create_pipeline_all_destinations(destination_config: DestinationTestConfiguration) -> None: # create pipelines, extract and normalize. that should be possible without installing any dependencies p = dlt.pipeline( - pipeline_name=destination_config.destination + "_pipeline", - destination=destination_config.destination, + pipeline_name=destination_config.destination_type + "_pipeline", + destination=destination_config.destination_type, staging=destination_config.staging, ) # are capabilities injected From e3eaa43c4f680728120ff48d86cc261df52527e0 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 11 Sep 2024 12:20:03 -0400 Subject: [PATCH 03/37] Some job client/sql client tests running on sqlite --- .../impl/sqlalchemy/db_api_client.py | 15 ++++----- tests/cases.py | 7 ++-- tests/load/pipeline/test_arrow_loading.py | 6 ++-- tests/load/pipeline/test_pipelines.py | 4 ++- tests/load/pipeline/test_stage_loading.py | 5 ++- tests/load/test_job_client.py | 32 +++++++++++++------ tests/load/test_sql_client.py | 3 +- tests/load/utils.py | 10 ++++-- 8 files changed, 54 insertions(+), 28 deletions(-) diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py index 0969fcb628..2255bc740a 100644 --- a/dlt/destinations/impl/sqlalchemy/db_api_client.py +++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -114,7 +114,8 @@ def __init__( self.external_engine = True else: self.engine = sa.create_engine( - credentials.to_url().render_as_string(hide_password=False), **(engine_args or {}) + credentials.to_url().render_as_string(hide_password=False), + **(engine_args or {}), ) self._current_connection: Optional[Connection] = None @@ -198,7 +199,7 @@ def _sqlite_create_dataset(self, dataset_name: str) -> None: ) statement = "ATTACH DATABASE :fn AS :name" - self.execute(statement, fn=new_db_fn, name=dataset_name) + self.execute_sql(statement, fn=new_db_fn, name=dataset_name) def _sqlite_drop_dataset(self, dataset_name: str) -> None: """Drop a dataset in sqlite by detaching the database file @@ -208,10 +209,10 @@ def _sqlite_drop_dataset(self, dataset_name: str) -> None: rows = self.execute_sql("PRAGMA database_list") dbs = {row[1]: row[2] for row in rows} # db_name: filename if dataset_name not in dbs: - return + raise DatabaseUndefinedRelation(f"Database {dataset_name} does not exist") statement = "DETACH DATABASE :name" - self.execute(statement, name=dataset_name) + self.execute_sql(statement, name=dataset_name) fn = dbs[dataset_name] if not fn: # It's a memory database, nothing to do @@ -230,11 +231,7 @@ def drop_dataset(self) -> None: try: self.execute_sql(sa.schema.DropSchema(self.dataset_name, cascade=True)) except DatabaseTransientException as e: - if isinstance(e.__cause__, sa.exc.ProgrammingError): - # May not support CASCADE - self.execute_sql(sa.schema.DropSchema(self.dataset_name)) - else: - raise + self.execute_sql(sa.schema.DropSchema(self.dataset_name)) def truncate_tables(self, *tables: str) -> None: # TODO: alchemy doesn't have a construct for TRUNCATE TABLE diff --git a/tests/cases.py b/tests/cases.py index dd362b538f..162a4c4b1a 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -216,19 +216,22 @@ def assert_all_data_types_row( expected_rows = {key: value for key, value in expected_row.items() if key in schema} # prepare date to be compared: convert into pendulum instance, 
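For context on the ATTACH/DETACH statements above: SQLite has no CREATE SCHEMA, so the sqlalchemy client emulates a dataset by attaching a separate database file under the dataset name. A stdlib sqlite3 sketch of that idea; the file path and alias below are made up for illustration.

import os
import sqlite3
import tempfile

dataset_file = os.path.join(tempfile.mkdtemp(), "my_dataset.db")
conn = sqlite3.connect(":memory:")
# attach a separate database file and address it as a schema-like prefix
conn.execute("ATTACH DATABASE ? AS my_dataset", (dataset_file,))
conn.execute("CREATE TABLE my_dataset.event (id INTEGER)")
attached = {row[1] for row in conn.execute("PRAGMA database_list")}
assert "my_dataset" in attached
# dropping the dataset is just detaching the file again
conn.execute("DETACH DATABASE my_dataset")
conn.close()
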
adjust microsecond precision if "col4" in expected_rows: - parsed_date = pendulum.instance(db_mapping["col4"]) + parsed_date = ensure_pendulum_datetime((db_mapping["col4"])) db_mapping["col4"] = reduce_pendulum_datetime_precision(parsed_date, timestamp_precision) expected_rows["col4"] = reduce_pendulum_datetime_precision( ensure_pendulum_datetime(expected_rows["col4"]), # type: ignore[arg-type] timestamp_precision, ) if "col4_precision" in expected_rows: - parsed_date = pendulum.instance(db_mapping["col4_precision"]) + parsed_date = ensure_pendulum_datetime((db_mapping["col4_precision"])) db_mapping["col4_precision"] = reduce_pendulum_datetime_precision(parsed_date, 3) expected_rows["col4_precision"] = reduce_pendulum_datetime_precision( ensure_pendulum_datetime(expected_rows["col4_precision"]), 3 # type: ignore[arg-type] ) + if "col10" in expected_rows: + db_mapping["col10"] = ensure_pendulum_date(db_mapping["col10"]) + if "col11" in expected_rows: expected_rows["col11"] = reduce_pendulum_datetime_precision( ensure_pendulum_time(expected_rows["col11"]), timestamp_precision # type: ignore[arg-type] diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index 86711e679a..b2fda87860 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -54,10 +54,12 @@ def test_load_arrow_item( ) include_decimal = not ( - destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl" + destination_config.destination_type == "databricks" + and destination_config.file_format == "jsonl" ) include_date = not ( - destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl" + destination_config.destination_type == "databricks" + and destination_config.file_format == "jsonl" ) item, records, _ = arrow_table_all_data_types( diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 8435c13dd1..5af41025ba 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -585,7 +585,9 @@ def some_source(): parse_json_strings=destination_config.destination_type in ["snowflake", "bigquery", "redshift"], allow_string_binary=destination_config.destination_type == "clickhouse", - timestamp_precision=3 if destination_config.destination_type in ("athena", "dremio") else 6, + timestamp_precision=( + 3 if destination_config.destination_type in ("athena", "dremio") else 6 + ), ) diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index 42c288397a..de30615a6a 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -287,7 +287,10 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non ) and destination_config.file_format in ("parquet", "jsonl"): # Redshift copy doesn't support TIME column exclude_types.append("time") - if destination_config.destination_type == "synapse" and destination_config.file_format == "parquet": + if ( + destination_config.destination_type == "synapse" + and destination_config.file_format == "parquet" + ): # TIME columns are not supported for staged parquet loads into Synapse exclude_types.append("time") if destination_config.destination_type in ( diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index de4e987aa2..ba3baec17c 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -27,7 +27,12 @@ ) 
from dlt.destinations.job_client_impl import SqlJobClientBase -from dlt.common.destination.reference import StateInfo, WithStagingDataset +from dlt.common.destination.reference import ( + StateInfo, + WithStagingDataset, + DestinationClientConfiguration, +) +from dlt.common.time import ensure_pendulum_datetime from tests.cases import table_update_and_row, assert_all_data_types_row from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage @@ -202,7 +207,7 @@ def test_complete_load(naming: str, client: SqlJobClientBase) -> None: assert load_rows[0][2] == 0 import datetime # noqa: I251 - assert type(load_rows[0][3]) is datetime.datetime + assert isinstance(ensure_pendulum_datetime(load_rows[0][3]), datetime.datetime) assert load_rows[0][4] == client.schema.version_hash # make sure that hash in loads exists in schema versions table versions_table = client.sql_client.make_qualified_table_name(version_table_name) @@ -571,7 +576,7 @@ def test_load_with_all_types( if not client.capabilities.preferred_loader_file_format: pytest.skip("preferred loader file format not set, destination will only work with staging") table_name = "event_test_table" + uniq_id() - column_schemas, data_row = get_columns_and_row_all_types(client.config.destination_type) + column_schemas, data_row = get_columns_and_row_all_types(client.config) # we should have identical content with all disposition types partial = client.schema.update_table( @@ -646,7 +651,7 @@ def test_write_dispositions( os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy table_name = "event_test_table" + uniq_id() - column_schemas, data_row = get_columns_and_row_all_types(client.config.destination_type) + column_schemas, data_row = get_columns_and_row_all_types(client.config) client.schema.update_table( new_table(table_name, write_disposition=write_disposition, columns=column_schemas.values()) ) @@ -805,7 +810,8 @@ def test_get_stored_state( os.environ["SCHEMA__NAMING"] = naming_convention with cm_yield_client_with_storage( - destination_config.destination_factory(), default_config_values={"default_schema_name": None} + destination_config.destination_factory(), + default_config_values={"default_schema_name": None}, ) as client: # event schema with event table if not client.capabilities.preferred_loader_file_format: @@ -869,7 +875,8 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None: assert len(db_rows) == expected_rows with cm_yield_client_with_storage( - destination_config.destination_factory(), default_config_values={"default_schema_name": None} + destination_config.destination_factory(), + default_config_values={"default_schema_name": None}, ) as client: # event schema with event table if not client.capabilities.preferred_loader_file_format: @@ -965,11 +972,16 @@ def normalize_rows(rows: List[Dict[str, Any]], naming: NamingConvention) -> None row[naming.normalize_identifier(k)] = row.pop(k) -def get_columns_and_row_all_types(destination_type: str): +def get_columns_and_row_all_types(destination_config: DestinationClientConfiguration): + exclude_types = [] + if destination_config.destination_type in ["databricks", "clickhouse", "motherduck"]: + exclude_types.append("time") + if destination_config.destination_name == "sqlalchemy_sqlite": + exclude_types.extend(["decimal", "wei"]) return table_update_and_row( # TIME + parquet is actually a duckdb problem: https://github.com/duckdb/duckdb/pull/13283 - exclude_types=( - ["time"] if destination_type in ["databricks", "clickhouse", "motherduck"] else None + 
exclude_types=exclude_types, # type: ignore[arg-type] + exclude_columns=( + ["col4_precision"] if destination_config.destination_type in ["motherduck"] else None ), - exclude_columns=["col4_precision"] if destination_type in ["motherduck"] else None, ) diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index 9743eb8ec6..2b1097ac82 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -149,6 +149,7 @@ def test_has_dataset(naming: str, client: SqlJobClientBase) -> None: ) def test_create_drop_dataset(naming: str, client: SqlJobClientBase) -> None: # client.sql_client.create_dataset() + # Dataset is already create in fixture, so next time it fails with pytest.raises(DatabaseException): client.sql_client.create_dataset() client.sql_client.drop_dataset() @@ -212,7 +213,7 @@ def test_execute_sql(client: SqlJobClientBase) -> None: assert len(rows) == 1 # print(rows) assert rows[0][0] == "event" - assert isinstance(rows[0][1], datetime.datetime) + assert isinstance(ensure_pendulum_datetime(rows[0][1]), datetime.datetime) assert rows[0][0] == "event" # print(rows[0][1]) # print(type(rows[0][1])) diff --git a/tests/load/utils.py b/tests/load/utils.py index da936a311c..55a3d9dd0b 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -321,8 +321,14 @@ def destinations_configs( destination_type="sqlalchemy", supports_merge=False, supports_dbt=False, - destination_name="mysql_driver", - ) + destination_name="sqlalchemy_mysql", + ), + DestinationTestConfiguration( + destination_type="sqlalchemy", + supports_merge=False, + supports_dbt=False, + destination_name="sqlalchemy_sqlite", + ), ] destination_configs += [ From 2973526bcf198fc47e0bac6ac1caf0553f9a58aa Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 4 Sep 2024 15:21:16 -0400 Subject: [PATCH 04/37] Fix more tests --- .../impl/sqlalchemy/db_api_client.py | 72 ++++++++++------ .../impl/sqlalchemy/sqlalchemy_job_client.py | 2 +- tests/load/test_sql_client.py | 83 +++++++++++-------- 3 files changed, 97 insertions(+), 60 deletions(-) diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py index 2255bc740a..a72af376da 100644 --- a/dlt/destinations/impl/sqlalchemy/db_api_client.py +++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -143,19 +143,27 @@ def native_connection(self) -> Connection: raise LoadClientNotConnected(type(self).__name__, self.dataset_name) return self._current_connection + def _in_transaction(self) -> bool: + return ( + self._current_transaction is not None + and self._current_transaction.sqla_transaction.is_active + ) + @contextmanager @raise_database_error def begin_transaction(self) -> Iterator[DBTransaction]: - if self._current_transaction is not None: + if self._in_transaction(): raise DatabaseTerminalException("Transaction already started") trans = self._current_transaction = SqlaTransactionWrapper(self._current_connection.begin()) try: yield trans except Exception: - self.rollback_transaction() + if self._in_transaction(): + self.rollback_transaction() raise else: - self.commit_transaction() + if self._in_transaction(): # Transaction could be committed/rolled back before __exit__ + self.commit_transaction() finally: self._current_transaction = None @@ -173,14 +181,15 @@ def _transaction(self) -> Iterator[DBTransaction]: New transaction will be committed/rolled back on exit. If the transaction is already open, finalization is handled by the top level context manager. 
""" - if self._current_transaction is not None: + if self._in_transaction(): yield self._current_transaction return with self.begin_transaction() as tx: yield tx def has_dataset(self) -> bool: - schema_names = self.engine.dialect.get_schema_names(self._current_connection) + with self._transaction(): + schema_names = self.engine.dialect.get_schema_names(self._current_connection) return self.dataset_name in schema_names def _sqlite_create_dataset(self, dataset_name: str) -> None: @@ -208,9 +217,6 @@ def _sqlite_drop_dataset(self, dataset_name: str) -> None: # Get a list of attached databases and filenames rows = self.execute_sql("PRAGMA database_list") dbs = {row[1]: row[2] for row in rows} # db_name: filename - if dataset_name not in dbs: - raise DatabaseUndefinedRelation(f"Database {dataset_name} does not exist") - statement = "DETACH DATABASE :name" self.execute_sql(statement, name=dataset_name) @@ -230,7 +236,7 @@ def drop_dataset(self) -> None: return self._sqlite_drop_dataset(self.dataset_name) try: self.execute_sql(sa.schema.DropSchema(self.dataset_name, cascade=True)) - except DatabaseTransientException as e: + except DatabaseException as e: self.execute_sql(sa.schema.DropSchema(self.dataset_name)) def truncate_tables(self, *tables: str) -> None: @@ -256,14 +262,21 @@ def execute_sql( def execute_query( self, query: Union[AnyStr, sa.sql.Executable], *args: Any, **kwargs: Any ) -> Iterator[DBApiCursor]: + if args and kwargs: + raise ValueError("Cannot use both positional and keyword arguments") if isinstance(query, str): if args: # Sqlalchemy text supports :named paramstyle for all dialects query, kwargs = self._to_named_paramstyle(query, args) # type: ignore[assignment] - args = () + args = [ + kwargs, + ] query = sa.text(query) + if kwargs: + # sqla2 takes either a dict or list of dicts + args = [kwargs] with self._transaction(): - yield SqlaDbApiCursor(self._current_connection.execute(query, *args, **kwargs)) # type: ignore[abstract] + yield SqlaDbApiCursor(self._current_connection.execute(query, *args)) # type: ignore[abstract] def get_existing_table(self, table_name: str) -> Optional[sa.Table]: """Get a table object from metadata if it exists""" @@ -271,7 +284,8 @@ def get_existing_table(self, table_name: str) -> Optional[sa.Table]: return self.metadata.tables.get(key) # type: ignore[no-any-return] def create_table(self, table_obj: sa.Table) -> None: - table_obj.create(self._current_connection) + with self._transaction(): + table_obj.create(self._current_connection) def _make_qualified_table_name(self, table: sa.Table, escape: bool = True) -> str: if escape: @@ -285,6 +299,11 @@ def make_qualified_table_name(self, table_name: str, escape: bool = True) -> str tbl = sa.Table(table_name, tmp_metadata, schema=self.dataset_name) return self._make_qualified_table_name(tbl, escape) + def fully_qualified_dataset_name(self, escape: bool = True, staging: bool = False) -> str: + if staging: + raise NotImplementedError("Staging not supported") + return self.dialect.identifier_preparer.format_schema(self.dataset_name) + def alter_table_add_column(self, column: sa.Column) -> None: """Execute an ALTER TABLE ... ADD COLUMN ... statement for the given column. The column must be fully defined and attached to a table. 
@@ -318,15 +337,16 @@ def reflect_table( if metadata is None: metadata = self.metadata try: - return sa.Table( - table_name, - metadata, - autoload_with=self._current_connection, - schema=self.dataset_name, - include_columns=include_columns, - extend_existing=True, - ) - except sa.exc.NoSuchTableError: + with self._transaction(): + return sa.Table( + table_name, + metadata, + autoload_with=self._current_connection, + schema=self.dataset_name, + include_columns=include_columns, + extend_existing=True, + ) + except DatabaseUndefinedRelation: return None def compare_storage_table(self, table_name: str) -> Tuple[sa.Table, List[sa.Column], bool]: @@ -358,14 +378,18 @@ def _make_database_exception(e: Exception) -> Exception: if isinstance(e, (sa.exc.ProgrammingError, sa.exc.OperationalError)): if "exist" in msg: # TODO: Hack return DatabaseUndefinedRelation(e) - elif "no such" in msg: # sqlite # TODO: Hack - return DatabaseUndefinedRelation(e) elif "unknown table" in msg: return DatabaseUndefinedRelation(e) elif "unknown database" in msg: return DatabaseUndefinedRelation(e) + elif "no such table" in msg: # sqlite # TODO: Hack + return DatabaseUndefinedRelation(e) + elif "no such database" in msg: # sqlite # TODO: Hack + return DatabaseUndefinedRelation(e) + elif "syntax" in msg: + return DatabaseTransientException(e) elif isinstance(e, (sa.exc.OperationalError, sa.exc.IntegrityError)): - raise DatabaseTerminalException(e) + return DatabaseTerminalException(e) return DatabaseTransientException(e) elif isinstance(e, sa.exc.SQLAlchemyError): return DatabaseTransientException(e) diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py index 3cd6133d47..7fb10f4b18 100644 --- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -437,7 +437,7 @@ def _get_stored_schema( version_table = self.schema.tables[self.schema.version_table_name] table_obj = self._to_table_object(version_table) with suppress(DatabaseUndefinedRelation): - q = sa.select([table_obj]) + q = sa.select(table_obj) if version_hash is not None: version_hash_col = self.schema.naming.normalize_identifier("version_hash") q = q.where(table_obj.c[version_hash_col] == version_hash) diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index 2b1097ac82..82af82334e 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -1,7 +1,7 @@ import os import pytest import datetime # noqa: I251 -from typing import Iterator, Any +from typing import Iterator, Any, Tuple, Type, Union from threading import Thread, Event from time import sleep @@ -218,9 +218,14 @@ def test_execute_sql(client: SqlJobClientBase) -> None: # print(rows[0][1]) # print(type(rows[0][1])) # ensure datetime obj to make sure it is supported by dbapi + inserted_at = to_py_datetime(ensure_pendulum_datetime(rows[0][1])) + if client.config.destination_name == "sqlalchemy_sqlite": + # timezone aware datetime is not supported by sqlite + inserted_at = inserted_at.replace(tzinfo=None) + rows = client.sql_client.execute_sql( f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at = %s", - to_py_datetime(ensure_pendulum_datetime(rows[0][1])), + inserted_at, ) assert len(rows) == 1 # use rows in subsequent test @@ -246,20 +251,20 @@ def test_execute_sql(client: SqlJobClientBase) -> None: def test_execute_ddl(client: SqlJobClientBase) -> None: uniq_suffix = uniq_id() 
client.update_stored_schema() - table_name = prepare_temp_table(client) + table_name, py_type = prepare_temp_table(client) f_q_table_name = client.sql_client.make_qualified_table_name(table_name) client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (1.0)") rows = client.sql_client.execute_sql(f"SELECT * FROM {f_q_table_name}") - assert rows[0][0] == Decimal("1.0") + assert rows[0][0] == py_type("1.0") if client.config.destination_type == "dremio": username = client.config.credentials["username"] view_name = f'"@{username}"."view_tmp_{uniq_suffix}"' else: # create view, note that bigquery will not let you execute a view that does not have fully qualified table names. view_name = client.sql_client.make_qualified_table_name(f"view_tmp_{uniq_suffix}") - client.sql_client.execute_sql(f"CREATE VIEW {view_name} AS (SELECT * FROM {f_q_table_name});") + client.sql_client.execute_sql(f"CREATE VIEW {view_name} AS SELECT * FROM {f_q_table_name};") rows = client.sql_client.execute_sql(f"SELECT * FROM {view_name}") - assert rows[0][0] == Decimal("1.0") + assert rows[0][0] == py_type("1.0") @pytest.mark.parametrize( @@ -280,7 +285,7 @@ def test_execute_query(client: SqlJobClientBase) -> None: rows = curr.fetchall() assert len(rows) == 1 assert rows[0][0] == "event" - assert isinstance(rows[0][1], datetime.datetime) + assert isinstance(ensure_pendulum_datetime(rows[0][1]), datetime.datetime) with client.sql_client.execute_query( f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at = %s", rows[0][1], @@ -319,7 +324,7 @@ def test_execute_df(client: SqlJobClientBase) -> None: total_records = 3000 client.update_stored_schema() - table_name = prepare_temp_table(client) + table_name, py_type = prepare_temp_table(client) f_q_table_name = client.sql_client.make_qualified_table_name(table_name) if client.capabilities.insert_values_writer_type == "default": @@ -420,8 +425,7 @@ def test_database_exceptions(client: SqlJobClientBase) -> None: assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) if client.config.destination_type not in ["dremio", "clickhouse"]: with pytest.raises(DatabaseUndefinedRelation) as term_ex: - with client.sql_client.execute_query("DROP SCHEMA UNKNOWN"): - pass + client.sql_client.drop_dataset() assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) @@ -432,29 +436,29 @@ def test_database_exceptions(client: SqlJobClientBase) -> None: ids=lambda x: x.name, ) def test_commit_transaction(client: SqlJobClientBase) -> None: - table_name = prepare_temp_table(client) + table_name, py_type = prepare_temp_table(client) f_q_table_name = client.sql_client.make_qualified_table_name(table_name) with client.sql_client.begin_transaction(): - client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (%s)", Decimal("1.0")) + client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (%s)", py_type("1.0")) # check row still in transaction rows = client.sql_client.execute_sql( - f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) assert len(rows) == 1 # check row after commit rows = client.sql_client.execute_sql( - f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) assert len(rows) == 1 assert rows[0][0] == 1.0 with client.sql_client.begin_transaction() as tx: client.sql_client.execute_sql( - f"DELETE FROM {f_q_table_name} WHERE col = %s", 
Decimal("1.0") + f"DELETE FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) # explicit commit tx.commit_transaction() rows = client.sql_client.execute_sql( - f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) assert len(rows) == 0 @@ -468,22 +472,22 @@ def test_commit_transaction(client: SqlJobClientBase) -> None: def test_rollback_transaction(client: SqlJobClientBase) -> None: if client.capabilities.supports_transactions is False: pytest.skip("Destination does not support tx") - table_name = prepare_temp_table(client) + table_name, py_type = prepare_temp_table(client) f_q_table_name = client.sql_client.make_qualified_table_name(table_name) # test python exception with pytest.raises(RuntimeError): with client.sql_client.begin_transaction(): client.sql_client.execute_sql( - f"INSERT INTO {f_q_table_name} VALUES (%s)", Decimal("1.0") + f"INSERT INTO {f_q_table_name} VALUES (%s)", py_type("1.0") ) rows = client.sql_client.execute_sql( - f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) assert len(rows) == 1 # python exception triggers rollback raise RuntimeError("ROLLBACK") rows = client.sql_client.execute_sql( - f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) assert len(rows) == 0 @@ -492,23 +496,23 @@ def test_rollback_transaction(client: SqlJobClientBase) -> None: with pytest.raises(DatabaseException): with client.sql_client.begin_transaction(): client.sql_client.execute_sql( - f"INSERT INTO {f_q_table_name} VALUES (%s)", Decimal("1.0") + f"INSERT INTO {f_q_table_name} VALUES (%s)", py_type("1.0") ) # table does not exist client.sql_client.execute_sql( - f"SELECT col FROM {f_q_wrong_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_wrong_table_name} WHERE col = %s", py_type("1.0") ) rows = client.sql_client.execute_sql( - f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) assert len(rows) == 0 # test explicit rollback with client.sql_client.begin_transaction() as tx: - client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (%s)", Decimal("1.0")) + client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (%s)", py_type("1.0")) tx.rollback_transaction() rows = client.sql_client.execute_sql( - f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) assert len(rows) == 0 @@ -529,12 +533,12 @@ def test_rollback_transaction(client: SqlJobClientBase) -> None: def test_transaction_isolation(client: SqlJobClientBase) -> None: if client.capabilities.supports_transactions is False: pytest.skip("Destination does not support tx") - table_name = prepare_temp_table(client) + table_name, py_type = prepare_temp_table(client) f_q_table_name = client.sql_client.make_qualified_table_name(table_name) event = Event() event.clear() - def test_thread(thread_id: Decimal) -> None: + def test_thread(thread_id: Union[Decimal, float]) -> None: # make a copy of the sql_client thread_client = client.sql_client.__class__( client.sql_client.dataset_name, @@ -548,8 +552,8 @@ def test_thread(thread_id: Decimal) -> None: event.wait() with client.sql_client.begin_transaction() as tx: - client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES 
(%s)", Decimal("1.0")) - t = Thread(target=test_thread, daemon=True, args=(Decimal("2.0"),)) + client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (%s)", py_type("1.0")) + t = Thread(target=test_thread, daemon=True, args=(py_type("2.0"),)) t.start() # thread 2.0 inserts sleep(3.0) @@ -560,13 +564,16 @@ def test_thread(thread_id: Decimal) -> None: t.join() # just in case close the connection - client.sql_client.close_connection() - # re open connection - client.sql_client.open_connection() + if ( + client.config.destination_name != "sqlalchemy_sqlite" + ): # keep sqlite connection to maintain attached datasets + client.sql_client.close_connection() + # re open connection + client.sql_client.open_connection() rows = client.sql_client.execute_sql(f"SELECT col FROM {f_q_table_name} ORDER BY col") assert len(rows) == 1 # only thread 2 is left - assert rows[0][0] == Decimal("2.0") + assert rows[0][0] == py_type("2.0") @pytest.mark.parametrize( @@ -679,11 +686,13 @@ def assert_load_id(sql_client: SqlClientBase[TNativeConn], load_id: str) -> None assert len(rows) == 1 -def prepare_temp_table(client: SqlJobClientBase) -> str: +def prepare_temp_table(client: SqlJobClientBase) -> Tuple[str, Type[Union[Decimal, float]]]: + """Return the table name and py type of value to insert""" uniq_suffix = uniq_id() table_name = f"tmp_{uniq_suffix}" ddl_suffix = "" coltype = "numeric" + py_type = Decimal if client.config.destination_type == "athena": ddl_suffix = ( f"LOCATION '{AWS_BUCKET}/ci/{table_name}' TBLPROPERTIES ('table_type'='ICEBERG'," @@ -691,6 +700,10 @@ def prepare_temp_table(client: SqlJobClientBase) -> str: ) coltype = "bigint" qualified_table_name = table_name + elif client.config.destination_name == "sqlalchemy_sqlite": + coltype = "float" + py_type = float + qualified_table_name = client.sql_client.make_qualified_table_name(table_name) elif client.config.destination_type == "clickhouse": ddl_suffix = "ENGINE = MergeTree() ORDER BY col" qualified_table_name = client.sql_client.make_qualified_table_name(table_name) @@ -699,4 +712,4 @@ def prepare_temp_table(client: SqlJobClientBase) -> str: client.sql_client.execute_sql( f"CREATE TABLE {qualified_table_name} (col {coltype}) {ddl_suffix};" ) - return table_name + return table_name, py_type From 8caf2f3cb39ea3e235e8fbfa4da8ff91e14dd761 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 5 Sep 2024 18:45:06 -0400 Subject: [PATCH 05/37] ALl sqlite tests passing --- .../impl/sqlalchemy/db_api_client.py | 39 +++++++++++++------ dlt/pipeline/pipeline.py | 2 +- tests/load/pipeline/test_arrow_loading.py | 35 +++++++++++------ tests/load/pipeline/test_merge_disposition.py | 6 ++- .../test_write_disposition_changes.py | 2 +- tests/load/test_sql_client.py | 13 ++++++- 6 files changed, 69 insertions(+), 28 deletions(-) diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py index a72af376da..ca483c0a2e 100644 --- a/dlt/destinations/impl/sqlalchemy/db_api_client.py +++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -127,6 +127,8 @@ def __init__( def open_connection(self) -> Connection: if self._current_connection is None: self._current_connection = self.engine.connect() + if self.dialect_name == "sqlite": + self._sqlite_reattach_dataset_if_exists(self.dataset_name) return self._current_connection def close_connection(self) -> None: @@ -192,20 +194,34 @@ def has_dataset(self) -> bool: schema_names = self.engine.dialect.get_schema_names(self._current_connection) return 
self.dataset_name in schema_names + def _sqlite_dataset_filename(self, dataset_name: str) -> str: + db_name = self.engine.url.database + current_file_path = Path(db_name) + return str( + current_file_path.parent + / f"{current_file_path.stem}__{dataset_name}{current_file_path.suffix}" + ) + + def _sqlite_is_memory_db(self) -> bool: + return self.engine.url.database == ":memory:" + + def _sqlite_reattach_dataset_if_exists(self, dataset_name: str) -> None: + """Re-attach previously created databases for a new sqlite connection + """ + if self._sqlite_is_memory_db(): + return + new_db_fn = self._sqlite_dataset_filename(dataset_name) + if Path(new_db_fn).exists(): + self._sqlite_create_dataset(dataset_name) + def _sqlite_create_dataset(self, dataset_name: str) -> None: """Mimic multiple schemas in sqlite using ATTACH DATABASE to attach a new database file to the current connection. """ - db_name = self.engine.url.database - if db_name == ":memory:": + if self._sqlite_is_memory_db(): new_db_fn = ":memory:" else: - current_file_path = Path(db_name) - # New filename e.g. ./results.db -> ./results__new_dataset_name.db - new_db_fn = str( - current_file_path.parent - / f"{current_file_path.stem}__{dataset_name}{current_file_path.suffix}" - ) + new_db_fn = self._sqlite_dataset_filename(dataset_name) statement = "ATTACH DATABASE :fn AS :name" self.execute_sql(statement, fn=new_db_fn, name=dataset_name) @@ -217,8 +233,9 @@ def _sqlite_drop_dataset(self, dataset_name: str) -> None: # Get a list of attached databases and filenames rows = self.execute_sql("PRAGMA database_list") dbs = {row[1]: row[2] for row in rows} # db_name: filename - statement = "DETACH DATABASE :name" - self.execute_sql(statement, name=dataset_name) + if dataset_name != "main": # main is the default database, it cannot be detached + statement = "DETACH DATABASE :name" + self.execute_sql(statement, name=dataset_name) fn = dbs[dataset_name] if not fn: # It's a memory database, nothing to do @@ -236,7 +253,7 @@ def drop_dataset(self) -> None: return self._sqlite_drop_dataset(self.dataset_name) try: self.execute_sql(sa.schema.DropSchema(self.dataset_name, cascade=True)) - except DatabaseException as e: + except DatabaseException: # Try again in case cascade is not supported self.execute_sql(sa.schema.DropSchema(self.dataset_name)) def truncate_tables(self, *tables: str) -> None: diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index fa10f5ac89..118c126082 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -578,7 +578,7 @@ def load( step_info = self._get_step_info(load_step) raise PipelineStepFailed( self, "load", load_step.current_load_id, l_ex, step_info - ) from l_ex + ) from l_ex @with_runtime_trace() @with_config_section(("run",)) diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index b2fda87860..a0da556e48 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta, time as dt_time # noqa: I251 +from datetime import datetime, timedelta, time as dt_time, date # noqa: I251 import os import pytest @@ -9,7 +9,12 @@ import dlt from dlt.common import pendulum -from dlt.common.time import reduce_pendulum_datetime_precision, ensure_pendulum_time +from dlt.common.time import ( + reduce_pendulum_datetime_precision, + ensure_pendulum_time, + ensure_pendulum_datetime, + ensure_pendulum_date, +) from dlt.common.utils import uniq_id from tests.load.utils import 
destinations_configs, DestinationTestConfiguration @@ -53,10 +58,14 @@ def test_load_arrow_item( and destination_config.file_format == "jsonl" ) - include_decimal = not ( + include_decimal = True + + if ( destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl" - ) + ) or (destination_config.destination_name == "sqlalchemy_sqlite"): + include_decimal = False + include_date = not ( destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl" @@ -123,23 +132,25 @@ def some_data(): if "binary" in record and destination_config.file_format == "parquet": record["binary"] = record["binary"].decode("ascii") + expected = sorted([list(r.values()) for r in records]) first_record = list(records[0].values()) - for row in rows: - for i in range(len(row)): - if isinstance(row[i], datetime): - row[i] = pendulum.instance(row[i]) + for row, expected_row in zip(rows, expected): + for i in range(len(expected_row)): + if isinstance(expected_row[i], datetime): + row[i] = ensure_pendulum_datetime(row[i]) # clickhouse produces rounding errors on double with jsonl, so we round the result coming from there - if ( + elif ( destination_config.destination_type == "clickhouse" and destination_config.file_format == "jsonl" and isinstance(row[i], float) ): row[i] = round(row[i], 4) - if isinstance(row[i], timedelta) and isinstance(first_record[i], dt_time): + elif isinstance(first_record[i], dt_time): # Some drivers (mysqlclient) return TIME columns as timedelta as seconds since midnight + # sqlite returns iso strings row[i] = ensure_pendulum_time(row[i]) - - expected = sorted([list(r.values()) for r in records]) + elif isinstance(expected_row[i], date): + row[i] = ensure_pendulum_date(row[i]) for row in expected: for i in range(len(row)): diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 8a1ef5c2d1..ae51917672 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -542,13 +542,17 @@ def _get_shuffled_events(shuffle: bool = dlt.secrets.value): @pytest.mark.parametrize( "destination_config", - destinations_configs(default_sql_configs=True, exclude=["sqlalchemy"]), + destinations_configs(default_sql_configs=True), ids=lambda x: x.name, ) @pytest.mark.parametrize("github_resource", [github_repo_events, github_repo_events_table_meta]) def test_merge_with_dispatch_and_incremental( destination_config: DestinationTestConfiguration, github_resource: DltResource ) -> None: + if destination_config.destination_name == 'sqlalchemy_mysql': + # TODO: Github events have too many columns for MySQL + pytest.skip("MySQL can't handle too many columns") + newest_issues = list( sorted(_get_shuffled_events(True), key=lambda x: x["created_at"], reverse=True) ) diff --git a/tests/load/pipeline/test_write_disposition_changes.py b/tests/load/pipeline/test_write_disposition_changes.py index c28890e2c6..2498a7bdbb 100644 --- a/tests/load/pipeline/test_write_disposition_changes.py +++ b/tests/load/pipeline/test_write_disposition_changes.py @@ -91,7 +91,7 @@ def test_switch_from_merge(destination_config: DestinationTestConfiguration): @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", destinations_configs(default_sql_configs=True, supports_merge=True), ids=lambda x: x.name ) @pytest.mark.parametrize("with_root_key", [True, False]) def 
test_switch_to_merge(destination_config: DestinationTestConfiguration, with_root_key: bool): diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index 82af82334e..1d978c15d4 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -533,6 +533,9 @@ def test_rollback_transaction(client: SqlJobClientBase) -> None: def test_transaction_isolation(client: SqlJobClientBase) -> None: if client.capabilities.supports_transactions is False: pytest.skip("Destination does not support tx") + if client.config.destination_name == "sqlalchemy_sqlite": + # because other schema names must be attached for each connection + client.sql_client.dataset_name = "main" table_name, py_type = prepare_temp_table(client) f_q_table_name = client.sql_client.make_qualified_table_name(table_name) event = Event() @@ -577,7 +580,10 @@ def test_thread(thread_id: Union[Decimal, float]) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True, exclude=["sqlalchemy"]), + indirect=True, + ids=lambda x: x.name, ) def test_max_table_identifier_length(client: SqlJobClientBase) -> None: if client.capabilities.max_identifier_length >= 65536: @@ -607,7 +613,10 @@ def test_max_table_identifier_length(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True, exclude=["sqlalchemy"]), + indirect=True, + ids=lambda x: x.name, ) def test_max_column_identifier_length(client: SqlJobClientBase) -> None: if client.capabilities.max_column_identifier_length >= 65536: From 2a30b36b0e222f886d4cd9f102964130bd34238a Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 5 Sep 2024 18:46:01 -0400 Subject: [PATCH 06/37] Add sqlalchemy tests in ci --- .github/workflows/test_local_destinations.yml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test_local_destinations.yml b/.github/workflows/test_local_destinations.yml index 2d712814bd..734c9512d7 100644 --- a/.github/workflows/test_local_destinations.yml +++ b/.github/workflows/test_local_destinations.yml @@ -21,7 +21,7 @@ env: RUNTIME__SENTRY_DSN: https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 RUNTIME__LOG_LEVEL: ERROR RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }} - ACTIVE_DESTINATIONS: "[\"duckdb\", \"postgres\", \"filesystem\", \"weaviate\", \"qdrant\"]" + ACTIVE_DESTINATIONS: "[\"duckdb\", \"postgres\", \"filesystem\", \"weaviate\", \"qdrant\", \"sqlalchemy\"]" ALL_FILESYSTEM_DRIVERS: "[\"memory\", \"file\"]" DESTINATION__WEAVIATE__VECTORIZER: text2vec-contextionary @@ -68,6 +68,16 @@ jobs: ports: - 6333:6333 + mysql: + image: mysql:5.7 + env: + MYSQL_ROOT_PASSWORD: root + MYSQL_DATABASE: dlt_data + MYSQL_USER: loader + MYSQL_PASSWORD: loader + ports: + - 3306:3306 + steps: - name: Check out uses: actions/checkout@master @@ -103,6 +113,8 @@ jobs: env: DESTINATION__POSTGRES__CREDENTIALS: postgresql://loader:loader@localhost:5432/dlt_data DESTINATION__QDRANT__CREDENTIALS__location: http://localhost:6333 + DESTINATION__SQLALCHEMY_MYSQL__CREDENTIALS: mysql://root:root@localhost:3306/dlt_data # Use root cause we need to create databases + DESTINATION__SQLALCHEMY_SQLITE__CREDENTIALS: sqlite:///_storage/dl_data.sqlite - name: Stop weaviate 
if: always() From e7f56c91963b50fdb5cd9532dadd7593e3bf23e4 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 5 Sep 2024 18:58:25 -0400 Subject: [PATCH 07/37] Type errors --- .../impl/sqlalchemy/db_api_client.py | 33 ++++++++----------- .../impl/sqlalchemy/sqlalchemy_job_client.py | 19 ++++++----- dlt/pipeline/pipeline.py | 2 +- tests/load/pipeline/test_merge_disposition.py | 4 +-- .../test_write_disposition_changes.py | 4 ++- 5 files changed, 31 insertions(+), 31 deletions(-) diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py index ca483c0a2e..13b28385d7 100644 --- a/dlt/destinations/impl/sqlalchemy/db_api_client.py +++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -71,9 +71,9 @@ def __init__(self, curr: sa.engine.CursorResult) -> None: self.native_cursor = curr # type: ignore[assignment] curr.columns - self.fetchall = curr.fetchall # type: ignore[method-assign] - self.fetchone = curr.fetchone # type: ignore[method-assign] - self.fetchmany = curr.fetchmany # type: ignore[method-assign] + self.fetchall = curr.fetchall # type: ignore[assignment] + self.fetchone = curr.fetchone # type: ignore[assignment] + self.fetchmany = curr.fetchmany # type: ignore[assignment] def _get_columns(self) -> List[str]: return list(self.native_cursor.keys()) # type: ignore[attr-defined] @@ -122,7 +122,7 @@ def __init__( self._current_transaction: Optional[SqlaTransactionWrapper] = None self.metadata = sa.MetaData() self.dialect = self.engine.dialect - self.dialect_name = self.dialect.name # type: ignore[attr-defined] + self.dialect_name = self.dialect.name def open_connection(self) -> Connection: if self._current_connection is None: @@ -154,8 +154,6 @@ def _in_transaction(self) -> bool: @contextmanager @raise_database_error def begin_transaction(self) -> Iterator[DBTransaction]: - if self._in_transaction(): - raise DatabaseTerminalException("Transaction already started") trans = self._current_transaction = SqlaTransactionWrapper(self._current_connection.begin()) try: yield trans @@ -191,7 +189,7 @@ def _transaction(self) -> Iterator[DBTransaction]: def has_dataset(self) -> bool: with self._transaction(): - schema_names = self.engine.dialect.get_schema_names(self._current_connection) + schema_names = self.engine.dialect.get_schema_names(self._current_connection) # type: ignore[attr-defined] return self.dataset_name in schema_names def _sqlite_dataset_filename(self, dataset_name: str) -> str: @@ -206,8 +204,7 @@ def _sqlite_is_memory_db(self) -> bool: return self.engine.url.database == ":memory:" def _sqlite_reattach_dataset_if_exists(self, dataset_name: str) -> None: - """Re-attach previously created databases for a new sqlite connection - """ + """Re-attach previously created databases for a new sqlite connection""" if self._sqlite_is_memory_db(): return new_db_fn = self._sqlite_dataset_filename(dataset_name) @@ -271,7 +268,7 @@ def execute_sql( self, sql: Union[AnyStr, sa.sql.Executable], *args: Any, **kwargs: Any ) -> Optional[Sequence[Sequence[Any]]]: with self.execute_query(sql, *args, **kwargs) as cursor: - if cursor.returns_rows: + if cursor.returns_rows: # type: ignore[attr-defined] return cursor.fetchall() return None @@ -285,15 +282,13 @@ def execute_query( if args: # Sqlalchemy text supports :named paramstyle for all dialects query, kwargs = self._to_named_paramstyle(query, args) # type: ignore[assignment] - args = [ - kwargs, - ] + args = (kwargs,) query = sa.text(query) if kwargs: # sqla2 takes either a dict or list of 
dicts - args = [kwargs] + args = (kwargs,) with self._transaction(): - yield SqlaDbApiCursor(self._current_connection.execute(query, *args)) # type: ignore[abstract] + yield SqlaDbApiCursor(self._current_connection.execute(query, *args)) # type: ignore[call-overload, abstract] def get_existing_table(self, table_name: str) -> Optional[sa.Table]: """Get a table object from metadata if it exists""" @@ -319,7 +314,7 @@ def make_qualified_table_name(self, table_name: str, escape: bool = True) -> str def fully_qualified_dataset_name(self, escape: bool = True, staging: bool = False) -> str: if staging: raise NotImplementedError("Staging not supported") - return self.dialect.identifier_preparer.format_schema(self.dataset_name) + return self.dialect.identifier_preparer.format_schema(self.dataset_name) # type: ignore[attr-defined, no-any-return] def alter_table_add_column(self, column: sa.Column) -> None: """Execute an ALTER TABLE ... ADD COLUMN ... statement for the given column. @@ -328,14 +323,14 @@ def alter_table_add_column(self, column: sa.Column) -> None: # TODO: May need capability to override ALTER TABLE statement for different dialects alter_tmpl = "ALTER TABLE {table} ADD COLUMN {column};" statement = alter_tmpl.format( - table=self._make_qualified_table_name(self._make_qualified_table_name(column.table)), + table=self._make_qualified_table_name(self._make_qualified_table_name(column.table)), # type: ignore[arg-type] column=self.compile_column_def(column), ) self.execute_sql(statement) def escape_column_name(self, column_name: str, escape: bool = True) -> str: - if self.dialect.requires_name_normalize: - column_name = self.dialect.normalize_name(column_name) + if self.dialect.requires_name_normalize: # type: ignore[attr-defined] + column_name = self.dialect.normalize_name(column_name) # type: ignore[func-returns-value] if escape: return self.dialect.identifier_preparer.format_column(sa.Column(column_name)) # type: ignore[attr-defined,no-any-return] return column_name diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py index 7fb10f4b18..d126cc8663 100644 --- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -54,9 +54,10 @@ def _db_integer_type(self, precision: Optional[int]) -> sa.types.TypeEngine: def _create_date_time_type(self, sc_t: str, precision: Optional[int]) -> sa.types.TypeEngine: """Use the dialect specific datetime/time type if possible since the generic type doesn't accept precision argument""" precision = precision if precision is not None else self.capabilities.timestamp_precision + base_type: sa.types.TypeEngine if sc_t == "timestamp": base_type = sa.DateTime() - if self.dialect.name == "mysql": # type: ignore[attr-defined] + if self.dialect.name == "mysql": # Special case, type_descriptor does not return the specifc datetime type from sqlalchemy.dialects.mysql import DATETIME @@ -76,7 +77,7 @@ def _create_date_time_type(self, sc_t: str, precision: Optional[int]) -> sa.type kwargs["fsp"] = precision # MySQL uses fsp for fractional seconds elif "precision" in params: kwargs["precision"] = precision - return dialect_type(**kwargs) # type: ignore[no-any-return] + return dialect_type(**kwargs) # type: ignore[no-any-return,misc] def _create_double_type(self) -> sa.types.TypeEngine: if dbl := getattr(sa, "Double", None): @@ -102,8 +103,8 @@ def to_db_type(self, column: TColumnSchema, table_format: TTableSchema) -> sa.ty if 
length is None and column.get("unique"): length = 128 if length is None: - return sa.Text() # type: ignore[no-any-return] - return sa.String(length=length) # type: ignore[no-any-return] + return sa.Text() + return sa.String(length=length) elif sc_t == "double": return self._create_double_type() elif sc_t == "bool": @@ -422,7 +423,7 @@ def _update_schema_in_storage(self, schema: Schema) -> None: schema_mapping = StorageSchemaInfo( version=schema.version, - engine_version=schema.ENGINE_VERSION, + engine_version=str(schema.ENGINE_VERSION), schema_name=schema.name, version_hash=schema.stored_version_hash, schema=schema_str, @@ -452,7 +453,9 @@ def _get_stored_schema( return None # TODO: Decode compressed schema str if needed - return StorageSchemaInfo.from_normalized_mapping(row._mapping, self.schema.naming) + return StorageSchemaInfo.from_normalized_mapping( + row._mapping, self.schema.naming # type: ignore[attr-defined] + ) def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]: return self._get_stored_schema(version_hash) @@ -490,7 +493,7 @@ def get_stored_state(self, pipeline_name: str) -> StateInfo: row = cur.fetchone() if not row: return None - mapping = dict(row._mapping) + mapping = dict(row._mapping) # type: ignore[attr-defined] return StateInfo.from_normalized_mapping(mapping, self.schema.naming) @@ -499,5 +502,5 @@ def _from_db_type( ) -> TColumnType: raise NotImplementedError() - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableSchema = None) -> str: raise NotImplementedError() diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 118c126082..fa10f5ac89 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -578,7 +578,7 @@ def load( step_info = self._get_step_info(load_step) raise PipelineStepFailed( self, "load", load_step.current_load_id, l_ex, step_info - ) from l_ex + ) from l_ex @with_runtime_trace() @with_config_section(("run",)) diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index ae51917672..d1082263dd 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -549,10 +549,10 @@ def _get_shuffled_events(shuffle: bool = dlt.secrets.value): def test_merge_with_dispatch_and_incremental( destination_config: DestinationTestConfiguration, github_resource: DltResource ) -> None: - if destination_config.destination_name == 'sqlalchemy_mysql': + if destination_config.destination_name == "sqlalchemy_mysql": # TODO: Github events have too many columns for MySQL pytest.skip("MySQL can't handle too many columns") - + newest_issues = list( sorted(_get_shuffled_events(True), key=lambda x: x["created_at"], reverse=True) ) diff --git a/tests/load/pipeline/test_write_disposition_changes.py b/tests/load/pipeline/test_write_disposition_changes.py index 2498a7bdbb..f7d915903e 100644 --- a/tests/load/pipeline/test_write_disposition_changes.py +++ b/tests/load/pipeline/test_write_disposition_changes.py @@ -91,7 +91,9 @@ def test_switch_from_merge(destination_config: DestinationTestConfiguration): @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True, supports_merge=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, supports_merge=True), + ids=lambda x: x.name, ) @pytest.mark.parametrize("with_root_key", [True, False]) 
def test_switch_to_merge(destination_config: DestinationTestConfiguration, with_root_key: bool): From 11d52db56a99a13ff3c060847fe8c4b01c71b756 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 5 Sep 2024 19:03:47 -0400 Subject: [PATCH 08/37] Test sqlalchemy in own workflow --- .github/workflows/test_local_destinations.yml | 3 +- .../test_sqlalchemy_destinations.yml | 94 +++++++++++++++++++ tests/load/sqlalchemy/__init__.py | 3 + .../test_sqlalchemy_configuration.py | 23 +++++ 4 files changed, 121 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/test_sqlalchemy_destinations.yml create mode 100644 tests/load/sqlalchemy/__init__.py create mode 100644 tests/load/sqlalchemy/test_sqlalchemy_configuration.py diff --git a/.github/workflows/test_local_destinations.yml b/.github/workflows/test_local_destinations.yml index 734c9512d7..0da3fc71e8 100644 --- a/.github/workflows/test_local_destinations.yml +++ b/.github/workflows/test_local_destinations.yml @@ -113,8 +113,7 @@ jobs: env: DESTINATION__POSTGRES__CREDENTIALS: postgresql://loader:loader@localhost:5432/dlt_data DESTINATION__QDRANT__CREDENTIALS__location: http://localhost:6333 - DESTINATION__SQLALCHEMY_MYSQL__CREDENTIALS: mysql://root:root@localhost:3306/dlt_data # Use root cause we need to create databases - DESTINATION__SQLALCHEMY_SQLITE__CREDENTIALS: sqlite:///_storage/dl_data.sqlite + - name: Stop weaviate if: always() diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml new file mode 100644 index 0000000000..145850c664 --- /dev/null +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -0,0 +1,94 @@ +# Tests destinations that can run without credentials. +# i.e. local postgres, duckdb, filesystem (with local fs/memory bucket) + +name: dest | sqlalchemy mysql and sqlite + +on: + pull_request: + branches: + - master + - devel + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +env: + # NOTE: this workflow can't use github secrets! 
+ # DLT_SECRETS_TOML: ${{ secrets.DLT_SECRETS_TOML }} + + RUNTIME__SENTRY_DSN: https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 + RUNTIME__LOG_LEVEL: ERROR + RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }} + ACTIVE_DESTINATIONS: "[\"sqlalchemy\"]" + ALL_FILESYSTEM_DRIVERS: "[\"memory\", \"file\"]" + +jobs: + get_docs_changes: + name: docs changes + uses: ./.github/workflows/get_docs_changes.yml + + run_loader: + name: dest | sqlalchemy mysql and sqlite + needs: get_docs_changes + if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' + strategy: + fail-fast: false + defaults: + run: + shell: bash + runs-on: "ubuntu-latest" + + # Service containers to run with `container-job` + services: + # Label used to access the service container + mysql: + image: mysql:5.7 + env: + MYSQL_ROOT_PASSWORD: root + MYSQL_DATABASE: dlt_data + MYSQL_USER: loader + MYSQL_PASSWORD: loader + ports: + - 3306:3306 + + steps: + - name: Check out + uses: actions/checkout@master + + - name: Start weaviate + run: docker compose -f ".github/weaviate-compose.yml" up -d + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: "3.10.x" + + - name: Install Poetry + uses: snok/install-poetry@v1.3.2 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v3 + with: + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-local-destinations + + - name: Install dependencies + run: poetry install --no-interaction -E parquet -E filesystem -E sqlalchemy -E cli --with sentry-sdk --with pipeline + run: pip install mysqlclient + + - name: create secrets.toml + run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml + + # always run full suite, also on branches + - run: poetry run pytest tests/load && poetry run pytest tests/cli + name: Run tests Linux + env: + DESTINATION__SQLALCHEMY_MYSQL__CREDENTIALS: mysql://root:root@localhost:3306/dlt_data # Use root cause we need to create databases + DESTINATION__SQLALCHEMY_SQLITE__CREDENTIALS: sqlite:///_storage/dl_data.sqlite diff --git a/tests/load/sqlalchemy/__init__.py b/tests/load/sqlalchemy/__init__.py new file mode 100644 index 0000000000..250c1f7626 --- /dev/null +++ b/tests/load/sqlalchemy/__init__.py @@ -0,0 +1,3 @@ +from tests.utils import skip_if_not_active + +skip_if_not_active("sqlalchemy") diff --git a/tests/load/sqlalchemy/test_sqlalchemy_configuration.py b/tests/load/sqlalchemy/test_sqlalchemy_configuration.py new file mode 100644 index 0000000000..281593aaf7 --- /dev/null +++ b/tests/load/sqlalchemy/test_sqlalchemy_configuration.py @@ -0,0 +1,23 @@ +import pytest + +import sqlalchemy as sa + +from dlt.common.configuration import resolve_configuration +from dlt.destinations.impl.sqlalchemy.configuration import ( + SqlalchemyClientConfiguration, + SqlalchemyCredentials, +) + + +def test_sqlalchemy_credentials_from_engine() -> None: + engine = sa.create_engine("sqlite:///:memory:") + + creds = resolve_configuration(SqlalchemyCredentials(engine)) + + # Url is taken from engine + assert creds.to_url() == sa.engine.make_url("sqlite:///:memory:") + # Engine is stored on the instance + assert creds.engine is engine + + assert creds.drivername == "sqlite" + assert creds.database == ":memory:" From 9d37ea6c0e46466c07f83e8aead7cece911dc131 Mon Sep 17 00:00:00 2001 From: Steinthor 
Palsson Date: Wed, 11 Sep 2024 12:20:56 -0400 Subject: [PATCH 09/37] Fix tests, type errors --- .../impl/sqlalchemy/sqlalchemy_job_client.py | 9 ++++++--- tests/load/pipeline/test_pipelines.py | 2 +- tests/load/test_insert_job_client.py | 2 +- tests/load/test_sql_client.py | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py index d126cc8663..bd163c2c53 100644 --- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -51,10 +51,13 @@ def _db_integer_type(self, precision: Optional[int]) -> sa.types.TypeEngine: return sa.BigInteger() raise TerminalValueError(f"Unsupported precision for integer type: {precision}") - def _create_date_time_type(self, sc_t: str, precision: Optional[int]) -> sa.types.TypeEngine: + def _create_date_time_type( + self, sc_t: str, precision: Optional[int], timezone: Optional[bool] + ) -> sa.types.TypeEngine: """Use the dialect specific datetime/time type if possible since the generic type doesn't accept precision argument""" precision = precision if precision is not None else self.capabilities.timestamp_precision base_type: sa.types.TypeEngine + timezone = timezone is None or bool(timezone) if sc_t == "timestamp": base_type = sa.DateTime() if self.dialect.name == "mysql": @@ -72,7 +75,7 @@ def _create_date_time_type(self, sc_t: str, precision: Optional[int]) -> sa.type # Find out whether the dialect type accepts precision or fsp argument params = inspect.signature(dialect_type).parameters - kwargs: Dict[str, Any] = dict(timezone=True) + kwargs: Dict[str, Any] = dict(timezone=timezone) if "fsp" in params: kwargs["fsp"] = precision # MySQL uses fsp for fractional seconds elif "precision" in params: @@ -154,7 +157,7 @@ def from_db_type(self, db_type: sa.types.TypeEngine) -> TColumnType: elif isinstance(db_type, sa.Boolean): return dict(data_type="bool") elif isinstance(db_type, sa.DateTime): - return dict(data_type="timestamp") + return dict(data_type="timestamp", timezone=db_type.timezone) elif isinstance(db_type, sa.Integer): return self._from_db_integer_type(db_type) elif isinstance(db_type, sqltypes._Binary): diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 5af41025ba..462a7d253a 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -882,7 +882,7 @@ def events(): ids=lambda x: x.name, ) def test_dest_column_hint_timezone(destination_config: DestinationTestConfiguration) -> None: - destination = destination_config.destination + destination = destination_config.destination_type input_data = [ {"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}, diff --git a/tests/load/test_insert_job_client.py b/tests/load/test_insert_job_client.py index a957c871bb..4359ac6885 100644 --- a/tests/load/test_insert_job_client.py +++ b/tests/load/test_insert_job_client.py @@ -28,7 +28,7 @@ def file_storage() -> FileStorage: @pytest.fixture(scope="function") def client(request) -> Iterator[InsertValuesJobClient]: - yield from yield_client_with_storage(request.param.destination) # type: ignore[misc] + yield from yield_client_with_storage(request.param.destination_factory()) # type: ignore[misc] @pytest.mark.essential diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index 1d978c15d4..e46d1c0618 100644 --- a/tests/load/test_sql_client.py +++ 
b/tests/load/test_sql_client.py @@ -701,7 +701,7 @@ def prepare_temp_table(client: SqlJobClientBase) -> Tuple[str, Type[Union[Decima table_name = f"tmp_{uniq_suffix}" ddl_suffix = "" coltype = "numeric" - py_type = Decimal + py_type: Union[Type[Decimal], Type[float]] = Decimal if client.config.destination_type == "athena": ddl_suffix = ( f"LOCATION '{AWS_BUCKET}/ci/{table_name}' TBLPROPERTIES ('table_type'='ICEBERG'," From cdeb17d3faf3b98b5d8bcbcc8988c141c5276738 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 5 Sep 2024 20:44:02 -0400 Subject: [PATCH 10/37] Fix config --- dlt/destinations/impl/sqlalchemy/configuration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/destinations/impl/sqlalchemy/configuration.py b/dlt/destinations/impl/sqlalchemy/configuration.py index 0236f84086..bdbece0310 100644 --- a/dlt/destinations/impl/sqlalchemy/configuration.py +++ b/dlt/destinations/impl/sqlalchemy/configuration.py @@ -9,7 +9,7 @@ from sqlalchemy.engine import Engine, Dialect -@configspec +@configspec(init=False) class SqlalchemyCredentials(ConnectionStringCredentials): if TYPE_CHECKING: _engine: Optional["Engine"] = None From a730a9163b6ff729b57688299cefd78d9baee975 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 5 Sep 2024 21:20:53 -0400 Subject: [PATCH 11/37] CI fix --- .github/workflows/test_local_destinations.yml | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/.github/workflows/test_local_destinations.yml b/.github/workflows/test_local_destinations.yml index 0da3fc71e8..a21c3f0618 100644 --- a/.github/workflows/test_local_destinations.yml +++ b/.github/workflows/test_local_destinations.yml @@ -21,7 +21,7 @@ env: RUNTIME__SENTRY_DSN: https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 RUNTIME__LOG_LEVEL: ERROR RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }} - ACTIVE_DESTINATIONS: "[\"duckdb\", \"postgres\", \"filesystem\", \"weaviate\", \"qdrant\", \"sqlalchemy\"]" + ACTIVE_DESTINATIONS: "[\"duckdb\", \"postgres\", \"filesystem\", \"weaviate\", \"qdrant\"]" ALL_FILESYSTEM_DRIVERS: "[\"memory\", \"file\"]" DESTINATION__WEAVIATE__VECTORIZER: text2vec-contextionary @@ -68,16 +68,6 @@ jobs: ports: - 6333:6333 - mysql: - image: mysql:5.7 - env: - MYSQL_ROOT_PASSWORD: root - MYSQL_DATABASE: dlt_data - MYSQL_USER: loader - MYSQL_PASSWORD: loader - ports: - - 3306:3306 - steps: - name: Check out uses: actions/checkout@master From 33265802dac67005ae6fe773eaf300da0b64b998 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 11 Sep 2024 12:21:40 -0400 Subject: [PATCH 12/37] Add alembic to handle ALTER TABLE --- .../impl/sqlalchemy/alter_table.py | 38 +++++++++++++++++++ .../impl/sqlalchemy/db_api_client.py | 23 +++++------ .../impl/sqlalchemy/sqlalchemy_job_client.py | 12 ++---- pyproject.toml | 4 +- 4 files changed, 55 insertions(+), 22 deletions(-) create mode 100644 dlt/destinations/impl/sqlalchemy/alter_table.py diff --git a/dlt/destinations/impl/sqlalchemy/alter_table.py b/dlt/destinations/impl/sqlalchemy/alter_table.py new file mode 100644 index 0000000000..f85101a740 --- /dev/null +++ b/dlt/destinations/impl/sqlalchemy/alter_table.py @@ -0,0 +1,38 @@ +from typing import List + +import sqlalchemy as sa +from alembic.runtime.migration import MigrationContext +from alembic.operations import Operations + + +class ListBuffer: + """A partial implementation of string IO to use with alembic. 
+ SQL statements are stored in a list instead of file/stdio + """ + + def __init__(self) -> None: + self._buf = "" + self.sql_lines: List[str] = [] + + def write(self, data: str) -> None: + self._buf += data + + def flush(self) -> None: + if self._buf: + self.sql_lines.append(self._buf) + self._buf = "" + + +class MigrationMaker: + def __init__(self, dialect: sa.engine.Dialect) -> None: + self._buf = ListBuffer() + self.ctx = MigrationContext(dialect, None, {"as_sql": True, "output_buffer": self._buf}) + self.ops = Operations(self.ctx) + + def add_column(self, table_name: str, column: sa.Column, schema: str) -> None: + self.ops.add_column(table_name, column, schema=schema) + + def consume_statements(self) -> List[str]: + lines = self._buf.sql_lines[:] + self._buf.sql_lines.clear() + return lines diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py index 13b28385d7..8c3df409aa 100644 --- a/dlt/destinations/impl/sqlalchemy/db_api_client.py +++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -29,6 +29,7 @@ from dlt.destinations.typing import DBTransaction, DBApiCursor from dlt.destinations.sql_client import SqlClientBase, DBApiCursorImpl from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyCredentials +from dlt.destinations.impl.sqlalchemy.alter_table import MigrationMaker from dlt.common.typing import TFun @@ -97,6 +98,7 @@ class SqlalchemyClient(SqlClientBase[Connection]): dialect: sa.engine.interfaces.Dialect dialect_name: str dbapi = DbApiProps # type: ignore[assignment] + migrations: Optional[MigrationMaker] = None # lazy init as needed def __init__( self, @@ -316,17 +318,16 @@ def fully_qualified_dataset_name(self, escape: bool = True, staging: bool = Fals raise NotImplementedError("Staging not supported") return self.dialect.identifier_preparer.format_schema(self.dataset_name) # type: ignore[attr-defined, no-any-return] - def alter_table_add_column(self, column: sa.Column) -> None: - """Execute an ALTER TABLE ... ADD COLUMN ... statement for the given column. - The column must be fully defined and attached to a table. 
- """ - # TODO: May need capability to override ALTER TABLE statement for different dialects - alter_tmpl = "ALTER TABLE {table} ADD COLUMN {column};" - statement = alter_tmpl.format( - table=self._make_qualified_table_name(self._make_qualified_table_name(column.table)), # type: ignore[arg-type] - column=self.compile_column_def(column), - ) - self.execute_sql(statement) + def alter_table_add_columns(self, columns: Sequence[sa.Column]) -> None: + if not columns: + return + if self.migrations is None: + self.migrations = MigrationMaker(self.dialect) + for column in columns: + self.migrations.add_column(column.table.name, column, self.dataset_name) + statements = self.migrations.consume_statements() + for statement in statements: + self.execute_sql(statement) def escape_column_name(self, column_name: str, escape: bool = True) -> str: if self.dialect.requires_name_normalize: # type: ignore[attr-defined] diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py index bd163c2c53..33f00870c1 100644 --- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -113,7 +113,7 @@ def to_db_type(self, column: TColumnSchema, table_format: TTableSchema) -> sa.ty elif sc_t == "bool": return sa.Boolean() elif sc_t == "timestamp": - return self._create_date_time_type(sc_t, precision) + return self._create_date_time_type(sc_t, precision, column.get("timezone")) elif sc_t == "bigint": return self._db_integer_type(precision) elif sc_t == "binary": @@ -128,7 +128,7 @@ def to_db_type(self, column: TColumnSchema, table_format: TTableSchema) -> sa.ty elif sc_t == "date": return sa.Date() elif sc_t == "time": - return self._create_date_time_type(sc_t, precision) + return self._create_date_time_type(sc_t, precision, column.get("timezone")) raise TerminalValueError(f"Unsupported data type: {sc_t}") def _from_db_integer_type(self, db_type: sa.Integer) -> TColumnType: @@ -400,13 +400,7 @@ def update_stored_schema( with self.sql_client.begin_transaction(): for table_obj in tables_to_create: self.sql_client.create_table(table_obj) - for col in columns_to_add: - alter = "ALTER TABLE {} ADD COLUMN {}".format( - self.sql_client.make_qualified_table_name(col.table.name), - self.sql_client.compile_column_def(col), - ) - self.sql_client.execute_sql(alter) - + self.sql_client.alter_table_add_columns(columns_to_add) self._update_schema_in_storage(self.schema) return schema_update diff --git a/pyproject.toml b/pyproject.toml index 1d4c9cd768..df0bc11782 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ lancedb = { version = ">=0.8.2", optional = true, markers = "python_version >= ' tantivy = { version = ">= 0.22.0", optional = true } deltalake = { version = ">=0.19.0", optional = true } sqlalchemy = { version = ">=1.4", optional = true } +alembic = {version = "^1.13.2", optional = true} [tool.poetry.extras] gcp = ["grpcio", "google-cloud-bigquery", "db-dtypes", "gcsfs"] @@ -111,8 +112,7 @@ dremio = ["pyarrow"] lancedb = ["lancedb", "pyarrow", "tantivy"] deltalake = ["deltalake", "pyarrow"] sql_database = ["sqlalchemy"] -sqlalchemy = ["sqlalchemy"] - +sqlalchemy = ["sqlalchemy", "alembic"] [tool.poetry.scripts] dlt = "dlt.cli._dlt:_main" From 567359d1c2638a8293b175dcded8fc26e7ba4017 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 5 Sep 2024 21:41:06 -0400 Subject: [PATCH 13/37] FIx workflow --- .github/workflows/test_sqlalchemy_destinations.yml | 3 +-- 1 file 
changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml index 145850c664..6b304fe47f 100644 --- a/.github/workflows/test_sqlalchemy_destinations.yml +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -80,8 +80,7 @@ jobs: key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-local-destinations - name: Install dependencies - run: poetry install --no-interaction -E parquet -E filesystem -E sqlalchemy -E cli --with sentry-sdk --with pipeline - run: pip install mysqlclient + run: poetry install --no-interaction -E parquet -E filesystem -E sqlalchemy -E cli --with sentry-sdk --with pipeline && pip install mysqlclient - name: create secrets.toml run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml From babcd3c12b594aa08f45fcd2c4747af482aa4b91 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 5 Sep 2024 21:59:13 -0400 Subject: [PATCH 14/37] Install mysqlclient in venv --- .github/workflows/test_sqlalchemy_destinations.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml index 6b304fe47f..0296ebdfdc 100644 --- a/.github/workflows/test_sqlalchemy_destinations.yml +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -80,7 +80,7 @@ jobs: key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-local-destinations - name: Install dependencies - run: poetry install --no-interaction -E parquet -E filesystem -E sqlalchemy -E cli --with sentry-sdk --with pipeline && pip install mysqlclient + run: poetry install --no-interaction -E parquet -E filesystem -E sqlalchemy -E cli --with sentry-sdk --with pipeline && poetry run pip install mysqlclient - name: create secrets.toml run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml From 9dec1c5786e6d77352b24ac1b76e71a172494117 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 5 Sep 2024 22:05:23 -0400 Subject: [PATCH 15/37] Mysql service version --- .github/workflows/test_sqlalchemy_destinations.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml index 0296ebdfdc..6c5df752b6 100644 --- a/.github/workflows/test_sqlalchemy_destinations.yml +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -44,7 +44,7 @@ jobs: services: # Label used to access the service container mysql: - image: mysql:5.7 + image: mysql:8 env: MYSQL_ROOT_PASSWORD: root MYSQL_DATABASE: dlt_data From 3e282eaab4b8883833d02a883a7b69cd8433dd1c Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 5 Sep 2024 22:06:21 -0400 Subject: [PATCH 16/37] Single fail --- .github/workflows/test_sqlalchemy_destinations.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml index 6c5df752b6..20662133ee 100644 --- a/.github/workflows/test_sqlalchemy_destinations.yml +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -86,7 +86,7 @@ jobs: run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml # always run full suite, also on branches - - run: poetry run pytest tests/load && poetry run pytest tests/cli + - run: poetry run pytest tests/load tests/cli -x name: Run tests Linux 
env: DESTINATION__SQLALCHEMY_MYSQL__CREDENTIALS: mysql://root:root@localhost:3306/dlt_data # Use root cause we need to create databases From 0439015248b2913f301593a522d76a56bed8d7bf Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 5 Sep 2024 22:19:36 -0400 Subject: [PATCH 17/37] mysql healtcheck --- .github/workflows/test_sqlalchemy_destinations.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml index 20662133ee..ff299b085b 100644 --- a/.github/workflows/test_sqlalchemy_destinations.yml +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -52,6 +52,12 @@ jobs: MYSQL_PASSWORD: loader ports: - 3306:3306 + # Wait for the service to be ready before completing the job + options: >- + --health-cmd="mysqladmin ping -h localhost -u root -proot" + --health-interval=10s + --health-timeout=5s + --health-retries=5 steps: - name: Check out From 61c835588bfeaf13dfc81fa03520b208393b60ba Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 5 Sep 2024 22:25:26 -0400 Subject: [PATCH 18/37] No localhost --- .github/workflows/test_sqlalchemy_destinations.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml index ff299b085b..d5902b434f 100644 --- a/.github/workflows/test_sqlalchemy_destinations.yml +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -95,5 +95,5 @@ jobs: - run: poetry run pytest tests/load tests/cli -x name: Run tests Linux env: - DESTINATION__SQLALCHEMY_MYSQL__CREDENTIALS: mysql://root:root@localhost:3306/dlt_data # Use root cause we need to create databases + DESTINATION__SQLALCHEMY_MYSQL__CREDENTIALS: mysql://root:root@127.0.0.1:3306/dlt_data # Use root cause we need to create databases DESTINATION__SQLALCHEMY_SQLITE__CREDENTIALS: sqlite:///_storage/dl_data.sqlite From 84dc4cf91579289b58a5d1dcbbbced0d6cc35d7d Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 5 Sep 2024 22:27:11 -0400 Subject: [PATCH 19/37] Remove weaviate --- .github/workflows/test_sqlalchemy_destinations.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml index d5902b434f..d1a3b6e9e2 100644 --- a/.github/workflows/test_sqlalchemy_destinations.yml +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -63,9 +63,6 @@ jobs: - name: Check out uses: actions/checkout@master - - name: Start weaviate - run: docker compose -f ".github/weaviate-compose.yml" up -d - - name: Setup Python uses: actions/setup-python@v4 with: From 4bcc425ec654eb69d61d005aecbec2f08285718c Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 6 Sep 2024 12:28:26 -0400 Subject: [PATCH 20/37] Change ubuntu version --- .github/workflows/test_sqlalchemy_destinations.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml index d1a3b6e9e2..e33e1a2e15 100644 --- a/.github/workflows/test_sqlalchemy_destinations.yml +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -38,7 +38,7 @@ jobs: defaults: run: shell: bash - runs-on: "ubuntu-latest" + runs-on: "ubuntu-22.04" # Service containers to run with `container-job` services: From a9b7e4979926d520d97f76d9f040d53bf3e6b082 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 6 Sep 2024 14:03:19 -0400 
Subject: [PATCH 21/37] Debug sqlite version --- .../workflows/test_sqlalchemy_destinations.yml | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml index e33e1a2e15..1b4c6aa087 100644 --- a/.github/workflows/test_sqlalchemy_destinations.yml +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -85,12 +85,14 @@ jobs: - name: Install dependencies run: poetry install --no-interaction -E parquet -E filesystem -E sqlalchemy -E cli --with sentry-sdk --with pipeline && poetry run pip install mysqlclient - - name: create secrets.toml - run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml + - name: Check sqlite version + run: poetry run python -c "import sqlite3; print(sqlite3.sqlite_version)" + # - name: create secrets.toml + # run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - # always run full suite, also on branches - - run: poetry run pytest tests/load tests/cli -x - name: Run tests Linux - env: - DESTINATION__SQLALCHEMY_MYSQL__CREDENTIALS: mysql://root:root@127.0.0.1:3306/dlt_data # Use root cause we need to create databases - DESTINATION__SQLALCHEMY_SQLITE__CREDENTIALS: sqlite:///_storage/dl_data.sqlite + # # always run full suite, also on branches + # - run: poetry run pytest tests/load tests/cli -x + # name: Run tests Linux + # env: + # DESTINATION__SQLALCHEMY_MYSQL__CREDENTIALS: mysql://root:root@127.0.0.1:3306/dlt_data # Use root cause we need to create databases + # DESTINATION__SQLALCHEMY_SQLITE__CREDENTIALS: sqlite:///_storage/dl_data.sqlite From e0a078130001d54b4b0a755323a1a5901da2c727 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 6 Sep 2024 14:10:14 -0400 Subject: [PATCH 22/37] Revert --- .../test_sqlalchemy_destinations.yml | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml index 1b4c6aa087..d1a3b6e9e2 100644 --- a/.github/workflows/test_sqlalchemy_destinations.yml +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -38,7 +38,7 @@ jobs: defaults: run: shell: bash - runs-on: "ubuntu-22.04" + runs-on: "ubuntu-latest" # Service containers to run with `container-job` services: @@ -85,14 +85,12 @@ jobs: - name: Install dependencies run: poetry install --no-interaction -E parquet -E filesystem -E sqlalchemy -E cli --with sentry-sdk --with pipeline && poetry run pip install mysqlclient - - name: Check sqlite version - run: poetry run python -c "import sqlite3; print(sqlite3.sqlite_version)" - # - name: create secrets.toml - # run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml + - name: create secrets.toml + run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - # # always run full suite, also on branches - # - run: poetry run pytest tests/load tests/cli -x - # name: Run tests Linux - # env: - # DESTINATION__SQLALCHEMY_MYSQL__CREDENTIALS: mysql://root:root@127.0.0.1:3306/dlt_data # Use root cause we need to create databases - # DESTINATION__SQLALCHEMY_SQLITE__CREDENTIALS: sqlite:///_storage/dl_data.sqlite + # always run full suite, also on branches + - run: poetry run pytest tests/load tests/cli -x + name: Run tests Linux + env: + DESTINATION__SQLALCHEMY_MYSQL__CREDENTIALS: mysql://root:root@127.0.0.1:3306/dlt_data # Use root cause we need to create databases + DESTINATION__SQLALCHEMY_SQLITE__CREDENTIALS: sqlite:///_storage/dl_data.sqlite From 
98f8de2368eeb81d70595d8e32b65c4f61196f16 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 6 Sep 2024 14:21:06 -0400 Subject: [PATCH 23/37] Use py datetime in tests --- tests/load/test_sql_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index e46d1c0618..199b4b83b7 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -295,7 +295,7 @@ def test_execute_query(client: SqlJobClientBase) -> None: assert rows[0][0] == "event" with client.sql_client.execute_query( f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at = %s", - pendulum.now().add(seconds=1), + to_py_datetime(pendulum.now().add(seconds=1)), ) as curr: rows = curr.fetchall() assert len(rows) == 0 @@ -303,7 +303,7 @@ def test_execute_query(client: SqlJobClientBase) -> None: with client.sql_client.execute_query( f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at =" " %(date)s", - date=pendulum.now().add(seconds=1), + date=to_py_datetime(pendulum.now().add(seconds=1)), ) as curr: rows = curr.fetchall() assert len(rows) == 0 From 4f8d8f655ed9151076c31e91a5f6077c3bb6f17a Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 6 Sep 2024 14:21:19 -0400 Subject: [PATCH 24/37] Test on sqlalchemy 1.4 and 2 remove secrets toml remove secrets toml Revert "remove secrets toml" This reverts commit 7dd189c39bbd78c942dd5ea56bafdc6e02738805. Fix default pipeline name test --- .github/workflows/test_sqlalchemy_destinations.yml | 5 ++++- tests/load/pipeline/test_pipelines.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml index d1a3b6e9e2..9186e5f8cd 100644 --- a/.github/workflows/test_sqlalchemy_destinations.yml +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -35,6 +35,9 @@ jobs: if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' strategy: fail-fast: false + # Run on sqlalchemy 1.4 and 2.0 + matrix: + sqlalchemy: [1.4, 2] defaults: run: shell: bash @@ -83,7 +86,7 @@ jobs: key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-local-destinations - name: Install dependencies - run: poetry install --no-interaction -E parquet -E filesystem -E sqlalchemy -E cli --with sentry-sdk --with pipeline && poetry run pip install mysqlclient + run: poetry install --no-interaction -E parquet -E filesystem -E sqlalchemy -E cli --with sentry-sdk --with pipeline && poetry run pip install mysqlclient && poetry run pip install "sqlalchemy==${{ matrix.sqlalchemy }}" - name: create secrets.toml run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 462a7d253a..2a29a0a24d 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -100,7 +100,7 @@ def data_fun() -> Iterator[Any]: # mock the correct destinations (never do that in normal code) with p.managed_state(): p._set_destinations( - destination=Destination.from_reference(destination_config.destination_type), + destination=destination_config.destination_factory(), staging=( Destination.from_reference(destination_config.staging) if destination_config.staging From 79631b2b6fdf30dcb12abe0bb1eb0ae921fa6348 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 6 Sep 2024 15:00:56 -0400 
Subject: [PATCH 25/37] Lint, no cli tests --- .github/workflows/test_sqlalchemy_destinations.yml | 2 +- dlt/destinations/impl/sqlalchemy/configuration.py | 7 ++++++- tests/cases.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml index 9186e5f8cd..245060e811 100644 --- a/.github/workflows/test_sqlalchemy_destinations.yml +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -92,7 +92,7 @@ jobs: run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml # always run full suite, also on branches - - run: poetry run pytest tests/load tests/cli -x + - run: poetry run pytest tests/load -x name: Run tests Linux env: DESTINATION__SQLALCHEMY_MYSQL__CREDENTIALS: mysql://root:root@127.0.0.1:3306/dlt_data # Use root cause we need to create databases diff --git a/dlt/destinations/impl/sqlalchemy/configuration.py b/dlt/destinations/impl/sqlalchemy/configuration.py index bdbece0310..f99b06a27b 100644 --- a/dlt/destinations/impl/sqlalchemy/configuration.py +++ b/dlt/destinations/impl/sqlalchemy/configuration.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Any, Final, Type, Dict +from typing import TYPE_CHECKING, Optional, Any, Final, Type, Dict, Union import dataclasses from dlt.common.configuration import configspec @@ -16,6 +16,11 @@ class SqlalchemyCredentials(ConnectionStringCredentials): username: Optional[str] = None # e.g. sqlite doesn't need username + def __init__( + self, connection_string: Optional[Union[str, Dict[str, Any], "Engine"]] = None + ) -> None: + super().__init__(connection_string) # type: ignore[arg-type] + def parse_native_representation(self, native_value: Any) -> None: from sqlalchemy.engine import Engine diff --git a/tests/cases.py b/tests/cases.py index 162a4c4b1a..9b636d9b60 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -1,4 +1,4 @@ -import datetime +import datetime # noqa: I251 import hashlib from typing import Dict, List, Any, Sequence, Tuple, Literal, Union import base64 From 8068595179f59d9d35d9d4289f69cd23cd73a7e4 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 11 Sep 2024 16:54:48 -0400 Subject: [PATCH 26/37] Update lockfile --- poetry.lock | 115 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 108 insertions(+), 7 deletions(-) diff --git a/poetry.lock b/poetry.lock index 9cbc4b66ea..e1afddfd5f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "about-time" @@ -216,13 +216,13 @@ frozenlist = ">=1.1.0" [[package]] name = "alembic" -version = "1.12.0" +version = "1.13.2" description = "A database migration tool for SQLAlchemy." 
optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "alembic-1.12.0-py3-none-any.whl", hash = "sha256:03226222f1cf943deee6c85d9464261a6c710cd19b4fe867a3ad1f25afda610f"}, - {file = "alembic-1.12.0.tar.gz", hash = "sha256:8e7645c32e4f200675e69f0745415335eb59a3663f5feb487abfa0b30c45888b"}, + {file = "alembic-1.13.2-py3-none-any.whl", hash = "sha256:6b8733129a6224a9a711e17c99b08462dbf7cc9670ba8f2e2ae9af860ceb1953"}, + {file = "alembic-1.13.2.tar.gz", hash = "sha256:1ff0ae32975f4fd96028c39ed9bb3c867fe3af956bd7bb37343b54c9fe7445ef"}, ] [package.dependencies] @@ -233,7 +233,7 @@ SQLAlchemy = ">=1.3.0" typing-extensions = ">=4" [package.extras] -tz = ["python-dateutil"] +tz = ["backports.zoneinfo"] [[package]] name = "alive-progress" @@ -3749,6 +3749,106 @@ files = [ {file = "google_re2-1.1-4-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f4d4f0823e8b2f6952a145295b1ff25245ce9bb136aff6fe86452e507d4c1dd"}, {file = "google_re2-1.1-4-cp39-cp39-win32.whl", hash = "sha256:1afae56b2a07bb48cfcfefaa15ed85bae26a68f5dc7f9e128e6e6ea36914e847"}, {file = "google_re2-1.1-4-cp39-cp39-win_amd64.whl", hash = "sha256:aa7d6d05911ab9c8adbf3c225a7a120ab50fd2784ac48f2f0d140c0b7afc2b55"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:222fc2ee0e40522de0b21ad3bc90ab8983be3bf3cec3d349c80d76c8bb1a4beb"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:d4763b0b9195b72132a4e7de8e5a9bf1f05542f442a9115aa27cfc2a8004f581"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:209649da10c9d4a93d8a4d100ecbf9cc3b0252169426bec3e8b4ad7e57d600cf"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:68813aa333c1604a2df4a495b2a6ed065d7c8aebf26cc7e7abb5a6835d08353c"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:370a23ec775ad14e9d1e71474d56f381224dcf3e72b15d8ca7b4ad7dd9cd5853"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:14664a66a3ddf6bc9e56f401bf029db2d169982c53eff3f5876399104df0e9a6"}, + {file = "google_re2-1.1-5-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3ea3722cc4932cbcebd553b69dce1b4a73572823cff4e6a244f1c855da21d511"}, + {file = "google_re2-1.1-5-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e14bb264c40fd7c627ef5678e295370cd6ba95ca71d835798b6e37502fc4c690"}, + {file = "google_re2-1.1-5-cp310-cp310-win32.whl", hash = "sha256:39512cd0151ea4b3969c992579c79b423018b464624ae955be685fc07d94556c"}, + {file = "google_re2-1.1-5-cp310-cp310-win_amd64.whl", hash = "sha256:ac66537aa3bc5504320d922b73156909e3c2b6da19739c866502f7827b3f9fdf"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5b5ea68d54890c9edb1b930dcb2658819354e5d3f2201f811798bbc0a142c2b4"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:33443511b6b83c35242370908efe2e8e1e7cae749c766b2b247bf30e8616066c"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:413d77bdd5ba0bfcada428b4c146e87707452ec50a4091ec8e8ba1413d7e0619"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:5171686e43304996a34baa2abcee6f28b169806d0e583c16d55e5656b092a414"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3b284db130283771558e31a02d8eb8fb756156ab98ce80035ae2e9e3a5f307c4"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_14_0_x86_64.whl", 
hash = "sha256:296e6aed0b169648dc4b870ff47bd34c702a32600adb9926154569ef51033f47"}, + {file = "google_re2-1.1-5-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:38d50e68ead374160b1e656bbb5d101f0b95fb4cc57f4a5c12100155001480c5"}, + {file = "google_re2-1.1-5-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2a0416a35921e5041758948bcb882456916f22845f66a93bc25070ef7262b72a"}, + {file = "google_re2-1.1-5-cp311-cp311-win32.whl", hash = "sha256:a1d59568bbb5de5dd56dd6cdc79907db26cce63eb4429260300c65f43469e3e7"}, + {file = "google_re2-1.1-5-cp311-cp311-win_amd64.whl", hash = "sha256:72f5a2f179648b8358737b2b493549370debd7d389884a54d331619b285514e3"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:cbc72c45937b1dc5acac3560eb1720007dccca7c9879138ff874c7f6baf96005"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:5fadd1417fbef7235fa9453dba4eb102e6e7d94b1e4c99d5fa3dd4e288d0d2ae"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:040f85c63cc02696485b59b187a5ef044abe2f99b92b4fb399de40b7d2904ccc"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:64e3b975ee6d9bbb2420494e41f929c1a0de4bcc16d86619ab7a87f6ea80d6bd"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:8ee370413e00f4d828eaed0e83b8af84d7a72e8ee4f4bd5d3078bc741dfc430a"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:5b89383001079323f693ba592d7aad789d7a02e75adb5d3368d92b300f5963fd"}, + {file = "google_re2-1.1-5-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:63cb4fdfbbda16ae31b41a6388ea621510db82feb8217a74bf36552ecfcd50ad"}, + {file = "google_re2-1.1-5-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ebedd84ae8be10b7a71a16162376fd67a2386fe6361ef88c622dcf7fd679daf"}, + {file = "google_re2-1.1-5-cp312-cp312-win32.whl", hash = "sha256:c8e22d1692bc2c81173330c721aff53e47ffd3c4403ff0cd9d91adfd255dd150"}, + {file = "google_re2-1.1-5-cp312-cp312-win_amd64.whl", hash = "sha256:5197a6af438bb8c4abda0bbe9c4fbd6c27c159855b211098b29d51b73e4cbcf6"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b6727e0b98417e114b92688ad2aa256102ece51f29b743db3d831df53faf1ce3"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:711e2b6417eb579c61a4951029d844f6b95b9b373b213232efd413659889a363"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:71ae8b3df22c5c154c8af0f0e99d234a450ef1644393bc2d7f53fc8c0a1e111c"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:94a04e214bc521a3807c217d50cf099bbdd0c0a80d2d996c0741dbb995b5f49f"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_14_0_arm64.whl", hash = "sha256:a770f75358508a9110c81a1257721f70c15d9bb592a2fb5c25ecbd13566e52a5"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_14_0_x86_64.whl", hash = "sha256:07c9133357f7e0b17c6694d5dcb82e0371f695d7c25faef2ff8117ef375343ff"}, + {file = "google_re2-1.1-5-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:204ca6b1cf2021548f4a9c29ac015e0a4ab0a7b6582bf2183d838132b60c8fda"}, + {file = "google_re2-1.1-5-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f0b95857c2c654f419ca684ec38c9c3325c24e6ba7d11910a5110775a557bb18"}, + {file = "google_re2-1.1-5-cp38-cp38-win32.whl", hash = 
"sha256:347ac770e091a0364e822220f8d26ab53e6fdcdeaec635052000845c5a3fb869"}, + {file = "google_re2-1.1-5-cp38-cp38-win_amd64.whl", hash = "sha256:ec32bb6de7ffb112a07d210cf9f797b7600645c2d5910703fa07f456dd2150e0"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:eb5adf89060f81c5ff26c28e261e6b4997530a923a6093c9726b8dec02a9a326"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:a22630c9dd9ceb41ca4316bccba2643a8b1d5c198f21c00ed5b50a94313aaf10"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:544dc17fcc2d43ec05f317366375796351dec44058e1164e03c3f7d050284d58"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:19710af5ea88751c7768575b23765ce0dfef7324d2539de576f75cdc319d6654"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:f82995a205e08ad896f4bd5ce4847c834fab877e1772a44e5f262a647d8a1dec"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:63533c4d58da9dc4bc040250f1f52b089911699f0368e0e6e15f996387a984ed"}, + {file = "google_re2-1.1-5-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:79e00fcf0cb04ea35a22b9014712d448725ce4ddc9f08cc818322566176ca4b0"}, + {file = "google_re2-1.1-5-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bc41afcefee2da6c4ed883a93d7f527c4b960cd1d26bbb0020a7b8c2d341a60a"}, + {file = "google_re2-1.1-5-cp39-cp39-win32.whl", hash = "sha256:486730b5e1f1c31b0abc6d80abe174ce4f1188fe17d1b50698f2bf79dc6e44be"}, + {file = "google_re2-1.1-5-cp39-cp39-win_amd64.whl", hash = "sha256:4de637ca328f1d23209e80967d1b987d6b352cd01b3a52a84b4d742c69c3da6c"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:621e9c199d1ff0fdb2a068ad450111a84b3bf14f96dfe5a8a7a0deae5f3f4cce"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:220acd31e7dde95373f97c3d1f3b3bd2532b38936af28b1917ee265d25bebbf4"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:db34e1098d164f76251a6ece30e8f0ddfd65bb658619f48613ce71acb3f9cbdb"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:5152bac41d8073977582f06257219541d0fc46ad99b0bbf30e8f60198a43b08c"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:6191294799e373ee1735af91f55abd23b786bdfd270768a690d9d55af9ea1b0d"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:070cbafbb4fecbb02e98feb28a1eb292fb880f434d531f38cc33ee314b521f1f"}, + {file = "google_re2-1.1-6-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8437d078b405a59a576cbed544490fe041140f64411f2d91012e8ec05ab8bf86"}, + {file = "google_re2-1.1-6-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f00f9a9af8896040e37896d9b9fc409ad4979f1ddd85bb188694a7d95ddd1164"}, + {file = "google_re2-1.1-6-cp310-cp310-win32.whl", hash = "sha256:df26345f229a898b4fd3cafd5f82259869388cee6268fc35af16a8e2293dd4e5"}, + {file = "google_re2-1.1-6-cp310-cp310-win_amd64.whl", hash = "sha256:3665d08262c57c9b28a5bdeb88632ad792c4e5f417e5645901695ab2624f5059"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b26b869d8aa1d8fe67c42836bf3416bb72f444528ee2431cfb59c0d3e02c6ce3"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:41fd4486c57dea4f222a6bb7f1ff79accf76676a73bdb8da0fcbd5ba73f8da71"}, + {file = 
"google_re2-1.1-6-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:0ee378e2e74e25960070c338c28192377c4dd41e7f4608f2688064bd2badc41e"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:a00cdbf662693367b36d075b29feb649fd7ee1b617cf84f85f2deebeda25fc64"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:4c09455014217a41499432b8c8f792f25f3df0ea2982203c3a8c8ca0e7895e69"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:6501717909185327935c7945e23bb5aa8fc7b6f237b45fe3647fa36148662158"}, + {file = "google_re2-1.1-6-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3510b04790355f199e7861c29234081900e1e1cbf2d1484da48aa0ba6d7356ab"}, + {file = "google_re2-1.1-6-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8c0e64c187ca406764f9e9ad6e750d62e69ed8f75bf2e865d0bfbc03b642361c"}, + {file = "google_re2-1.1-6-cp311-cp311-win32.whl", hash = "sha256:2a199132350542b0de0f31acbb3ca87c3a90895d1d6e5235f7792bb0af02e523"}, + {file = "google_re2-1.1-6-cp311-cp311-win_amd64.whl", hash = "sha256:83bdac8ceaece8a6db082ea3a8ba6a99a2a1ee7e9f01a9d6d50f79c6f251a01d"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:81985ff894cd45ab5a73025922ac28c0707759db8171dd2f2cc7a0e856b6b5ad"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:5635af26065e6b45456ccbea08674ae2ab62494008d9202df628df3b267bc095"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:813b6f04de79f4a8fdfe05e2cb33e0ccb40fe75d30ba441d519168f9d958bd54"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:5ec2f5332ad4fd232c3f2d6748c2c7845ccb66156a87df73abcc07f895d62ead"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:5a687b3b32a6cbb731647393b7c4e3fde244aa557f647df124ff83fb9b93e170"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:39a62f9b3db5d3021a09a47f5b91708b64a0580193e5352751eb0c689e4ad3d7"}, + {file = "google_re2-1.1-6-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ca0f0b45d4a1709cbf5d21f355e5809ac238f1ee594625a1e5ffa9ff7a09eb2b"}, + {file = "google_re2-1.1-6-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a64b3796a7a616c7861247bd061c9a836b5caf0d5963e5ea8022125601cf7b09"}, + {file = "google_re2-1.1-6-cp312-cp312-win32.whl", hash = "sha256:32783b9cb88469ba4cd9472d459fe4865280a6b1acdad4480a7b5081144c4eb7"}, + {file = "google_re2-1.1-6-cp312-cp312-win_amd64.whl", hash = "sha256:259ff3fd2d39035b9cbcbf375995f83fa5d9e6a0c5b94406ff1cc168ed41d6c6"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:e4711bcffe190acd29104d8ecfea0c0e42b754837de3fb8aad96e6cc3c613cdc"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:4d081cce43f39c2e813fe5990e1e378cbdb579d3f66ded5bade96130269ffd75"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:4f123b54d48450d2d6b14d8fad38e930fb65b5b84f1b022c10f2913bd956f5b5"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:e1928b304a2b591a28eb3175f9db7f17c40c12cf2d4ec2a85fdf1cc9c073ff91"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_14_0_arm64.whl", hash = "sha256:3a69f76146166aec1173003c1f547931bdf288c6b135fda0020468492ac4149f"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_14_0_x86_64.whl", hash = 
"sha256:fc08c388f4ebbbca345e84a0c56362180d33d11cbe9ccfae663e4db88e13751e"}, + {file = "google_re2-1.1-6-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b057adf38ce4e616486922f2f47fc7d19c827ba0a7f69d540a3664eba2269325"}, + {file = "google_re2-1.1-6-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4138c0b933ab099e96f5d8defce4486f7dfd480ecaf7f221f2409f28022ccbc5"}, + {file = "google_re2-1.1-6-cp38-cp38-win32.whl", hash = "sha256:9693e45b37b504634b1abbf1ee979471ac6a70a0035954592af616306ab05dd6"}, + {file = "google_re2-1.1-6-cp38-cp38-win_amd64.whl", hash = "sha256:5674d437baba0ea287a5a7f8f81f24265d6ae8f8c09384e2ef7b6f84b40a7826"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7783137cb2e04f458a530c6d0ee9ef114815c1d48b9102f023998c371a3b060e"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:a49b7153935e7a303675f4deb5f5d02ab1305adefc436071348706d147c889e0"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:a96a8bb309182090704593c60bdb369a2756b38fe358bbf0d40ddeb99c71769f"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:dff3d4be9f27ef8ec3705eed54f19ef4ab096f5876c15fe011628c69ba3b561c"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:40f818b0b39e26811fa677978112a8108269977fdab2ba0453ac4363c35d9e66"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:8a7e53538cdb40ef4296017acfbb05cab0c19998be7552db1cfb85ba40b171b9"}, + {file = "google_re2-1.1-6-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6ee18e7569fb714e5bb8c42809bf8160738637a5e71ed5a4797757a1fb4dc4de"}, + {file = "google_re2-1.1-6-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1cda4f6d1a7d5b43ea92bc395f23853fba0caf8b1e1efa6e8c48685f912fcb89"}, + {file = "google_re2-1.1-6-cp39-cp39-win32.whl", hash = "sha256:6a9cdbdc36a2bf24f897be6a6c85125876dc26fea9eb4247234aec0decbdccfd"}, + {file = "google_re2-1.1-6-cp39-cp39-win_amd64.whl", hash = "sha256:73f646cecfad7cc5b4330b4192c25f2e29730a3b8408e089ffd2078094208196"}, ] [[package]] @@ -9664,10 +9764,11 @@ redshift = ["psycopg2-binary", "psycopg2cffi"] s3 = ["botocore", "s3fs"] snowflake = ["snowflake-connector-python"] sql-database = ["sqlalchemy"] +sqlalchemy = ["alembic", "sqlalchemy"] synapse = ["adlfs", "pyarrow", "pyodbc"] weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "e5342d5cdc135a27b89747a3665ff68aa76025efcfde6f86318144ce0fd70284" +content-hash = "e6409812458e0aae06a9b5f2c816b19d73645c8598a1789b8d5b39a45e974a9f" From a9c89a0ca70fb1f3d2fe9445a2593c7f94354ddf Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 11 Sep 2024 16:54:52 -0400 Subject: [PATCH 27/37] Fix test, complex -> json --- dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py | 4 ++-- tests/load/pipeline/test_arrow_loading.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py index 33f00870c1..095adeccf3 100644 --- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -118,7 +118,7 @@ def to_db_type(self, column: TColumnSchema, table_format: TTableSchema) -> sa.ty return self._db_integer_type(precision) elif sc_t == "binary": return sa.LargeBinary(length=precision) 
- elif sc_t == "complex": + elif sc_t == "json": return sa.JSON(none_as_null=True) elif sc_t == "decimal": return self._to_db_decimal_type(column) @@ -163,7 +163,7 @@ def from_db_type(self, db_type: sa.types.TypeEngine) -> TColumnType: elif isinstance(db_type, sqltypes._Binary): return dict(data_type="binary", precision=db_type.length) elif isinstance(db_type, sa.JSON): - return dict(data_type="complex") + return dict(data_type="json") elif isinstance(db_type, sa.Numeric): return self._from_db_decimal_type(db_type) elif isinstance(db_type, sa.Date): diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index a0da556e48..369359d61a 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -252,7 +252,7 @@ def test_load_arrow_with_not_null_columns( ) -> None: """Resource schema contains non-nullable columns. Arrow schema should be written accordingly""" if ( - destination_config.destination in ("databricks", "redshift") + destination_config.destination_type in ("databricks", "redshift") and destination_config.file_format == "jsonl" ): pytest.skip( From 874c8710900a8ca87b8d7afe753fdbe3f92fdb28 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 11 Sep 2024 19:28:47 -0400 Subject: [PATCH 28/37] Refactor type mapper --- dlt/common/destination/capabilities.py | 4 +- dlt/common/destination/reference.py | 1 - dlt/destinations/impl/sqlalchemy/factory.py | 11 +- .../impl/sqlalchemy/sqlalchemy_job_client.py | 176 ++---------------- .../impl/sqlalchemy/type_mapper.py | 174 +++++++++++++++++ 5 files changed, 202 insertions(+), 164 deletions(-) create mode 100644 dlt/destinations/impl/sqlalchemy/type_mapper.py diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index 722c82c176..8f0dce79ce 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -226,8 +226,8 @@ def generic_capabilities( caps.merge_strategies_selector = merge_strategies_selector return caps - def get_type_mapper(self) -> DataTypeMapper: - return self.type_mapper(self) + def get_type_mapper(self, *args: Any, **kwargs: Any) -> DataTypeMapper: + return self.type_mapper(self, *args, **kwargs) def merge_caps_file_formats( diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index b394a15e60..9e27b66335 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -393,7 +393,6 @@ def run_managed( self.run() self._state = "completed" except (DestinationTerminalException, TerminalValueError) as e: - logger.exception(f"Job {self.job_id()} failed terminally") self._state = "failed" self._exception = e logger.exception(f"Terminal exception in job {self.job_id()} in file {self._file_path}") diff --git a/dlt/destinations/impl/sqlalchemy/factory.py b/dlt/destinations/impl/sqlalchemy/factory.py index c6b78baf3b..10372cda34 100644 --- a/dlt/destinations/impl/sqlalchemy/factory.py +++ b/dlt/destinations/impl/sqlalchemy/factory.py @@ -1,7 +1,7 @@ import typing as t -from dlt.common.data_writers.configuration import CsvFormatConfiguration from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.destination.capabilities import DataTypeMapper from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.normalizers import NamingConvention @@ -10,6 +10,14 @@ SqlalchemyClientConfiguration, ) +SqlalchemyTypeMapper: t.Type[DataTypeMapper] 
+ +try: + from dlt.destinations.impl.sqlalchemy.type_mapper import SqlalchemyTypeMapper +except ModuleNotFoundError: + # assign mock type mapper if no sqlalchemy + from dlt.common.destination.capabilities import UnsupportedTypeMapper as SqlalchemyTypeMapper + if t.TYPE_CHECKING: # from dlt.destinations.impl.sqlalchemy.sqlalchemy_client import SqlalchemyJobClient from dlt.destinations.impl.sqlalchemy.sqlalchemy_job_client import SqlalchemyJobClient @@ -37,6 +45,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supports_ddl_transactions = True caps.max_query_parameters = 20_0000 caps.max_rows_per_insert = 10_000 # Set a default to avoid OOM on large datasets + caps.type_mapper = SqlalchemyTypeMapper return caps diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py index 095adeccf3..c51d3cbe3a 100644 --- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -1,26 +1,23 @@ -from typing import Iterable, Optional, Type, Dict, Any, Iterator, Sequence, List, Tuple, IO -from types import TracebackType +from typing import Iterable, Optional, Dict, Any, Iterator, Sequence, List, Tuple, IO from contextlib import suppress -import inspect import math import sqlalchemy as sa -from sqlalchemy.sql import sqltypes from dlt.common import logger from dlt.common import pendulum -from dlt.common.exceptions import TerminalValueError from dlt.common.destination.reference import ( JobClientBase, LoadJob, RunnableLoadJob, StorageSchemaInfo, StateInfo, + PreparedTableSchema, ) from dlt.destinations.job_client_impl import SqlJobClientBase from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.schema import Schema, TTableSchema, TColumnSchema, TSchemaTables -from dlt.common.schema.typing import TColumnType, TTableFormat, TTableSchemaColumns +from dlt.common.schema.typing import TColumnType, TTableSchemaColumns from dlt.common.schema.utils import pipeline_state_table, normalize_table_identifiers from dlt.common.storages import FileStorage from dlt.common.json import json, PY_DATETIME_DECODERS @@ -32,147 +29,6 @@ from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyClientConfiguration -class SqlaTypeMapper: - # TODO: May be merged with TypeMapper as a generic - def __init__( - self, capabilities: DestinationCapabilitiesContext, dialect: sa.engine.Dialect - ) -> None: - self.capabilities = capabilities - self.dialect = dialect - - def _db_integer_type(self, precision: Optional[int]) -> sa.types.TypeEngine: - if precision is None: - return sa.BigInteger() - elif precision <= 16: - return sa.SmallInteger() - elif precision <= 32: - return sa.Integer() - elif precision <= 64: - return sa.BigInteger() - raise TerminalValueError(f"Unsupported precision for integer type: {precision}") - - def _create_date_time_type( - self, sc_t: str, precision: Optional[int], timezone: Optional[bool] - ) -> sa.types.TypeEngine: - """Use the dialect specific datetime/time type if possible since the generic type doesn't accept precision argument""" - precision = precision if precision is not None else self.capabilities.timestamp_precision - base_type: sa.types.TypeEngine - timezone = timezone is None or bool(timezone) - if sc_t == "timestamp": - base_type = sa.DateTime() - if self.dialect.name == "mysql": - # Special case, type_descriptor does not return the specifc datetime type - from sqlalchemy.dialects.mysql import 
DATETIME - - return DATETIME(fsp=precision) - elif sc_t == "time": - base_type = sa.Time() - - dialect_type = type( - self.dialect.type_descriptor(base_type) - ) # Get the dialect specific subtype - precision = precision if precision is not None else self.capabilities.timestamp_precision - - # Find out whether the dialect type accepts precision or fsp argument - params = inspect.signature(dialect_type).parameters - kwargs: Dict[str, Any] = dict(timezone=timezone) - if "fsp" in params: - kwargs["fsp"] = precision # MySQL uses fsp for fractional seconds - elif "precision" in params: - kwargs["precision"] = precision - return dialect_type(**kwargs) # type: ignore[no-any-return,misc] - - def _create_double_type(self) -> sa.types.TypeEngine: - if dbl := getattr(sa, "Double", None): - # Sqlalchemy 2 has generic double type - return dbl() # type: ignore[no-any-return] - elif self.dialect.name == "mysql": - # MySQL has a specific double type - from sqlalchemy.dialects.mysql import DOUBLE - return sa.Float(precision=53) # Otherwise use float - - def _to_db_decimal_type(self, column: TColumnSchema) -> sa.types.TypeEngine: - precision, scale = column.get("precision"), column.get("scale") - if precision is None and scale is None: - precision, scale = self.capabilities.decimal_precision - return sa.Numeric(precision, scale) - - def to_db_type(self, column: TColumnSchema, table_format: TTableSchema) -> sa.types.TypeEngine: - sc_t = column["data_type"] - precision = column.get("precision") - # TODO: Precision and scale for supported types - if sc_t == "text": - length = precision - if length is None and column.get("unique"): - length = 128 - if length is None: - return sa.Text() - return sa.String(length=length) - elif sc_t == "double": - return self._create_double_type() - elif sc_t == "bool": - return sa.Boolean() - elif sc_t == "timestamp": - return self._create_date_time_type(sc_t, precision, column.get("timezone")) - elif sc_t == "bigint": - return self._db_integer_type(precision) - elif sc_t == "binary": - return sa.LargeBinary(length=precision) - elif sc_t == "json": - return sa.JSON(none_as_null=True) - elif sc_t == "decimal": - return self._to_db_decimal_type(column) - elif sc_t == "wei": - wei_precision, wei_scale = self.capabilities.wei_precision - return sa.Numeric(precision=wei_precision, scale=wei_scale) - elif sc_t == "date": - return sa.Date() - elif sc_t == "time": - return self._create_date_time_type(sc_t, precision, column.get("timezone")) - raise TerminalValueError(f"Unsupported data type: {sc_t}") - - def _from_db_integer_type(self, db_type: sa.Integer) -> TColumnType: - if isinstance(db_type, sa.SmallInteger): - return dict(data_type="bigint", precision=16) - elif isinstance(db_type, sa.Integer): - return dict(data_type="bigint", precision=32) - elif isinstance(db_type, sa.BigInteger): - return dict(data_type="bigint") - return dict(data_type="bigint") - - def _from_db_decimal_type(self, db_type: sa.Numeric) -> TColumnType: - precision, scale = db_type.precision, db_type.scale - if (precision, scale) == self.capabilities.wei_precision: - return dict(data_type="wei") - - return dict(data_type="decimal", precision=precision, scale=scale) - - def from_db_type(self, db_type: sa.types.TypeEngine) -> TColumnType: - # TODO: pass the sqla type through dialect.type_descriptor before instance check - # Possibly need to check both dialect specific and generic types - if isinstance(db_type, sa.String): - return dict(data_type="text") - elif isinstance(db_type, sa.Float): - return 
dict(data_type="double") - elif isinstance(db_type, sa.Boolean): - return dict(data_type="bool") - elif isinstance(db_type, sa.DateTime): - return dict(data_type="timestamp", timezone=db_type.timezone) - elif isinstance(db_type, sa.Integer): - return self._from_db_integer_type(db_type) - elif isinstance(db_type, sqltypes._Binary): - return dict(data_type="binary", precision=db_type.length) - elif isinstance(db_type, sa.JSON): - return dict(data_type="json") - elif isinstance(db_type, sa.Numeric): - return self._from_db_decimal_type(db_type) - elif isinstance(db_type, sa.Date): - return dict(data_type="date") - elif isinstance(db_type, sa.Time): - return dict(data_type="time") - raise TerminalValueError(f"Unsupported db type: {db_type}") - - class SqlalchemyJsonLInsertJob(RunnableLoadJob): def __init__(self, file_path: str, table: sa.Table) -> None: super().__init__(file_path) @@ -270,9 +126,9 @@ def __init__( self.schema = schema self.capabilities = capabilities self.config = config - self.type_mapper = SqlaTypeMapper(capabilities, self.sql_client.dialect) + self.type_mapper = self.capabilities.get_type_mapper(self.sql_client.dialect) - def _to_table_object(self, schema_table: TTableSchema) -> sa.Table: + def _to_table_object(self, schema_table: PreparedTableSchema) -> sa.Table: existing = self.sql_client.get_existing_table(schema_table["name"]) if existing is not None: existing_col_names = set(col.name for col in existing.columns) @@ -292,17 +148,17 @@ def _to_table_object(self, schema_table: TTableSchema) -> sa.Table: ) def _to_column_object( - self, schema_column: TColumnSchema, table_format: TTableSchema + self, schema_column: TColumnSchema, table: PreparedTableSchema ) -> sa.Column: return sa.Column( schema_column["name"], - self.type_mapper.to_db_type(schema_column, table_format), + self.type_mapper.to_destination_type(schema_column, table), nullable=schema_column.get("nullable", True), unique=schema_column.get("unique", False), ) def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: if file_path.endswith(".typed-jsonl"): table_obj = self._to_table_object(table) @@ -313,7 +169,7 @@ def create_load_job( return None def complete_load(self, load_id: str) -> None: - loads_table = self._to_table_object(self.schema.tables[self.schema.loads_table_name]) + loads_table = self._to_table_object(self.schema.tables[self.schema.loads_table_name]) # type: ignore[arg-type] now_ts = pendulum.now() self.sql_client.execute_sql( loads_table.insert().values( @@ -346,7 +202,7 @@ def get_storage_tables( col.name: { "name": col.name, "nullable": col.nullable, - **self.type_mapper.from_db_type(col.type), + **self.type_mapper.from_destination_type(col.type, None, None), } for col in table_obj.columns } @@ -373,7 +229,7 @@ def update_stored_schema( # Create all schema tables in metadata for table_name in only_tables or self.schema.tables: - self._to_table_object(self.schema.tables[table_name]) + self._to_table_object(self.schema.tables[table_name]) # type: ignore[arg-type] schema_update: TSchemaTables = {} tables_to_create: List[sa.Table] = [] @@ -407,7 +263,7 @@ def update_stored_schema( def _delete_schema_in_storage(self, schema: Schema) -> None: version_table = schema.tables[schema.version_table_name] - table_obj = self._to_table_object(version_table) + table_obj = self._to_table_object(version_table) # type: ignore[arg-type] schema_name_col = 
schema.naming.normalize_identifier("schema_name") self.sql_client.execute_sql( table_obj.delete().where(table_obj.c[schema_name_col] == schema.name) @@ -415,7 +271,7 @@ def _delete_schema_in_storage(self, schema: Schema) -> None: def _update_schema_in_storage(self, schema: Schema) -> None: version_table = schema.tables[schema.version_table_name] - table_obj = self._to_table_object(version_table) + table_obj = self._to_table_object(version_table) # type: ignore[arg-type] schema_str = json.dumps(schema.to_dict()) schema_mapping = StorageSchemaInfo( @@ -433,7 +289,7 @@ def _get_stored_schema( self, version_hash: Optional[str] = None, schema_name: Optional[str] = None ) -> Optional[StorageSchemaInfo]: version_table = self.schema.tables[self.schema.version_table_name] - table_obj = self._to_table_object(version_table) + table_obj = self._to_table_object(version_table) # type: ignore[arg-type] with suppress(DatabaseUndefinedRelation): q = sa.select(table_obj) if version_hash is not None: @@ -465,9 +321,9 @@ def get_stored_state(self, pipeline_name: str) -> StateInfo: state_table = self.schema.tables.get( self.schema.state_table_name ) or normalize_table_identifiers(pipeline_state_table(), self.schema.naming) - state_table_obj = self._to_table_object(state_table) + state_table_obj = self._to_table_object(state_table) # type: ignore[arg-type] loads_table = self.schema.tables[self.schema.loads_table_name] - loads_table_obj = self._to_table_object(loads_table) + loads_table_obj = self._to_table_object(loads_table) # type: ignore[arg-type] c_load_id, c_dlt_load_id, c_pipeline_name, c_status = map( self.schema.naming.normalize_identifier, diff --git a/dlt/destinations/impl/sqlalchemy/type_mapper.py b/dlt/destinations/impl/sqlalchemy/type_mapper.py new file mode 100644 index 0000000000..767d2115d4 --- /dev/null +++ b/dlt/destinations/impl/sqlalchemy/type_mapper.py @@ -0,0 +1,174 @@ +from typing import Optional, Dict, Any +import inspect + +import sqlalchemy as sa +from sqlalchemy.sql import sqltypes + +from dlt.common.exceptions import TerminalValueError +from dlt.common.typing import TLoaderFileFormat +from dlt.common.destination.capabilities import DataTypeMapper, DestinationCapabilitiesContext +from dlt.common.destination.typing import PreparedTableSchema +from dlt.common.schema.typing import TColumnSchema + + +# TODO: base type mapper should be a generic class to support TypeEngine instead of str types +class SqlalchemyTypeMapper(DataTypeMapper): + def __init__( + self, + capabilities: DestinationCapabilitiesContext, + dialect: Optional[sa.engine.Dialect] = None, + ): + super().__init__(capabilities) + # Mapper is used to verify supported types without client, dialect is not important for this + self.dialect = dialect or sa.engine.default.DefaultDialect() + + def _db_integer_type(self, precision: Optional[int]) -> sa.types.TypeEngine: + if precision is None: + return sa.BigInteger() + elif precision <= 16: + return sa.SmallInteger() + elif precision <= 32: + return sa.Integer() + elif precision <= 64: + return sa.BigInteger() + raise TerminalValueError(f"Unsupported precision for integer type: {precision}") + + def _create_date_time_type( + self, sc_t: str, precision: Optional[int], timezone: Optional[bool] + ) -> sa.types.TypeEngine: + """Use the dialect specific datetime/time type if possible since the generic type doesn't accept precision argument""" + precision = precision if precision is not None else self.capabilities.timestamp_precision + base_type: sa.types.TypeEngine + timezone = timezone 
is None or bool(timezone) + if sc_t == "timestamp": + base_type = sa.DateTime() + if self.dialect.name == "mysql": + # Special case, type_descriptor does not return the specifc datetime type + from sqlalchemy.dialects.mysql import DATETIME + + return DATETIME(fsp=precision) + elif sc_t == "time": + base_type = sa.Time() + + dialect_type = type( + self.dialect.type_descriptor(base_type) + ) # Get the dialect specific subtype + precision = precision if precision is not None else self.capabilities.timestamp_precision + + # Find out whether the dialect type accepts precision or fsp argument + params = inspect.signature(dialect_type).parameters + kwargs: Dict[str, Any] = dict(timezone=timezone) + if "fsp" in params: + kwargs["fsp"] = precision # MySQL uses fsp for fractional seconds + elif "precision" in params: + kwargs["precision"] = precision + return dialect_type(**kwargs) # type: ignore[no-any-return,misc] + + def _create_double_type(self) -> sa.types.TypeEngine: + if dbl := getattr(sa, "Double", None): + # Sqlalchemy 2 has generic double type + return dbl() # type: ignore[no-any-return] + elif self.dialect.name == "mysql": + # MySQL has a specific double type + from sqlalchemy.dialects.mysql import DOUBLE + + return DOUBLE() + return sa.Float(precision=53) # Otherwise use float + + def _to_db_decimal_type(self, column: TColumnSchema) -> sa.types.TypeEngine: + precision, scale = column.get("precision"), column.get("scale") + if precision is None and scale is None: + precision, scale = self.capabilities.decimal_precision + return sa.Numeric(precision, scale) + + def to_destination_type( # type: ignore[override] + self, column: TColumnSchema, table: PreparedTableSchema = None + ) -> sqltypes.TypeEngine: + sc_t = column["data_type"] + precision = column.get("precision") + # TODO: Precision and scale for supported types + if sc_t == "text": + length = precision + if length is None and column.get("unique"): + length = 128 + if length is None: + return sa.Text() + return sa.String(length=length) + elif sc_t == "double": + return self._create_double_type() + elif sc_t == "bool": + return sa.Boolean() + elif sc_t == "timestamp": + return self._create_date_time_type(sc_t, precision, column.get("timezone")) + elif sc_t == "bigint": + return self._db_integer_type(precision) + elif sc_t == "binary": + return sa.LargeBinary(length=precision) + elif sc_t == "json": + return sa.JSON(none_as_null=True) + elif sc_t == "decimal": + return self._to_db_decimal_type(column) + elif sc_t == "wei": + wei_precision, wei_scale = self.capabilities.wei_precision + return sa.Numeric(precision=wei_precision, scale=wei_scale) + elif sc_t == "date": + return sa.Date() + elif sc_t == "time": + return self._create_date_time_type(sc_t, precision, column.get("timezone")) + raise TerminalValueError(f"Unsupported data type: {sc_t}") + + def _from_db_integer_type(self, db_type: sa.Integer) -> TColumnSchema: + if isinstance(db_type, sa.SmallInteger): + return dict(data_type="bigint", precision=16) + elif isinstance(db_type, sa.Integer): + return dict(data_type="bigint", precision=32) + elif isinstance(db_type, sa.BigInteger): + return dict(data_type="bigint") + return dict(data_type="bigint") + + def _from_db_decimal_type(self, db_type: sa.Numeric) -> TColumnSchema: + precision, scale = db_type.precision, db_type.scale + if (precision, scale) == self.capabilities.wei_precision: + return dict(data_type="wei") + + return dict(data_type="decimal", precision=precision, scale=scale) + + def from_destination_type( # type: 
ignore[override] + self, + db_type: sqltypes.TypeEngine, + precision: Optional[int] = None, + scale: Optional[int] = None, + ) -> TColumnSchema: + # TODO: pass the sqla type through dialect.type_descriptor before instance check + # Possibly need to check both dialect specific and generic types + if isinstance(db_type, sa.String): + return dict(data_type="text") + elif isinstance(db_type, sa.Float): + return dict(data_type="double") + elif isinstance(db_type, sa.Boolean): + return dict(data_type="bool") + elif isinstance(db_type, sa.DateTime): + return dict(data_type="timestamp", timezone=db_type.timezone) + elif isinstance(db_type, sa.Integer): + return self._from_db_integer_type(db_type) + elif isinstance(db_type, sqltypes._Binary): + return dict(data_type="binary", precision=db_type.length) + elif isinstance(db_type, sa.JSON): + return dict(data_type="json") + elif isinstance(db_type, sa.Numeric): + return self._from_db_decimal_type(db_type) + elif isinstance(db_type, sa.Date): + return dict(data_type="date") + elif isinstance(db_type, sa.Time): + return dict(data_type="time") + raise TerminalValueError(f"Unsupported db type: {db_type}") + + pass + + def ensure_supported_type( + self, + column: TColumnSchema, + table: PreparedTableSchema, + loader_file_format: TLoaderFileFormat, + ) -> None: + pass From a69d74943b6c68e9ff08858cdc6858edcf0e4cbf Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 11 Sep 2024 19:32:59 -0400 Subject: [PATCH 29/37] Update tests destination config --- tests/load/pipeline/test_arrow_loading.py | 4 +++- tests/load/pipeline/test_merge_disposition.py | 5 ++++- tests/load/pipeline/test_pipelines.py | 2 +- tests/load/pipeline/test_stage_loading.py | 9 ++++++--- tests/load/sources/filesystem/test_filesystem_source.py | 4 ++-- .../test_sql_database_source_all_destinations.py | 8 ++++---- tests/load/test_job_client.py | 4 ++-- 7 files changed, 22 insertions(+), 14 deletions(-) diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index 369359d61a..5bebf6f7ed 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -88,7 +88,9 @@ def some_data(): # use csv for postgres to get native arrow processing destination_config.file_format = ( - destination_config.file_format if destination_config.destination_type != "postgres" else "csv" + destination_config.file_format + if destination_config.destination_type != "postgres" + else "csv" ) load_info = pipeline.run(some_data(), **destination_config.run_kwargs) diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index d1082263dd..b1244de336 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -496,7 +496,10 @@ def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration) assert_load_info(info) # make sure it was parquet or sql inserts files = p.get_load_package_info(p.list_completed_load_packages()[1]).jobs["completed_jobs"] - if destination_config.destination == "athena" and destination_config.table_format == "iceberg": + if ( + destination_config.destination_type == "athena" + and destination_config.table_format == "iceberg" + ): # iceberg uses sql to copy tables expected_formats.append("sql") assert all(f.job_file_info.file_format in expected_formats for f in files) diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 2a29a0a24d..659bca6cb9 100644 --- 
a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -561,7 +561,7 @@ def some_source(): if destination_config.supports_merge: expected_completed_jobs += 1 # add iceberg copy jobs - if destination_config.destination == "athena": + if destination_config.destination_type == "athena": expected_completed_jobs += 2 # if destination_config.supports_merge else 4 assert len(package_info.jobs["completed_jobs"]) == expected_completed_jobs diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index de30615a6a..cc8175b677 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -231,7 +231,7 @@ def test_truncate_staging_dataset(destination_config: DestinationTestConfigurati with staging_client: # except Athena + Iceberg which does not store tables in staging dataset if ( - destination_config.destination == "athena" + destination_config.destination_type == "athena" and destination_config.table_format == "iceberg" ): table_count = 0 @@ -257,7 +257,7 @@ def test_truncate_staging_dataset(destination_config: DestinationTestConfigurati _, staging_client = pipeline._get_destination_clients(pipeline.default_schema) with staging_client: # except for Athena which does not delete staging destination tables - if destination_config.destination == "athena": + if destination_config.destination_type == "athena": if destination_config.table_format == "iceberg": table_count = 0 else: @@ -302,7 +302,10 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non ): # Redshift can't load fixed width binary columns from parquet exclude_columns.append("col7_precision") - if destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl": + if ( + destination_config.destination_type == "databricks" + and destination_config.file_format == "jsonl" + ): exclude_types.extend(["decimal", "binary", "wei", "json", "date"]) exclude_columns.append("col1_precision") diff --git a/tests/load/sources/filesystem/test_filesystem_source.py b/tests/load/sources/filesystem/test_filesystem_source.py index 947e7e9e1c..15a1079cca 100644 --- a/tests/load/sources/filesystem/test_filesystem_source.py +++ b/tests/load/sources/filesystem/test_filesystem_source.py @@ -126,7 +126,7 @@ def test_csv_transformers( # print(pipeline.last_trace.last_normalize_info) # must contain 24 rows of A881 - if not destination_config.destination == "filesystem": + if not destination_config.destination_type == "filesystem": # TODO: comment out when filesystem destination supports queries (data pond PR) assert_query_data(pipeline, "SELECT code FROM met_csv", ["A881"] * 24) @@ -138,7 +138,7 @@ def test_csv_transformers( assert_load_info(load_info) # print(pipeline.last_trace.last_normalize_info) # must contain 48 rows of A803 - if not destination_config.destination == "filesystem": + if not destination_config.destination_type == "filesystem": # TODO: comment out when filesystem destination supports queries (data pond PR) assert_query_data(pipeline, "SELECT code FROM met_csv", ["A803"] * 48) # and 48 rows in total -> A881 got replaced diff --git a/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py b/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py index 7012602b4a..4f4e876fb6 100644 --- a/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py +++ 
b/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py @@ -51,10 +51,10 @@ def test_load_sql_schema_loads_all_tables( schema=sql_source_db.schema, backend=backend, reflection_level="minimal", - type_adapter_callback=default_test_callback(destination_config.destination, backend), + type_adapter_callback=default_test_callback(destination_config.destination_type, backend), ) - if destination_config.destination == "bigquery" and backend == "connectorx": + if destination_config.destination_type == "bigquery" and backend == "connectorx": # connectorx generates nanoseconds time which bigquery cannot load source.has_precision.add_map(convert_time_to_us) source.has_precision_nullable.add_map(convert_time_to_us) @@ -91,10 +91,10 @@ def test_load_sql_schema_loads_all_tables_parallel( schema=sql_source_db.schema, backend=backend, reflection_level="minimal", - type_adapter_callback=default_test_callback(destination_config.destination, backend), + type_adapter_callback=default_test_callback(destination_config.destination_type, backend), ).parallelize() - if destination_config.destination == "bigquery" and backend == "connectorx": + if destination_config.destination_type == "bigquery" and backend == "connectorx": # connectorx generates nanoseconds time which bigquery cannot load source.has_precision.add_map(convert_time_to_us) source.has_precision_nullable.add_map(convert_time_to_us) diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index ba3baec17c..84d08a5a89 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -592,9 +592,9 @@ def test_load_with_all_types( client.update_stored_schema() if isinstance(client, WithStagingDataset): - should_load_to_staging = client.should_load_data_to_staging_dataset(table_name) # type: ignore[attr-defined] + should_load_to_staging = client.should_load_data_to_staging_dataset(table_name) if should_load_to_staging: - with client.with_staging_dataset(): # type: ignore[attr-defined] + with client.with_staging_dataset(): # create staging for merge dataset client.initialize_storage() client.update_stored_schema() From c25932b63b752ec64d23685bcd86b923c14c7562 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 11 Sep 2024 20:44:56 -0400 Subject: [PATCH 30/37] Fix tests --- .../sources/filesystem/test_filesystem_source.py | 16 +++++++++++----- tests/load/utils.py | 2 +- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/load/sources/filesystem/test_filesystem_source.py b/tests/load/sources/filesystem/test_filesystem_source.py index 15a1079cca..88796b9c4d 100644 --- a/tests/load/sources/filesystem/test_filesystem_source.py +++ b/tests/load/sources/filesystem/test_filesystem_source.py @@ -111,7 +111,9 @@ def test_fsspec_as_credentials(): @pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS) @pytest.mark.parametrize( "destination_config", - destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), + destinations_configs( + default_sql_configs=True, supports_merge=True, all_buckets_filesystem_configs=True + ), ids=lambda x: x.name, ) def test_csv_transformers( @@ -126,9 +128,11 @@ def test_csv_transformers( # print(pipeline.last_trace.last_normalize_info) # must contain 24 rows of A881 - if not destination_config.destination_type == "filesystem": + if destination_config.destination_type != "filesystem": + with pipeline.sql_client() as client: + table_name = client.make_qualified_table_name("met_csv") # TODO: comment out when filesystem destination 
supports queries (data pond PR) - assert_query_data(pipeline, "SELECT code FROM met_csv", ["A881"] * 24) + assert_query_data(pipeline, f"SELECT code FROM {table_name}", ["A881"] * 24) # load the other folder that contains data for the same day + one other day # the previous data will be replaced @@ -138,9 +142,11 @@ def test_csv_transformers( assert_load_info(load_info) # print(pipeline.last_trace.last_normalize_info) # must contain 48 rows of A803 - if not destination_config.destination_type == "filesystem": + if destination_config.destination_type != "filesystem": + with pipeline.sql_client() as client: + table_name = client.make_qualified_table_name("met_csv") # TODO: comment out when filesystem destination supports queries (data pond PR) - assert_query_data(pipeline, "SELECT code FROM met_csv", ["A803"] * 48) + assert_query_data(pipeline, f"SELECT code FROM {table_name}", ["A803"] * 48) # and 48 rows in total -> A881 got replaced # print(pipeline.default_schema.to_pretty_yaml()) assert load_table_counts(pipeline, "met_csv") == {"met_csv": 48} diff --git a/tests/load/utils.py b/tests/load/utils.py index 55a3d9dd0b..068c2b715a 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -834,7 +834,7 @@ def yield_client( client: SqlJobClientBase = None # athena requires staging config to be present, so stick this in there here - if destination.destination_type == "athena": + if destination.destination_name == "athena": staging_config = DestinationClientStagingConfiguration( bucket_url=AWS_BUCKET, )._bind_dataset_name(dataset_name=dest_config.dataset_name) From 6c426e6eb121abdc70aebc89041f929224216a7f Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 11 Sep 2024 21:24:09 -0400 Subject: [PATCH 31/37] Ignore sources tests --- .github/workflows/test_sqlalchemy_destinations.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml index 245060e811..5da2dac04b 100644 --- a/.github/workflows/test_sqlalchemy_destinations.yml +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -92,7 +92,7 @@ jobs: run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml # always run full suite, also on branches - - run: poetry run pytest tests/load -x + - run: poetry run pytest tests/load -x --ignore tests/load/sources name: Run tests Linux env: DESTINATION__SQLALCHEMY_MYSQL__CREDENTIALS: mysql://root:root@127.0.0.1:3306/dlt_data # Use root cause we need to create databases From 36b585e00458e80635f5b7ad5ca7b6316ab78389 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 12 Sep 2024 11:30:37 -0400 Subject: [PATCH 32/37] Fix overriding destination in test pipeline --- tests/load/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/load/utils.py b/tests/load/utils.py index 068c2b715a..bfe365eabc 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -220,10 +220,10 @@ def setup_pipeline( """Convenience method to setup pipeline with this configuration""" self.dev_mode = dev_mode - factory = self.destination_factory(**kwargs) + destination = kwargs.pop("destination", None) or self.destination_factory(**kwargs) pipeline = dlt.pipeline( pipeline_name=pipeline_name, - destination=factory, + destination=destination, staging=kwargs.pop("staging", self.staging), dataset_name=dataset_name or pipeline_name, dev_mode=dev_mode, From 0208c64924cf19e963fb79ac9c9596e3e6dd1c2b Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 12 Sep 2024 
12:07:09 -0400 Subject: [PATCH 33/37] Fix time precision in arrow test --- tests/load/pipeline/test_arrow_loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index 5bebf6f7ed..f72aaec1d8 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -156,7 +156,7 @@ def some_data(): for row in expected: for i in range(len(row)): - if isinstance(row[i], datetime): + if isinstance(row[i], (datetime, dt_time)): row[i] = reduce_pendulum_datetime_precision( row[i], pipeline.destination.capabilities().timestamp_precision ) From 65f6ef7fd7da20f629708f0190efa5c5a43b50ab Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 12 Sep 2024 12:10:25 -0400 Subject: [PATCH 34/37] Lint --- tests/load/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/load/utils.py b/tests/load/utils.py index bfe365eabc..d436172dbe 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -220,7 +220,7 @@ def setup_pipeline( """Convenience method to setup pipeline with this configuration""" self.dev_mode = dev_mode - destination = kwargs.pop("destination", None) or self.destination_factory(**kwargs) + destination = kwargs.pop("destination", None) or self.destination_factory(**kwargs) pipeline = dlt.pipeline( pipeline_name=pipeline_name, destination=destination, From 6c29071c9b5ad7ae30bdcde686510d1b48599457 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 12 Sep 2024 21:50:22 -0400 Subject: [PATCH 35/37] Fix destination setup in test --- tests/load/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/load/utils.py b/tests/load/utils.py index d436172dbe..9ee99eef3d 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -220,7 +220,11 @@ def setup_pipeline( """Convenience method to setup pipeline with this configuration""" self.dev_mode = dev_mode - destination = kwargs.pop("destination", None) or self.destination_factory(**kwargs) + destination = kwargs.pop("destination") + if destination is None: + destination = self.destination_factory(**kwargs) + else: + self.setup() pipeline = dlt.pipeline( pipeline_name=pipeline_name, destination=destination, From eec4e220669399734353d56638fc5afd20f48513 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 12 Sep 2024 21:53:04 -0400 Subject: [PATCH 36/37] Fix --- tests/load/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/load/utils.py b/tests/load/utils.py index 9ee99eef3d..1c47291a6c 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -220,7 +220,7 @@ def setup_pipeline( """Convenience method to setup pipeline with this configuration""" self.dev_mode = dev_mode - destination = kwargs.pop("destination") + destination = kwargs.pop("destination", None) if destination is None: destination = self.destination_factory(**kwargs) else: From dc4c29c5419661244b275d1951d175c246b6b490 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 13 Sep 2024 15:08:33 -0400 Subject: [PATCH 37/37] Use nullpool, lazy create engine, close current connection --- .../impl/sqlalchemy/db_api_client.py | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py index 8c3df409aa..c6c8ba53d6 100644 --- a/dlt/destinations/impl/sqlalchemy/db_api_client.py +++ 
b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -95,10 +95,9 @@ class DbApiProps: class SqlalchemyClient(SqlClientBase[Connection]): external_engine: bool = False - dialect: sa.engine.interfaces.Dialect - dialect_name: str dbapi = DbApiProps # type: ignore[assignment] migrations: Optional[MigrationMaker] = None # lazy init as needed + _engine: Optional[sa.engine.Engine] = None def __init__( self, @@ -111,20 +110,36 @@ def __init__( super().__init__(credentials.database, dataset_name, staging_dataset_name, capabilities) self.credentials = credentials + self.engine_args = engine_args or {} + if credentials.engine: - self.engine = credentials.engine + self._engine = credentials.engine self.external_engine = True else: - self.engine = sa.create_engine( - credentials.to_url().render_as_string(hide_password=False), - **(engine_args or {}), - ) + # Default to nullpool because we don't use connection pooling + self.engine_args.setdefault("poolclass", sa.pool.NullPool) self._current_connection: Optional[Connection] = None self._current_transaction: Optional[SqlaTransactionWrapper] = None self.metadata = sa.MetaData() - self.dialect = self.engine.dialect - self.dialect_name = self.dialect.name + + @property + def engine(self) -> sa.engine.Engine: + # Create engine lazily + if self._engine is not None: + return self._engine + self._engine = sa.create_engine( + self.credentials.to_url().render_as_string(hide_password=False), **self.engine_args + ) + return self._engine + + @property + def dialect(self) -> sa.engine.interfaces.Dialect: + return self.engine.dialect + + @property + def dialect_name(self) -> str: + return self.dialect.name def open_connection(self) -> Connection: if self._current_connection is None: @@ -136,6 +151,8 @@ def open_connection(self) -> Connection: def close_connection(self) -> None: if not self.external_engine: try: + if self._current_connection is not None: + self._current_connection.close() self.engine.dispose() finally: self._current_connection = None
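
A note on the datetime mapping in the type-mapper hunk above: fractional-second precision is passed under whichever keyword the resolved type class accepts ("fsp" for MySQL, "precision" for some other types). The following standalone sketch restates that check in isolation; it assumes only that SQLAlchemy is installed, and the helper name is made up for illustration:

    import inspect
    from typing import Any, Dict, Optional, Type

    import sqlalchemy as sa
    from sqlalchemy.dialects.mysql import DATETIME as MYSQL_DATETIME


    def build_datetime_type(
        type_cls: Type[sa.types.TypeEngine], precision: Optional[int], timezone: bool
    ) -> sa.types.TypeEngine:
        # Pass fractional-second precision under whichever keyword the constructor accepts
        params = inspect.signature(type_cls).parameters
        kwargs: Dict[str, Any] = {"timezone": timezone}
        if "fsp" in params:  # MySQL DATETIME takes fsp
            kwargs["fsp"] = precision
        elif "precision" in params:  # some dialect-specific types take precision instead
            kwargs["precision"] = precision
        return type_cls(**kwargs)


    print(build_datetime_type(MYSQL_DATETIME, precision=6, timezone=False))  # DATETIME(fsp=6)
    print(build_datetime_type(sa.DateTime, precision=6, timezone=True))  # DateTime(timezone=True)

In the client itself the type class is obtained via dialect.type_descriptor(), with MySQL special-cased because its dialect does not return the fsp-capable DATETIME from that call.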
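
The filesystem-source test fix builds queries from make_qualified_table_name() instead of a bare table name, so the dataset (schema) prefix, quoting and casing match whichever destination is under test. A minimal sketch of the same pattern outside the test suite, assuming the duckdb extra is installed and using a made-up pipeline name:

    import dlt

    pipeline = dlt.pipeline(
        pipeline_name="qualified_name_demo", destination="duckdb", dataset_name="demo_data"
    )
    pipeline.run([{"code": "A881"}], table_name="met_csv")

    with pipeline.sql_client() as client:
        # e.g. "demo_data"."met_csv" on duckdb; other destinations qualify and quote differently
        table_name = client.make_qualified_table_name("met_csv")
        print(client.execute_sql(f"SELECT code FROM {table_name}"))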
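
The final patch makes the engine lazy and disables connection pooling. A condensed sketch of that pattern on its own, with an illustrative class name and a sqlite URL standing in for real credentials:

    from typing import Any, Dict, Optional

    import sqlalchemy as sa


    class LazyEngineClient:
        def __init__(self, url: str, engine_args: Optional[Dict[str, Any]] = None) -> None:
            self.url = url
            self.engine_args = dict(engine_args or {})
            # NullPool: connections are opened per use and closed on release, nothing is pooled
            self.engine_args.setdefault("poolclass", sa.pool.NullPool)
            self._engine: Optional[sa.engine.Engine] = None

        @property
        def engine(self) -> sa.engine.Engine:
            # Created on first access, so instantiating the client does not touch the database
            if self._engine is None:
                self._engine = sa.create_engine(self.url, **self.engine_args)
            return self._engine


    client = LazyEngineClient("sqlite:///:memory:")
    with client.engine.connect() as conn:
        print(conn.execute(sa.text("select 1")).scalar())

NullPool keeps behaviour close to a plain DBAPI connection, which fits close_connection() closing the current connection and disposing the engine rather than returning connections to a pool.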