From 07076293bec0aa8f48dc3a31d2f7bb49859de3be Mon Sep 17 00:00:00 2001
From: Dave
Date: Fri, 13 Oct 2023 15:58:25 +0200
Subject: [PATCH] small changes

---
 dlt/destinations/athena/athena.py   | 19 +++++++++----------
 dlt/destinations/job_client_impl.py |  2 +-
 2 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/dlt/destinations/athena/athena.py b/dlt/destinations/athena/athena.py
index 9aa0493a4e..6e032a5acf 100644
--- a/dlt/destinations/athena/athena.py
+++ b/dlt/destinations/athena/athena.py
@@ -69,18 +69,17 @@ class AthenaTypeMapper(TypeMapper):
         "int": "bigint",
     }
 
-    def __init__(self, capabilities: DestinationCapabilitiesContext, iceberg_mode: bool):
+    def __init__(self, capabilities: DestinationCapabilitiesContext):
         super().__init__(capabilities)
-        self.iceberg_mode = iceberg_mode
 
     def to_db_integer_type(self, precision: Optional[int]) -> str:
         if precision is None:
             return "bigint"
-        # iceberg does not support smallint and tinyint
+        # TODO: iceberg does not support smallint and tinyint
         if precision <= 8:
-            return "int" if self.iceberg_mode else "tinyint"
+            return "int"
         elif precision <= 16:
-            return "int" if self.iceberg_mode else "smallint"
+            return "int"
         elif precision <= 32:
             return "int"
         return "bigint"
@@ -303,7 +302,7 @@ def __init__(self, schema: Schema, config: AthenaClientConfiguration) -> None:
         super().__init__(schema, config, sql_client)
         self.sql_client: AthenaSQLClient = sql_client  # type: ignore
         self.config: AthenaClientConfiguration = config
-        self.type_mapper = AthenaTypeMapper(self.capabilities, True)
+        self.type_mapper = AthenaTypeMapper(self.capabilities)
 
     def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None:
         # only truncate tables in iceberg mode
@@ -364,12 +363,12 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) ->
         return job
 
     def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
-        if self._is_iceberg_table(table_chain[0]):
+        if self._is_iceberg_table(self.get_load_table(table_chain[0]["name"])):
             return [SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": False})]
         return super()._create_append_followup_jobs(table_chain)
 
     def _create_replace_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
-        if self._is_iceberg_table(table_chain[0]):
+        if self._is_iceberg_table(self.get_load_table(table_chain[0]["name"])):
             return [SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True})]
         return super()._create_replace_followup_jobs(table_chain)
 
@@ -400,10 +399,10 @@ def should_load_data_to_staging_dataset_on_staging_destination(self, table: TTab
     def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema:
         table = super().get_load_table(table_name, staging)
+        if self.config.force_iceberg:
+            table["table_format"] = "iceberg"
         if staging and table.get("table_format", None) == "iceberg":
             table.pop("table_format")
-        elif self.config.force_iceberg:
-            table["table_format"] = "iceberg"
         return table
 
     @staticmethod
diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py
index 8b26ac06ee..cfde6625d5 100644
--- a/dlt/destinations/job_client_impl.py
+++ b/dlt/destinations/job_client_impl.py
@@ -311,7 +311,7 @@ def _build_schema_update_sql(self, only_tables: Iterable[str]) -> Tuple[List[str
             sql += ";"
             sql_updates.append(sql)
             # create a schema update for particular table
-            partial_table = copy(self.schema.get_table(table_name))
+            partial_table = copy(self.get_load_table(table_name))
             # keep only new columns
             partial_table["columns"] = {c["name"]: c for c in new_columns}
             schema_update[table_name] = partial_table
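
Note (illustration, not part of the patch): the reordering in AthenaClient.get_load_table
applies force_iceberg before the staging check, so a staging table never keeps
"table_format": "iceberg" even when the flag is set. The sketch below mirrors that logic
with simplified, hypothetical names (resolve_load_table and base_table are stand-ins for
the method and for the result of super().get_load_table); it is a minimal approximation
under those assumptions, not the dlt implementation.

    # minimal sketch of the reordered get_load_table logic (hypothetical helper, not dlt code)
    from typing import Any, Dict

    def resolve_load_table(base_table: Dict[str, Any], force_iceberg: bool, staging: bool) -> Dict[str, Any]:
        table = dict(base_table)
        # apply the forced format first ...
        if force_iceberg:
            table["table_format"] = "iceberg"
        # ... then strip it for staging tables, which are written as plain tables
        if staging and table.get("table_format") == "iceberg":
            table.pop("table_format")
        return table

    # usage: the final table keeps the format, the staging copy does not
    print(resolve_load_table({"name": "events"}, force_iceberg=True, staging=False))
    # {'name': 'events', 'table_format': 'iceberg'}
    print(resolve_load_table({"name": "events"}, force_iceberg=True, staging=True))
    # {'name': 'events'}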