One way support for precision, scale in column schema
steinitzu committed Sep 20, 2023
1 parent 390519d commit 3541919
Showing 14 changed files with 351 additions and 174 deletions.
2 changes: 2 additions & 0 deletions dlt/common/schema/typing.py
@@ -53,6 +53,8 @@ class TColumnSchema(TColumnSchemaBase, total=False):
root_key: Optional[bool]
merge_key: Optional[bool]
variant: Optional[bool]
precision: Optional[int]
scale: Optional[int]


TTableSchemaColumns = Dict[str, TColumnSchema]
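With these two hints in place, a column schema can carry explicit bounds that destinations may honor. A minimal illustration (the concrete values are invented for the example):

# Illustrative only: a TColumnSchema dict using the new precision/scale hints.
from dlt.common.schema.typing import TColumnSchema

amount_column: TColumnSchema = {
    "name": "amount",
    "data_type": "decimal",
    "nullable": False,
    "precision": 38,
    "scale": 9,
}
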
44 changes: 23 additions & 21 deletions dlt/destinations/athena/athena.py
@@ -31,20 +31,10 @@
from dlt.destinations.typing import DBApiCursor
from dlt.destinations.job_client_impl import SqlJobClientBase, StorageSchemaInfo
from dlt.destinations.athena.configuration import AthenaClientConfiguration
from dlt.destinations.type_mapping import TypeMapper
from dlt.destinations import path_utils

SCT_TO_HIVET: Dict[TDataType, str] = {
"complex": "string",
"text": "string",
"double": "double",
"bool": "boolean",
"date": "date",
"timestamp": "timestamp",
"bigint": "bigint",
"binary": "binary",
"decimal": "decimal(%i,%i)",
"time": "string"
}


HIVET_TO_SCT: Dict[str, TDataType] = {
"varchar": "text",
@@ -59,6 +49,25 @@
}


class AthenaTypeMapper(TypeMapper):
sct_to_unbound_dbt = {
"complex": "string",
"text": "string",
"double": "double",
"bool": "boolean",
"date": "date",
"timestamp": "timestamp",
"bigint": "bigint",
"binary": "binary",
"time": "string"
}

sct_to_dbt = {
"decimal": "decimal(%i,%i)",
"wei": "decimal(%i,%i)"
}


# add a formatter for pendulum to be used by pyathena dbapi
def _format_pendulum_datetime(formatter: Formatter, escaper: Callable[[str], str], val: Any) -> Any:
# copied from https://github.com/laughingman7743/PyAthena/blob/f4b21a0b0f501f5c3504698e25081f491a541d4e/pyathena/formatter.py#L114
@@ -265,19 +274,12 @@ def __init__(self, schema: Schema, config: AthenaClientConfiguration) -> None:
super().__init__(schema, config, sql_client)
self.sql_client: AthenaSQLClient = sql_client # type: ignore
self.config: AthenaClientConfiguration = config
self.type_mapper = AthenaTypeMapper(self.capabilities)

def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None:
# never truncate tables in athena
super().initialize_storage([])

@classmethod
def _to_db_type(cls, sc_t: TDataType) -> str:
if sc_t == "wei":
return SCT_TO_HIVET["decimal"] % cls.capabilities.wei_precision
if sc_t == "decimal":
return SCT_TO_HIVET["decimal"] % cls.capabilities.decimal_precision
return SCT_TO_HIVET[sc_t]

@classmethod
def _from_db_type(cls, hive_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
for key, val in HIVET_TO_SCT.items():
@@ -286,7 +288,7 @@ def _from_db_type(cls, hive_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
return None

def _get_column_def_sql(self, c: TColumnSchema) -> str:
return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self._to_db_type(c['data_type'])}"
return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c)}"

def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]:

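The dlt/destinations/type_mapping.py module imported above is among the files not expanded on this page. Going only by how the destination clients declare sct_to_unbound_dbt / sct_to_dbt and call to_db_type(c), a rough sketch of what the base class might look like follows; the fallback-to-capabilities logic and the timestamp default are assumptions, not the committed implementation:

# Hypothetical sketch of TypeMapper; the real module is not shown in this commit view.
from typing import Dict, Optional

from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.schema.typing import TColumnSchema


class TypeMapper:
    # destination types rendered without precision/scale arguments
    sct_to_unbound_dbt: Dict[str, str] = {}
    # destination types rendered through a %i-style template taking precision (and scale)
    sct_to_dbt: Dict[str, str] = {}

    def __init__(self, capabilities: DestinationCapabilitiesContext) -> None:
        self.capabilities = capabilities

    def to_db_type(self, column: TColumnSchema) -> str:
        sc_t = column["data_type"]
        precision: Optional[int] = column.get("precision")
        scale: Optional[int] = column.get("scale")
        # assumed: fall back to destination capabilities when the column gives no bounds
        if sc_t == "decimal" and precision is None:
            precision, scale = self.capabilities.decimal_precision
        elif sc_t == "wei" and precision is None:
            precision, scale = self.capabilities.wei_precision
        elif sc_t == "timestamp" and precision is None:
            # assumed: uses the new timestamp_precision capability (e.g. 7 for mssql below)
            precision = self.capabilities.timestamp_precision
        template = self.sct_to_dbt.get(sc_t)
        if template is None or precision is None:
            # no parametrized form defined, or no bounds available: emit the plain type
            return self.sct_to_unbound_dbt[sc_t]
        # single-argument templates such as "nvarchar(%i)" vs two-argument "decimal(%i,%i)";
        # assumed: a missing scale defaults to 0
        args = (precision,) if template.count("%i") == 1 else (precision, scale if scale is not None else 0)
        return template % args

Keeping the per-destination knowledge in two small class-level tables means each client, such as AthenaTypeMapper above, only overrides data, not logic.
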
35 changes: 29 additions & 6 deletions dlt/destinations/bigquery/bigquery.py
@@ -22,6 +22,7 @@
from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob
from dlt.destinations.job_impl import NewReferenceJob
from dlt.destinations.sql_client import SqlClientBase
from dlt.destinations.type_mapping import TypeMapper

from dlt.common.schema.utils import table_schema_has_type

@@ -53,6 +54,27 @@
"TIME": "time",
}


class BigQueryTypeMapper(TypeMapper):
sct_to_unbound_dbt = {
"complex": "JSON",
"text": "STRING",
"double": "FLOAT64",
"bool": "BOOLEAN",
"date": "DATE",
"timestamp": "TIMESTAMP",
"bigint": "INTEGER",
"binary": "BYTES",
"wei": "BIGNUMERIC", # non parametrized should hold wei values
"time": "TIME",
}

sct_to_dbt = {
"text": "STRING(%i)",
"binary": "BYTES(%i)",
"decimal": "NUMERIC(%i,%i)",
}

class BigQueryLoadJob(LoadJob, FollowupJob):
def __init__(
self,
@@ -146,6 +168,7 @@ def __init__(self, schema: Schema, config: BigQueryClientConfiguration) -> None:
super().__init__(schema, config, sql_client)
self.config: BigQueryClientConfiguration = config
self.sql_client: BigQuerySqlClient = sql_client # type: ignore
self.type_mapper = BigQueryTypeMapper(self.capabilities)

def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
return BigQueryMergeJob.from_table_chain(table_chain, self.sql_client)
@@ -229,7 +252,7 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]:

def _get_column_def_sql(self, c: TColumnSchema) -> str:
name = self.capabilities.escape_identifier(c["name"])
return f"{name} {self._to_db_type(c['data_type'])} {self._gen_not_null(c.get('nullable', True))}"
return f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}"

def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]:
schema_table: TTableSchemaColumns = {}
@@ -311,11 +334,11 @@ def _retrieve_load_job(self, file_path: str) -> bigquery.LoadJob:
job_id = BigQueryLoadJob.get_job_id_from_file_path(file_path)
return cast(bigquery.LoadJob, self.sql_client.native_connection.get_job(job_id))

@classmethod
def _to_db_type(cls, sc_t: TDataType) -> str:
if sc_t == "decimal":
return SCT_TO_BQT["decimal"] % cls.capabilities.decimal_precision
return SCT_TO_BQT[sc_t]
# @classmethod
# def _to_db_type(cls, sc_t: TDataType) -> str:
# if sc_t == "decimal":
# return SCT_TO_BQT["decimal"] % cls.capabilities.decimal_precision
# return SCT_TO_BQT[sc_t]

@classmethod
def _from_db_type(cls, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
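For illustration, the BigQuery tables above imply mappings along these lines, assuming the defaulting behaviour sketched earlier (expected values, not output captured from a run):

# Illustrative expected results of BigQueryTypeMapper.to_db_type:
# {"data_type": "decimal", "precision": 38, "scale": 9} -> "NUMERIC(38,9)"
# {"data_type": "text"}                                 -> "STRING"
# {"data_type": "text", "precision": 64}                -> "STRING(64)"
# {"data_type": "wei"}                                  -> "BIGNUMERIC"
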
69 changes: 46 additions & 23 deletions dlt/destinations/duckdb/duck.py
@@ -14,20 +14,21 @@
from dlt.destinations.duckdb import capabilities
from dlt.destinations.duckdb.sql_client import DuckDbSqlClient
from dlt.destinations.duckdb.configuration import DuckDbClientConfiguration


SCT_TO_PGT: Dict[TDataType, str] = {
"complex": "JSON",
"text": "VARCHAR",
"double": "DOUBLE",
"bool": "BOOLEAN",
"date": "DATE",
"timestamp": "TIMESTAMP WITH TIME ZONE",
"bigint": "BIGINT",
"binary": "BLOB",
"decimal": "DECIMAL(%i,%i)",
"time": "TIME"
}
from dlt.destinations.type_mapping import TypeMapper


# SCT_TO_PGT: Dict[TDataType, str] = {
# "complex": "JSON",
# "text": "VARCHAR",
# "double": "DOUBLE",
# "bool": "BOOLEAN",
# "date": "DATE",
# "timestamp": "TIMESTAMP WITH TIME ZONE",
# "bigint": "BIGINT",
# "binary": "BLOB",
# "decimal": "DECIMAL(%i,%i)",
# "time": "TIME"
# }

PGT_TO_SCT: Dict[str, TDataType] = {
"VARCHAR": "text",
@@ -51,6 +52,27 @@
TABLES_LOCKS: Dict[str, threading.Lock] = {}


class DuckDbTypeMapper(TypeMapper):
sct_to_unbound_dbt = {
"complex": "JSON",
"text": "VARCHAR",
"double": "DOUBLE",
"bool": "BOOLEAN",
"date": "DATE",
# Duck does not allow specifying precision on timestamp with tz
"timestamp": "TIMESTAMP WITH TIME ZONE",
"bigint": "BIGINT",
"binary": "BLOB",
"time": "TIME"
}

sct_to_dbt = {
# VARCHAR(n) is alias for VARCHAR in duckdb
# "text": "VARCHAR(%i)",
"decimal": "DECIMAL(%i,%i)",
"wei": "DECIMAL(%i,%i)",
}

class DuckDbCopyJob(LoadJob, FollowupJob):
def __init__(self, table_name: str, file_path: str, sql_client: DuckDbSqlClient) -> None:
super().__init__(FileStorage.get_file_name_from_file_path(file_path))
@@ -95,6 +117,7 @@ def __init__(self, schema: Schema, config: DuckDbClientConfiguration) -> None:
self.config: DuckDbClientConfiguration = config
self.sql_client: DuckDbSqlClient = sql_client # type: ignore
self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {}
self.type_mapper = DuckDbTypeMapper(self.capabilities)

def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
job = super().start_file_load(table, file_path, load_id)
@@ -105,15 +128,15 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
def _get_column_def_sql(self, c: TColumnSchema) -> str:
hints_str = " ".join(self.active_hints.get(h, "") for h in self.active_hints.keys() if c.get(h, False) is True)
column_name = self.capabilities.escape_identifier(c["name"])
return f"{column_name} {self._to_db_type(c['data_type'])} {hints_str} {self._gen_not_null(c.get('nullable', True))}"

@classmethod
def _to_db_type(cls, sc_t: TDataType) -> str:
if sc_t == "wei":
return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision
if sc_t == "decimal":
return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision
return SCT_TO_PGT[sc_t]
return f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"

# @classmethod
# def _to_db_type(cls, sc_t: TDataType) -> str:
# if sc_t == "wei":
# return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision
# if sc_t == "decimal":
# return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision
# return SCT_TO_PGT[sc_t]

@classmethod
def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
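The DuckDB comments above explain why text and timestamps stay unparametrized, so only decimal and wei actually consume the new hints there. Roughly, with identifier escaping and whitespace approximated and the default decimal bounds coming from the capabilities (not shown in this excerpt):

# Illustrative _get_column_def_sql output for DuckDB (approximate formatting):
# {"name": "amount", "data_type": "decimal", "precision": 18, "scale": 6, "nullable": False}
#   -> '"amount" DECIMAL(18,6) NOT NULL'
# {"name": "ts", "data_type": "timestamp", "precision": 3}
#   -> '"ts" TIMESTAMP WITH TIME ZONE'  (precision hint not applied: unsupported with tz)
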
5 changes: 0 additions & 5 deletions dlt/destinations/job_client_impl.py
@@ -247,11 +247,6 @@ def _null_to_bool(v: str) -> bool:
schema_table[c[0]] = schema_c # type: ignore
return True, schema_table

@classmethod
@abstractmethod
def _to_db_type(cls, schema_type: TDataType) -> str:
pass

@classmethod
@abstractmethod
def _from_db_type(cls, db_type: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
1 change: 1 addition & 0 deletions dlt/destinations/mssql/__init__.py
@@ -36,6 +36,7 @@ def capabilities() -> DestinationCapabilitiesContext:
caps.is_max_text_data_type_length_in_bytes = False
caps.supports_ddl_transactions = True
caps.max_rows_per_insert = 1000
caps.timestamp_precision = 7

return caps

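SQL Server's datetimeoffset accepts a fractional-second precision of 0 to 7, so 7 is the widest setting; presumably this new capability becomes the default argument for the datetimeoffset(%i) template in the mapper below when a column sets no precision of its own. Hypothetically:

# Assumed defaulting via the timestamp_precision capability:
# {"data_type": "timestamp"}                  -> "datetimeoffset(7)"
# {"data_type": "timestamp", "precision": 3}  -> "datetimeoffset(3)"
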
75 changes: 49 additions & 26 deletions dlt/destinations/mssql/mssql.py
@@ -16,20 +16,21 @@
from dlt.destinations.mssql.sql_client import PyOdbcMsSqlClient
from dlt.destinations.mssql.configuration import MsSqlClientConfiguration
from dlt.destinations.sql_client import SqlClientBase


SCT_TO_PGT: Dict[TDataType, str] = {
"complex": "nvarchar(max)",
"text": "nvarchar(max)",
"double": "float",
"bool": "bit",
"timestamp": "datetimeoffset",
"date": "date",
"bigint": "bigint",
"binary": "varbinary(max)",
"decimal": "decimal(%i,%i)",
"time": "time"
}
from dlt.destinations.type_mapping import TypeMapper


# SCT_TO_PGT: Dict[TDataType, str] = {
# "complex": "nvarchar(max)",
# "text": "nvarchar(max)",
# "double": "float",
# "bool": "bit",
# "timestamp": "datetimeoffset",
# "date": "date",
# "bigint": "bigint",
# "binary": "varbinary(max)",
# "decimal": "decimal(%i,%i)",
# "time": "time"
# }

PGT_TO_SCT: Dict[str, TDataType] = {
"nvarchar": "text",
@@ -47,6 +48,27 @@
"unique": "UNIQUE"
}


class MsSqlTypeMapper(TypeMapper):
sct_to_unbound_dbt = {
"complex": "nvarchar(max)",
"text": "nvarchar(max)",
"double": "float",
"bool": "bit",
"bigint": "bigint",
"binary": "varbinary(max)",
}

sct_to_dbt = {
"complex": "nvarchar(%i)",
"text": "nvarchar(%i)",
"timestamp": "datetimeoffset(%i)",
"binary": "varbinary(%i)",
"decimal": "decimal(%i,%i)",
"time": "time(%i)"
}


class MsSqlStagingCopyJob(SqlStagingCopyJob):

@classmethod
@@ -97,6 +119,7 @@ def __init__(self, schema: Schema, config: MsSqlClientConfiguration) -> None:
self.config: MsSqlClientConfiguration = config
self.sql_client = sql_client
self.active_hints = HINT_TO_MSSQL_ATTR if self.config.create_indexes else {}
self.type_mapper = MsSqlTypeMapper(self.capabilities)

def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
return MsSqlMergeJob.from_table_chain(table_chain, self.sql_client)
@@ -109,9 +132,9 @@ def _get_column_def_sql(self, c: TColumnSchema) -> str:
sc_type = c["data_type"]
if sc_type == "text" and c.get("unique"):
# MSSQL does not allow index on large TEXT columns
db_type = "nvarchar(900)"
db_type = "nvarchar(%i)" % (c.get("precision") or 900)
else:
db_type = self._to_db_type(sc_type)
db_type = self.type_mapper.to_db_type(c)

hints_str = " ".join(self.active_hints.get(h, "") for h in self.active_hints.keys() if c.get(h, False) is True)
column_name = self.capabilities.escape_identifier(c["name"])
@@ -120,16 +143,16 @@ def _get_column_def_sql(self, c: TColumnSchema) -> str:
def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
return MsSqlStagingCopyJob.from_table_chain(table_chain, self.sql_client)

@classmethod
def _to_db_type(cls, sc_t: TDataType) -> str:
if sc_t == "wei":
return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision
if sc_t == "decimal":
return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision

if sc_t == "wei":
return f"numeric({2*EVM_DECIMAL_PRECISION},{EVM_DECIMAL_PRECISION})"
return SCT_TO_PGT[sc_t]
# @classmethod
# def _to_db_type(cls, sc_t: TDataType) -> str:
# if sc_t == "wei":
# return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision
# if sc_t == "decimal":
# return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision

# if sc_t == "wei":
# return f"numeric({2*EVM_DECIMAL_PRECISION},{EVM_DECIMAL_PRECISION})"
# return SCT_TO_PGT[sc_t]

@classmethod
def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
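The unique-text branch above is the one place a destination still hard-codes a bound: indexed text columns get nvarchar with the column's precision, falling back to 900, presumably to stay within SQL Server's index key size limits. Illustrative outcomes (the unbound case assumes the defaulting sketched earlier):

# {"data_type": "text", "unique": True}                    -> "nvarchar(900)"
# {"data_type": "text", "unique": True, "precision": 200}  -> "nvarchar(200)"
# {"data_type": "text"}                                    -> "nvarchar(max)"
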
