Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sqlalchemy destination #1734

Merged
merged 37 commits into from
Sep 14, 2024
Merged
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
3f06e1e
Implement sqlalchemy loader
steinitzu Sep 11, 2024
31d9033
Support destination name in tests
steinitzu Sep 11, 2024
e3eaa43
Some job client/sql client tests running on sqlite
steinitzu Sep 11, 2024
2973526
Fix more tests
steinitzu Sep 4, 2024
8caf2f3
All sqlite tests passing
steinitzu Sep 5, 2024
2a30b36
Add sqlalchemy tests in ci
steinitzu Sep 5, 2024
e7f56c9
Type errors
steinitzu Sep 5, 2024
11d52db
Test sqlalchemy in own workflow
steinitzu Sep 5, 2024
9d37ea6
Fix tests, type errors
steinitzu Sep 11, 2024
cdeb17d
Fix config
steinitzu Sep 6, 2024
a730a91
CI fix
steinitzu Sep 6, 2024
3326580
Add alembic to handle ALTER TABLE
steinitzu Sep 11, 2024
567359d
Fix workflow
steinitzu Sep 6, 2024
babcd3c
Install mysqlclient in venv
steinitzu Sep 6, 2024
9dec1c5
Mysql service version
steinitzu Sep 6, 2024
3e282ea
Single fail
steinitzu Sep 6, 2024
0439015
mysql healthcheck
steinitzu Sep 6, 2024
61c8355
No localhost
steinitzu Sep 6, 2024
84dc4cf
Remove weaviate
steinitzu Sep 6, 2024
4bcc425
Change ubuntu version
steinitzu Sep 6, 2024
a9b7e49
Debug sqlite version
steinitzu Sep 6, 2024
e0a0781
Revert
steinitzu Sep 6, 2024
98f8de2
Use py datetime in tests
steinitzu Sep 6, 2024
4f8d8f6
Test on sqlalchemy 1.4 and 2
steinitzu Sep 6, 2024
79631b2
Lint, no cli tests
steinitzu Sep 6, 2024
8068595
Update lockfile
steinitzu Sep 11, 2024
a9c89a0
Fix test, complex -> json
steinitzu Sep 11, 2024
874c871
Refactor type mapper
steinitzu Sep 11, 2024
a69d749
Update tests destination config
steinitzu Sep 11, 2024
c25932b
Fix tests
steinitzu Sep 12, 2024
6c426e6
Ignore sources tests
steinitzu Sep 12, 2024
36b585e
Fix overriding destination in test pipeline
steinitzu Sep 12, 2024
0208c64
Fix time precision in arrow test
steinitzu Sep 12, 2024
65f6ef7
Lint
steinitzu Sep 12, 2024
6c29071
Fix destination setup in test
steinitzu Sep 13, 2024
eec4e22
Fix
steinitzu Sep 13, 2024
dc4c29c
Use nullpool, lazy create engine, close current connection
steinitzu Sep 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor type mapper
steinitzu committed Sep 11, 2024
commit 874c8710900a8ca87b8d7afe753fdbe3f92fdb28
4 changes: 2 additions & 2 deletions dlt/common/destination/capabilities.py
Original file line number Diff line number Diff line change
@@ -226,8 +226,8 @@ def generic_capabilities(
caps.merge_strategies_selector = merge_strategies_selector
return caps

def get_type_mapper(self) -> DataTypeMapper:
return self.type_mapper(self)
def get_type_mapper(self, *args: Any, **kwargs: Any) -> DataTypeMapper:
return self.type_mapper(self, *args, **kwargs)


def merge_caps_file_formats(
1 change: 0 additions & 1 deletion dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
@@ -393,7 +393,6 @@ def run_managed(
self.run()
self._state = "completed"
except (DestinationTerminalException, TerminalValueError) as e:
logger.exception(f"Job {self.job_id()} failed terminally")
self._state = "failed"
self._exception = e
logger.exception(f"Terminal exception in job {self.job_id()} in file {self._file_path}")
11 changes: 10 additions & 1 deletion dlt/destinations/impl/sqlalchemy/factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing as t

from dlt.common.data_writers.configuration import CsvFormatConfiguration
from dlt.common.destination import Destination, DestinationCapabilitiesContext
from dlt.common.destination.capabilities import DataTypeMapper
from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE
from dlt.common.normalizers import NamingConvention

@@ -10,6 +10,14 @@
SqlalchemyClientConfiguration,
)

SqlalchemyTypeMapper: t.Type[DataTypeMapper]

try:
from dlt.destinations.impl.sqlalchemy.type_mapper import SqlalchemyTypeMapper
except ModuleNotFoundError:
# assign mock type mapper if no sqlalchemy
from dlt.common.destination.capabilities import UnsupportedTypeMapper as SqlalchemyTypeMapper

if t.TYPE_CHECKING:
# from dlt.destinations.impl.sqlalchemy.sqlalchemy_client import SqlalchemyJobClient
from dlt.destinations.impl.sqlalchemy.sqlalchemy_job_client import SqlalchemyJobClient
@@ -37,6 +45,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext:
caps.supports_ddl_transactions = True
caps.max_query_parameters = 20_0000
caps.max_rows_per_insert = 10_000 # Set a default to avoid OOM on large datasets
caps.type_mapper = SqlalchemyTypeMapper

return caps

176 changes: 16 additions & 160 deletions dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,23 @@
from typing import Iterable, Optional, Type, Dict, Any, Iterator, Sequence, List, Tuple, IO
from types import TracebackType
from typing import Iterable, Optional, Dict, Any, Iterator, Sequence, List, Tuple, IO
from contextlib import suppress
import inspect
import math

import sqlalchemy as sa
from sqlalchemy.sql import sqltypes

from dlt.common import logger
from dlt.common import pendulum
from dlt.common.exceptions import TerminalValueError
from dlt.common.destination.reference import (
JobClientBase,
LoadJob,
RunnableLoadJob,
StorageSchemaInfo,
StateInfo,
PreparedTableSchema,
)
from dlt.destinations.job_client_impl import SqlJobClientBase
from dlt.common.destination.capabilities import DestinationCapabilitiesContext
from dlt.common.schema import Schema, TTableSchema, TColumnSchema, TSchemaTables
from dlt.common.schema.typing import TColumnType, TTableFormat, TTableSchemaColumns
from dlt.common.schema.typing import TColumnType, TTableSchemaColumns
from dlt.common.schema.utils import pipeline_state_table, normalize_table_identifiers
from dlt.common.storages import FileStorage
from dlt.common.json import json, PY_DATETIME_DECODERS
@@ -32,147 +29,6 @@
from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyClientConfiguration


class SqlaTypeMapper:
    # TODO: May be merged with TypeMapper as a generic
    """Translates dlt schema column types to sqlalchemy ``TypeEngine`` instances and back.

    The mapper is dialect-aware: where the generic sqlalchemy type cannot carry a
    precision argument (datetime/time), the dialect-specific subtype is resolved
    via ``dialect.type_descriptor`` and instantiated with the right keyword.
    """

    def __init__(
        self, capabilities: DestinationCapabilitiesContext, dialect: sa.engine.Dialect
    ) -> None:
        self.capabilities = capabilities
        self.dialect = dialect

    def _db_integer_type(self, precision: Optional[int]) -> sa.types.TypeEngine:
        """Pick the smallest sqlalchemy integer type that can hold `precision` bits.

        ``None`` precision defaults to BigInteger; anything above 64 bits is rejected.
        """
        if precision is None:
            return sa.BigInteger()
        elif precision <= 16:
            return sa.SmallInteger()
        elif precision <= 32:
            return sa.Integer()
        elif precision <= 64:
            return sa.BigInteger()
        raise TerminalValueError(f"Unsupported precision for integer type: {precision}")

    def _create_date_time_type(
        self, sc_t: str, precision: Optional[int], timezone: Optional[bool]
    ) -> sa.types.TypeEngine:
        """Use the dialect specific datetime/time type if possible since the generic type doesn't accept precision argument"""
        precision = precision if precision is not None else self.capabilities.timestamp_precision
        base_type: sa.types.TypeEngine
        # unset timezone flag defaults to tz-aware
        timezone = timezone is None or bool(timezone)
        if sc_t == "timestamp":
            base_type = sa.DateTime()
            if self.dialect.name == "mysql":
                # Special case, type_descriptor does not return the specific datetime type
                from sqlalchemy.dialects.mysql import DATETIME

                return DATETIME(fsp=precision)
        elif sc_t == "time":
            base_type = sa.Time()

        dialect_type = type(
            self.dialect.type_descriptor(base_type)
        )  # Get the dialect specific subtype
        precision = precision if precision is not None else self.capabilities.timestamp_precision

        # Find out whether the dialect type accepts precision or fsp argument
        params = inspect.signature(dialect_type).parameters
        kwargs: Dict[str, Any] = dict(timezone=timezone)
        if "fsp" in params:
            kwargs["fsp"] = precision  # MySQL uses fsp for fractional seconds
        elif "precision" in params:
            kwargs["precision"] = precision
        return dialect_type(**kwargs)  # type: ignore[no-any-return,misc]

    def _create_double_type(self) -> sa.types.TypeEngine:
        """Return the best available double-precision float type for the dialect.

        Fix: the mysql branch imported DOUBLE but fell through to the generic
        float — it now actually returns the MySQL-specific DOUBLE type.
        """
        if dbl := getattr(sa, "Double", None):
            # Sqlalchemy 2 has generic double type
            return dbl()  # type: ignore[no-any-return]
        elif self.dialect.name == "mysql":
            # MySQL has a specific double type
            from sqlalchemy.dialects.mysql import DOUBLE

            return DOUBLE()
        return sa.Float(precision=53)  # Otherwise use float

    def _to_db_decimal_type(self, column: TColumnSchema) -> sa.types.TypeEngine:
        """Build a Numeric type; fall back to capability defaults when the column
        declares neither precision nor scale."""
        precision, scale = column.get("precision"), column.get("scale")
        if precision is None and scale is None:
            precision, scale = self.capabilities.decimal_precision
        return sa.Numeric(precision, scale)

    def to_db_type(self, column: TColumnSchema, table_format: TTableSchema) -> sa.types.TypeEngine:
        """Map a dlt column schema to a sqlalchemy type.

        Raises:
            TerminalValueError: when the dlt data type has no sqlalchemy equivalent.
        """
        sc_t = column["data_type"]
        precision = column.get("precision")
        # TODO: Precision and scale for supported types
        if sc_t == "text":
            length = precision
            # unique text columns need a bounded length on most backends
            if length is None and column.get("unique"):
                length = 128
            if length is None:
                return sa.Text()
            return sa.String(length=length)
        elif sc_t == "double":
            return self._create_double_type()
        elif sc_t == "bool":
            return sa.Boolean()
        elif sc_t == "timestamp":
            return self._create_date_time_type(sc_t, precision, column.get("timezone"))
        elif sc_t == "bigint":
            return self._db_integer_type(precision)
        elif sc_t == "binary":
            return sa.LargeBinary(length=precision)
        elif sc_t == "json":
            return sa.JSON(none_as_null=True)
        elif sc_t == "decimal":
            return self._to_db_decimal_type(column)
        elif sc_t == "wei":
            wei_precision, wei_scale = self.capabilities.wei_precision
            return sa.Numeric(precision=wei_precision, scale=wei_scale)
        elif sc_t == "date":
            return sa.Date()
        elif sc_t == "time":
            return self._create_date_time_type(sc_t, precision, column.get("timezone"))
        raise TerminalValueError(f"Unsupported data type: {sc_t}")

    def _from_db_integer_type(self, db_type: sa.Integer) -> TColumnType:
        """Reverse-map sqlalchemy integer subtypes to dlt bigint with precision.

        NOTE: SmallInteger must be tested first — it subclasses Integer.
        """
        if isinstance(db_type, sa.SmallInteger):
            return dict(data_type="bigint", precision=16)
        elif isinstance(db_type, sa.Integer) and not isinstance(db_type, sa.BigInteger):
            return dict(data_type="bigint", precision=32)
        # BigInteger and anything else: plain bigint
        return dict(data_type="bigint")

    def _from_db_decimal_type(self, db_type: sa.Numeric) -> TColumnType:
        """Numeric columns matching the wei precision are mapped back to `wei`."""
        precision, scale = db_type.precision, db_type.scale
        if (precision, scale) == self.capabilities.wei_precision:
            return dict(data_type="wei")

        return dict(data_type="decimal", precision=precision, scale=scale)

    def from_db_type(self, db_type: sa.types.TypeEngine) -> TColumnType:
        """Map a reflected sqlalchemy type back to a dlt column type.

        Raises:
            TerminalValueError: when the database type has no dlt equivalent.
        """
        # TODO: pass the sqla type through dialect.type_descriptor before instance check
        # Possibly need to check both dialect specific and generic types
        if isinstance(db_type, sa.String):
            return dict(data_type="text")
        elif isinstance(db_type, sa.Float):
            return dict(data_type="double")
        elif isinstance(db_type, sa.Boolean):
            return dict(data_type="bool")
        elif isinstance(db_type, sa.DateTime):
            return dict(data_type="timestamp", timezone=db_type.timezone)
        elif isinstance(db_type, sa.Integer):
            return self._from_db_integer_type(db_type)
        elif isinstance(db_type, sqltypes._Binary):
            return dict(data_type="binary", precision=db_type.length)
        elif isinstance(db_type, sa.JSON):
            return dict(data_type="json")
        elif isinstance(db_type, sa.Numeric):
            return self._from_db_decimal_type(db_type)
        elif isinstance(db_type, sa.Date):
            return dict(data_type="date")
        elif isinstance(db_type, sa.Time):
            return dict(data_type="time")
        raise TerminalValueError(f"Unsupported db type: {db_type}")


class SqlalchemyJsonLInsertJob(RunnableLoadJob):
def __init__(self, file_path: str, table: sa.Table) -> None:
super().__init__(file_path)
@@ -270,9 +126,9 @@ def __init__(
self.schema = schema
self.capabilities = capabilities
self.config = config
self.type_mapper = SqlaTypeMapper(capabilities, self.sql_client.dialect)
self.type_mapper = self.capabilities.get_type_mapper(self.sql_client.dialect)

def _to_table_object(self, schema_table: TTableSchema) -> sa.Table:
def _to_table_object(self, schema_table: PreparedTableSchema) -> sa.Table:
existing = self.sql_client.get_existing_table(schema_table["name"])
if existing is not None:
existing_col_names = set(col.name for col in existing.columns)
@@ -292,17 +148,17 @@ def _to_table_object(self, schema_table: TTableSchema) -> sa.Table:
)

def _to_column_object(
self, schema_column: TColumnSchema, table_format: TTableSchema
self, schema_column: TColumnSchema, table: PreparedTableSchema
) -> sa.Column:
return sa.Column(
schema_column["name"],
self.type_mapper.to_db_type(schema_column, table_format),
self.type_mapper.to_destination_type(schema_column, table),
nullable=schema_column.get("nullable", True),
unique=schema_column.get("unique", False),
)

def create_load_job(
self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False
self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False
) -> LoadJob:
if file_path.endswith(".typed-jsonl"):
table_obj = self._to_table_object(table)
@@ -313,7 +169,7 @@ def create_load_job(
return None

def complete_load(self, load_id: str) -> None:
loads_table = self._to_table_object(self.schema.tables[self.schema.loads_table_name])
loads_table = self._to_table_object(self.schema.tables[self.schema.loads_table_name]) # type: ignore[arg-type]
now_ts = pendulum.now()
self.sql_client.execute_sql(
loads_table.insert().values(
@@ -346,7 +202,7 @@ def get_storage_tables(
col.name: {
"name": col.name,
"nullable": col.nullable,
**self.type_mapper.from_db_type(col.type),
**self.type_mapper.from_destination_type(col.type, None, None),
}
for col in table_obj.columns
}
@@ -373,7 +229,7 @@ def update_stored_schema(

# Create all schema tables in metadata
for table_name in only_tables or self.schema.tables:
self._to_table_object(self.schema.tables[table_name])
self._to_table_object(self.schema.tables[table_name]) # type: ignore[arg-type]

schema_update: TSchemaTables = {}
tables_to_create: List[sa.Table] = []
@@ -407,15 +263,15 @@ def update_stored_schema(

def _delete_schema_in_storage(self, schema: Schema) -> None:
version_table = schema.tables[schema.version_table_name]
table_obj = self._to_table_object(version_table)
table_obj = self._to_table_object(version_table) # type: ignore[arg-type]
schema_name_col = schema.naming.normalize_identifier("schema_name")
self.sql_client.execute_sql(
table_obj.delete().where(table_obj.c[schema_name_col] == schema.name)
)

def _update_schema_in_storage(self, schema: Schema) -> None:
version_table = schema.tables[schema.version_table_name]
table_obj = self._to_table_object(version_table)
table_obj = self._to_table_object(version_table) # type: ignore[arg-type]
schema_str = json.dumps(schema.to_dict())

schema_mapping = StorageSchemaInfo(
@@ -433,7 +289,7 @@ def _get_stored_schema(
self, version_hash: Optional[str] = None, schema_name: Optional[str] = None
) -> Optional[StorageSchemaInfo]:
version_table = self.schema.tables[self.schema.version_table_name]
table_obj = self._to_table_object(version_table)
table_obj = self._to_table_object(version_table) # type: ignore[arg-type]
with suppress(DatabaseUndefinedRelation):
q = sa.select(table_obj)
if version_hash is not None:
@@ -465,9 +321,9 @@ def get_stored_state(self, pipeline_name: str) -> StateInfo:
state_table = self.schema.tables.get(
self.schema.state_table_name
) or normalize_table_identifiers(pipeline_state_table(), self.schema.naming)
state_table_obj = self._to_table_object(state_table)
state_table_obj = self._to_table_object(state_table) # type: ignore[arg-type]
loads_table = self.schema.tables[self.schema.loads_table_name]
loads_table_obj = self._to_table_object(loads_table)
loads_table_obj = self._to_table_object(loads_table) # type: ignore[arg-type]

c_load_id, c_dlt_load_id, c_pipeline_name, c_status = map(
self.schema.naming.normalize_identifier,
174 changes: 174 additions & 0 deletions dlt/destinations/impl/sqlalchemy/type_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
from typing import Optional, Dict, Any
import inspect

import sqlalchemy as sa
from sqlalchemy.sql import sqltypes

from dlt.common.exceptions import TerminalValueError
from dlt.common.typing import TLoaderFileFormat
from dlt.common.destination.capabilities import DataTypeMapper, DestinationCapabilitiesContext
from dlt.common.destination.typing import PreparedTableSchema
from dlt.common.schema.typing import TColumnSchema


# TODO: base type mapper should be a generic class to support TypeEngine instead of str types
class SqlalchemyTypeMapper(DataTypeMapper):
    """DataTypeMapper implementation for the sqlalchemy destination.

    Maps dlt column schema types to sqlalchemy ``TypeEngine`` instances
    (``to_destination_type``) and reflected database types back to dlt types
    (``from_destination_type``). Dialect-aware where the generic sqlalchemy
    type cannot express precision (datetime/time, MySQL DOUBLE).
    """

    def __init__(
        self,
        capabilities: DestinationCapabilitiesContext,
        dialect: Optional[sa.engine.Dialect] = None,
    ):
        super().__init__(capabilities)
        # Mapper is used to verify supported types without client, dialect is not important for this
        self.dialect = dialect or sa.engine.default.DefaultDialect()

    def _db_integer_type(self, precision: Optional[int]) -> sa.types.TypeEngine:
        """Pick the smallest sqlalchemy integer type that can hold `precision` bits."""
        if precision is None:
            return sa.BigInteger()
        elif precision <= 16:
            return sa.SmallInteger()
        elif precision <= 32:
            return sa.Integer()
        elif precision <= 64:
            return sa.BigInteger()
        raise TerminalValueError(f"Unsupported precision for integer type: {precision}")

    def _create_date_time_type(
        self, sc_t: str, precision: Optional[int], timezone: Optional[bool]
    ) -> sa.types.TypeEngine:
        """Use the dialect specific datetime/time type if possible since the generic type doesn't accept precision argument"""
        precision = precision if precision is not None else self.capabilities.timestamp_precision
        base_type: sa.types.TypeEngine
        # unset timezone flag defaults to tz-aware
        timezone = timezone is None or bool(timezone)
        if sc_t == "timestamp":
            base_type = sa.DateTime()
            if self.dialect.name == "mysql":
                # Special case, type_descriptor does not return the specific datetime type
                from sqlalchemy.dialects.mysql import DATETIME

                return DATETIME(fsp=precision)
        elif sc_t == "time":
            base_type = sa.Time()

        dialect_type = type(
            self.dialect.type_descriptor(base_type)
        )  # Get the dialect specific subtype
        precision = precision if precision is not None else self.capabilities.timestamp_precision

        # Find out whether the dialect type accepts precision or fsp argument
        params = inspect.signature(dialect_type).parameters
        kwargs: Dict[str, Any] = dict(timezone=timezone)
        if "fsp" in params:
            kwargs["fsp"] = precision  # MySQL uses fsp for fractional seconds
        elif "precision" in params:
            kwargs["precision"] = precision
        return dialect_type(**kwargs)  # type: ignore[no-any-return,misc]

    def _create_double_type(self) -> sa.types.TypeEngine:
        """Return the best available double-precision float type for the dialect."""
        if dbl := getattr(sa, "Double", None):
            # Sqlalchemy 2 has generic double type
            return dbl()  # type: ignore[no-any-return]
        elif self.dialect.name == "mysql":
            # MySQL has a specific double type
            from sqlalchemy.dialects.mysql import DOUBLE

            return DOUBLE()
        return sa.Float(precision=53)  # Otherwise use float

    def _to_db_decimal_type(self, column: TColumnSchema) -> sa.types.TypeEngine:
        """Build a Numeric type; fall back to capability defaults when the column
        declares neither precision nor scale."""
        precision, scale = column.get("precision"), column.get("scale")
        if precision is None and scale is None:
            precision, scale = self.capabilities.decimal_precision
        return sa.Numeric(precision, scale)

    def to_destination_type(  # type: ignore[override]
        self, column: TColumnSchema, table: Optional[PreparedTableSchema] = None
    ) -> sqltypes.TypeEngine:
        """Map a dlt column schema to a sqlalchemy type.

        Args:
            column: dlt column schema to translate.
            table: prepared table schema; currently unused but part of the
                base mapper contract.

        Raises:
            TerminalValueError: when the dlt data type has no sqlalchemy equivalent.
        """
        sc_t = column["data_type"]
        precision = column.get("precision")
        # TODO: Precision and scale for supported types
        if sc_t == "text":
            length = precision
            # unique text columns need a bounded length on most backends
            if length is None and column.get("unique"):
                length = 128
            if length is None:
                return sa.Text()
            return sa.String(length=length)
        elif sc_t == "double":
            return self._create_double_type()
        elif sc_t == "bool":
            return sa.Boolean()
        elif sc_t == "timestamp":
            return self._create_date_time_type(sc_t, precision, column.get("timezone"))
        elif sc_t == "bigint":
            return self._db_integer_type(precision)
        elif sc_t == "binary":
            return sa.LargeBinary(length=precision)
        elif sc_t == "json":
            return sa.JSON(none_as_null=True)
        elif sc_t == "decimal":
            return self._to_db_decimal_type(column)
        elif sc_t == "wei":
            wei_precision, wei_scale = self.capabilities.wei_precision
            return sa.Numeric(precision=wei_precision, scale=wei_scale)
        elif sc_t == "date":
            return sa.Date()
        elif sc_t == "time":
            return self._create_date_time_type(sc_t, precision, column.get("timezone"))
        raise TerminalValueError(f"Unsupported data type: {sc_t}")

    def _from_db_integer_type(self, db_type: sa.Integer) -> TColumnSchema:
        """Reverse-map sqlalchemy integer subtypes to dlt bigint with precision.

        NOTE: SmallInteger must be tested first — it subclasses Integer.
        """
        if isinstance(db_type, sa.SmallInteger):
            return dict(data_type="bigint", precision=16)
        elif isinstance(db_type, sa.Integer) and not isinstance(db_type, sa.BigInteger):
            return dict(data_type="bigint", precision=32)
        # BigInteger and anything else: plain bigint
        return dict(data_type="bigint")

    def _from_db_decimal_type(self, db_type: sa.Numeric) -> TColumnSchema:
        """Numeric columns matching the wei precision are mapped back to `wei`."""
        precision, scale = db_type.precision, db_type.scale
        if (precision, scale) == self.capabilities.wei_precision:
            return dict(data_type="wei")

        return dict(data_type="decimal", precision=precision, scale=scale)

    def from_destination_type(  # type: ignore[override]
        self,
        db_type: sqltypes.TypeEngine,
        precision: Optional[int] = None,
        scale: Optional[int] = None,
    ) -> TColumnSchema:
        """Map a reflected sqlalchemy type back to a dlt column type.

        Args:
            db_type: sqlalchemy type reflected from the database.
            precision: unused; part of the base mapper contract.
            scale: unused; part of the base mapper contract.

        Raises:
            TerminalValueError: when the database type has no dlt equivalent.
        """
        # TODO: pass the sqla type through dialect.type_descriptor before instance check
        # Possibly need to check both dialect specific and generic types
        if isinstance(db_type, sa.String):
            return dict(data_type="text")
        elif isinstance(db_type, sa.Float):
            return dict(data_type="double")
        elif isinstance(db_type, sa.Boolean):
            return dict(data_type="bool")
        elif isinstance(db_type, sa.DateTime):
            return dict(data_type="timestamp", timezone=db_type.timezone)
        elif isinstance(db_type, sa.Integer):
            return self._from_db_integer_type(db_type)
        elif isinstance(db_type, sqltypes._Binary):
            return dict(data_type="binary", precision=db_type.length)
        elif isinstance(db_type, sa.JSON):
            return dict(data_type="json")
        elif isinstance(db_type, sa.Numeric):
            return self._from_db_decimal_type(db_type)
        elif isinstance(db_type, sa.Date):
            return dict(data_type="date")
        elif isinstance(db_type, sa.Time):
            return dict(data_type="time")
        raise TerminalValueError(f"Unsupported db type: {db_type}")

    def ensure_supported_type(
        self,
        column: TColumnSchema,
        table: PreparedTableSchema,
        loader_file_format: TLoaderFileFormat,
    ) -> None:
        # All dlt types are representable in sqlalchemy; nothing to validate here
        pass