make type mapper table format sensitive
sh-rp committed Oct 14, 2023
1 parent 0707629 commit 4692e37
Showing 9 changed files with 39 additions and 39 deletions.
15 changes: 7 additions & 8 deletions dlt/destinations/athena/athena.py
@@ -16,7 +16,7 @@
 from dlt.common.utils import without_none
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, Schema, TSchemaTables, TTableSchema
-from dlt.common.schema.typing import TTableSchema, TColumnType, TWriteDisposition
+from dlt.common.schema.typing import TTableSchema, TColumnType, TWriteDisposition, TTableFormat
 from dlt.common.schema.utils import table_schema_has_type, get_table_format
 from dlt.common.destination import DestinationCapabilitiesContext
 from dlt.common.destination.reference import LoadJob, FollowupJob
@@ -72,14 +72,13 @@ class AthenaTypeMapper(TypeMapper):
     def __init__(self, capabilities: DestinationCapabilitiesContext):
         super().__init__(capabilities)
 
-    def to_db_integer_type(self, precision: Optional[int]) -> str:
+    def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
         if precision is None:
             return "bigint"
-        # TODO: iceberg does not support smallint and tinyint
         if precision <= 8:
-            return "int"
+            return "int" if table_format == "iceberg" else "tinyint"
         elif precision <= 16:
-            return "int"
+            return "int" if table_format == "iceberg" else "smallint"
         elif precision <= 32:
             return "int"
         return "bigint"
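Illustration (not part of the commit): a self-contained sketch of the new Athena integer mapping; the function name and asserts are hypothetical, the branching mirrors the hunk above.

from typing import Optional

def athena_int_type(precision: Optional[int], table_format: Optional[str] = None) -> str:
    # mirrors AthenaTypeMapper.to_db_integer_type after this change
    if precision is None:
        return "bigint"
    if precision <= 8:
        return "int" if table_format == "iceberg" else "tinyint"
    elif precision <= 16:
        return "int" if table_format == "iceberg" else "smallint"
    elif precision <= 32:
        return "int"
    return "bigint"

assert athena_int_type(8) == "tinyint"         # plain hive tables now get the narrow type
assert athena_int_type(8, "iceberg") == "int"  # iceberg has no tinyint/smallint
assert athena_int_type(None) == "bigint"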
@@ -312,8 +311,8 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None:
     def _from_db_type(self, hive_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
         return self.type_mapper.from_db_type(hive_t, precision, scale)
 
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
-        return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c)}"
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
+        return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}"
 
     def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]:
 
@@ -326,7 +325,7 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]:
         # or if we are in iceberg mode, we create iceberg tables for all tables
         table = self.get_load_table(table_name, self.in_staging_mode)
         is_iceberg = self._is_iceberg_table(table) or table.get("write_disposition", None) == "skip"
-        columns = ", ".join([self._get_column_def_sql(c) for c in new_columns])
+        columns = ", ".join([self._get_column_def_sql(c, table.get("table_format")) for c in new_columns])
 
         # this will fail if the table prefix is not properly defined
         table_prefix = self.table_prefix_layout.format(table_name=table_name)
6 changes: 3 additions & 3 deletions dlt/destinations/bigquery/bigquery.py
@@ -11,7 +11,7 @@
 from dlt.common.data_types import TDataType
 from dlt.common.storages.file_storage import FileStorage
 from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns
-from dlt.common.schema.typing import TTableSchema, TColumnType
+from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat
 from dlt.common.schema.exceptions import UnknownTableException
 
 from dlt.destinations.job_client_impl import SqlJobClientWithStaging
@@ -250,9 +250,9 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]:

         return sql
 
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         name = self.capabilities.escape_identifier(c["name"])
-        return f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}"
+        return f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}"
 
     def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]:
         schema_table: TTableSchemaColumns = {}
6 changes: 3 additions & 3 deletions dlt/destinations/duckdb/duck.py
@@ -5,7 +5,7 @@
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, TColumnHint, Schema
 from dlt.common.destination.reference import LoadJob, FollowupJob, TLoadJobState
-from dlt.common.schema.typing import TTableSchema, TColumnType
+from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat
 from dlt.common.storages.file_storage import FileStorage
 from dlt.common.utils import maybe_context
 
@@ -65,7 +65,7 @@ class DuckDbTypeMapper(TypeMapper):
"HUGEINT": "bigint",
}

def to_db_integer_type(self, precision: Optional[int]) -> str:
def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
if precision is None:
return "BIGINT"
# Precision is number of bits
@@ -141,7 +141,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
         job = DuckDbCopyJob(table["name"], file_path, self.sql_client)
         return job
 
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         hints_str = " ".join(self.active_hints.get(h, "") for h in self.active_hints.keys() if c.get(h, False) is True)
         column_name = self.capabilities.escape_identifier(c["name"])
         return f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
13 changes: 7 additions & 6 deletions dlt/destinations/job_client_impl.py
@@ -12,7 +12,7 @@

 from dlt.common import json, pendulum, logger
 from dlt.common.data_types import TDataType
-from dlt.common.schema.typing import COLUMN_HINTS, TColumnType, TColumnSchemaBase, TTableSchema, TWriteDisposition
+from dlt.common.schema.typing import COLUMN_HINTS, TColumnType, TColumnSchemaBase, TTableSchema, TWriteDisposition, TTableFormat
 from dlt.common.storages import FileStorage
 from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables
 from dlt.common.destination.reference import StateInfo, StorageSchemaInfo, WithStateSync, DestinationClientConfiguration, DestinationClientDwhConfiguration, DestinationClientDwhWithStagingConfiguration, NewLoadJob, WithStagingDataset, TLoadJobState, LoadJob, JobClientBase, FollowupJob, CredentialsConfiguration
@@ -318,23 +318,24 @@ def _build_schema_update_sql(self, only_tables: Iterable[str]) -> Tuple[List[str], TSchemaTables]:

         return sql_updates, schema_update
 
-    def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str]:
+    def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None) -> List[str]:
         """Make one or more ADD COLUMN sql clauses to be joined in ALTER TABLE statement(s)"""
-        return [f"ADD COLUMN {self._get_column_def_sql(c)}" for c in new_columns]
+        return [f"ADD COLUMN {self._get_column_def_sql(c, table_format)}" for c in new_columns]
 
     def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]:
         # build sql
         canonical_name = self.sql_client.make_qualified_table_name(table_name)
+        table = self.get_load_table(table_name)
         sql_result: List[str] = []
         if not generate_alter:
             # build CREATE
             sql = f"CREATE TABLE {canonical_name} (\n"
-            sql += ",\n".join([self._get_column_def_sql(c) for c in new_columns])
+            sql += ",\n".join([self._get_column_def_sql(c, table.get("table_format")) for c in new_columns])
             sql += ")"
             sql_result.append(sql)
         else:
             sql_base = f"ALTER TABLE {canonical_name}\n"
-            add_column_statements = self._make_add_column_sql(new_columns)
+            add_column_statements = self._make_add_column_sql(new_columns, table.get("table_format"))
             if self.capabilities.alter_add_multi_column:
                 column_sql = ",\n"
                 sql_result.append(sql_base + column_sql.join(add_column_statements))
@@ -357,7 +358,7 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]:
         return sql_result
 
     @abstractmethod
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         pass
 
     @staticmethod
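Illustration (not part of the commit): the two statement shapes _get_table_update_sql assembles, with hypothetical column definitions standing in for _get_column_def_sql output.

canonical_name = '"my_dataset"."events"'
col_defs = ["user_id bigint", "name varchar"]  # hypothetical _get_column_def_sql results

create_sql = f"CREATE TABLE {canonical_name} (\n" + ",\n".join(col_defs) + ")"
alter_sql = f"ALTER TABLE {canonical_name}\n" + ",\n".join(f"ADD COLUMN {d}" for d in col_defs)

print(create_sql)  # CREATE TABLE ... with one column def per line
print(alter_sql)   # ALTER TABLE ... with one ADD COLUMN clause per column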
10 changes: 5 additions & 5 deletions dlt/destinations/mssql/mssql.py
@@ -5,7 +5,7 @@
 from dlt.common.destination import DestinationCapabilitiesContext
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, TColumnHint, Schema
-from dlt.common.schema.typing import TTableSchema, TColumnType
+from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat
 from dlt.common.utils import uniq_id
 
 from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob, SqlJobParams
@@ -62,7 +62,7 @@ class MsSqlTypeMapper(TypeMapper):
"int": "bigint",
}

def to_db_integer_type(self, precision: Optional[int]) -> str:
def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
if precision is None:
return "bigint"
if precision <= 8:
@@ -136,11 +136,11 @@ def __init__(self, schema: Schema, config: MsSqlClientConfiguration) -> None:
     def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
         return [MsSqlMergeJob.from_table_chain(table_chain, self.sql_client)]
 
-    def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str]:
+    def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None) -> List[str]:
         # Override because mssql requires multiple columns in a single ADD COLUMN clause
-        return ["ADD \n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)]
+        return ["ADD \n" + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns)]
 
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         sc_type = c["data_type"]
         if sc_type == "text" and c.get("unique"):
             # MSSQL does not allow index on large TEXT columns
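Illustration (not part of the commit): MSSQL takes all new columns in a single ADD clause, which is why _make_add_column_sql joins the definitions instead of emitting one ADD COLUMN per column; the definitions below are hypothetical.

col_defs = ["city nvarchar(max)", "zip int"]  # hypothetical _get_column_def_sql results
print("ALTER TABLE my_table\n" + "ADD \n" + ",\n".join(col_defs))
# ALTER TABLE my_table
# ADD
# city nvarchar(max),
# zip int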
6 changes: 3 additions & 3 deletions dlt/destinations/postgres/postgres.py
@@ -5,7 +5,7 @@
 from dlt.common.destination import DestinationCapabilitiesContext
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, TColumnHint, Schema
-from dlt.common.schema.typing import TTableSchema, TColumnType
+from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat
 
 from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams
 
@@ -59,7 +59,7 @@ class PostgresTypeMapper(TypeMapper):
"integer": "bigint",
}

def to_db_integer_type(self, precision: Optional[int]) -> str:
def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
if precision is None:
return "bigint"
# Precision is number of bits
@@ -109,7 +109,7 @@ def __init__(self, schema: Schema, config: PostgresClientConfiguration) -> None:
         self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {}
         self.type_mapper = PostgresTypeMapper(self.capabilities)
 
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         hints_str = " ".join(self.active_hints.get(h, "") for h in self.active_hints.keys() if c.get(h, False) is True)
         column_name = self.capabilities.escape_identifier(c["name"])
         return f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
6 changes: 3 additions & 3 deletions dlt/destinations/redshift/redshift.py
@@ -17,7 +17,7 @@
 from dlt.common.destination.reference import NewLoadJob, CredentialsConfiguration, SupportsStagingDestination
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, TColumnHint, Schema
-from dlt.common.schema.typing import TTableSchema, TColumnType
+from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat
 from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults
 
 from dlt.destinations.insert_job_client import InsertValuesJobClient
@@ -76,7 +76,7 @@ class RedshiftTypeMapper(TypeMapper):
"integer": "bigint",
}

def to_db_integer_type(self, precision: Optional[int]) -> str:
def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
if precision is None:
return "bigint"
if precision <= 16:
@@ -204,7 +204,7 @@ def __init__(self, schema: Schema, config: RedshiftClientConfiguration) -> None:
     def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
         return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)]
 
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         hints_str = " ".join(HINT_TO_REDSHIFT_ATTR.get(h, "") for h in HINT_TO_REDSHIFT_ATTR.keys() if c.get(h, False) is True)
         column_name = self.capabilities.escape_identifier(c["name"])
         return f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
8 changes: 4 additions & 4 deletions dlt/destinations/snowflake/snowflake.py
@@ -7,7 +7,7 @@
 from dlt.common.data_types import TDataType
 from dlt.common.storages.file_storage import FileStorage
 from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns
-from dlt.common.schema.typing import TTableSchema, TColumnType
+from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat
 
 
 from dlt.destinations.job_client_impl import SqlJobClientWithStaging
@@ -200,9 +200,9 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
     def restore_file_load(self, file_path: str) -> LoadJob:
         return EmptyLoadJob.from_file_path(file_path, "completed")
 
-    def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str]:
+    def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None) -> List[str]:
         # Override because snowflake requires multiple columns in a single ADD COLUMN clause
-        return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)]
+        return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns)]
 
     def _create_replace_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
         if self.config.replace_strategy == "staging-optimized":
@@ -222,7 +222,7 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]:
     def _from_db_type(self, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
         return self.type_mapper.from_db_type(bq_t, precision, scale)
 
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         name = self.capabilities.escape_identifier(c["name"])
         return f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}"
 
8 changes: 4 additions & 4 deletions dlt/destinations/type_mapping.py
@@ -1,6 +1,6 @@
 from typing import Tuple, ClassVar, Dict, Optional
 
-from dlt.common.schema.typing import TColumnSchema, TDataType, TColumnType
+from dlt.common.schema.typing import TColumnSchema, TDataType, TColumnType, TTableFormat
 from dlt.common.destination.capabilities import DestinationCapabilitiesContext
 from dlt.common.utils import without_none
 
@@ -20,15 +20,15 @@ class TypeMapper:
     def __init__(self, capabilities: DestinationCapabilitiesContext) -> None:
         self.capabilities = capabilities
 
-    def to_db_integer_type(self, precision: Optional[int]) -> str:
+    def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
         # Override in subclass if db supports other integer types (e.g. smallint, integer, tinyint, etc.)
         return self.sct_to_unbound_dbt["bigint"]
 
-    def to_db_type(self, column: TColumnSchema) -> str:
+    def to_db_type(self, column: TColumnSchema, table_format: TTableFormat = None) -> str:
         precision, scale = column.get("precision"), column.get("scale")
         sc_t = column["data_type"]
         if sc_t == "bigint":
-            return self.to_db_integer_type(precision)
+            return self.to_db_integer_type(precision, table_format)
         bounded_template = self.sct_to_dbt.get(sc_t)
         if not bounded_template:
             return self.sct_to_unbound_dbt[sc_t]
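Illustration (not part of the commit): how to_db_type now threads table_format down to to_db_integer_type, shown with a toy subclass; ToyMapper and its one-entry mapping table are hypothetical.

from typing import Optional

class ToyMapper:
    sct_to_unbound_dbt = {"bigint": "bigint"}

    def to_db_integer_type(self, precision: Optional[int], table_format: Optional[str] = None) -> str:
        # a real destination (e.g. Athena above) branches on precision and table_format here
        return self.sct_to_unbound_dbt["bigint"]

    def to_db_type(self, column: dict, table_format: Optional[str] = None) -> str:
        if column["data_type"] == "bigint":
            return self.to_db_integer_type(column.get("precision"), table_format)
        raise NotImplementedError(column["data_type"])

print(ToyMapper().to_db_type({"data_type": "bigint", "precision": 8}, "iceberg"))  # -> bigint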
