diff --git a/.github/workflows/test_local_destinations.yml b/.github/workflows/test_local_destinations.yml index 2d712814bd..a21c3f0618 100644 --- a/.github/workflows/test_local_destinations.yml +++ b/.github/workflows/test_local_destinations.yml @@ -104,6 +104,7 @@ jobs: DESTINATION__POSTGRES__CREDENTIALS: postgresql://loader:loader@localhost:5432/dlt_data DESTINATION__QDRANT__CREDENTIALS__location: http://localhost:6333 + - name: Stop weaviate if: always() run: docker compose -f ".github/weaviate-compose.yml" down -v diff --git a/.github/workflows/test_sqlalchemy_destinations.yml b/.github/workflows/test_sqlalchemy_destinations.yml new file mode 100644 index 0000000000..5da2dac04b --- /dev/null +++ b/.github/workflows/test_sqlalchemy_destinations.yml @@ -0,0 +1,99 @@ +# Tests destinations that can run without credentials. +# i.e. local postgres, duckdb, filesystem (with local fs/memory bucket) + +name: dest | sqlalchemy mysql and sqlite + +on: + pull_request: + branches: + - master + - devel + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +env: + # NOTE: this workflow can't use github secrets! + # DLT_SECRETS_TOML: ${{ secrets.DLT_SECRETS_TOML }} + + RUNTIME__SENTRY_DSN: https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 + RUNTIME__LOG_LEVEL: ERROR + RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }} + ACTIVE_DESTINATIONS: "[\"sqlalchemy\"]" + ALL_FILESYSTEM_DRIVERS: "[\"memory\", \"file\"]" + +jobs: + get_docs_changes: + name: docs changes + uses: ./.github/workflows/get_docs_changes.yml + + run_loader: + name: dest | sqlalchemy mysql and sqlite + needs: get_docs_changes + if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' + strategy: + fail-fast: false + # Run on sqlalchemy 1.4 and 2.0 + matrix: + sqlalchemy: [1.4, 2] + defaults: + run: + shell: bash + runs-on: "ubuntu-latest" + + # Service containers to run with `container-job` + services: + # Label used to access the service container + mysql: + image: mysql:8 + env: + MYSQL_ROOT_PASSWORD: root + MYSQL_DATABASE: dlt_data + MYSQL_USER: loader + MYSQL_PASSWORD: loader + ports: + - 3306:3306 + # Wait for the service to be ready before completing the job + options: >- + --health-cmd="mysqladmin ping -h localhost -u root -proot" + --health-interval=10s + --health-timeout=5s + --health-retries=5 + + steps: + - name: Check out + uses: actions/checkout@master + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: "3.10.x" + + - name: Install Poetry + uses: snok/install-poetry@v1.3.2 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v3 + with: + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-local-destinations + + - name: Install dependencies + run: poetry install --no-interaction -E parquet -E filesystem -E sqlalchemy -E cli --with sentry-sdk --with pipeline && poetry run pip install mysqlclient && poetry run pip install "sqlalchemy==${{ matrix.sqlalchemy }}" + + - name: create secrets.toml + run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml + + # always run full suite, also on branches + - run: poetry run pytest tests/load -x --ignore tests/load/sources + name: Run tests Linux + env: + 
DESTINATION__SQLALCHEMY_MYSQL__CREDENTIALS: mysql://root:root@127.0.0.1:3306/dlt_data # Use root cause we need to create databases + DESTINATION__SQLALCHEMY_SQLITE__CREDENTIALS: sqlite:///_storage/dl_data.sqlite diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index eed1d6189e..8f0dce79ce 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -176,6 +176,12 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): loader_parallelism_strategy: Optional[TLoaderParallelismStrategy] = None """The destination can override the parallelism strategy""" + max_query_parameters: Optional[int] = None + """The maximum number of parameters that can be supplied in a single parametrized query""" + + supports_native_boolean: bool = True + """The destination supports a native boolean type, otherwise bool columns are usually stored as integers""" + def generates_case_sensitive_identifiers(self) -> bool: """Tells if capabilities as currently adjusted, will generate case sensitive identifiers""" # must have case sensitive support and folding function must preserve casing @@ -220,8 +226,8 @@ def generic_capabilities( caps.merge_strategies_selector = merge_strategies_selector return caps - def get_type_mapper(self) -> DataTypeMapper: - return self.type_mapper(self) + def get_type_mapper(self, *args: Any, **kwargs: Any) -> DataTypeMapper: + return self.type_mapper(self, *args, **kwargs) def merge_caps_file_formats( diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 05ea5f3515..9e27b66335 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -86,6 +86,20 @@ def from_normalized_mapping( schema=normalized_doc[naming_convention.normalize_identifier("schema")], ) + def to_normalized_mapping(self, naming_convention: NamingConvention) -> Dict[str, Any]: + """Convert this instance to mapping where keys are normalized according to given naming convention + + Args: + naming_convention: Naming convention that should be used to normalize keys + + Returns: + Dict[str, Any]: Mapping with normalized keys (e.g. {Version: ..., SchemaName: ...}) + """ + return { + naming_convention.normalize_identifier(key): value + for key, value in self._asdict().items() + } + @dataclasses.dataclass class StateInfo: @@ -439,7 +453,7 @@ def __init__( self.capabilities = capabilities @abstractmethod - def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: + def initialize_storage(self, truncate_tables: Optional[Iterable[str]] = None) -> None: """Prepares storage to be used ie. creates database schema or file system folder. Truncates requested tables.""" pass diff --git a/dlt/common/json/__init__.py b/dlt/common/json/__init__.py index 72ab453cbf..fe762cdf11 100644 --- a/dlt/common/json/__init__.py +++ b/dlt/common/json/__init__.py @@ -19,35 +19,7 @@ from dlt.common.utils import map_nested_in_place -class SupportsJson(Protocol): - """Minimum adapter for different json parser implementations""" - - _impl_name: str - """Implementation name""" - - def dump( - self, obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty: bool = False - ) -> None: ... - - def typed_dump(self, obj: Any, fp: IO[bytes], pretty: bool = False) -> None: ... - - def typed_dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: ... - - def typed_loads(self, s: str) -> Any: ... 
- - def typed_dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: ... - - def typed_loadb(self, s: Union[bytes, bytearray, memoryview]) -> Any: ... - - def dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: ... - - def dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: ... - - def load(self, fp: Union[IO[bytes], IO[str]]) -> Any: ... - - def loads(self, s: str) -> Any: ... - - def loadb(self, s: Union[bytes, bytearray, memoryview]) -> Any: ... +TPuaDecoders = List[Callable[[Any], Any]] def custom_encode(obj: Any) -> str: @@ -104,7 +76,7 @@ def _datetime_decoder(obj: str) -> datetime: # define decoder for each prefix -DECODERS: List[Callable[[Any], Any]] = [ +DECODERS: TPuaDecoders = [ Decimal, _datetime_decoder, pendulum.Date.fromisoformat, @@ -114,6 +86,11 @@ def _datetime_decoder(obj: str) -> datetime: Wei, pendulum.Time.fromisoformat, ] +# Alternate decoders that decode date/time/datetime to stdlib types instead of pendulum +PY_DATETIME_DECODERS = list(DECODERS) +PY_DATETIME_DECODERS[1] = datetime.fromisoformat +PY_DATETIME_DECODERS[2] = date.fromisoformat +PY_DATETIME_DECODERS[7] = time.fromisoformat # how many decoders? PUA_CHARACTER_MAX = len(DECODERS) @@ -151,13 +128,13 @@ def custom_pua_encode(obj: Any) -> str: raise TypeError(repr(obj) + " is not JSON serializable") -def custom_pua_decode(obj: Any) -> Any: +def custom_pua_decode(obj: Any, decoders: TPuaDecoders = DECODERS) -> Any: if isinstance(obj, str) and len(obj) > 1: c = ord(obj[0]) - PUA_START # decode only the PUA space defined in DECODERS if c >= 0 and c <= PUA_CHARACTER_MAX: try: - return DECODERS[c](obj[1:]) + return decoders[c](obj[1:]) except Exception: # return strings that cannot be parsed # this may be due @@ -167,11 +144,11 @@ def custom_pua_decode(obj: Any) -> Any: return obj -def custom_pua_decode_nested(obj: Any) -> Any: +def custom_pua_decode_nested(obj: Any, decoders: TPuaDecoders = DECODERS) -> Any: if isinstance(obj, str): - return custom_pua_decode(obj) + return custom_pua_decode(obj, decoders) elif isinstance(obj, (list, dict)): - return map_nested_in_place(custom_pua_decode, obj) + return map_nested_in_place(custom_pua_decode, obj, decoders=decoders) return obj @@ -190,6 +167,39 @@ def may_have_pua(line: bytes) -> bool: return PUA_START_UTF8_MAGIC in line +class SupportsJson(Protocol): + """Minimum adapter for different json parser implementations""" + + _impl_name: str + """Implementation name""" + + def dump( + self, obj: Any, fp: IO[bytes], sort_keys: bool = False, pretty: bool = False + ) -> None: ... + + def typed_dump(self, obj: Any, fp: IO[bytes], pretty: bool = False) -> None: ... + + def typed_dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: ... + + def typed_loads(self, s: str) -> Any: ... + + def typed_dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: ... + + def typed_loadb( + self, s: Union[bytes, bytearray, memoryview], decoders: TPuaDecoders = DECODERS + ) -> Any: ... + + def dumps(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: ... + + def dumpb(self, obj: Any, sort_keys: bool = False, pretty: bool = False) -> bytes: ... + + def load(self, fp: Union[IO[bytes], IO[str]]) -> Any: ... + + def loads(self, s: str) -> Any: ... + + def loadb(self, s: Union[bytes, bytearray, memoryview]) -> Any: ... 
+ + # pick the right impl json: SupportsJson = None if os.environ.get(known_env.DLT_USE_JSON) == "simplejson": @@ -216,4 +226,7 @@ def may_have_pua(line: bytes) -> bool: "custom_pua_remove", "SupportsJson", "may_have_pua", + "TPuaDecoders", + "DECODERS", + "PY_DATETIME_DECODERS", ] diff --git a/dlt/common/json/_orjson.py b/dlt/common/json/_orjson.py index d2d960e6ce..d066ffe875 100644 --- a/dlt/common/json/_orjson.py +++ b/dlt/common/json/_orjson.py @@ -1,7 +1,13 @@ from typing import IO, Any, Union import orjson -from dlt.common.json import custom_pua_encode, custom_pua_decode_nested, custom_encode +from dlt.common.json import ( + custom_pua_encode, + custom_pua_decode_nested, + custom_encode, + TPuaDecoders, + DECODERS, +) from dlt.common.typing import AnyFun _impl_name = "orjson" @@ -38,8 +44,8 @@ def typed_loads(s: str) -> Any: return custom_pua_decode_nested(loads(s)) -def typed_loadb(s: Union[bytes, bytearray, memoryview]) -> Any: - return custom_pua_decode_nested(loadb(s)) +def typed_loadb(s: Union[bytes, bytearray, memoryview], decoders: TPuaDecoders = DECODERS) -> Any: + return custom_pua_decode_nested(loadb(s), decoders) def dumps(obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: diff --git a/dlt/common/json/_simplejson.py b/dlt/common/json/_simplejson.py index 10ee17e2f6..e5adcc7120 100644 --- a/dlt/common/json/_simplejson.py +++ b/dlt/common/json/_simplejson.py @@ -4,7 +4,13 @@ import simplejson import platform -from dlt.common.json import custom_pua_encode, custom_pua_decode_nested, custom_encode +from dlt.common.json import ( + custom_pua_encode, + custom_pua_decode_nested, + custom_encode, + TPuaDecoders, + DECODERS, +) if platform.python_implementation() == "PyPy": # disable speedups on PyPy, it can be actually faster than Python C @@ -73,8 +79,8 @@ def typed_dumpb(obj: Any, sort_keys: bool = False, pretty: bool = False) -> byte return typed_dumps(obj, sort_keys, pretty).encode("utf-8") -def typed_loadb(s: Union[bytes, bytearray, memoryview]) -> Any: - return custom_pua_decode_nested(loadb(s)) +def typed_loadb(s: Union[bytes, bytearray, memoryview], decoders: TPuaDecoders = DECODERS) -> Any: + return custom_pua_decode_nested(loadb(s), decoders) def dumps(obj: Any, sort_keys: bool = False, pretty: bool = False) -> str: diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 51d67275af..adba832c43 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -30,6 +30,7 @@ import pyarrow.parquet import pyarrow.compute import pyarrow.dataset + from pyarrow.parquet import ParquetFile except ModuleNotFoundError: raise MissingDependencyException( "dlt pyarrow helpers", diff --git a/dlt/common/time.py b/dlt/common/time.py index 8532f566b8..26de0b5645 100644 --- a/dlt/common/time.py +++ b/dlt/common/time.py @@ -143,6 +143,14 @@ def ensure_pendulum_time(value: Union[str, datetime.time]) -> pendulum.Time: return result else: raise ValueError(f"{value} is not a valid ISO time string.") + elif isinstance(value, timedelta): + # Assume timedelta is seconds passed since midnight. 
Some drivers (mysqlclient) return time in this format + return pendulum.time( + value.seconds // 3600, + (value.seconds // 60) % 60, + value.seconds % 60, + value.microseconds, + ) raise TypeError(f"Cannot coerce {value} to a pendulum.Time object.") diff --git a/dlt/common/utils.py b/dlt/common/utils.py index 37b644c0b5..9980e725ee 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -282,8 +282,10 @@ def clone_dict_nested(src: TDict) -> TDict: return update_dict_nested({}, src, copy_src_dicts=True) # type: ignore[return-value] -def map_nested_in_place(func: AnyFun, _nested: TAny) -> TAny: - """Applies `func` to all elements in `_dict` recursively, replacing elements in nested dictionaries and lists in place.""" +def map_nested_in_place(func: AnyFun, _nested: TAny, *args: Any, **kwargs: Any) -> TAny: + """Applies `func` to all elements in `_dict` recursively, replacing elements in nested dictionaries and lists in place. + Additional `*args` and `**kwargs` are passed to `func`. + """ if isinstance(_nested, tuple): if hasattr(_nested, "_asdict"): _nested = _nested._asdict() @@ -293,15 +295,15 @@ def map_nested_in_place(func: AnyFun, _nested: TAny) -> TAny: if isinstance(_nested, dict): for k, v in _nested.items(): if isinstance(v, (dict, list, tuple)): - _nested[k] = map_nested_in_place(func, v) + _nested[k] = map_nested_in_place(func, v, *args, **kwargs) else: - _nested[k] = func(v) + _nested[k] = func(v, *args, **kwargs) elif isinstance(_nested, list): for idx, _l in enumerate(_nested): if isinstance(_l, (dict, list, tuple)): - _nested[idx] = map_nested_in_place(func, _l) + _nested[idx] = map_nested_in_place(func, _l, *args, **kwargs) else: - _nested[idx] = func(_l) + _nested[idx] = func(_l, *args, **kwargs) else: raise ValueError(_nested, "Not a nested type") return _nested diff --git a/dlt/destinations/__init__.py b/dlt/destinations/__init__.py index 0546d16bcd..a856f574d8 100644 --- a/dlt/destinations/__init__.py +++ b/dlt/destinations/__init__.py @@ -16,6 +16,7 @@ from dlt.destinations.impl.databricks.factory import databricks from dlt.destinations.impl.dremio.factory import dremio from dlt.destinations.impl.clickhouse.factory import clickhouse +from dlt.destinations.impl.sqlalchemy.factory import sqlalchemy __all__ = [ @@ -37,4 +38,5 @@ "dremio", "clickhouse", "destination", + "sqlalchemy", ] diff --git a/dlt/destinations/impl/sqlalchemy/__init__.py b/dlt/destinations/impl/sqlalchemy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dlt/destinations/impl/sqlalchemy/alter_table.py b/dlt/destinations/impl/sqlalchemy/alter_table.py new file mode 100644 index 0000000000..f85101a740 --- /dev/null +++ b/dlt/destinations/impl/sqlalchemy/alter_table.py @@ -0,0 +1,38 @@ +from typing import List + +import sqlalchemy as sa +from alembic.runtime.migration import MigrationContext +from alembic.operations import Operations + + +class ListBuffer: + """A partial implementation of string IO to use with alembic. 
+ SQL statements are stored in a list instead of file/stdio + """ + + def __init__(self) -> None: + self._buf = "" + self.sql_lines: List[str] = [] + + def write(self, data: str) -> None: + self._buf += data + + def flush(self) -> None: + if self._buf: + self.sql_lines.append(self._buf) + self._buf = "" + + +class MigrationMaker: + def __init__(self, dialect: sa.engine.Dialect) -> None: + self._buf = ListBuffer() + self.ctx = MigrationContext(dialect, None, {"as_sql": True, "output_buffer": self._buf}) + self.ops = Operations(self.ctx) + + def add_column(self, table_name: str, column: sa.Column, schema: str) -> None: + self.ops.add_column(table_name, column, schema=schema) + + def consume_statements(self) -> List[str]: + lines = self._buf.sql_lines[:] + self._buf.sql_lines.clear() + return lines diff --git a/dlt/destinations/impl/sqlalchemy/configuration.py b/dlt/destinations/impl/sqlalchemy/configuration.py new file mode 100644 index 0000000000..f99b06a27b --- /dev/null +++ b/dlt/destinations/impl/sqlalchemy/configuration.py @@ -0,0 +1,63 @@ +from typing import TYPE_CHECKING, Optional, Any, Final, Type, Dict, Union +import dataclasses + +from dlt.common.configuration import configspec +from dlt.common.configuration.specs import ConnectionStringCredentials +from dlt.common.destination.reference import DestinationClientDwhConfiguration + +if TYPE_CHECKING: + from sqlalchemy.engine import Engine, Dialect + + +@configspec(init=False) +class SqlalchemyCredentials(ConnectionStringCredentials): + if TYPE_CHECKING: + _engine: Optional["Engine"] = None + + username: Optional[str] = None # e.g. sqlite doesn't need username + + def __init__( + self, connection_string: Optional[Union[str, Dict[str, Any], "Engine"]] = None + ) -> None: + super().__init__(connection_string) # type: ignore[arg-type] + + def parse_native_representation(self, native_value: Any) -> None: + from sqlalchemy.engine import Engine + + if isinstance(native_value, Engine): + self.engine = native_value + super().parse_native_representation( + native_value.url.render_as_string(hide_password=False) + ) + else: + super().parse_native_representation(native_value) + + @property + def engine(self) -> Optional["Engine"]: + return getattr(self, "_engine", None) # type: ignore[no-any-return] + + @engine.setter + def engine(self, value: "Engine") -> None: + self._engine = value + + def get_dialect(self) -> Optional[Type["Dialect"]]: + if not self.drivername: + return None + # Type-ignore because of ported URL class has no get_dialect method, + # but here sqlalchemy should be available + if engine := self.engine: + return type(engine.dialect) + return self.to_url().get_dialect() # type: ignore[attr-defined,no-any-return] + + +@configspec +class SqlalchemyClientConfiguration(DestinationClientDwhConfiguration): + destination_type: Final[str] = dataclasses.field(default="sqlalchemy", init=False, repr=False, compare=False) # type: ignore + credentials: SqlalchemyCredentials = None + """SQLAlchemy connection string""" + + engine_args: Dict[str, Any] = dataclasses.field(default_factory=dict) + """Additional arguments passed to `sqlalchemy.create_engine`""" + + def get_dialect(self) -> Type["Dialect"]: + return self.credentials.get_dialect() diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py new file mode 100644 index 0000000000..c6c8ba53d6 --- /dev/null +++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py @@ -0,0 +1,432 @@ +from typing import ( + Optional, + Iterator, + Any, + 
Sequence, + ContextManager, + AnyStr, + Union, + Tuple, + List, + Dict, +) +from contextlib import contextmanager +from functools import wraps +import inspect +from pathlib import Path + +import sqlalchemy as sa +from sqlalchemy.engine import Connection + +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.destinations.exceptions import ( + DatabaseUndefinedRelation, + DatabaseTerminalException, + DatabaseTransientException, + LoadClientNotConnected, + DatabaseException, +) +from dlt.destinations.typing import DBTransaction, DBApiCursor +from dlt.destinations.sql_client import SqlClientBase, DBApiCursorImpl +from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyCredentials +from dlt.destinations.impl.sqlalchemy.alter_table import MigrationMaker +from dlt.common.typing import TFun + + +class SqlaTransactionWrapper(DBTransaction): + def __init__(self, sqla_transaction: sa.engine.Transaction) -> None: + self.sqla_transaction = sqla_transaction + + def commit_transaction(self) -> None: + if self.sqla_transaction.is_active: + self.sqla_transaction.commit() + + def rollback_transaction(self) -> None: + if self.sqla_transaction.is_active: + self.sqla_transaction.rollback() + + +def raise_database_error(f: TFun) -> TFun: + @wraps(f) + def _wrap_gen(self: "SqlalchemyClient", *args: Any, **kwargs: Any) -> Any: + try: + return (yield from f(self, *args, **kwargs)) + except Exception as e: + raise self._make_database_exception(e) from e + + @wraps(f) + def _wrap(self: "SqlalchemyClient", *args: Any, **kwargs: Any) -> Any: + try: + return f(self, *args, **kwargs) + except Exception as e: + raise self._make_database_exception(e) from e + + if inspect.isgeneratorfunction(f): + return _wrap_gen # type: ignore[return-value] + return _wrap # type: ignore[return-value] + + +class SqlaDbApiCursor(DBApiCursorImpl): + def __init__(self, curr: sa.engine.CursorResult) -> None: + # Sqlalchemy CursorResult is *mostly* compatible with DB-API cursor + self.native_cursor = curr # type: ignore[assignment] + curr.columns + + self.fetchall = curr.fetchall # type: ignore[assignment] + self.fetchone = curr.fetchone # type: ignore[assignment] + self.fetchmany = curr.fetchmany # type: ignore[assignment] + + def _get_columns(self) -> List[str]: + return list(self.native_cursor.keys()) # type: ignore[attr-defined] + + # @property + # def description(self) -> Any: + # # Get the underlying driver's cursor description, this is mostly used in tests + # return self.native_cursor.cursor.description # type: ignore[attr-defined] + + def execute(self, query: AnyStr, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError("execute not implemented") + + +class DbApiProps: + # Only needed for some tests + paramstyle = "named" + + +class SqlalchemyClient(SqlClientBase[Connection]): + external_engine: bool = False + dbapi = DbApiProps # type: ignore[assignment] + migrations: Optional[MigrationMaker] = None # lazy init as needed + _engine: Optional[sa.engine.Engine] = None + + def __init__( + self, + dataset_name: str, + staging_dataset_name: str, + credentials: SqlalchemyCredentials, + capabilities: DestinationCapabilitiesContext, + engine_args: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(credentials.database, dataset_name, staging_dataset_name, capabilities) + self.credentials = credentials + + self.engine_args = engine_args or {} + + if credentials.engine: + self._engine = credentials.engine + self.external_engine = True + else: + # Default to nullpool because we don't 
use connection pooling + self.engine_args.setdefault("poolclass", sa.pool.NullPool) + + self._current_connection: Optional[Connection] = None + self._current_transaction: Optional[SqlaTransactionWrapper] = None + self.metadata = sa.MetaData() + + @property + def engine(self) -> sa.engine.Engine: + # Create engine lazily + if self._engine is not None: + return self._engine + self._engine = sa.create_engine( + self.credentials.to_url().render_as_string(hide_password=False), **self.engine_args + ) + return self._engine + + @property + def dialect(self) -> sa.engine.interfaces.Dialect: + return self.engine.dialect + + @property + def dialect_name(self) -> str: + return self.dialect.name + + def open_connection(self) -> Connection: + if self._current_connection is None: + self._current_connection = self.engine.connect() + if self.dialect_name == "sqlite": + self._sqlite_reattach_dataset_if_exists(self.dataset_name) + return self._current_connection + + def close_connection(self) -> None: + if not self.external_engine: + try: + if self._current_connection is not None: + self._current_connection.close() + self.engine.dispose() + finally: + self._current_connection = None + self._current_transaction = None + + @property + def native_connection(self) -> Connection: + if not self._current_connection: + raise LoadClientNotConnected(type(self).__name__, self.dataset_name) + return self._current_connection + + def _in_transaction(self) -> bool: + return ( + self._current_transaction is not None + and self._current_transaction.sqla_transaction.is_active + ) + + @contextmanager + @raise_database_error + def begin_transaction(self) -> Iterator[DBTransaction]: + trans = self._current_transaction = SqlaTransactionWrapper(self._current_connection.begin()) + try: + yield trans + except Exception: + if self._in_transaction(): + self.rollback_transaction() + raise + else: + if self._in_transaction(): # Transaction could be committed/rolled back before __exit__ + self.commit_transaction() + finally: + self._current_transaction = None + + def commit_transaction(self) -> None: + """Commits the current transaction.""" + self._current_transaction.commit_transaction() + + def rollback_transaction(self) -> None: + """Rolls back the current transaction.""" + self._current_transaction.rollback_transaction() + + @contextmanager + def _transaction(self) -> Iterator[DBTransaction]: + """Context manager yielding either a new or the currently open transaction. + New transaction will be committed/rolled back on exit. + If the transaction is already open, finalization is handled by the top level context manager. 
+ """ + if self._in_transaction(): + yield self._current_transaction + return + with self.begin_transaction() as tx: + yield tx + + def has_dataset(self) -> bool: + with self._transaction(): + schema_names = self.engine.dialect.get_schema_names(self._current_connection) # type: ignore[attr-defined] + return self.dataset_name in schema_names + + def _sqlite_dataset_filename(self, dataset_name: str) -> str: + db_name = self.engine.url.database + current_file_path = Path(db_name) + return str( + current_file_path.parent + / f"{current_file_path.stem}__{dataset_name}{current_file_path.suffix}" + ) + + def _sqlite_is_memory_db(self) -> bool: + return self.engine.url.database == ":memory:" + + def _sqlite_reattach_dataset_if_exists(self, dataset_name: str) -> None: + """Re-attach previously created databases for a new sqlite connection""" + if self._sqlite_is_memory_db(): + return + new_db_fn = self._sqlite_dataset_filename(dataset_name) + if Path(new_db_fn).exists(): + self._sqlite_create_dataset(dataset_name) + + def _sqlite_create_dataset(self, dataset_name: str) -> None: + """Mimic multiple schemas in sqlite using ATTACH DATABASE to + attach a new database file to the current connection. + """ + if self._sqlite_is_memory_db(): + new_db_fn = ":memory:" + else: + new_db_fn = self._sqlite_dataset_filename(dataset_name) + + statement = "ATTACH DATABASE :fn AS :name" + self.execute_sql(statement, fn=new_db_fn, name=dataset_name) + + def _sqlite_drop_dataset(self, dataset_name: str) -> None: + """Drop a dataset in sqlite by detaching the database file + attached to the current connection. + """ + # Get a list of attached databases and filenames + rows = self.execute_sql("PRAGMA database_list") + dbs = {row[1]: row[2] for row in rows} # db_name: filename + if dataset_name != "main": # main is the default database, it cannot be detached + statement = "DETACH DATABASE :name" + self.execute_sql(statement, name=dataset_name) + + fn = dbs[dataset_name] + if not fn: # It's a memory database, nothing to do + return + # Delete the database file + Path(fn).unlink() + + def create_dataset(self) -> None: + if self.dialect_name == "sqlite": + return self._sqlite_create_dataset(self.dataset_name) + self.execute_sql(sa.schema.CreateSchema(self.dataset_name)) + + def drop_dataset(self) -> None: + if self.dialect_name == "sqlite": + return self._sqlite_drop_dataset(self.dataset_name) + try: + self.execute_sql(sa.schema.DropSchema(self.dataset_name, cascade=True)) + except DatabaseException: # Try again in case cascade is not supported + self.execute_sql(sa.schema.DropSchema(self.dataset_name)) + + def truncate_tables(self, *tables: str) -> None: + # TODO: alchemy doesn't have a construct for TRUNCATE TABLE + for table in tables: + tbl = sa.Table(table, self.metadata, schema=self.dataset_name, keep_existing=True) + self.execute_sql(tbl.delete()) + + def drop_tables(self, *tables: str) -> None: + for table in tables: + tbl = sa.Table(table, self.metadata, schema=self.dataset_name, keep_existing=True) + self.execute_sql(sa.schema.DropTable(tbl, if_exists=True)) + + def execute_sql( + self, sql: Union[AnyStr, sa.sql.Executable], *args: Any, **kwargs: Any + ) -> Optional[Sequence[Sequence[Any]]]: + with self.execute_query(sql, *args, **kwargs) as cursor: + if cursor.returns_rows: # type: ignore[attr-defined] + return cursor.fetchall() + return None + + @contextmanager + def execute_query( + self, query: Union[AnyStr, sa.sql.Executable], *args: Any, **kwargs: Any + ) -> Iterator[DBApiCursor]: + if args and kwargs: + 
raise ValueError("Cannot use both positional and keyword arguments") + if isinstance(query, str): + if args: + # Sqlalchemy text supports :named paramstyle for all dialects + query, kwargs = self._to_named_paramstyle(query, args) # type: ignore[assignment] + args = (kwargs,) + query = sa.text(query) + if kwargs: + # sqla2 takes either a dict or list of dicts + args = (kwargs,) + with self._transaction(): + yield SqlaDbApiCursor(self._current_connection.execute(query, *args)) # type: ignore[call-overload, abstract] + + def get_existing_table(self, table_name: str) -> Optional[sa.Table]: + """Get a table object from metadata if it exists""" + key = self.dataset_name + "." + table_name + return self.metadata.tables.get(key) # type: ignore[no-any-return] + + def create_table(self, table_obj: sa.Table) -> None: + with self._transaction(): + table_obj.create(self._current_connection) + + def _make_qualified_table_name(self, table: sa.Table, escape: bool = True) -> str: + if escape: + return self.dialect.identifier_preparer.format_table(table) # type: ignore[attr-defined,no-any-return] + return table.fullname # type: ignore[no-any-return] + + def make_qualified_table_name(self, table_name: str, escape: bool = True) -> str: + tbl = self.get_existing_table(table_name) + if tbl is None: + tmp_metadata = sa.MetaData() + tbl = sa.Table(table_name, tmp_metadata, schema=self.dataset_name) + return self._make_qualified_table_name(tbl, escape) + + def fully_qualified_dataset_name(self, escape: bool = True, staging: bool = False) -> str: + if staging: + raise NotImplementedError("Staging not supported") + return self.dialect.identifier_preparer.format_schema(self.dataset_name) # type: ignore[attr-defined, no-any-return] + + def alter_table_add_columns(self, columns: Sequence[sa.Column]) -> None: + if not columns: + return + if self.migrations is None: + self.migrations = MigrationMaker(self.dialect) + for column in columns: + self.migrations.add_column(column.table.name, column, self.dataset_name) + statements = self.migrations.consume_statements() + for statement in statements: + self.execute_sql(statement) + + def escape_column_name(self, column_name: str, escape: bool = True) -> str: + if self.dialect.requires_name_normalize: # type: ignore[attr-defined] + column_name = self.dialect.normalize_name(column_name) # type: ignore[func-returns-value] + if escape: + return self.dialect.identifier_preparer.format_column(sa.Column(column_name)) # type: ignore[attr-defined,no-any-return] + return column_name + + def compile_column_def(self, column: sa.Column) -> str: + """Compile a column definition including type for ADD COLUMN clause""" + return str(sa.schema.CreateColumn(column).compile(self.engine)) + + def reflect_table( + self, + table_name: str, + metadata: Optional[sa.MetaData] = None, + include_columns: Optional[Sequence[str]] = None, + ) -> Optional[sa.Table]: + """Reflect a table from the database and return the Table object""" + if metadata is None: + metadata = self.metadata + try: + with self._transaction(): + return sa.Table( + table_name, + metadata, + autoload_with=self._current_connection, + schema=self.dataset_name, + include_columns=include_columns, + extend_existing=True, + ) + except DatabaseUndefinedRelation: + return None + + def compare_storage_table(self, table_name: str) -> Tuple[sa.Table, List[sa.Column], bool]: + """Reflect the table from database and compare it with the version already in metadata. 
+ Returns a 3 part tuple: + - The current version of the table in metadata + - List of columns that are missing from the storage table (all columns if it doesn't exist in storage) + - boolean indicating whether the table exists in storage + """ + existing = self.get_existing_table(table_name) + assert existing is not None, "Table must be present in metadata" + all_columns = list(existing.columns) + all_column_names = [c.name for c in all_columns] + tmp_metadata = sa.MetaData() + reflected = self.reflect_table( + table_name, include_columns=all_column_names, metadata=tmp_metadata + ) + if reflected is None: + missing_columns = all_columns + else: + missing_columns = [c for c in all_columns if c.name not in reflected.columns] + return existing, missing_columns, reflected is not None + + @staticmethod + def _make_database_exception(e: Exception) -> Exception: + if isinstance(e, sa.exc.NoSuchTableError): + return DatabaseUndefinedRelation(e) + msg = str(e).lower() + if isinstance(e, (sa.exc.ProgrammingError, sa.exc.OperationalError)): + if "exist" in msg: # TODO: Hack + return DatabaseUndefinedRelation(e) + elif "unknown table" in msg: + return DatabaseUndefinedRelation(e) + elif "unknown database" in msg: + return DatabaseUndefinedRelation(e) + elif "no such table" in msg: # sqlite # TODO: Hack + return DatabaseUndefinedRelation(e) + elif "no such database" in msg: # sqlite # TODO: Hack + return DatabaseUndefinedRelation(e) + elif "syntax" in msg: + return DatabaseTransientException(e) + elif isinstance(e, (sa.exc.OperationalError, sa.exc.IntegrityError)): + return DatabaseTerminalException(e) + return DatabaseTransientException(e) + elif isinstance(e, sa.exc.SQLAlchemyError): + return DatabaseTransientException(e) + else: + return e + # return DatabaseTerminalException(e) + + def _ensure_native_conn(self) -> None: + if not self.native_connection: + raise LoadClientNotConnected(type(self).__name__, self.dataset_name) diff --git a/dlt/destinations/impl/sqlalchemy/factory.py b/dlt/destinations/impl/sqlalchemy/factory.py new file mode 100644 index 0000000000..10372cda34 --- /dev/null +++ b/dlt/destinations/impl/sqlalchemy/factory.py @@ -0,0 +1,99 @@ +import typing as t + +from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.destination.capabilities import DataTypeMapper +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.normalizers import NamingConvention + +from dlt.destinations.impl.sqlalchemy.configuration import ( + SqlalchemyCredentials, + SqlalchemyClientConfiguration, +) + +SqlalchemyTypeMapper: t.Type[DataTypeMapper] + +try: + from dlt.destinations.impl.sqlalchemy.type_mapper import SqlalchemyTypeMapper +except ModuleNotFoundError: + # assign mock type mapper if no sqlalchemy + from dlt.common.destination.capabilities import UnsupportedTypeMapper as SqlalchemyTypeMapper + +if t.TYPE_CHECKING: + # from dlt.destinations.impl.sqlalchemy.sqlalchemy_client import SqlalchemyJobClient + from dlt.destinations.impl.sqlalchemy.sqlalchemy_job_client import SqlalchemyJobClient + + +class sqlalchemy(Destination[SqlalchemyClientConfiguration, "SqlalchemyJobClient"]): + spec = SqlalchemyClientConfiguration + + def _raw_capabilities(self) -> DestinationCapabilitiesContext: + # https://www.sqlalchemyql.org/docs/current/limits.html + caps = DestinationCapabilitiesContext.generic_capabilities() + caps.preferred_loader_file_format = "typed-jsonl" + caps.supported_loader_file_formats = ["typed-jsonl", "parquet"] + 
caps.preferred_staging_file_format = None + caps.supported_staging_file_formats = [] + caps.has_case_sensitive_identifiers = True + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 63 + caps.max_column_identifier_length = 63 + caps.max_query_length = 32 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 1024 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = True + caps.max_query_parameters = 20_0000 + caps.max_rows_per_insert = 10_000 # Set a default to avoid OOM on large datasets + caps.type_mapper = SqlalchemyTypeMapper + + return caps + + @classmethod + def adjust_capabilities( + cls, + caps: DestinationCapabilitiesContext, + config: SqlalchemyClientConfiguration, + naming: t.Optional[NamingConvention], + ) -> DestinationCapabilitiesContext: + caps = super(sqlalchemy, cls).adjust_capabilities(caps, config, naming) + dialect = config.get_dialect() + if dialect is None: + return caps + caps.max_identifier_length = dialect.max_identifier_length + caps.max_column_identifier_length = dialect.max_identifier_length + caps.supports_native_boolean = dialect.supports_native_boolean + + return caps + + @property + def client_class(self) -> t.Type["SqlalchemyJobClient"]: + from dlt.destinations.impl.sqlalchemy.sqlalchemy_job_client import SqlalchemyJobClient + + return SqlalchemyJobClient + + def __init__( + self, + credentials: t.Union[SqlalchemyCredentials, t.Dict[str, t.Any], str] = None, + destination_name: t.Optional[str] = None, + environment: t.Optional[str] = None, + engine_args: t.Optional[t.Dict[str, t.Any]] = None, + **kwargs: t.Any, + ) -> None: + """Configure the Sqlalchemy destination to use in a pipeline. + + All arguments provided here supersede other configuration sources such as environment variables and dlt config files. + + Args: + credentials: Credentials to connect to the sqlalchemy database. 
Can be an instance of `SqlalchemyCredentials` or + a connection string in the format `mysql://user:password@host:port/database` + destination_name: The name of the destination + environment: The environment to use + **kwargs: Additional arguments passed to the destination + """ + super().__init__( + credentials=credentials, + destination_name=destination_name, + environment=environment, + **kwargs, + ) diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py new file mode 100644 index 0000000000..c51d3cbe3a --- /dev/null +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -0,0 +1,359 @@ +from typing import Iterable, Optional, Dict, Any, Iterator, Sequence, List, Tuple, IO +from contextlib import suppress +import math + +import sqlalchemy as sa + +from dlt.common import logger +from dlt.common import pendulum +from dlt.common.destination.reference import ( + JobClientBase, + LoadJob, + RunnableLoadJob, + StorageSchemaInfo, + StateInfo, + PreparedTableSchema, +) +from dlt.destinations.job_client_impl import SqlJobClientBase +from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.common.schema import Schema, TTableSchema, TColumnSchema, TSchemaTables +from dlt.common.schema.typing import TColumnType, TTableSchemaColumns +from dlt.common.schema.utils import pipeline_state_table, normalize_table_identifiers +from dlt.common.storages import FileStorage +from dlt.common.json import json, PY_DATETIME_DECODERS +from dlt.destinations.exceptions import DatabaseUndefinedRelation + + +# from dlt.destinations.impl.sqlalchemy.sql_client import SqlalchemyClient +from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient +from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyClientConfiguration + + +class SqlalchemyJsonLInsertJob(RunnableLoadJob): + def __init__(self, file_path: str, table: sa.Table) -> None: + super().__init__(file_path) + self._job_client: "SqlalchemyJobClient" = None + self.table = table + + def _open_load_file(self) -> IO[bytes]: + return FileStorage.open_zipsafe_ro(self._file_path, "rb") + + def _iter_data_items(self) -> Iterator[Dict[str, Any]]: + all_cols = {col.name: None for col in self.table.columns} + with FileStorage.open_zipsafe_ro(self._file_path, "rb") as f: + for line in f: + # Decode date/time to py datetime objects. Some drivers have issues with pendulum objects + for item in json.typed_loadb(line, decoders=PY_DATETIME_DECODERS): + # Fill any missing columns in item with None. Bulk insert fails when items have different keys + if item.keys() != all_cols.keys(): + yield {**all_cols, **item} + else: + yield item + + def _iter_data_item_chunks(self) -> Iterator[Sequence[Dict[str, Any]]]: + max_rows = self._job_client.capabilities.max_rows_per_insert or math.inf + # Limit by max query length should not be needed, + # bulk insert generates an INSERT template with a single VALUES tuple of placeholders + # If any dialects don't do that we need to check the str length of the query + # TODO: Max params may not be needed. 
Limits only apply to placeholders in sql string (mysql/sqlite) + max_params = self._job_client.capabilities.max_query_parameters or math.inf + chunk: List[Dict[str, Any]] = [] + params_count = 0 + for item in self._iter_data_items(): + if len(chunk) + 1 == max_rows or params_count + len(item) > max_params: + # Rotate chunk + yield chunk + chunk = [] + params_count = 0 + params_count += len(item) + chunk.append(item) + + if chunk: + yield chunk + + def run(self) -> None: + _sql_client = self._job_client.sql_client + + with _sql_client.begin_transaction(): + for chunk in self._iter_data_item_chunks(): + _sql_client.execute_sql(self.table.insert(), chunk) + + +class SqlalchemyParquetInsertJob(SqlalchemyJsonLInsertJob): + def _iter_data_item_chunks(self) -> Iterator[Sequence[Dict[str, Any]]]: + from dlt.common.libs.pyarrow import ParquetFile + + num_cols = len(self.table.columns) + max_rows = self._job_client.capabilities.max_rows_per_insert or None + max_params = self._job_client.capabilities.max_query_parameters or None + read_limit = None + + with ParquetFile(self._file_path) as reader: + if max_params is not None: + read_limit = math.floor(max_params / num_cols) + + if max_rows is not None: + if read_limit is None: + read_limit = max_rows + else: + read_limit = min(read_limit, max_rows) + + if read_limit is None: + yield reader.read().to_pylist() + return + + for chunk in reader.iter_batches(batch_size=read_limit): + yield chunk.to_pylist() + + +class SqlalchemyJobClient(SqlJobClientBase): + sql_client: SqlalchemyClient # type: ignore[assignment] + + def __init__( + self, + schema: Schema, + config: SqlalchemyClientConfiguration, + capabilities: DestinationCapabilitiesContext, + ) -> None: + self.sql_client = SqlalchemyClient( + config.normalize_dataset_name(schema), + None, + config.credentials, + capabilities, + engine_args=config.engine_args, + ) + + self.schema = schema + self.capabilities = capabilities + self.config = config + self.type_mapper = self.capabilities.get_type_mapper(self.sql_client.dialect) + + def _to_table_object(self, schema_table: PreparedTableSchema) -> sa.Table: + existing = self.sql_client.get_existing_table(schema_table["name"]) + if existing is not None: + existing_col_names = set(col.name for col in existing.columns) + new_col_names = set(schema_table["columns"]) + # Re-generate the table if columns have changed + if existing_col_names == new_col_names: + return existing + return sa.Table( + schema_table["name"], + self.sql_client.metadata, + *[ + self._to_column_object(col, schema_table) + for col in schema_table["columns"].values() + ], + extend_existing=True, + schema=self.sql_client.dataset_name, + ) + + def _to_column_object( + self, schema_column: TColumnSchema, table: PreparedTableSchema + ) -> sa.Column: + return sa.Column( + schema_column["name"], + self.type_mapper.to_destination_type(schema_column, table), + nullable=schema_column.get("nullable", True), + unique=schema_column.get("unique", False), + ) + + def create_load_job( + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + if file_path.endswith(".typed-jsonl"): + table_obj = self._to_table_object(table) + return SqlalchemyJsonLInsertJob(file_path, table_obj) + elif file_path.endswith(".parquet"): + table_obj = self._to_table_object(table) + return SqlalchemyParquetInsertJob(file_path, table_obj) + return None + + def complete_load(self, load_id: str) -> None: + loads_table = 
self._to_table_object(self.schema.tables[self.schema.loads_table_name]) # type: ignore[arg-type] + now_ts = pendulum.now() + self.sql_client.execute_sql( + loads_table.insert().values( + ( + load_id, + self.schema.name, + 0, + now_ts, + self.schema.version_hash, + ) + ) + ) + + def _get_table_key(self, name: str, schema: Optional[str]) -> str: + if schema is None: + return name + else: + return schema + "." + name + + def get_storage_tables( + self, table_names: Iterable[str] + ) -> Iterable[Tuple[str, TTableSchemaColumns]]: + metadata = sa.MetaData() + for table_name in table_names: + table_obj = self.sql_client.reflect_table(table_name, metadata) + if table_obj is None: + yield table_name, {} + continue + yield table_name, { + col.name: { + "name": col.name, + "nullable": col.nullable, + **self.type_mapper.from_destination_type(col.type, None, None), + } + for col in table_obj.columns + } + + def update_stored_schema( + self, only_tables: Iterable[str] = None, expected_update: TSchemaTables = None + ) -> Optional[TSchemaTables]: + # super().update_stored_schema(only_tables, expected_update) + JobClientBase.update_stored_schema(self, only_tables, expected_update) + + schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) + if schema_info is not None: + logger.info( + "Schema with hash %s inserted at %s found in storage, no upgrade required", + self.schema.stored_version_hash, + schema_info.inserted_at, + ) + return {} + else: + logger.info( + "Schema with hash %s not found in storage, upgrading", + self.schema.stored_version_hash, + ) + + # Create all schema tables in metadata + for table_name in only_tables or self.schema.tables: + self._to_table_object(self.schema.tables[table_name]) # type: ignore[arg-type] + + schema_update: TSchemaTables = {} + tables_to_create: List[sa.Table] = [] + columns_to_add: List[sa.Column] = [] + + for table_name in only_tables or self.schema.tables: + table = self.schema.tables[table_name] + table_obj, new_columns, exists = self.sql_client.compare_storage_table(table["name"]) + if not new_columns: # Nothing to do, don't create table without columns + continue + if not exists: + tables_to_create.append(table_obj) + else: + columns_to_add.extend(new_columns) + partial_table = self.prepare_load_table(table_name) + new_column_names = set(col.name for col in new_columns) + partial_table["columns"] = { + col_name: col_def + for col_name, col_def in partial_table["columns"].items() + if col_name in new_column_names + } + schema_update[table_name] = partial_table + + with self.sql_client.begin_transaction(): + for table_obj in tables_to_create: + self.sql_client.create_table(table_obj) + self.sql_client.alter_table_add_columns(columns_to_add) + self._update_schema_in_storage(self.schema) + + return schema_update + + def _delete_schema_in_storage(self, schema: Schema) -> None: + version_table = schema.tables[schema.version_table_name] + table_obj = self._to_table_object(version_table) # type: ignore[arg-type] + schema_name_col = schema.naming.normalize_identifier("schema_name") + self.sql_client.execute_sql( + table_obj.delete().where(table_obj.c[schema_name_col] == schema.name) + ) + + def _update_schema_in_storage(self, schema: Schema) -> None: + version_table = schema.tables[schema.version_table_name] + table_obj = self._to_table_object(version_table) # type: ignore[arg-type] + schema_str = json.dumps(schema.to_dict()) + + schema_mapping = StorageSchemaInfo( + version=schema.version, + engine_version=str(schema.ENGINE_VERSION), + 
schema_name=schema.name, + version_hash=schema.stored_version_hash, + schema=schema_str, + inserted_at=pendulum.now(), + ).to_normalized_mapping(schema.naming) + + self.sql_client.execute_sql(table_obj.insert().values(schema_mapping)) + + def _get_stored_schema( + self, version_hash: Optional[str] = None, schema_name: Optional[str] = None + ) -> Optional[StorageSchemaInfo]: + version_table = self.schema.tables[self.schema.version_table_name] + table_obj = self._to_table_object(version_table) # type: ignore[arg-type] + with suppress(DatabaseUndefinedRelation): + q = sa.select(table_obj) + if version_hash is not None: + version_hash_col = self.schema.naming.normalize_identifier("version_hash") + q = q.where(table_obj.c[version_hash_col] == version_hash) + if schema_name is not None: + schema_name_col = self.schema.naming.normalize_identifier("schema_name") + q = q.where(table_obj.c[schema_name_col] == schema_name) + inserted_at_col = self.schema.naming.normalize_identifier("inserted_at") + q = q.order_by(table_obj.c[inserted_at_col].desc()) + with self.sql_client.execute_query(q) as cur: + row = cur.fetchone() + if row is None: + return None + + # TODO: Decode compressed schema str if needed + return StorageSchemaInfo.from_normalized_mapping( + row._mapping, self.schema.naming # type: ignore[attr-defined] + ) + + def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchemaInfo]: + return self._get_stored_schema(version_hash) + + def get_stored_schema(self) -> Optional[StorageSchemaInfo]: + """Get the latest stored schema""" + return self._get_stored_schema(schema_name=self.schema.name) + + def get_stored_state(self, pipeline_name: str) -> StateInfo: + state_table = self.schema.tables.get( + self.schema.state_table_name + ) or normalize_table_identifiers(pipeline_state_table(), self.schema.naming) + state_table_obj = self._to_table_object(state_table) # type: ignore[arg-type] + loads_table = self.schema.tables[self.schema.loads_table_name] + loads_table_obj = self._to_table_object(loads_table) # type: ignore[arg-type] + + c_load_id, c_dlt_load_id, c_pipeline_name, c_status = map( + self.schema.naming.normalize_identifier, + ("load_id", "_dlt_load_id", "pipeline_name", "status"), + ) + + query = ( + sa.select(state_table_obj) + .join(loads_table_obj, loads_table_obj.c[c_load_id] == state_table_obj.c[c_dlt_load_id]) + .where( + sa.and_( + state_table_obj.c[c_pipeline_name] == pipeline_name, + loads_table_obj.c[c_status] == 0, + ) + ) + .order_by(loads_table_obj.c[c_load_id].desc()) + ) + + with self.sql_client.execute_query(query) as cur: + row = cur.fetchone() + if not row: + return None + mapping = dict(row._mapping) # type: ignore[attr-defined] + + return StateInfo.from_normalized_mapping(mapping, self.schema.naming) + + def _from_db_type( + self, db_type: str, precision: Optional[int], scale: Optional[int] + ) -> TColumnType: + raise NotImplementedError() + + def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableSchema = None) -> str: + raise NotImplementedError() diff --git a/dlt/destinations/impl/sqlalchemy/type_mapper.py b/dlt/destinations/impl/sqlalchemy/type_mapper.py new file mode 100644 index 0000000000..767d2115d4 --- /dev/null +++ b/dlt/destinations/impl/sqlalchemy/type_mapper.py @@ -0,0 +1,174 @@ +from typing import Optional, Dict, Any +import inspect + +import sqlalchemy as sa +from sqlalchemy.sql import sqltypes + +from dlt.common.exceptions import TerminalValueError +from dlt.common.typing import TLoaderFileFormat +from 
dlt.common.destination.capabilities import DataTypeMapper, DestinationCapabilitiesContext +from dlt.common.destination.typing import PreparedTableSchema +from dlt.common.schema.typing import TColumnSchema + + +# TODO: base type mapper should be a generic class to support TypeEngine instead of str types +class SqlalchemyTypeMapper(DataTypeMapper): + def __init__( + self, + capabilities: DestinationCapabilitiesContext, + dialect: Optional[sa.engine.Dialect] = None, + ): + super().__init__(capabilities) + # Mapper is used to verify supported types without client, dialect is not important for this + self.dialect = dialect or sa.engine.default.DefaultDialect() + + def _db_integer_type(self, precision: Optional[int]) -> sa.types.TypeEngine: + if precision is None: + return sa.BigInteger() + elif precision <= 16: + return sa.SmallInteger() + elif precision <= 32: + return sa.Integer() + elif precision <= 64: + return sa.BigInteger() + raise TerminalValueError(f"Unsupported precision for integer type: {precision}") + + def _create_date_time_type( + self, sc_t: str, precision: Optional[int], timezone: Optional[bool] + ) -> sa.types.TypeEngine: + """Use the dialect specific datetime/time type if possible since the generic type doesn't accept precision argument""" + precision = precision if precision is not None else self.capabilities.timestamp_precision + base_type: sa.types.TypeEngine + timezone = timezone is None or bool(timezone) + if sc_t == "timestamp": + base_type = sa.DateTime() + if self.dialect.name == "mysql": + # Special case, type_descriptor does not return the specifc datetime type + from sqlalchemy.dialects.mysql import DATETIME + + return DATETIME(fsp=precision) + elif sc_t == "time": + base_type = sa.Time() + + dialect_type = type( + self.dialect.type_descriptor(base_type) + ) # Get the dialect specific subtype + precision = precision if precision is not None else self.capabilities.timestamp_precision + + # Find out whether the dialect type accepts precision or fsp argument + params = inspect.signature(dialect_type).parameters + kwargs: Dict[str, Any] = dict(timezone=timezone) + if "fsp" in params: + kwargs["fsp"] = precision # MySQL uses fsp for fractional seconds + elif "precision" in params: + kwargs["precision"] = precision + return dialect_type(**kwargs) # type: ignore[no-any-return,misc] + + def _create_double_type(self) -> sa.types.TypeEngine: + if dbl := getattr(sa, "Double", None): + # Sqlalchemy 2 has generic double type + return dbl() # type: ignore[no-any-return] + elif self.dialect.name == "mysql": + # MySQL has a specific double type + from sqlalchemy.dialects.mysql import DOUBLE + + return DOUBLE() + return sa.Float(precision=53) # Otherwise use float + + def _to_db_decimal_type(self, column: TColumnSchema) -> sa.types.TypeEngine: + precision, scale = column.get("precision"), column.get("scale") + if precision is None and scale is None: + precision, scale = self.capabilities.decimal_precision + return sa.Numeric(precision, scale) + + def to_destination_type( # type: ignore[override] + self, column: TColumnSchema, table: PreparedTableSchema = None + ) -> sqltypes.TypeEngine: + sc_t = column["data_type"] + precision = column.get("precision") + # TODO: Precision and scale for supported types + if sc_t == "text": + length = precision + if length is None and column.get("unique"): + length = 128 + if length is None: + return sa.Text() + return sa.String(length=length) + elif sc_t == "double": + return self._create_double_type() + elif sc_t == "bool": + return sa.Boolean() + 
elif sc_t == "timestamp": + return self._create_date_time_type(sc_t, precision, column.get("timezone")) + elif sc_t == "bigint": + return self._db_integer_type(precision) + elif sc_t == "binary": + return sa.LargeBinary(length=precision) + elif sc_t == "json": + return sa.JSON(none_as_null=True) + elif sc_t == "decimal": + return self._to_db_decimal_type(column) + elif sc_t == "wei": + wei_precision, wei_scale = self.capabilities.wei_precision + return sa.Numeric(precision=wei_precision, scale=wei_scale) + elif sc_t == "date": + return sa.Date() + elif sc_t == "time": + return self._create_date_time_type(sc_t, precision, column.get("timezone")) + raise TerminalValueError(f"Unsupported data type: {sc_t}") + + def _from_db_integer_type(self, db_type: sa.Integer) -> TColumnSchema: + if isinstance(db_type, sa.SmallInteger): + return dict(data_type="bigint", precision=16) + elif isinstance(db_type, sa.Integer): + return dict(data_type="bigint", precision=32) + elif isinstance(db_type, sa.BigInteger): + return dict(data_type="bigint") + return dict(data_type="bigint") + + def _from_db_decimal_type(self, db_type: sa.Numeric) -> TColumnSchema: + precision, scale = db_type.precision, db_type.scale + if (precision, scale) == self.capabilities.wei_precision: + return dict(data_type="wei") + + return dict(data_type="decimal", precision=precision, scale=scale) + + def from_destination_type( # type: ignore[override] + self, + db_type: sqltypes.TypeEngine, + precision: Optional[int] = None, + scale: Optional[int] = None, + ) -> TColumnSchema: + # TODO: pass the sqla type through dialect.type_descriptor before instance check + # Possibly need to check both dialect specific and generic types + if isinstance(db_type, sa.String): + return dict(data_type="text") + elif isinstance(db_type, sa.Float): + return dict(data_type="double") + elif isinstance(db_type, sa.Boolean): + return dict(data_type="bool") + elif isinstance(db_type, sa.DateTime): + return dict(data_type="timestamp", timezone=db_type.timezone) + elif isinstance(db_type, sa.Integer): + return self._from_db_integer_type(db_type) + elif isinstance(db_type, sqltypes._Binary): + return dict(data_type="binary", precision=db_type.length) + elif isinstance(db_type, sa.JSON): + return dict(data_type="json") + elif isinstance(db_type, sa.Numeric): + return self._from_db_decimal_type(db_type) + elif isinstance(db_type, sa.Date): + return dict(data_type="date") + elif isinstance(db_type, sa.Time): + return dict(data_type="time") + raise TerminalValueError(f"Unsupported db type: {db_type}") + + pass + + def ensure_supported_type( + self, + column: TColumnSchema, + table: PreparedTableSchema, + loader_file_format: TLoaderFileFormat, + ) -> None: + pass diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index 27d1bc7ce5..96f18cea3d 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -133,6 +133,23 @@ def drop_tables(self, *tables: str) -> None: ] self.execute_many(statements) + def _to_named_paramstyle(self, query: str, args: Sequence[Any]) -> Tuple[str, Dict[str, Any]]: + """Convert a query from "format" ( %s ) paramstyle to "named" ( :param_name ) paramstyle. + The %s are replaced with :arg0, :arg1, ... and the arguments are returned as a dictionary. 
+ + Args: + query: SQL query with %s placeholders + args: arguments to be passed to the query + + Returns: + Tuple of the new query and a dictionary of named arguments + """ + keys = [f"arg{i}" for i in range(len(args))] + # Replace position arguments (%s) with named arguments (:arg0, :arg1, ...) + query = query % tuple(f":{key}" for key in keys) + db_args = {key: db_arg for key, db_arg in zip(keys, args)} + return query, db_args + @abstractmethod def execute_sql( self, sql: AnyStr, *args: Any, **kwargs: Any diff --git a/poetry.lock b/poetry.lock index 9cbc4b66ea..e1afddfd5f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "about-time" @@ -216,13 +216,13 @@ frozenlist = ">=1.1.0" [[package]] name = "alembic" -version = "1.12.0" +version = "1.13.2" description = "A database migration tool for SQLAlchemy." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "alembic-1.12.0-py3-none-any.whl", hash = "sha256:03226222f1cf943deee6c85d9464261a6c710cd19b4fe867a3ad1f25afda610f"}, - {file = "alembic-1.12.0.tar.gz", hash = "sha256:8e7645c32e4f200675e69f0745415335eb59a3663f5feb487abfa0b30c45888b"}, + {file = "alembic-1.13.2-py3-none-any.whl", hash = "sha256:6b8733129a6224a9a711e17c99b08462dbf7cc9670ba8f2e2ae9af860ceb1953"}, + {file = "alembic-1.13.2.tar.gz", hash = "sha256:1ff0ae32975f4fd96028c39ed9bb3c867fe3af956bd7bb37343b54c9fe7445ef"}, ] [package.dependencies] @@ -233,7 +233,7 @@ SQLAlchemy = ">=1.3.0" typing-extensions = ">=4" [package.extras] -tz = ["python-dateutil"] +tz = ["backports.zoneinfo"] [[package]] name = "alive-progress" @@ -3749,6 +3749,106 @@ files = [ {file = "google_re2-1.1-4-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f4d4f0823e8b2f6952a145295b1ff25245ce9bb136aff6fe86452e507d4c1dd"}, {file = "google_re2-1.1-4-cp39-cp39-win32.whl", hash = "sha256:1afae56b2a07bb48cfcfefaa15ed85bae26a68f5dc7f9e128e6e6ea36914e847"}, {file = "google_re2-1.1-4-cp39-cp39-win_amd64.whl", hash = "sha256:aa7d6d05911ab9c8adbf3c225a7a120ab50fd2784ac48f2f0d140c0b7afc2b55"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:222fc2ee0e40522de0b21ad3bc90ab8983be3bf3cec3d349c80d76c8bb1a4beb"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:d4763b0b9195b72132a4e7de8e5a9bf1f05542f442a9115aa27cfc2a8004f581"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:209649da10c9d4a93d8a4d100ecbf9cc3b0252169426bec3e8b4ad7e57d600cf"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:68813aa333c1604a2df4a495b2a6ed065d7c8aebf26cc7e7abb5a6835d08353c"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:370a23ec775ad14e9d1e71474d56f381224dcf3e72b15d8ca7b4ad7dd9cd5853"}, + {file = "google_re2-1.1-5-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:14664a66a3ddf6bc9e56f401bf029db2d169982c53eff3f5876399104df0e9a6"}, + {file = "google_re2-1.1-5-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3ea3722cc4932cbcebd553b69dce1b4a73572823cff4e6a244f1c855da21d511"}, + {file = "google_re2-1.1-5-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e14bb264c40fd7c627ef5678e295370cd6ba95ca71d835798b6e37502fc4c690"}, + {file = 
"google_re2-1.1-5-cp310-cp310-win32.whl", hash = "sha256:39512cd0151ea4b3969c992579c79b423018b464624ae955be685fc07d94556c"}, + {file = "google_re2-1.1-5-cp310-cp310-win_amd64.whl", hash = "sha256:ac66537aa3bc5504320d922b73156909e3c2b6da19739c866502f7827b3f9fdf"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5b5ea68d54890c9edb1b930dcb2658819354e5d3f2201f811798bbc0a142c2b4"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:33443511b6b83c35242370908efe2e8e1e7cae749c766b2b247bf30e8616066c"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:413d77bdd5ba0bfcada428b4c146e87707452ec50a4091ec8e8ba1413d7e0619"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:5171686e43304996a34baa2abcee6f28b169806d0e583c16d55e5656b092a414"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3b284db130283771558e31a02d8eb8fb756156ab98ce80035ae2e9e3a5f307c4"}, + {file = "google_re2-1.1-5-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:296e6aed0b169648dc4b870ff47bd34c702a32600adb9926154569ef51033f47"}, + {file = "google_re2-1.1-5-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:38d50e68ead374160b1e656bbb5d101f0b95fb4cc57f4a5c12100155001480c5"}, + {file = "google_re2-1.1-5-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2a0416a35921e5041758948bcb882456916f22845f66a93bc25070ef7262b72a"}, + {file = "google_re2-1.1-5-cp311-cp311-win32.whl", hash = "sha256:a1d59568bbb5de5dd56dd6cdc79907db26cce63eb4429260300c65f43469e3e7"}, + {file = "google_re2-1.1-5-cp311-cp311-win_amd64.whl", hash = "sha256:72f5a2f179648b8358737b2b493549370debd7d389884a54d331619b285514e3"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:cbc72c45937b1dc5acac3560eb1720007dccca7c9879138ff874c7f6baf96005"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:5fadd1417fbef7235fa9453dba4eb102e6e7d94b1e4c99d5fa3dd4e288d0d2ae"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:040f85c63cc02696485b59b187a5ef044abe2f99b92b4fb399de40b7d2904ccc"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:64e3b975ee6d9bbb2420494e41f929c1a0de4bcc16d86619ab7a87f6ea80d6bd"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:8ee370413e00f4d828eaed0e83b8af84d7a72e8ee4f4bd5d3078bc741dfc430a"}, + {file = "google_re2-1.1-5-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:5b89383001079323f693ba592d7aad789d7a02e75adb5d3368d92b300f5963fd"}, + {file = "google_re2-1.1-5-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:63cb4fdfbbda16ae31b41a6388ea621510db82feb8217a74bf36552ecfcd50ad"}, + {file = "google_re2-1.1-5-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ebedd84ae8be10b7a71a16162376fd67a2386fe6361ef88c622dcf7fd679daf"}, + {file = "google_re2-1.1-5-cp312-cp312-win32.whl", hash = "sha256:c8e22d1692bc2c81173330c721aff53e47ffd3c4403ff0cd9d91adfd255dd150"}, + {file = "google_re2-1.1-5-cp312-cp312-win_amd64.whl", hash = "sha256:5197a6af438bb8c4abda0bbe9c4fbd6c27c159855b211098b29d51b73e4cbcf6"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b6727e0b98417e114b92688ad2aa256102ece51f29b743db3d831df53faf1ce3"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_12_0_x86_64.whl", hash = 
"sha256:711e2b6417eb579c61a4951029d844f6b95b9b373b213232efd413659889a363"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:71ae8b3df22c5c154c8af0f0e99d234a450ef1644393bc2d7f53fc8c0a1e111c"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:94a04e214bc521a3807c217d50cf099bbdd0c0a80d2d996c0741dbb995b5f49f"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_14_0_arm64.whl", hash = "sha256:a770f75358508a9110c81a1257721f70c15d9bb592a2fb5c25ecbd13566e52a5"}, + {file = "google_re2-1.1-5-cp38-cp38-macosx_14_0_x86_64.whl", hash = "sha256:07c9133357f7e0b17c6694d5dcb82e0371f695d7c25faef2ff8117ef375343ff"}, + {file = "google_re2-1.1-5-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:204ca6b1cf2021548f4a9c29ac015e0a4ab0a7b6582bf2183d838132b60c8fda"}, + {file = "google_re2-1.1-5-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f0b95857c2c654f419ca684ec38c9c3325c24e6ba7d11910a5110775a557bb18"}, + {file = "google_re2-1.1-5-cp38-cp38-win32.whl", hash = "sha256:347ac770e091a0364e822220f8d26ab53e6fdcdeaec635052000845c5a3fb869"}, + {file = "google_re2-1.1-5-cp38-cp38-win_amd64.whl", hash = "sha256:ec32bb6de7ffb112a07d210cf9f797b7600645c2d5910703fa07f456dd2150e0"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:eb5adf89060f81c5ff26c28e261e6b4997530a923a6093c9726b8dec02a9a326"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:a22630c9dd9ceb41ca4316bccba2643a8b1d5c198f21c00ed5b50a94313aaf10"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:544dc17fcc2d43ec05f317366375796351dec44058e1164e03c3f7d050284d58"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:19710af5ea88751c7768575b23765ce0dfef7324d2539de576f75cdc319d6654"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:f82995a205e08ad896f4bd5ce4847c834fab877e1772a44e5f262a647d8a1dec"}, + {file = "google_re2-1.1-5-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:63533c4d58da9dc4bc040250f1f52b089911699f0368e0e6e15f996387a984ed"}, + {file = "google_re2-1.1-5-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:79e00fcf0cb04ea35a22b9014712d448725ce4ddc9f08cc818322566176ca4b0"}, + {file = "google_re2-1.1-5-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bc41afcefee2da6c4ed883a93d7f527c4b960cd1d26bbb0020a7b8c2d341a60a"}, + {file = "google_re2-1.1-5-cp39-cp39-win32.whl", hash = "sha256:486730b5e1f1c31b0abc6d80abe174ce4f1188fe17d1b50698f2bf79dc6e44be"}, + {file = "google_re2-1.1-5-cp39-cp39-win_amd64.whl", hash = "sha256:4de637ca328f1d23209e80967d1b987d6b352cd01b3a52a84b4d742c69c3da6c"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:621e9c199d1ff0fdb2a068ad450111a84b3bf14f96dfe5a8a7a0deae5f3f4cce"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:220acd31e7dde95373f97c3d1f3b3bd2532b38936af28b1917ee265d25bebbf4"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:db34e1098d164f76251a6ece30e8f0ddfd65bb658619f48613ce71acb3f9cbdb"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:5152bac41d8073977582f06257219541d0fc46ad99b0bbf30e8f60198a43b08c"}, + {file = "google_re2-1.1-6-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:6191294799e373ee1735af91f55abd23b786bdfd270768a690d9d55af9ea1b0d"}, + {file = 
"google_re2-1.1-6-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:070cbafbb4fecbb02e98feb28a1eb292fb880f434d531f38cc33ee314b521f1f"}, + {file = "google_re2-1.1-6-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8437d078b405a59a576cbed544490fe041140f64411f2d91012e8ec05ab8bf86"}, + {file = "google_re2-1.1-6-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f00f9a9af8896040e37896d9b9fc409ad4979f1ddd85bb188694a7d95ddd1164"}, + {file = "google_re2-1.1-6-cp310-cp310-win32.whl", hash = "sha256:df26345f229a898b4fd3cafd5f82259869388cee6268fc35af16a8e2293dd4e5"}, + {file = "google_re2-1.1-6-cp310-cp310-win_amd64.whl", hash = "sha256:3665d08262c57c9b28a5bdeb88632ad792c4e5f417e5645901695ab2624f5059"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b26b869d8aa1d8fe67c42836bf3416bb72f444528ee2431cfb59c0d3e02c6ce3"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:41fd4486c57dea4f222a6bb7f1ff79accf76676a73bdb8da0fcbd5ba73f8da71"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:0ee378e2e74e25960070c338c28192377c4dd41e7f4608f2688064bd2badc41e"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:a00cdbf662693367b36d075b29feb649fd7ee1b617cf84f85f2deebeda25fc64"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:4c09455014217a41499432b8c8f792f25f3df0ea2982203c3a8c8ca0e7895e69"}, + {file = "google_re2-1.1-6-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:6501717909185327935c7945e23bb5aa8fc7b6f237b45fe3647fa36148662158"}, + {file = "google_re2-1.1-6-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3510b04790355f199e7861c29234081900e1e1cbf2d1484da48aa0ba6d7356ab"}, + {file = "google_re2-1.1-6-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8c0e64c187ca406764f9e9ad6e750d62e69ed8f75bf2e865d0bfbc03b642361c"}, + {file = "google_re2-1.1-6-cp311-cp311-win32.whl", hash = "sha256:2a199132350542b0de0f31acbb3ca87c3a90895d1d6e5235f7792bb0af02e523"}, + {file = "google_re2-1.1-6-cp311-cp311-win_amd64.whl", hash = "sha256:83bdac8ceaece8a6db082ea3a8ba6a99a2a1ee7e9f01a9d6d50f79c6f251a01d"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:81985ff894cd45ab5a73025922ac28c0707759db8171dd2f2cc7a0e856b6b5ad"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:5635af26065e6b45456ccbea08674ae2ab62494008d9202df628df3b267bc095"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:813b6f04de79f4a8fdfe05e2cb33e0ccb40fe75d30ba441d519168f9d958bd54"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:5ec2f5332ad4fd232c3f2d6748c2c7845ccb66156a87df73abcc07f895d62ead"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:5a687b3b32a6cbb731647393b7c4e3fde244aa557f647df124ff83fb9b93e170"}, + {file = "google_re2-1.1-6-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:39a62f9b3db5d3021a09a47f5b91708b64a0580193e5352751eb0c689e4ad3d7"}, + {file = "google_re2-1.1-6-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ca0f0b45d4a1709cbf5d21f355e5809ac238f1ee594625a1e5ffa9ff7a09eb2b"}, + {file = "google_re2-1.1-6-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a64b3796a7a616c7861247bd061c9a836b5caf0d5963e5ea8022125601cf7b09"}, + {file = "google_re2-1.1-6-cp312-cp312-win32.whl", 
hash = "sha256:32783b9cb88469ba4cd9472d459fe4865280a6b1acdad4480a7b5081144c4eb7"}, + {file = "google_re2-1.1-6-cp312-cp312-win_amd64.whl", hash = "sha256:259ff3fd2d39035b9cbcbf375995f83fa5d9e6a0c5b94406ff1cc168ed41d6c6"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:e4711bcffe190acd29104d8ecfea0c0e42b754837de3fb8aad96e6cc3c613cdc"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:4d081cce43f39c2e813fe5990e1e378cbdb579d3f66ded5bade96130269ffd75"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_13_0_arm64.whl", hash = "sha256:4f123b54d48450d2d6b14d8fad38e930fb65b5b84f1b022c10f2913bd956f5b5"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:e1928b304a2b591a28eb3175f9db7f17c40c12cf2d4ec2a85fdf1cc9c073ff91"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_14_0_arm64.whl", hash = "sha256:3a69f76146166aec1173003c1f547931bdf288c6b135fda0020468492ac4149f"}, + {file = "google_re2-1.1-6-cp38-cp38-macosx_14_0_x86_64.whl", hash = "sha256:fc08c388f4ebbbca345e84a0c56362180d33d11cbe9ccfae663e4db88e13751e"}, + {file = "google_re2-1.1-6-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b057adf38ce4e616486922f2f47fc7d19c827ba0a7f69d540a3664eba2269325"}, + {file = "google_re2-1.1-6-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4138c0b933ab099e96f5d8defce4486f7dfd480ecaf7f221f2409f28022ccbc5"}, + {file = "google_re2-1.1-6-cp38-cp38-win32.whl", hash = "sha256:9693e45b37b504634b1abbf1ee979471ac6a70a0035954592af616306ab05dd6"}, + {file = "google_re2-1.1-6-cp38-cp38-win_amd64.whl", hash = "sha256:5674d437baba0ea287a5a7f8f81f24265d6ae8f8c09384e2ef7b6f84b40a7826"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7783137cb2e04f458a530c6d0ee9ef114815c1d48b9102f023998c371a3b060e"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:a49b7153935e7a303675f4deb5f5d02ab1305adefc436071348706d147c889e0"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:a96a8bb309182090704593c60bdb369a2756b38fe358bbf0d40ddeb99c71769f"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:dff3d4be9f27ef8ec3705eed54f19ef4ab096f5876c15fe011628c69ba3b561c"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:40f818b0b39e26811fa677978112a8108269977fdab2ba0453ac4363c35d9e66"}, + {file = "google_re2-1.1-6-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:8a7e53538cdb40ef4296017acfbb05cab0c19998be7552db1cfb85ba40b171b9"}, + {file = "google_re2-1.1-6-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6ee18e7569fb714e5bb8c42809bf8160738637a5e71ed5a4797757a1fb4dc4de"}, + {file = "google_re2-1.1-6-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1cda4f6d1a7d5b43ea92bc395f23853fba0caf8b1e1efa6e8c48685f912fcb89"}, + {file = "google_re2-1.1-6-cp39-cp39-win32.whl", hash = "sha256:6a9cdbdc36a2bf24f897be6a6c85125876dc26fea9eb4247234aec0decbdccfd"}, + {file = "google_re2-1.1-6-cp39-cp39-win_amd64.whl", hash = "sha256:73f646cecfad7cc5b4330b4192c25f2e29730a3b8408e089ffd2078094208196"}, ] [[package]] @@ -9664,10 +9764,11 @@ redshift = ["psycopg2-binary", "psycopg2cffi"] s3 = ["botocore", "s3fs"] snowflake = ["snowflake-connector-python"] sql-database = ["sqlalchemy"] +sqlalchemy = ["alembic", "sqlalchemy"] synapse = ["adlfs", "pyarrow", "pyodbc"] weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = 
">=3.8.1,<3.13" -content-hash = "e5342d5cdc135a27b89747a3665ff68aa76025efcfde6f86318144ce0fd70284" +content-hash = "e6409812458e0aae06a9b5f2c816b19d73645c8598a1789b8d5b39a45e974a9f" diff --git a/pyproject.toml b/pyproject.toml index 4ca80d0993..df0bc11782 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ lancedb = { version = ">=0.8.2", optional = true, markers = "python_version >= ' tantivy = { version = ">= 0.22.0", optional = true } deltalake = { version = ">=0.19.0", optional = true } sqlalchemy = { version = ">=1.4", optional = true } +alembic = {version = "^1.13.2", optional = true} [tool.poetry.extras] gcp = ["grpcio", "google-cloud-bigquery", "db-dtypes", "gcsfs"] @@ -111,7 +112,7 @@ dremio = ["pyarrow"] lancedb = ["lancedb", "pyarrow", "tantivy"] deltalake = ["deltalake", "pyarrow"] sql_database = ["sqlalchemy"] - +sqlalchemy = ["sqlalchemy", "alembic"] [tool.poetry.scripts] dlt = "dlt.cli._dlt:_main" diff --git a/tests/cases.py b/tests/cases.py index 11358441ee..9b636d9b60 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -1,3 +1,4 @@ +import datetime # noqa: I251 import hashlib from typing import Dict, List, Any, Sequence, Tuple, Literal, Union import base64 @@ -5,6 +6,7 @@ from copy import deepcopy from string import ascii_lowercase import random +import secrets from dlt.common import Decimal, pendulum, json from dlt.common.data_types import TDataType @@ -214,19 +216,22 @@ def assert_all_data_types_row( expected_rows = {key: value for key, value in expected_row.items() if key in schema} # prepare date to be compared: convert into pendulum instance, adjust microsecond precision if "col4" in expected_rows: - parsed_date = pendulum.instance(db_mapping["col4"]) + parsed_date = ensure_pendulum_datetime((db_mapping["col4"])) db_mapping["col4"] = reduce_pendulum_datetime_precision(parsed_date, timestamp_precision) expected_rows["col4"] = reduce_pendulum_datetime_precision( ensure_pendulum_datetime(expected_rows["col4"]), # type: ignore[arg-type] timestamp_precision, ) if "col4_precision" in expected_rows: - parsed_date = pendulum.instance(db_mapping["col4_precision"]) + parsed_date = ensure_pendulum_datetime((db_mapping["col4_precision"])) db_mapping["col4_precision"] = reduce_pendulum_datetime_precision(parsed_date, 3) expected_rows["col4_precision"] = reduce_pendulum_datetime_precision( ensure_pendulum_datetime(expected_rows["col4_precision"]), 3 # type: ignore[arg-type] ) + if "col10" in expected_rows: + db_mapping["col10"] = ensure_pendulum_date(db_mapping["col10"]) + if "col11" in expected_rows: expected_rows["col11"] = reduce_pendulum_datetime_precision( ensure_pendulum_time(expected_rows["col11"]), timestamp_precision # type: ignore[arg-type] @@ -315,7 +320,7 @@ def arrow_table_all_data_types( import numpy as np data = { - "string": [random.choice(ascii_lowercase) + "\"'\\🦆\n\r" for _ in range(num_rows)], + "string": [secrets.token_urlsafe(8) + "\"'\\🦆\n\r" for _ in range(num_rows)], "float": [round(random.uniform(0, 100), 4) for _ in range(num_rows)], "int": [random.randrange(0, 100) for _ in range(num_rows)], "datetime": pd.date_range("2021-01-01T01:02:03.1234", periods=num_rows, tz=tz, unit="us"), @@ -340,7 +345,18 @@ def arrow_table_all_data_types( data["json"] = [{"a": random.randrange(0, 100)} for _ in range(num_rows)] if include_time: - data["time"] = pd.date_range("2021-01-01", periods=num_rows, tz="UTC").time + # data["time"] = pd.date_range("2021-01-01", periods=num_rows, tz="UTC").time + # data["time"] = pd.date_range("2021-01-01T01:02:03.1234", 
periods=num_rows, tz=tz, unit="us").time + # random time objects with different hours/minutes/seconds/microseconds + data["time"] = [ + datetime.time( + random.randrange(0, 24), + random.randrange(0, 60), + random.randrange(0, 60), + random.randrange(0, 1000000), + ) + for _ in range(num_rows) + ] if include_binary: # "binary": [hashlib.sha3_256(random.choice(ascii_lowercase).encode()).digest() for _ in range(num_rows)], diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index a86deea799..f72aaec1d8 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -1,4 +1,4 @@ -from datetime import datetime # noqa: I251 +from datetime import datetime, timedelta, time as dt_time, date # noqa: I251 import os import pytest @@ -9,7 +9,12 @@ import dlt from dlt.common import pendulum -from dlt.common.time import reduce_pendulum_datetime_precision +from dlt.common.time import ( + reduce_pendulum_datetime_precision, + ensure_pendulum_time, + ensure_pendulum_datetime, + ensure_pendulum_date, +) from dlt.common.utils import uniq_id from tests.load.utils import destinations_configs, DestinationTestConfiguration @@ -41,7 +46,7 @@ def test_load_arrow_item( # os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "True" os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_ID"] = "True" - include_time = destination_config.destination not in ( + include_time = destination_config.destination_type not in ( "athena", "redshift", "databricks", @@ -49,15 +54,21 @@ def test_load_arrow_item( "clickhouse", ) # athena/redshift can't load TIME columns include_binary = not ( - destination_config.destination in ("redshift", "databricks") + destination_config.destination_type in ("redshift", "databricks") and destination_config.file_format == "jsonl" ) - include_decimal = not ( - destination_config.destination == "databricks" and destination_config.file_format == "jsonl" - ) + include_decimal = True + + if ( + destination_config.destination_type == "databricks" + and destination_config.file_format == "jsonl" + ) or (destination_config.destination_name == "sqlalchemy_sqlite"): + include_decimal = False + include_date = not ( - destination_config.destination == "databricks" and destination_config.file_format == "jsonl" + destination_config.destination_type == "databricks" + and destination_config.file_format == "jsonl" ) item, records, _ = arrow_table_all_data_types( @@ -77,7 +88,9 @@ def some_data(): # use csv for postgres to get native arrow processing destination_config.file_format = ( - destination_config.file_format if destination_config.destination != "postgres" else "csv" + destination_config.file_format + if destination_config.destination_type != "postgres" + else "csv" ) load_info = pipeline.run(some_data(), **destination_config.run_kwargs) @@ -107,13 +120,13 @@ def some_data(): if isinstance(row[i], memoryview): row[i] = row[i].tobytes() - if destination_config.destination == "redshift": + if destination_config.destination_type == "redshift": # Redshift needs hex string for record in records: if "binary" in record: record["binary"] = record["binary"].hex() - if destination_config.destination == "clickhouse": + if destination_config.destination_type == "clickhouse": for record in records: # Clickhouse needs base64 string for jsonl if "binary" in record and destination_config.file_format == "jsonl": @@ -121,23 +134,29 @@ def some_data(): if "binary" in record and 
destination_config.file_format == "parquet": record["binary"] = record["binary"].decode("ascii") - for row in rows: - for i in range(len(row)): - if isinstance(row[i], datetime): - row[i] = pendulum.instance(row[i]) + expected = sorted([list(r.values()) for r in records]) + first_record = list(records[0].values()) + for row, expected_row in zip(rows, expected): + for i in range(len(expected_row)): + if isinstance(expected_row[i], datetime): + row[i] = ensure_pendulum_datetime(row[i]) # clickhouse produces rounding errors on double with jsonl, so we round the result coming from there - if ( - destination_config.destination == "clickhouse" + elif ( + destination_config.destination_type == "clickhouse" and destination_config.file_format == "jsonl" and isinstance(row[i], float) ): row[i] = round(row[i], 4) - - expected = sorted([list(r.values()) for r in records]) + elif isinstance(first_record[i], dt_time): + # Some drivers (mysqlclient) return TIME columns as timedelta as seconds since midnight + # sqlite returns iso strings + row[i] = ensure_pendulum_time(row[i]) + elif isinstance(expected_row[i], date): + row[i] = ensure_pendulum_date(row[i]) for row in expected: for i in range(len(row)): - if isinstance(row[i], datetime): + if isinstance(row[i], (datetime, dt_time)): row[i] = reduce_pendulum_datetime_precision( row[i], pipeline.destination.capabilities().timestamp_precision ) @@ -235,7 +254,7 @@ def test_load_arrow_with_not_null_columns( ) -> None: """Resource schema contains non-nullable columns. Arrow schema should be written accordingly""" if ( - destination_config.destination in ("databricks", "redshift") + destination_config.destination_type in ("databricks", "redshift") and destination_config.file_format == "jsonl" ): pytest.skip( diff --git a/tests/load/pipeline/test_dbt_helper.py b/tests/load/pipeline/test_dbt_helper.py index 0ee02ba4b5..d55c81e998 100644 --- a/tests/load/pipeline/test_dbt_helper.py +++ b/tests/load/pipeline/test_dbt_helper.py @@ -36,7 +36,7 @@ def dbt_venv() -> Iterator[Venv]: def test_run_jaffle_package( destination_config: DestinationTestConfiguration, dbt_venv: Venv ) -> None: - if destination_config.destination == "athena": + if destination_config.destination_type == "athena": pytest.skip( "dbt-athena requires database to be created and we don't do it in case of Jaffle" ) @@ -71,7 +71,7 @@ def test_run_jaffle_package( ids=lambda x: x.name, ) def test_run_chess_dbt(destination_config: DestinationTestConfiguration, dbt_venv: Venv) -> None: - if destination_config.destination == "mssql": + if destination_config.destination_type == "mssql": pytest.skip( "mssql requires non standard SQL syntax and we do not have specialized dbt package" " for it" @@ -130,7 +130,7 @@ def test_run_chess_dbt(destination_config: DestinationTestConfiguration, dbt_ven def test_run_chess_dbt_to_other_dataset( destination_config: DestinationTestConfiguration, dbt_venv: Venv ) -> None: - if destination_config.destination == "mssql": + if destination_config.destination_type == "mssql": pytest.skip( "mssql requires non standard SQL syntax and we do not have specialized dbt package" " for it" diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 6c6ef21140..b1244de336 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -496,7 +496,10 @@ def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration) assert_load_info(info) # make sure it was parquet or sql 
inserts files = p.get_load_package_info(p.list_completed_load_packages()[1]).jobs["completed_jobs"] - if destination_config.destination == "athena" and destination_config.table_format == "iceberg": + if ( + destination_config.destination_type == "athena" + and destination_config.table_format == "iceberg" + ): # iceberg uses sql to copy tables expected_formats.append("sql") assert all(f.job_file_info.file_format in expected_formats for f in files) @@ -541,12 +544,18 @@ def _get_shuffled_events(shuffle: bool = dlt.secrets.value): @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, ) @pytest.mark.parametrize("github_resource", [github_repo_events, github_repo_events_table_meta]) def test_merge_with_dispatch_and_incremental( destination_config: DestinationTestConfiguration, github_resource: DltResource ) -> None: + if destination_config.destination_name == "sqlalchemy_mysql": + # TODO: Github events have too many columns for MySQL + pytest.skip("MySQL can't handle too many columns") + newest_issues = list( sorted(_get_shuffled_events(True), key=lambda x: x["created_at"], reverse=True) ) @@ -1118,7 +1127,7 @@ def data_resource(data): assert sorted(observed, key=lambda d: d["id"]) == expected # additional tests with two records, run only on duckdb to limit test load - if destination_config.destination == "duckdb": + if destination_config.destination_type == "duckdb": # two records with same primary key # record with highest value in sort column is a delete # existing record is deleted and no record will be inserted @@ -1199,7 +1208,7 @@ def r(): ids=lambda x: x.name, ) def test_upsert_merge_strategy_config(destination_config: DestinationTestConfiguration) -> None: - if destination_config.destination == "filesystem": + if destination_config.destination_type == "filesystem": # TODO: implement validation and remove this test exception pytest.skip( "`upsert` merge strategy configuration validation has not yet been" diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 95fed2343c..659bca6cb9 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -100,7 +100,7 @@ def data_fun() -> Iterator[Any]: # mock the correct destinations (never do that in normal code) with p.managed_state(): p._set_destinations( - destination=Destination.from_reference(destination_config.destination), + destination=destination_config.destination_factory(), staging=( Destination.from_reference(destination_config.staging) if destination_config.staging @@ -162,13 +162,12 @@ def test_default_schema_name( for idx, alpha in [(0, "A"), (0, "B"), (0, "C")] ] - p = dlt.pipeline( + p = destination_config.setup_pipeline( "test_default_schema_name", - TEST_STORAGE_ROOT, - destination=destination_config.destination, - staging=destination_config.staging, dataset_name=dataset_name, + pipelines_dir=TEST_STORAGE_ROOT, ) + p.config.use_single_dataset = use_single_dataset p.extract( data, @@ -212,7 +211,7 @@ def _data(): destination_config.setup() info = dlt.run( _data(), - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name="specific" + uniq_id(), **destination_config.run_kwargs, @@ -288,7 +287,7 @@ def _data(): p = dlt.pipeline(dev_mode=True) info = p.run( _data(), - 
destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name="iteration" + uniq_id(), **destination_config.run_kwargs, @@ -380,7 +379,7 @@ def extended_rows(): assert "new_column" not in schema.get_table("simple_rows")["columns"] # lets violate unique constraint on postgres, redshift and BQ ignore unique indexes - if destination_config.destination == "postgres": + if destination_config.destination_type == "postgres": # let it complete even with PK violation (which is a teminal error) os.environ["RAISE_ON_FAILED_JOBS"] = "false" assert p.dataset_name == dataset_name @@ -466,7 +465,7 @@ def nested_data(): info = dlt.run( nested_data(), - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name="ds_" + uniq_id(), **destination_config.run_kwargs, @@ -509,11 +508,11 @@ def other_data(): column_schemas = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA) # parquet on bigquery and clickhouse does not support JSON but we still want to run the test - if destination_config.destination in ["bigquery"]: + if destination_config.destination_type in ["bigquery"]: column_schemas["col9_null"]["data_type"] = column_schemas["col9"]["data_type"] = "text" # duckdb 0.9.1 does not support TIME other than 6 - if destination_config.destination in ["duckdb", "motherduck"]: + if destination_config.destination_type in ["duckdb", "motherduck"]: column_schemas["col11_precision"]["precision"] = None # also we do not want to test col4_precision (datetime) because # those timestamps are not TZ aware in duckdb and we'd need to @@ -522,7 +521,7 @@ def other_data(): column_schemas["col4_precision"]["precision"] = 6 # drop TIME from databases not supporting it via parquet - if destination_config.destination in [ + if destination_config.destination_type in [ "redshift", "athena", "synapse", @@ -536,7 +535,7 @@ def other_data(): column_schemas.pop("col11_null") column_schemas.pop("col11_precision") - if destination_config.destination in ("redshift", "dremio"): + if destination_config.destination_type in ("redshift", "dremio"): data_types.pop("col7_precision") column_schemas.pop("col7_precision") @@ -562,7 +561,7 @@ def some_source(): if destination_config.supports_merge: expected_completed_jobs += 1 # add iceberg copy jobs - if destination_config.destination == "athena": + if destination_config.destination_type == "athena": expected_completed_jobs += 2 # if destination_config.supports_merge else 4 assert len(package_info.jobs["completed_jobs"]) == expected_completed_jobs @@ -583,10 +582,12 @@ def some_source(): assert_all_data_types_row( db_row, schema=column_schemas, - parse_json_strings=destination_config.destination + parse_json_strings=destination_config.destination_type in ["snowflake", "bigquery", "redshift"], - allow_string_binary=destination_config.destination == "clickhouse", - timestamp_precision=3 if destination_config.destination in ("athena", "dremio") else 6, + allow_string_binary=destination_config.destination_type == "clickhouse", + timestamp_precision=( + 3 if destination_config.destination_type in ("athena", "dremio") else 6 + ), ) @@ -755,7 +756,9 @@ def table_3(make_data=False): @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, exclude=["sqlalchemy"]), + ids=lambda x: x.name, ) def 
test_query_all_info_tables_fallback(destination_config: DestinationTestConfiguration) -> None: pipeline = destination_config.setup_pipeline( @@ -837,7 +840,7 @@ def _data(): p = dlt.pipeline( pipeline_name=f"pipeline_{dataset_name}", dev_mode=dev_mode, - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name=dataset_name, ) @@ -879,7 +882,7 @@ def events(): ids=lambda x: x.name, ) def test_dest_column_hint_timezone(destination_config: DestinationTestConfiguration) -> None: - destination = destination_config.destination + destination = destination_config.destination_type input_data = [ {"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}, diff --git a/tests/load/pipeline/test_refresh_modes.py b/tests/load/pipeline/test_refresh_modes.py index 59063aacea..dcb2be44dc 100644 --- a/tests/load/pipeline/test_refresh_modes.py +++ b/tests/load/pipeline/test_refresh_modes.py @@ -275,7 +275,7 @@ def test_refresh_drop_data_only(destination_config: DestinationTestConfiguration data = load_tables_to_dicts(pipeline, "some_data_1", "some_data_2", "some_data_3") # name column still remains when table was truncated instead of dropped # (except on filesystem where truncate and drop are the same) - if destination_config.destination == "filesystem": + if destination_config.destination_type == "filesystem": result = sorted([row["id"] for row in data["some_data_1"]]) assert result == [3, 4] diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index 0e03119b7b..050636c491 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -219,13 +219,13 @@ def test_get_schemas_from_destination( use_single_dataset: bool, naming_convention: str, ) -> None: - set_naming_env(destination_config.destination, naming_convention) + set_naming_env(destination_config.destination_type, naming_convention) pipeline_name = "pipe_" + uniq_id() dataset_name = "state_test_" + uniq_id() p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) - assert_naming_to_caps(destination_config.destination, p.destination.capabilities()) + assert_naming_to_caps(destination_config.destination_type, p.destination.capabilities()) p.config.use_single_dataset = use_single_dataset def _make_dn_name(schema_name: str) -> str: @@ -318,13 +318,13 @@ def _make_dn_name(schema_name: str) -> str: def test_restore_state_pipeline( destination_config: DestinationTestConfiguration, naming_convention: str ) -> None: - set_naming_env(destination_config.destination, naming_convention) + set_naming_env(destination_config.destination_type, naming_convention) # enable restoring from destination os.environ["RESTORE_FROM_DESTINATION"] = "True" pipeline_name = "pipe_" + uniq_id() dataset_name = "state_test_" + uniq_id() p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) - assert_naming_to_caps(destination_config.destination, p.destination.capabilities()) + assert_naming_to_caps(destination_config.destination_type, p.destination.capabilities()) def some_data_gen(param: str) -> Any: dlt.current.source_state()[param] = param @@ -550,7 +550,7 @@ def test_restore_schemas_while_import_schemas_exist( ) # use run to get changes p.run( - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name=dataset_name, 
**destination_config.run_kwargs, @@ -601,7 +601,7 @@ def some_data(param: str) -> Any: p.run( [data1, some_data("state2")], schema=Schema("default"), - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name=dataset_name, **destination_config.run_kwargs, @@ -611,7 +611,7 @@ def some_data(param: str) -> Any: # create a production pipeline in separate pipelines_dir production_p = dlt.pipeline(pipeline_name=pipeline_name, pipelines_dir=TEST_STORAGE_ROOT) production_p.run( - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name=dataset_name, **destination_config.run_kwargs, @@ -691,7 +691,9 @@ def some_data(param: str) -> Any: [5, 4, 4, 3, 2], ) except SqlClientNotAvailable: - pytest.skip(f"destination {destination_config.destination} does not support sql client") + pytest.skip( + f"destination {destination_config.destination_type} does not support sql client" + ) @pytest.mark.parametrize( @@ -719,7 +721,7 @@ def some_data(param: str) -> Any: p.run( data4, schema=Schema("sch1"), - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name=dataset_name, **destination_config.run_kwargs, @@ -749,7 +751,7 @@ def some_data(param: str) -> Any: p.run( data4, schema=Schema("sch1"), - destination=destination_config.destination, + destination=destination_config.destination_factory(), staging=destination_config.staging, dataset_name=dataset_name, **destination_config.run_kwargs, diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index 0fc6a37dd9..cc8175b677 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -126,7 +126,7 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: # check item of first row in db with pipeline.sql_client() as sql_client: qual_name = sql_client.make_qualified_table_name - if destination_config.destination in ["mssql", "synapse"]: + if destination_config.destination_type in ["mssql", "synapse"]: rows = sql_client.execute_sql( f"SELECT TOP 1 url FROM {qual_name('issues')} WHERE id = 388089021" ) @@ -148,7 +148,7 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: # check changes where merged in with pipeline.sql_client() as sql_client: - if destination_config.destination in ["mssql", "synapse"]: + if destination_config.destination_type in ["mssql", "synapse"]: qual_name = sql_client.make_qualified_table_name rows_1 = sql_client.execute_sql( f"SELECT TOP 1 number FROM {qual_name('issues')} WHERE id = 1232152492" @@ -231,7 +231,7 @@ def test_truncate_staging_dataset(destination_config: DestinationTestConfigurati with staging_client: # except Athena + Iceberg which does not store tables in staging dataset if ( - destination_config.destination == "athena" + destination_config.destination_type == "athena" and destination_config.table_format == "iceberg" ): table_count = 0 @@ -257,7 +257,7 @@ def test_truncate_staging_dataset(destination_config: DestinationTestConfigurati _, staging_client = pipeline._get_destination_clients(pipeline.default_schema) with staging_client: # except for Athena which does not delete staging destination tables - if destination_config.destination == "athena": + if destination_config.destination_type == "athena": if 
destination_config.table_format == "iceberg": table_count = 0 else: @@ -279,7 +279,7 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non # redshift and athena, parquet and jsonl, exclude time types exclude_types: List[TDataType] = [] exclude_columns: List[str] = [] - if destination_config.destination in ( + if destination_config.destination_type in ( "redshift", "athena", "databricks", @@ -287,10 +287,13 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non ) and destination_config.file_format in ("parquet", "jsonl"): # Redshift copy doesn't support TIME column exclude_types.append("time") - if destination_config.destination == "synapse" and destination_config.file_format == "parquet": + if ( + destination_config.destination_type == "synapse" + and destination_config.file_format == "parquet" + ): # TIME columns are not supported for staged parquet loads into Synapse exclude_types.append("time") - if destination_config.destination in ( + if destination_config.destination_type in ( "redshift", "dremio", ) and destination_config.file_format in ( @@ -299,7 +302,10 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non ): # Redshift can't load fixed width binary columns from parquet exclude_columns.append("col7_precision") - if destination_config.destination == "databricks" and destination_config.file_format == "jsonl": + if ( + destination_config.destination_type == "databricks" + and destination_config.file_format == "jsonl" + ): exclude_types.extend(["decimal", "binary", "wei", "json", "date"]) exclude_columns.append("col1_precision") @@ -309,12 +315,12 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non # bigquery and clickhouse cannot load into JSON fields from parquet if destination_config.file_format == "parquet": - if destination_config.destination in ["bigquery"]: + if destination_config.destination_type in ["bigquery"]: # change datatype to text and then allow for it in the assert (parse_json_strings) column_schemas["col9_null"]["data_type"] = column_schemas["col9"]["data_type"] = "text" # redshift cannot load from json into VARBYTE if destination_config.file_format == "jsonl": - if destination_config.destination == "redshift": + if destination_config.destination_type == "redshift": # change the datatype to text which will result in inserting base64 (allow_base64_binary) binary_cols = ["col7", "col7_null"] for col in binary_cols: @@ -344,15 +350,15 @@ def my_source(): # parquet is not really good at inserting json, best we get are strings in JSON columns parse_json_strings = ( destination_config.file_format == "parquet" - and destination_config.destination in ["redshift", "bigquery", "snowflake"] + and destination_config.destination_type in ["redshift", "bigquery", "snowflake"] ) allow_base64_binary = ( destination_config.file_format == "jsonl" - and destination_config.destination in ["redshift", "clickhouse"] + and destination_config.destination_type in ["redshift", "clickhouse"] ) allow_string_binary = ( destination_config.file_format == "parquet" - and destination_config.destination in ["clickhouse"] + and destination_config.destination_type in ["clickhouse"] ) # content must equal assert_all_data_types_row( diff --git a/tests/load/pipeline/test_write_disposition_changes.py b/tests/load/pipeline/test_write_disposition_changes.py index 1a799059d0..f7d915903e 100644 --- a/tests/load/pipeline/test_write_disposition_changes.py +++ 
b/tests/load/pipeline/test_write_disposition_changes.py @@ -91,7 +91,9 @@ def test_switch_from_merge(destination_config: DestinationTestConfiguration): @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, supports_merge=True), + ids=lambda x: x.name, ) @pytest.mark.parametrize("with_root_key", [True, False]) def test_switch_to_merge(destination_config: DestinationTestConfiguration, with_root_key: bool): @@ -126,7 +128,7 @@ def source(): # schemaless destinations allow adding of root key without the pipeline failing # they do not mind adding NOT NULL columns to tables with existing data (id NOT NULL is supported at all) # doing this will result in somewhat useless behavior - destination_allows_adding_root_key = destination_config.destination in [ + destination_allows_adding_root_key = destination_config.destination_type in [ "dremio", "clickhouse", "athena", diff --git a/tests/load/sources/filesystem/test_filesystem_source.py b/tests/load/sources/filesystem/test_filesystem_source.py index 947e7e9e1c..88796b9c4d 100644 --- a/tests/load/sources/filesystem/test_filesystem_source.py +++ b/tests/load/sources/filesystem/test_filesystem_source.py @@ -111,7 +111,9 @@ def test_fsspec_as_credentials(): @pytest.mark.parametrize("bucket_url", TESTS_BUCKET_URLS) @pytest.mark.parametrize( "destination_config", - destinations_configs(default_sql_configs=True, all_buckets_filesystem_configs=True), + destinations_configs( + default_sql_configs=True, supports_merge=True, all_buckets_filesystem_configs=True + ), ids=lambda x: x.name, ) def test_csv_transformers( @@ -126,9 +128,11 @@ def test_csv_transformers( # print(pipeline.last_trace.last_normalize_info) # must contain 24 rows of A881 - if not destination_config.destination == "filesystem": + if destination_config.destination_type != "filesystem": + with pipeline.sql_client() as client: + table_name = client.make_qualified_table_name("met_csv") # TODO: comment out when filesystem destination supports queries (data pond PR) - assert_query_data(pipeline, "SELECT code FROM met_csv", ["A881"] * 24) + assert_query_data(pipeline, f"SELECT code FROM {table_name}", ["A881"] * 24) # load the other folder that contains data for the same day + one other day # the previous data will be replaced @@ -138,9 +142,11 @@ def test_csv_transformers( assert_load_info(load_info) # print(pipeline.last_trace.last_normalize_info) # must contain 48 rows of A803 - if not destination_config.destination == "filesystem": + if destination_config.destination_type != "filesystem": + with pipeline.sql_client() as client: + table_name = client.make_qualified_table_name("met_csv") # TODO: comment out when filesystem destination supports queries (data pond PR) - assert_query_data(pipeline, "SELECT code FROM met_csv", ["A803"] * 48) + assert_query_data(pipeline, f"SELECT code FROM {table_name}", ["A803"] * 48) # and 48 rows in total -> A881 got replaced # print(pipeline.default_schema.to_pretty_yaml()) assert load_table_counts(pipeline, "met_csv") == {"met_csv": 48} diff --git a/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py b/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py index 7012602b4a..4f4e876fb6 100644 --- a/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py +++ b/tests/load/sources/sql_database/test_sql_database_source_all_destinations.py @@ -51,10 +51,10 @@ def 
test_load_sql_schema_loads_all_tables( schema=sql_source_db.schema, backend=backend, reflection_level="minimal", - type_adapter_callback=default_test_callback(destination_config.destination, backend), + type_adapter_callback=default_test_callback(destination_config.destination_type, backend), ) - if destination_config.destination == "bigquery" and backend == "connectorx": + if destination_config.destination_type == "bigquery" and backend == "connectorx": # connectorx generates nanoseconds time which bigquery cannot load source.has_precision.add_map(convert_time_to_us) source.has_precision_nullable.add_map(convert_time_to_us) @@ -91,10 +91,10 @@ def test_load_sql_schema_loads_all_tables_parallel( schema=sql_source_db.schema, backend=backend, reflection_level="minimal", - type_adapter_callback=default_test_callback(destination_config.destination, backend), + type_adapter_callback=default_test_callback(destination_config.destination_type, backend), ).parallelize() - if destination_config.destination == "bigquery" and backend == "connectorx": + if destination_config.destination_type == "bigquery" and backend == "connectorx": # connectorx generates nanoseconds time which bigquery cannot load source.has_precision.add_map(convert_time_to_us) source.has_precision_nullable.add_map(convert_time_to_us) diff --git a/tests/load/sqlalchemy/__init__.py b/tests/load/sqlalchemy/__init__.py new file mode 100644 index 0000000000..250c1f7626 --- /dev/null +++ b/tests/load/sqlalchemy/__init__.py @@ -0,0 +1,3 @@ +from tests.utils import skip_if_not_active + +skip_if_not_active("sqlalchemy") diff --git a/tests/load/sqlalchemy/test_sqlalchemy_configuration.py b/tests/load/sqlalchemy/test_sqlalchemy_configuration.py new file mode 100644 index 0000000000..281593aaf7 --- /dev/null +++ b/tests/load/sqlalchemy/test_sqlalchemy_configuration.py @@ -0,0 +1,23 @@ +import pytest + +import sqlalchemy as sa + +from dlt.common.configuration import resolve_configuration +from dlt.destinations.impl.sqlalchemy.configuration import ( + SqlalchemyClientConfiguration, + SqlalchemyCredentials, +) + + +def test_sqlalchemy_credentials_from_engine() -> None: + engine = sa.create_engine("sqlite:///:memory:") + + creds = resolve_configuration(SqlalchemyCredentials(engine)) + + # Url is taken from engine + assert creds.to_url() == sa.engine.make_url("sqlite:///:memory:") + # Engine is stored on the instance + assert creds.engine is engine + + assert creds.drivername == "sqlite" + assert creds.database == ":memory:" diff --git a/tests/load/test_insert_job_client.py b/tests/load/test_insert_job_client.py index a957c871bb..4359ac6885 100644 --- a/tests/load/test_insert_job_client.py +++ b/tests/load/test_insert_job_client.py @@ -28,7 +28,7 @@ def file_storage() -> FileStorage: @pytest.fixture(scope="function") def client(request) -> Iterator[InsertValuesJobClient]: - yield from yield_client_with_storage(request.param.destination) # type: ignore[misc] + yield from yield_client_with_storage(request.param.destination_factory()) # type: ignore[misc] @pytest.mark.essential diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 891ca5e809..84d08a5a89 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -27,7 +27,12 @@ ) from dlt.destinations.job_client_impl import SqlJobClientBase -from dlt.common.destination.reference import StateInfo, WithStagingDataset +from dlt.common.destination.reference import ( + StateInfo, + WithStagingDataset, + DestinationClientConfiguration, +) +from 
dlt.common.time import ensure_pendulum_datetime from tests.cases import table_update_and_row, assert_all_data_types_row from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage @@ -62,7 +67,7 @@ def file_storage() -> FileStorage: @pytest.fixture(scope="function") def client(request, naming) -> Iterator[SqlJobClientBase]: - yield from yield_client_with_storage(request.param.destination) + yield from yield_client_with_storage(request.param.destination_factory()) @pytest.fixture(scope="function") @@ -202,7 +207,7 @@ def test_complete_load(naming: str, client: SqlJobClientBase) -> None: assert load_rows[0][2] == 0 import datetime # noqa: I251 - assert type(load_rows[0][3]) is datetime.datetime + assert isinstance(ensure_pendulum_datetime(load_rows[0][3]), datetime.datetime) assert load_rows[0][4] == client.schema.version_hash # make sure that hash in loads exists in schema versions table versions_table = client.sql_client.make_qualified_table_name(version_table_name) @@ -417,7 +422,9 @@ def test_get_storage_table_with_all_types(client: SqlJobClientBase) -> None: assert len(storage_table) > 0 # column order must match TABLE_UPDATE storage_columns = list(storage_table.values()) - for c, expected_c in zip(TABLE_UPDATE, storage_columns): + for c, expected_c in zip( + TABLE_UPDATE, storage_columns + ): # TODO: c and expected_c need to be swapped # storage columns are returned with column names as in information schema assert client.capabilities.casefold_identifier(c["name"]) == expected_c["name"] # athena does not know wei data type and has no JSON type, time is not supported with parquet tables @@ -442,11 +449,19 @@ def test_get_storage_table_with_all_types(client: SqlJobClientBase) -> None: continue if client.config.destination_type == "dremio" and c["data_type"] == "json": continue + if not client.capabilities.supports_native_boolean and c["data_type"] == "bool": + # The reflected data type is probably either int or boolean depending on how the client is implemented + assert expected_c["data_type"] in ("bigint", "bool") + continue + assert c["data_type"] == expected_c["data_type"] @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True, exclude=("sqlalchemy",)), + indirect=True, + ids=lambda x: x.name, ) def test_preserve_column_order(client: SqlJobClientBase) -> None: schema = client.schema @@ -561,7 +576,7 @@ def test_load_with_all_types( if not client.capabilities.preferred_loader_file_format: pytest.skip("preferred loader file format not set, destination will only work with staging") table_name = "event_test_table" + uniq_id() - column_schemas, data_row = get_columns_and_row_all_types(client.config.destination_type) + column_schemas, data_row = get_columns_and_row_all_types(client.config) # we should have identical content with all disposition types partial = client.schema.update_table( @@ -576,18 +591,21 @@ def test_load_with_all_types( client.schema._bump_version() client.update_stored_schema() - should_load_to_staging = client.should_load_data_to_staging_dataset(table_name) # type: ignore[attr-defined] - if should_load_to_staging: - with client.with_staging_dataset(): # type: ignore[attr-defined] - # create staging for merge dataset - client.initialize_storage() - client.update_stored_schema() + if isinstance(client, WithStagingDataset): + should_load_to_staging = client.should_load_data_to_staging_dataset(table_name) + if should_load_to_staging: + with 
client.with_staging_dataset(): + # create staging for merge dataset + client.initialize_storage() + client.update_stored_schema() - with client.sql_client.with_alternative_dataset_name( - client.sql_client.staging_dataset_name - if should_load_to_staging - else client.sql_client.dataset_name - ): + with client.sql_client.with_alternative_dataset_name( + client.sql_client.staging_dataset_name + if should_load_to_staging + else client.sql_client.dataset_name + ): + canonical_name = client.sql_client.make_qualified_table_name(table_name) + else: canonical_name = client.sql_client.make_qualified_table_name(table_name) # write row print(data_row) @@ -633,7 +651,7 @@ def test_write_dispositions( os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy table_name = "event_test_table" + uniq_id() - column_schemas, data_row = get_columns_and_row_all_types(client.config.destination_type) + column_schemas, data_row = get_columns_and_row_all_types(client.config) client.schema.update_table( new_table(table_name, write_disposition=write_disposition, columns=column_schemas.values()) ) @@ -646,6 +664,8 @@ def test_write_dispositions( client.update_stored_schema() if write_disposition == "merge": + if not client.capabilities.supported_merge_strategies: + pytest.skip("destination does not support merge") # add root key client.schema.tables[table_name]["columns"]["col1"]["root_key"] = True # create staging for merge dataset @@ -665,9 +685,11 @@ def test_write_dispositions( with io.BytesIO() as f: write_dataset(client, f, [data_row], column_schemas) query = f.getvalue() - if client.should_load_data_to_staging_dataset(table_name): # type: ignore[attr-defined] + if isinstance( + client, WithStagingDataset + ) and client.should_load_data_to_staging_dataset(table_name): # load to staging dataset on merge - with client.with_staging_dataset(): # type: ignore[attr-defined] + with client.with_staging_dataset(): expect_load_file(client, file_storage, query, t) else: # load directly on other @@ -737,7 +759,7 @@ def test_get_resumed_job(client: SqlJobClientBase, file_storage: FileStorage) -> ) def test_default_schema_name_init_storage(destination_config: DestinationTestConfiguration) -> None: with cm_yield_client_with_storage( - destination_config.destination, + destination_config.destination_factory(), default_config_values={ "default_schema_name": ( # pass the schema that is a default schema. that should create dataset with the name `dataset_name` "event" @@ -748,7 +770,7 @@ def test_default_schema_name_init_storage(destination_config: DestinationTestCon assert client.sql_client.has_dataset() with cm_yield_client_with_storage( - destination_config.destination, + destination_config.destination_factory(), default_config_values={ "default_schema_name": ( None # no default_schema. that should create dataset with the name `dataset_name` @@ -759,7 +781,7 @@ def test_default_schema_name_init_storage(destination_config: DestinationTestCon assert client.sql_client.has_dataset() with cm_yield_client_with_storage( - destination_config.destination, + destination_config.destination_factory(), default_config_values={ "default_schema_name": ( # the default schema is not event schema . 
that should create dataset with the name `dataset_name` with schema suffix
                 "event_2"
@@ -788,7 +810,8 @@ def test_get_stored_state(
     os.environ["SCHEMA__NAMING"] = naming_convention
 
     with cm_yield_client_with_storage(
-        destination_config.destination, default_config_values={"default_schema_name": None}
+        destination_config.destination_factory(),
+        default_config_values={"default_schema_name": None},
     ) as client:
         # event schema with event table
         if not client.capabilities.preferred_loader_file_format:
@@ -814,6 +837,8 @@ def test_get_stored_state(
 
         # get state
         stored_state = client.get_stored_state("pipeline")
+        # Ensure a timezone-aware datetime for comparison
+        stored_state.created_at = pendulum.instance(stored_state.created_at)
 
         assert doc == stored_state.as_doc()
 
@@ -850,7 +875,8 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None:
         assert len(db_rows) == expected_rows
 
     with cm_yield_client_with_storage(
-        destination_config.destination, default_config_values={"default_schema_name": None}
+        destination_config.destination_factory(),
+        default_config_values={"default_schema_name": None},
     ) as client:
         # event schema with event table
         if not client.capabilities.preferred_loader_file_format:
@@ -909,7 +935,11 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None:
             "mandatory_column", "text", nullable=False
         )
         client.schema._bump_version()
-        if destination_config.destination == "clickhouse":
+        if destination_config.destination_type == "clickhouse" or (
+            # mysql allows adding not-null columns (they have an implicit default)
+            destination_config.destination_type == "sqlalchemy"
+            and client.sql_client.dialect_name == "mysql"
+        ):
             client.update_stored_schema()
         else:
             with pytest.raises(DatabaseException) as py_ex:
@@ -942,11 +972,16 @@ def normalize_rows(rows: List[Dict[str, Any]], naming: NamingConvention) -> None
             row[naming.normalize_identifier(k)] = row.pop(k)
 
 
-def get_columns_and_row_all_types(destination_type: str):
+def get_columns_and_row_all_types(destination_config: DestinationClientConfiguration):
+    exclude_types = []
+    if destination_config.destination_type in ["databricks", "clickhouse", "motherduck"]:
+        exclude_types.append("time")
+    if destination_config.destination_name == "sqlalchemy_sqlite":
+        exclude_types.extend(["decimal", "wei"])
     return table_update_and_row(
         # TIME + parquet is actually a duckdb problem: https://github.com/duckdb/duckdb/pull/13283
-        exclude_types=(
-            ["time"] if destination_type in ["databricks", "clickhouse", "motherduck"] else None
+        exclude_types=exclude_types,  # type: ignore[arg-type]
+        exclude_columns=(
+            ["col4_precision"] if destination_config.destination_type in ["motherduck"] else None
         ),
-        exclude_columns=["col4_precision"] if destination_type in ["motherduck"] else None,
     )
diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py
index e167f0ceda..199b4b83b7 100644
--- a/tests/load/test_sql_client.py
+++ b/tests/load/test_sql_client.py
@@ -1,7 +1,7 @@
 import os
 import pytest
 import datetime  # noqa: I251
-from typing import Iterator, Any
+from typing import Iterator, Any, Tuple, Type, Union
 from threading import Thread, Event
 from time import sleep
 
@@ -20,7 +20,7 @@ from dlt.destinations.sql_client import DBApiCursor, SqlClientBase
 from dlt.destinations.job_client_impl import SqlJobClientBase
 from dlt.destinations.typing import TNativeConn
-from dlt.common.time import ensure_pendulum_datetime
+from dlt.common.time import ensure_pendulum_datetime, to_py_datetime
 
 from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage
 from tests.load.utils import (
@@ -28,6 +28,7 @@
     prepare_table,
     AWS_BUCKET,
     destinations_configs,
+    DestinationTestConfiguration,
 )
 
 # mark all tests as essential, do not remove
@@ -46,7 +47,8 @@ def file_storage() -> FileStorage:
 
 @pytest.fixture(scope="function")
 def client(request, naming) -> Iterator[SqlJobClientBase]:
-    yield from yield_client_with_storage(request.param.destination)
+    param: DestinationTestConfiguration = request.param
+    yield from yield_client_with_storage(param.destination_factory())
 
 
 @pytest.fixture(scope="function")
@@ -62,7 +64,9 @@ def naming(request) -> str:
 @pytest.mark.parametrize(
     "client",
     destinations_configs(
-        default_sql_configs=True, exclude=["mssql", "synapse", "dremio", "clickhouse"]
+        # Only databases that support search path or equivalent
+        default_sql_configs=True,
+        exclude=["mssql", "synapse", "dremio", "clickhouse", "sqlalchemy"],
     ),
     indirect=True,
     ids=lambda x: x.name,
@@ -145,6 +149,7 @@ def test_has_dataset(naming: str, client: SqlJobClientBase) -> None:
 )
 def test_create_drop_dataset(naming: str, client: SqlJobClientBase) -> None:
     # client.sql_client.create_dataset()
+    # The dataset is already created in the fixture, so creating it again fails
     with pytest.raises(DatabaseException):
         client.sql_client.create_dataset()
     client.sql_client.drop_dataset()
@@ -208,14 +213,19 @@ def test_execute_sql(client: SqlJobClientBase) -> None:
     assert len(rows) == 1
     # print(rows)
     assert rows[0][0] == "event"
-    assert isinstance(rows[0][1], datetime.datetime)
+    assert isinstance(ensure_pendulum_datetime(rows[0][1]), datetime.datetime)
     assert rows[0][0] == "event"
     # print(rows[0][1])
     # print(type(rows[0][1]))
-    # convert to pendulum to make sure it is supported by dbapi
+    # ensure a datetime object to make sure it is supported by dbapi
+    inserted_at = to_py_datetime(ensure_pendulum_datetime(rows[0][1]))
+    if client.config.destination_name == "sqlalchemy_sqlite":
+        # timezone aware datetime is not supported by sqlite
+        inserted_at = inserted_at.replace(tzinfo=None)
+
     rows = client.sql_client.execute_sql(
         f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at = %s",
-        ensure_pendulum_datetime(rows[0][1]),
+        inserted_at,
     )
     assert len(rows) == 1
     # use rows in subsequent test
@@ -241,20 +251,20 @@ def test_execute_sql(client: SqlJobClientBase) -> None:
 def test_execute_ddl(client: SqlJobClientBase) -> None:
     uniq_suffix = uniq_id()
     client.update_stored_schema()
-    table_name = prepare_temp_table(client)
+    table_name, py_type = prepare_temp_table(client)
     f_q_table_name = client.sql_client.make_qualified_table_name(table_name)
     client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (1.0)")
     rows = client.sql_client.execute_sql(f"SELECT * FROM {f_q_table_name}")
-    assert rows[0][0] == Decimal("1.0")
+    assert rows[0][0] == py_type("1.0")
     if client.config.destination_type == "dremio":
         username = client.config.credentials["username"]
         view_name = f'"@{username}"."view_tmp_{uniq_suffix}"'
     else:
         # create view, note that bigquery will not let you execute a view that does not have fully qualified table names.
view_name = client.sql_client.make_qualified_table_name(f"view_tmp_{uniq_suffix}") - client.sql_client.execute_sql(f"CREATE VIEW {view_name} AS (SELECT * FROM {f_q_table_name});") + client.sql_client.execute_sql(f"CREATE VIEW {view_name} AS SELECT * FROM {f_q_table_name};") rows = client.sql_client.execute_sql(f"SELECT * FROM {view_name}") - assert rows[0][0] == Decimal("1.0") + assert rows[0][0] == py_type("1.0") @pytest.mark.parametrize( @@ -275,7 +285,7 @@ def test_execute_query(client: SqlJobClientBase) -> None: rows = curr.fetchall() assert len(rows) == 1 assert rows[0][0] == "event" - assert isinstance(rows[0][1], datetime.datetime) + assert isinstance(ensure_pendulum_datetime(rows[0][1]), datetime.datetime) with client.sql_client.execute_query( f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at = %s", rows[0][1], @@ -285,7 +295,7 @@ def test_execute_query(client: SqlJobClientBase) -> None: assert rows[0][0] == "event" with client.sql_client.execute_query( f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at = %s", - pendulum.now().add(seconds=1), + to_py_datetime(pendulum.now().add(seconds=1)), ) as curr: rows = curr.fetchall() assert len(rows) == 0 @@ -293,7 +303,7 @@ def test_execute_query(client: SqlJobClientBase) -> None: with client.sql_client.execute_query( f"SELECT schema_name, inserted_at FROM {version_table_name} WHERE inserted_at =" " %(date)s", - date=pendulum.now().add(seconds=1), + date=to_py_datetime(pendulum.now().add(seconds=1)), ) as curr: rows = curr.fetchall() assert len(rows) == 0 @@ -314,7 +324,7 @@ def test_execute_df(client: SqlJobClientBase) -> None: total_records = 3000 client.update_stored_schema() - table_name = prepare_temp_table(client) + table_name, py_type = prepare_temp_table(client) f_q_table_name = client.sql_client.make_qualified_table_name(table_name) if client.capabilities.insert_values_writer_type == "default": @@ -415,8 +425,7 @@ def test_database_exceptions(client: SqlJobClientBase) -> None: assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) if client.config.destination_type not in ["dremio", "clickhouse"]: with pytest.raises(DatabaseUndefinedRelation) as term_ex: - with client.sql_client.execute_query("DROP SCHEMA UNKNOWN"): - pass + client.sql_client.drop_dataset() assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) @@ -427,29 +436,29 @@ def test_database_exceptions(client: SqlJobClientBase) -> None: ids=lambda x: x.name, ) def test_commit_transaction(client: SqlJobClientBase) -> None: - table_name = prepare_temp_table(client) + table_name, py_type = prepare_temp_table(client) f_q_table_name = client.sql_client.make_qualified_table_name(table_name) with client.sql_client.begin_transaction(): - client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (%s)", Decimal("1.0")) + client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (%s)", py_type("1.0")) # check row still in transaction rows = client.sql_client.execute_sql( - f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) assert len(rows) == 1 # check row after commit rows = client.sql_client.execute_sql( - f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) assert len(rows) == 1 assert rows[0][0] == 1.0 with client.sql_client.begin_transaction() as tx: client.sql_client.execute_sql( - f"DELETE FROM 
{f_q_table_name} WHERE col = %s", Decimal("1.0") + f"DELETE FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) # explicit commit tx.commit_transaction() rows = client.sql_client.execute_sql( - f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) assert len(rows) == 0 @@ -463,22 +472,22 @@ def test_commit_transaction(client: SqlJobClientBase) -> None: def test_rollback_transaction(client: SqlJobClientBase) -> None: if client.capabilities.supports_transactions is False: pytest.skip("Destination does not support tx") - table_name = prepare_temp_table(client) + table_name, py_type = prepare_temp_table(client) f_q_table_name = client.sql_client.make_qualified_table_name(table_name) # test python exception with pytest.raises(RuntimeError): with client.sql_client.begin_transaction(): client.sql_client.execute_sql( - f"INSERT INTO {f_q_table_name} VALUES (%s)", Decimal("1.0") + f"INSERT INTO {f_q_table_name} VALUES (%s)", py_type("1.0") ) rows = client.sql_client.execute_sql( - f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) assert len(rows) == 1 # python exception triggers rollback raise RuntimeError("ROLLBACK") rows = client.sql_client.execute_sql( - f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) assert len(rows) == 0 @@ -487,23 +496,23 @@ def test_rollback_transaction(client: SqlJobClientBase) -> None: with pytest.raises(DatabaseException): with client.sql_client.begin_transaction(): client.sql_client.execute_sql( - f"INSERT INTO {f_q_table_name} VALUES (%s)", Decimal("1.0") + f"INSERT INTO {f_q_table_name} VALUES (%s)", py_type("1.0") ) # table does not exist client.sql_client.execute_sql( - f"SELECT col FROM {f_q_wrong_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_wrong_table_name} WHERE col = %s", py_type("1.0") ) rows = client.sql_client.execute_sql( - f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) assert len(rows) == 0 # test explicit rollback with client.sql_client.begin_transaction() as tx: - client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (%s)", Decimal("1.0")) + client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (%s)", py_type("1.0")) tx.rollback_transaction() rows = client.sql_client.execute_sql( - f"SELECT col FROM {f_q_table_name} WHERE col = %s", Decimal("1.0") + f"SELECT col FROM {f_q_table_name} WHERE col = %s", py_type("1.0") ) assert len(rows) == 0 @@ -524,12 +533,15 @@ def test_rollback_transaction(client: SqlJobClientBase) -> None: def test_transaction_isolation(client: SqlJobClientBase) -> None: if client.capabilities.supports_transactions is False: pytest.skip("Destination does not support tx") - table_name = prepare_temp_table(client) + if client.config.destination_name == "sqlalchemy_sqlite": + # because other schema names must be attached for each connection + client.sql_client.dataset_name = "main" + table_name, py_type = prepare_temp_table(client) f_q_table_name = client.sql_client.make_qualified_table_name(table_name) event = Event() event.clear() - def test_thread(thread_id: Decimal) -> None: + def test_thread(thread_id: Union[Decimal, float]) -> None: # make a copy of the sql_client thread_client = client.sql_client.__class__( 
client.sql_client.dataset_name, @@ -543,8 +555,8 @@ def test_thread(thread_id: Decimal) -> None: event.wait() with client.sql_client.begin_transaction() as tx: - client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (%s)", Decimal("1.0")) - t = Thread(target=test_thread, daemon=True, args=(Decimal("2.0"),)) + client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES (%s)", py_type("1.0")) + t = Thread(target=test_thread, daemon=True, args=(py_type("2.0"),)) t.start() # thread 2.0 inserts sleep(3.0) @@ -555,17 +567,23 @@ def test_thread(thread_id: Decimal) -> None: t.join() # just in case close the connection - client.sql_client.close_connection() - # re open connection - client.sql_client.open_connection() + if ( + client.config.destination_name != "sqlalchemy_sqlite" + ): # keep sqlite connection to maintain attached datasets + client.sql_client.close_connection() + # re open connection + client.sql_client.open_connection() rows = client.sql_client.execute_sql(f"SELECT col FROM {f_q_table_name} ORDER BY col") assert len(rows) == 1 # only thread 2 is left - assert rows[0][0] == Decimal("2.0") + assert rows[0][0] == py_type("2.0") @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True, exclude=["sqlalchemy"]), + indirect=True, + ids=lambda x: x.name, ) def test_max_table_identifier_length(client: SqlJobClientBase) -> None: if client.capabilities.max_identifier_length >= 65536: @@ -595,7 +613,10 @@ def test_max_table_identifier_length(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True, exclude=["sqlalchemy"]), + indirect=True, + ids=lambda x: x.name, ) def test_max_column_identifier_length(client: SqlJobClientBase) -> None: if client.capabilities.max_column_identifier_length >= 65536: @@ -620,7 +641,7 @@ def test_max_column_identifier_length(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( "client", - destinations_configs(default_sql_configs=True, exclude=["databricks"]), + destinations_configs(default_sql_configs=True, exclude=["databricks", "sqlalchemy"]), indirect=True, ids=lambda x: x.name, ) @@ -674,11 +695,13 @@ def assert_load_id(sql_client: SqlClientBase[TNativeConn], load_id: str) -> None assert len(rows) == 1 -def prepare_temp_table(client: SqlJobClientBase) -> str: +def prepare_temp_table(client: SqlJobClientBase) -> Tuple[str, Type[Union[Decimal, float]]]: + """Return the table name and py type of value to insert""" uniq_suffix = uniq_id() table_name = f"tmp_{uniq_suffix}" ddl_suffix = "" coltype = "numeric" + py_type: Union[Type[Decimal], Type[float]] = Decimal if client.config.destination_type == "athena": ddl_suffix = ( f"LOCATION '{AWS_BUCKET}/ci/{table_name}' TBLPROPERTIES ('table_type'='ICEBERG'," @@ -686,6 +709,10 @@ def prepare_temp_table(client: SqlJobClientBase) -> str: ) coltype = "bigint" qualified_table_name = table_name + elif client.config.destination_name == "sqlalchemy_sqlite": + coltype = "float" + py_type = float + qualified_table_name = client.sql_client.make_qualified_table_name(table_name) elif client.config.destination_type == "clickhouse": ddl_suffix = "ENGINE = MergeTree() ORDER BY col" qualified_table_name = client.sql_client.make_qualified_table_name(table_name) @@ -694,4 +721,4 @@ def prepare_temp_table(client: SqlJobClientBase) 
-> str: client.sql_client.execute_sql( f"CREATE TABLE {qualified_table_name} (col {coltype}) {ddl_suffix};" ) - return table_name + return table_name, py_type diff --git a/tests/load/utils.py b/tests/load/utils.py index 0eaf68d8f8..1c47291a6c 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -144,7 +144,7 @@ class DestinationTestConfiguration: """Class for defining test setup for one destination.""" - destination: str + destination_type: str staging: Optional[TDestinationReferenceArg] = None file_format: Optional[TLoaderFileFormat] = None table_format: Optional[TTableFormat] = None @@ -160,10 +160,17 @@ class DestinationTestConfiguration: dev_mode: bool = False credentials: Optional[Union[CredentialsConfiguration, Dict[str, Any]]] = None env_vars: Optional[Dict[str, str]] = None + destination_name: Optional[str] = None + + def destination_factory(self, **kwargs) -> Destination[Any, Any]: + dest_type = kwargs.pop("destination", self.destination_type) + dest_name = kwargs.pop("destination_name", self.destination_name) + self.setup() + return Destination.from_reference(dest_type, destination_name=dest_name, **kwargs) @property def name(self) -> str: - name: str = self.destination + name: str = self.destination_name or self.destination_type if self.file_format: name += f"-{self.file_format}" if self.table_format: @@ -196,7 +203,7 @@ def setup(self) -> None: os.environ[f"DESTINATION__{k.upper()}"] = str(v) # For the filesystem destinations we disable compression to make analyzing the result easier - if self.destination == "filesystem" or self.disable_compression: + if self.destination_type == "filesystem" or self.disable_compression: os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" if self.credentials is not None: @@ -211,11 +218,16 @@ def setup_pipeline( self, pipeline_name: str, dataset_name: str = None, dev_mode: bool = False, **kwargs ) -> dlt.Pipeline: """Convenience method to setup pipeline with this configuration""" + self.dev_mode = dev_mode - self.setup() + destination = kwargs.pop("destination", None) + if destination is None: + destination = self.destination_factory(**kwargs) + else: + self.setup() pipeline = dlt.pipeline( pipeline_name=pipeline_name, - destination=kwargs.pop("destination", self.destination), + destination=destination, staging=kwargs.pop("staging", self.staging), dataset_name=dataset_name or pipeline_name, dev_mode=dev_mode, @@ -274,13 +286,13 @@ def destinations_configs( default_sql_configs_with_staging = [ # Athena needs filesystem staging, which will be automatically set; we have to supply a bucket url though. 
DestinationTestConfiguration( - destination="athena", + destination_type="athena", file_format="parquet", supports_merge=False, bucket_url=AWS_BUCKET, ), DestinationTestConfiguration( - destination="athena", + destination_type="athena", file_format="parquet", bucket_url=AWS_BUCKET, supports_merge=True, @@ -293,13 +305,16 @@ def destinations_configs( # default non staging sql based configs, one per destination if default_sql_configs: destination_configs += [ - DestinationTestConfiguration(destination=destination) + DestinationTestConfiguration(destination_type=destination) for destination in SQL_DESTINATIONS - if destination not in ("athena", "synapse", "databricks", "dremio", "clickhouse") + if destination + not in ("athena", "synapse", "databricks", "dremio", "clickhouse", "sqlalchemy") ] destination_configs += [ - DestinationTestConfiguration(destination="duckdb", file_format="parquet"), - DestinationTestConfiguration(destination="motherduck", file_format="insert_values"), + DestinationTestConfiguration(destination_type="duckdb", file_format="parquet"), + DestinationTestConfiguration( + destination_type="motherduck", file_format="insert_values" + ), ] # add Athena staging configs @@ -307,12 +322,27 @@ def destinations_configs( destination_configs += [ DestinationTestConfiguration( - destination="clickhouse", file_format="jsonl", supports_dbt=False + destination_type="sqlalchemy", + supports_merge=False, + supports_dbt=False, + destination_name="sqlalchemy_mysql", + ), + DestinationTestConfiguration( + destination_type="sqlalchemy", + supports_merge=False, + supports_dbt=False, + destination_name="sqlalchemy_sqlite", + ), + ] + + destination_configs += [ + DestinationTestConfiguration( + destination_type="clickhouse", file_format="jsonl", supports_dbt=False ) ] destination_configs += [ DestinationTestConfiguration( - destination="databricks", + destination_type="databricks", file_format="parquet", bucket_url=AZ_BUCKET, extra_info="az-authorization", @@ -321,7 +351,7 @@ def destinations_configs( destination_configs += [ DestinationTestConfiguration( - destination="dremio", + destination_type="dremio", staging=filesystem(destination_name="minio"), file_format="parquet", bucket_url=AWS_BUCKET, @@ -329,24 +359,24 @@ def destinations_configs( ) ] destination_configs += [ - # DestinationTestConfiguration(destination="mssql", supports_dbt=False), - DestinationTestConfiguration(destination="synapse", supports_dbt=False), + # DestinationTestConfiguration(destination_type="mssql", supports_dbt=False), + DestinationTestConfiguration(destination_type="synapse", supports_dbt=False), ] # sanity check that when selecting default destinations, one of each sql destination is actually # provided - assert set(SQL_DESTINATIONS) == {d.destination for d in destination_configs} + assert set(SQL_DESTINATIONS) == {d.destination_type for d in destination_configs} if default_vector_configs: destination_configs += [ - DestinationTestConfiguration(destination="weaviate"), - DestinationTestConfiguration(destination="lancedb"), + DestinationTestConfiguration(destination_type="weaviate"), + DestinationTestConfiguration(destination_type="lancedb"), DestinationTestConfiguration( - destination="qdrant", + destination_type="qdrant", credentials=dict(path=str(Path(FILE_BUCKET) / "qdrant_data")), extra_info="local-file", ), - DestinationTestConfiguration(destination="qdrant", extra_info="server"), + DestinationTestConfiguration(destination_type="qdrant", extra_info="server"), ] if (default_sql_configs or all_staging_configs) 
and not default_sql_configs: @@ -356,7 +386,7 @@ def destinations_configs( if default_staging_configs or all_staging_configs: destination_configs += [ DestinationTestConfiguration( - destination="redshift", + destination_type="redshift", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, @@ -364,14 +394,14 @@ def destinations_configs( extra_info="s3-role", ), DestinationTestConfiguration( - destination="bigquery", + destination_type="bigquery", staging="filesystem", file_format="parquet", bucket_url=GCS_BUCKET, extra_info="gcs-authorization", ), DestinationTestConfiguration( - destination="snowflake", + destination_type="snowflake", staging="filesystem", file_format="jsonl", bucket_url=GCS_BUCKET, @@ -379,14 +409,14 @@ def destinations_configs( extra_info="gcs-integration", ), DestinationTestConfiguration( - destination="snowflake", + destination_type="snowflake", staging="filesystem", file_format="jsonl", bucket_url=AWS_BUCKET, extra_info="s3-integration", ), DestinationTestConfiguration( - destination="snowflake", + destination_type="snowflake", staging="filesystem", file_format="jsonl", bucket_url=AWS_BUCKET, @@ -394,7 +424,7 @@ def destinations_configs( extra_info="s3-integration", ), DestinationTestConfiguration( - destination="snowflake", + destination_type="snowflake", staging="filesystem", file_format="jsonl", bucket_url=AZ_BUCKET, @@ -402,14 +432,14 @@ def destinations_configs( extra_info="az-integration", ), DestinationTestConfiguration( - destination="snowflake", + destination_type="snowflake", staging="filesystem", file_format="jsonl", bucket_url=AZ_BUCKET, extra_info="az-authorization", ), DestinationTestConfiguration( - destination="databricks", + destination_type="databricks", staging="filesystem", file_format="jsonl", bucket_url=AWS_BUCKET, @@ -417,7 +447,7 @@ def destinations_configs( disable_compression=True, ), DestinationTestConfiguration( - destination="databricks", + destination_type="databricks", staging="filesystem", file_format="jsonl", bucket_url=AZ_BUCKET, @@ -425,14 +455,14 @@ def destinations_configs( disable_compression=True, ), DestinationTestConfiguration( - destination="databricks", + destination_type="databricks", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, extra_info="s3-authorization", ), DestinationTestConfiguration( - destination="synapse", + destination_type="synapse", staging="filesystem", file_format="parquet", bucket_url=AZ_BUCKET, @@ -440,35 +470,35 @@ def destinations_configs( disable_compression=True, ), DestinationTestConfiguration( - destination="clickhouse", + destination_type="clickhouse", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, extra_info="s3-authorization", ), DestinationTestConfiguration( - destination="clickhouse", + destination_type="clickhouse", staging="filesystem", file_format="parquet", bucket_url=AZ_BUCKET, extra_info="az-authorization", ), DestinationTestConfiguration( - destination="clickhouse", + destination_type="clickhouse", staging="filesystem", file_format="jsonl", bucket_url=AZ_BUCKET, extra_info="az-authorization", ), DestinationTestConfiguration( - destination="clickhouse", + destination_type="clickhouse", staging="filesystem", file_format="jsonl", bucket_url=AWS_BUCKET, extra_info="s3-authorization", ), DestinationTestConfiguration( - destination="dremio", + destination_type="dremio", staging=filesystem(destination_name="minio"), file_format="parquet", bucket_url=AWS_BUCKET, @@ -479,35 +509,35 @@ def destinations_configs( if all_staging_configs: 
destination_configs += [ DestinationTestConfiguration( - destination="redshift", + destination_type="redshift", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, extra_info="credential-forwarding", ), DestinationTestConfiguration( - destination="snowflake", + destination_type="snowflake", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, extra_info="credential-forwarding", ), DestinationTestConfiguration( - destination="redshift", + destination_type="redshift", staging="filesystem", file_format="jsonl", bucket_url=AWS_BUCKET, extra_info="credential-forwarding", ), DestinationTestConfiguration( - destination="bigquery", + destination_type="bigquery", staging="filesystem", file_format="jsonl", bucket_url=GCS_BUCKET, extra_info="gcs-authorization", ), DestinationTestConfiguration( - destination="synapse", + destination_type="synapse", staging="filesystem", file_format="parquet", bucket_url=AZ_BUCKET, @@ -520,7 +550,7 @@ def destinations_configs( if local_filesystem_configs: destination_configs += [ DestinationTestConfiguration( - destination="filesystem", + destination_type="filesystem", bucket_url=FILE_BUCKET, file_format="insert_values", supports_merge=False, @@ -528,7 +558,7 @@ def destinations_configs( ] destination_configs += [ DestinationTestConfiguration( - destination="filesystem", + destination_type="filesystem", bucket_url=FILE_BUCKET, file_format="parquet", supports_merge=False, @@ -536,7 +566,7 @@ def destinations_configs( ] destination_configs += [ DestinationTestConfiguration( - destination="filesystem", + destination_type="filesystem", bucket_url=FILE_BUCKET, file_format="jsonl", supports_merge=False, @@ -547,7 +577,7 @@ def destinations_configs( for bucket in DEFAULT_BUCKETS: destination_configs += [ DestinationTestConfiguration( - destination="filesystem", + destination_type="filesystem", bucket_url=bucket, extra_info=bucket, supports_merge=False, @@ -558,7 +588,7 @@ def destinations_configs( for bucket in DEFAULT_BUCKETS: destination_configs += [ DestinationTestConfiguration( - destination="filesystem", + destination_type="filesystem", bucket_url=bucket, extra_info=bucket, table_format="delta", @@ -577,27 +607,29 @@ def destinations_configs( # filter out non active destinations destination_configs = [ - conf for conf in destination_configs if conf.destination in ACTIVE_DESTINATIONS + conf for conf in destination_configs if conf.destination_type in ACTIVE_DESTINATIONS ] # filter out destinations not in subset if subset: - destination_configs = [conf for conf in destination_configs if conf.destination in subset] + destination_configs = [ + conf for conf in destination_configs if conf.destination_type in subset + ] if bucket_subset: destination_configs = [ conf for conf in destination_configs - if conf.destination != "filesystem" or conf.bucket_url in bucket_subset + if conf.destination_type != "filesystem" or conf.bucket_url in bucket_subset ] if exclude: destination_configs = [ - conf for conf in destination_configs if conf.destination not in exclude + conf for conf in destination_configs if conf.destination_type not in exclude ] if bucket_exclude: destination_configs = [ conf for conf in destination_configs - if conf.destination != "filesystem" or conf.bucket_url not in bucket_exclude + if conf.destination_type != "filesystem" or conf.bucket_url not in bucket_exclude ] if with_file_format: if not isinstance(with_file_format, Sequence): @@ -774,14 +806,14 @@ def prepare_table( def yield_client( - destination_type: str, + destination_ref: 
TDestinationReferenceArg, dataset_name: str = None, default_config_values: StrAny = None, schema_name: str = "event", ) -> Iterator[SqlJobClientBase]: os.environ.pop("DATASET_NAME", None) # import destination reference by name - destination = Destination.from_reference(destination_type) + destination = Destination.from_reference(destination_ref) # create initial config dest_config: DestinationClientDwhConfiguration = None dest_config = destination.spec() # type: ignore @@ -806,7 +838,7 @@ def yield_client( client: SqlJobClientBase = None # athena requires staging config to be present, so stick this in there here - if destination_type == "athena": + if destination.destination_name == "athena": staging_config = DestinationClientStagingConfiguration( bucket_url=AWS_BUCKET, )._bind_dataset_name(dataset_name=dest_config.dataset_name) @@ -819,7 +851,7 @@ def yield_client( ConfigSectionContext( sections=( "destination", - destination_type, + destination.destination_name, ) ) ): @@ -829,23 +861,23 @@ def yield_client( @contextlib.contextmanager def cm_yield_client( - destination_type: str, + destination: TDestinationReferenceArg, dataset_name: str, default_config_values: StrAny = None, schema_name: str = "event", ) -> Iterator[SqlJobClientBase]: - return yield_client(destination_type, dataset_name, default_config_values, schema_name) + return yield_client(destination, dataset_name, default_config_values, schema_name) def yield_client_with_storage( - destination_type: str, default_config_values: StrAny = None, schema_name: str = "event" + destination: TDestinationReferenceArg, + default_config_values: StrAny = None, + schema_name: str = "event", ) -> Iterator[SqlJobClientBase]: # create dataset with random name dataset_name = "test_" + uniq_id() - with cm_yield_client( - destination_type, dataset_name, default_config_values, schema_name - ) as client: + with cm_yield_client(destination, dataset_name, default_config_values, schema_name) as client: client.initialize_storage() yield client if client.is_storage_initialized(): @@ -866,9 +898,11 @@ def delete_dataset(client: SqlClientBase[Any], normalized_dataset_name: str) -> @contextlib.contextmanager def cm_yield_client_with_storage( - destination_type: str, default_config_values: StrAny = None, schema_name: str = "event" + destination: TDestinationReferenceArg, + default_config_values: StrAny = None, + schema_name: str = "event", ) -> Iterator[SqlJobClientBase]: - return yield_client_with_storage(destination_type, default_config_values, schema_name) + return yield_client_with_storage(destination, default_config_values, schema_name) def write_dataset( diff --git a/tests/pipeline/test_pipeline_extra.py b/tests/pipeline/test_pipeline_extra.py index 1fe6231279..821bec8e08 100644 --- a/tests/pipeline/test_pipeline_extra.py +++ b/tests/pipeline/test_pipeline_extra.py @@ -56,8 +56,8 @@ class BaseModel: # type: ignore[no-redef] def test_create_pipeline_all_destinations(destination_config: DestinationTestConfiguration) -> None: # create pipelines, extract and normalize. 
that should be possible without installing any dependencies
     p = dlt.pipeline(
-        pipeline_name=destination_config.destination + "_pipeline",
-        destination=destination_config.destination,
+        pipeline_name=destination_config.destination_type + "_pipeline",
+        destination=destination_config.destination_type,
         staging=destination_config.staging,
     )
     # are capabilities injected
diff --git a/tests/utils.py b/tests/utils.py
index 9facdfc375..e90ac5a626 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -54,6 +54,7 @@
     "databricks",
     "clickhouse",
     "dremio",
+    "sqlalchemy",
 }
 NON_SQL_DESTINATIONS = {
     "filesystem",