From 4692e373f28fb95529633bb120536554d5780013 Mon Sep 17 00:00:00 2001
From: Dave
Date: Sat, 14 Oct 2023 10:25:57 +0200
Subject: [PATCH] make type mapper table format sensitive

---
 dlt/destinations/athena/athena.py       | 15 +++++++--------
 dlt/destinations/bigquery/bigquery.py   |  6 +++---
 dlt/destinations/duckdb/duck.py         |  6 +++---
 dlt/destinations/job_client_impl.py     | 13 +++++++------
 dlt/destinations/mssql/mssql.py         | 10 +++++-----
 dlt/destinations/postgres/postgres.py   |  6 +++---
 dlt/destinations/redshift/redshift.py   |  6 +++---
 dlt/destinations/snowflake/snowflake.py |  8 ++++----
 dlt/destinations/type_mapping.py        |  8 ++++----
 9 files changed, 39 insertions(+), 39 deletions(-)

diff --git a/dlt/destinations/athena/athena.py b/dlt/destinations/athena/athena.py
index 6e032a5acf..44d020c127 100644
--- a/dlt/destinations/athena/athena.py
+++ b/dlt/destinations/athena/athena.py
@@ -16,7 +16,7 @@
 from dlt.common.utils import without_none
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, Schema, TSchemaTables, TTableSchema
-from dlt.common.schema.typing import TTableSchema, TColumnType, TWriteDisposition
+from dlt.common.schema.typing import TTableSchema, TColumnType, TWriteDisposition, TTableFormat
 from dlt.common.schema.utils import table_schema_has_type, get_table_format
 from dlt.common.destination import DestinationCapabilitiesContext
 from dlt.common.destination.reference import LoadJob, FollowupJob
@@ -72,14 +72,13 @@ class AthenaTypeMapper(TypeMapper):
     def __init__(self, capabilities: DestinationCapabilitiesContext):
         super().__init__(capabilities)
 
-    def to_db_integer_type(self, precision: Optional[int]) -> str:
+    def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
         if precision is None:
             return "bigint"
-        # TODO: iceberg does not support smallint and tinyint
         if precision <= 8:
-            return "int"
+            return "int" if table_format == "iceberg" else "tinyint"
         elif precision <= 16:
-            return "int"
+            return "int" if table_format == "iceberg" else "smallint"
         elif precision <= 32:
             return "int"
         return "bigint"
@@ -312,8 +311,8 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None:
 
     def _from_db_type(self, hive_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
         return self.type_mapper.from_db_type(hive_t, precision, scale)
 
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
-        return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c)}"
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
+        return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}"
 
     def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]:
@@ -326,7 +325,7 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc
         # or if we are in iceberg mode, we create iceberg tables for all tables
         table = self.get_load_table(table_name, self.in_staging_mode)
         is_iceberg = self._is_iceberg_table(table) or table.get("write_disposition", None) == "skip"
-        columns = ", ".join([self._get_column_def_sql(c) for c in new_columns])
+        columns = ", ".join([self._get_column_def_sql(c, table.get("table_format")) for c in new_columns])
 
         # this will fail if the table prefix is not properly defined
         table_prefix = self.table_prefix_layout.format(table_name=table_name)
diff --git a/dlt/destinations/bigquery/bigquery.py b/dlt/destinations/bigquery/bigquery.py
index f1478697fc..4cb467a7af 100644
--- a/dlt/destinations/bigquery/bigquery.py
+++ b/dlt/destinations/bigquery/bigquery.py
@@ -11,7 +11,7 @@
 from dlt.common.data_types import TDataType
 from dlt.common.storages.file_storage import FileStorage
 from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns
-from dlt.common.schema.typing import TTableSchema, TColumnType
+from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat
 from dlt.common.schema.exceptions import UnknownTableException
 
 from dlt.destinations.job_client_impl import SqlJobClientWithStaging
@@ -250,9 +250,9 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc
 
         return sql
 
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         name = self.capabilities.escape_identifier(c["name"])
-        return f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}"
+        return f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}"
 
     def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]:
         schema_table: TTableSchemaColumns = {}
diff --git a/dlt/destinations/duckdb/duck.py b/dlt/destinations/duckdb/duck.py
index c40abd56a0..fe4ebac37e 100644
--- a/dlt/destinations/duckdb/duck.py
+++ b/dlt/destinations/duckdb/duck.py
@@ -5,7 +5,7 @@
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, TColumnHint, Schema
 from dlt.common.destination.reference import LoadJob, FollowupJob, TLoadJobState
-from dlt.common.schema.typing import TTableSchema, TColumnType
+from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat
 from dlt.common.storages.file_storage import FileStorage
 from dlt.common.utils import maybe_context
 
@@ -65,7 +65,7 @@ class DuckDbTypeMapper(TypeMapper):
         "HUGEINT": "bigint",
     }
 
-    def to_db_integer_type(self, precision: Optional[int]) -> str:
+    def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
         if precision is None:
             return "BIGINT"
         # Precision is number of bits
@@ -141,7 +141,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) ->
             job = DuckDbCopyJob(table["name"], file_path, self.sql_client)
         return job
 
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         hints_str = " ".join(self.active_hints.get(h, "") for h in self.active_hints.keys() if c.get(h, False) is True)
         column_name = self.capabilities.escape_identifier(c["name"])
         return f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py
index cfde6625d5..0124a277d3 100644
--- a/dlt/destinations/job_client_impl.py
+++ b/dlt/destinations/job_client_impl.py
@@ -12,7 +12,7 @@
 from dlt.common import json, pendulum, logger
 from dlt.common.data_types import TDataType
-from dlt.common.schema.typing import COLUMN_HINTS, TColumnType, TColumnSchemaBase, TTableSchema, TWriteDisposition
+from dlt.common.schema.typing import COLUMN_HINTS, TColumnType, TColumnSchemaBase, TTableSchema, TWriteDisposition, TTableFormat
 from dlt.common.storages import FileStorage
 from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables
 from dlt.common.destination.reference import StateInfo, StorageSchemaInfo, WithStateSync, DestinationClientConfiguration, DestinationClientDwhConfiguration, DestinationClientDwhWithStagingConfiguration, NewLoadJob, WithStagingDataset, TLoadJobState, LoadJob, JobClientBase, FollowupJob, CredentialsConfiguration
@@ -318,23 +318,24 @@ def _build_schema_update_sql(self, only_tables: Iterable[str]) -> Tuple[List[str
 
         return sql_updates, schema_update
 
-    def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str]:
+    def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None) -> List[str]:
         """Make one or more ADD COLUMN sql clauses to be joined in ALTER TABLE statement(s)"""
-        return [f"ADD COLUMN {self._get_column_def_sql(c)}" for c in new_columns]
+        return [f"ADD COLUMN {self._get_column_def_sql(c, table_format)}" for c in new_columns]
 
     def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]:
         # build sql
         canonical_name = self.sql_client.make_qualified_table_name(table_name)
+        table = self.get_load_table(table_name)
         sql_result: List[str] = []
         if not generate_alter:
             # build CREATE
             sql = f"CREATE TABLE {canonical_name} (\n"
-            sql += ",\n".join([self._get_column_def_sql(c) for c in new_columns])
+            sql += ",\n".join([self._get_column_def_sql(c, table.get("table_format")) for c in new_columns])
             sql += ")"
             sql_result.append(sql)
         else:
             sql_base = f"ALTER TABLE {canonical_name}\n"
-            add_column_statements = self._make_add_column_sql(new_columns)
+            add_column_statements = self._make_add_column_sql(new_columns, table.get("table_format"))
             if self.capabilities.alter_add_multi_column:
                 column_sql = ",\n"
                 sql_result.append(sql_base + column_sql.join(add_column_statements))
@@ -357,7 +358,7 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc
         return sql_result
 
     @abstractmethod
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         pass
 
     @staticmethod
diff --git a/dlt/destinations/mssql/mssql.py b/dlt/destinations/mssql/mssql.py
index c06ddeadbd..cd999441ff 100644
--- a/dlt/destinations/mssql/mssql.py
+++ b/dlt/destinations/mssql/mssql.py
@@ -5,7 +5,7 @@
 from dlt.common.destination import DestinationCapabilitiesContext
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, TColumnHint, Schema
-from dlt.common.schema.typing import TTableSchema, TColumnType
+from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat
 from dlt.common.utils import uniq_id
 
 from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob, SqlJobParams
@@ -62,7 +62,7 @@ class MsSqlTypeMapper(TypeMapper):
         "int": "bigint",
     }
 
-    def to_db_integer_type(self, precision: Optional[int]) -> str:
+    def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
         if precision is None:
             return "bigint"
         if precision <= 8:
@@ -136,11 +136,11 @@ def __init__(self, schema: Schema, config: MsSqlClientConfiguration) -> None:
     def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
         return [MsSqlMergeJob.from_table_chain(table_chain, self.sql_client)]
 
-    def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str]:
+    def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None) -> List[str]:
         # Override because mssql requires multiple columns in a single ADD COLUMN clause
-        return ["ADD \n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)]
+        return ["ADD \n" + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns)]
 
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         sc_type = c["data_type"]
         if sc_type == "text" and c.get("unique"):
             # MSSQL does not allow index on large TEXT columns
diff --git a/dlt/destinations/postgres/postgres.py b/dlt/destinations/postgres/postgres.py
index 72837b42b3..2812d1d4c4 100644
--- a/dlt/destinations/postgres/postgres.py
+++ b/dlt/destinations/postgres/postgres.py
@@ -5,7 +5,7 @@
 from dlt.common.destination import DestinationCapabilitiesContext
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, TColumnHint, Schema
-from dlt.common.schema.typing import TTableSchema, TColumnType
+from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat
 
 from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams
@@ -59,7 +59,7 @@ class PostgresTypeMapper(TypeMapper):
         "integer": "bigint",
     }
 
-    def to_db_integer_type(self, precision: Optional[int]) -> str:
+    def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
         if precision is None:
             return "bigint"
         # Precision is number of bits
@@ -109,7 +109,7 @@ def __init__(self, schema: Schema, config: PostgresClientConfiguration) -> None:
         self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {}
         self.type_mapper = PostgresTypeMapper(self.capabilities)
 
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         hints_str = " ".join(self.active_hints.get(h, "") for h in self.active_hints.keys() if c.get(h, False) is True)
         column_name = self.capabilities.escape_identifier(c["name"])
         return f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
diff --git a/dlt/destinations/redshift/redshift.py b/dlt/destinations/redshift/redshift.py
index 77ef05de74..888f27ae7c 100644
--- a/dlt/destinations/redshift/redshift.py
+++ b/dlt/destinations/redshift/redshift.py
@@ -17,7 +17,7 @@
 from dlt.common.destination.reference import NewLoadJob, CredentialsConfiguration, SupportsStagingDestination
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, TColumnHint, Schema
-from dlt.common.schema.typing import TTableSchema, TColumnType
+from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat
 from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults
 
 from dlt.destinations.insert_job_client import InsertValuesJobClient
@@ -76,7 +76,7 @@ class RedshiftTypeMapper(TypeMapper):
         "integer": "bigint",
     }
 
-    def to_db_integer_type(self, precision: Optional[int]) -> str:
+    def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
         if precision is None:
             return "bigint"
         if precision <= 16:
@@ -204,7 +204,7 @@ def __init__(self, schema: Schema, config: RedshiftClientConfiguration) -> None:
     def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
         return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)]
 
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         hints_str = " ".join(HINT_TO_REDSHIFT_ATTR.get(h, "") for h in HINT_TO_REDSHIFT_ATTR.keys() if c.get(h, False) is True)
         column_name = self.capabilities.escape_identifier(c["name"])
         return f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
diff --git a/dlt/destinations/snowflake/snowflake.py b/dlt/destinations/snowflake/snowflake.py
index 5aa721dd03..f433ec7e7d 100644
--- a/dlt/destinations/snowflake/snowflake.py
+++ b/dlt/destinations/snowflake/snowflake.py
@@ -7,7 +7,7 @@
 from dlt.common.data_types import TDataType
 from dlt.common.storages.file_storage import FileStorage
 from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns
-from dlt.common.schema.typing import TTableSchema, TColumnType
+from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat
 
 from dlt.destinations.job_client_impl import SqlJobClientWithStaging
@@ -200,9 +200,9 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) ->
     def restore_file_load(self, file_path: str) -> LoadJob:
         return EmptyLoadJob.from_file_path(file_path, "completed")
 
-    def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str]:
+    def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None) -> List[str]:
         # Override because snowflake requires multiple columns in a single ADD COLUMN clause
-        return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)]
+        return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns)]
 
     def _create_replace_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
         if self.config.replace_strategy == "staging-optimized":
@@ -222,7 +222,7 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc
 
     def _from_db_type(self, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
         return self.type_mapper.from_db_type(bq_t, precision, scale)
 
-    def _get_column_def_sql(self, c: TColumnSchema) -> str:
+    def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
         name = self.capabilities.escape_identifier(c["name"])
         return f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}"
diff --git a/dlt/destinations/type_mapping.py b/dlt/destinations/type_mapping.py
index dcbb1a4261..3ddfee5904 100644
--- a/dlt/destinations/type_mapping.py
+++ b/dlt/destinations/type_mapping.py
@@ -1,6 +1,6 @@
 from typing import Tuple, ClassVar, Dict, Optional
 
-from dlt.common.schema.typing import TColumnSchema, TDataType, TColumnType
+from dlt.common.schema.typing import TColumnSchema, TDataType, TColumnType, TTableFormat
 from dlt.common.destination.capabilities import DestinationCapabilitiesContext
 from dlt.common.utils import without_none
 
@@ -20,15 +20,15 @@ class TypeMapper:
     def __init__(self, capabilities: DestinationCapabilitiesContext) -> None:
         self.capabilities = capabilities
 
-    def to_db_integer_type(self, precision: Optional[int]) -> str:
+    def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str:
         # Override in subclass if db supports other integer types (e.g. smallint, integer, tinyint, etc.)
         return self.sct_to_unbound_dbt["bigint"]
 
-    def to_db_type(self, column: TColumnSchema) -> str:
+    def to_db_type(self, column: TColumnSchema, table_format: TTableFormat = None) -> str:
         precision, scale = column.get("precision"), column.get("scale")
         sc_t = column["data_type"]
         if sc_t == "bigint":
-            return self.to_db_integer_type(precision)
+            return self.to_db_integer_type(precision, table_format)
         bounded_template = self.sct_to_dbt.get(sc_t)
         if not bounded_template:
             return self.sct_to_unbound_dbt[sc_t]
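
Note, below the patch rather than in it: a minimal runnable sketch of the behavior this change enables, mirroring the AthenaTypeMapper hunk above. The "iceberg" value and the precision thresholds come from the diff itself; the standalone function name, the TTableFormat stand-in alias, and the asserts are illustrative assumptions, not dlt's actual API surface.

    from typing import Literal, Optional

    # Stand-in for dlt.common.schema.typing.TTableFormat (assumption: only the
    # value "iceberg" matters for the branching this patch adds).
    TTableFormat = Literal["iceberg"]

    def athena_to_db_integer_type(precision: Optional[int], table_format: Optional[TTableFormat] = None) -> str:
        # Mirrors AthenaTypeMapper.to_db_integer_type after the patch: Athena
        # Iceberg tables lack tinyint/smallint, so small precisions widen to int.
        if precision is None:
            return "bigint"
        if precision <= 8:
            return "int" if table_format == "iceberg" else "tinyint"
        elif precision <= 16:
            return "int" if table_format == "iceberg" else "smallint"
        elif precision <= 32:
            return "int"
        return "bigint"

    # The format flows through DDL generation as the patch wires it:
    # _get_table_update_sql -> _get_column_def_sql(c, table.get("table_format"))
    #   -> to_db_type(column, table_format) -> to_db_integer_type(precision, table_format)
    assert athena_to_db_integer_type(8) == "tinyint"
    assert athena_to_db_integer_type(8, "iceberg") == "int"
    assert athena_to_db_integer_type(16, "iceberg") == "int"
    assert athena_to_db_integer_type(33) == "bigint"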