Skip to content

Commit

Permalink
small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Oct 13, 2023
1 parent acfcd16 commit 0707629
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
19 changes: 9 additions & 10 deletions dlt/destinations/athena/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,17 @@ class AthenaTypeMapper(TypeMapper):
"int": "bigint",
}

- def __init__(self, capabilities: DestinationCapabilitiesContext, iceberg_mode: bool):
+ def __init__(self, capabilities: DestinationCapabilitiesContext):
super().__init__(capabilities)
- self.iceberg_mode = iceberg_mode

def to_db_integer_type(self, precision: Optional[int]) -> str:
if precision is None:
return "bigint"
- # iceberg does not support smallint and tinyint
+ # TODO: iceberg does not support smallint and tinyint
if precision <= 8:
- return "int" if self.iceberg_mode else "tinyint"
+ return "int"
elif precision <= 16:
- return "int" if self.iceberg_mode else "smallint"
+ return "int"
elif precision <= 32:
return "int"
return "bigint"
Expand Down Expand Up @@ -303,7 +302,7 @@ def __init__(self, schema: Schema, config: AthenaClientConfiguration) -> None:
super().__init__(schema, config, sql_client)
self.sql_client: AthenaSQLClient = sql_client # type: ignore
self.config: AthenaClientConfiguration = config
- self.type_mapper = AthenaTypeMapper(self.capabilities, True)
+ self.type_mapper = AthenaTypeMapper(self.capabilities)

def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None:
# only truncate tables in iceberg mode
Expand Down Expand Up @@ -364,12 +363,12 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) ->
return job

def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
- if self._is_iceberg_table(table_chain[0]):
+ if self._is_iceberg_table(self.get_load_table(table_chain[0]["name"])):
return [SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": False})]
return super()._create_append_followup_jobs(table_chain)

def _create_replace_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
- if self._is_iceberg_table(table_chain[0]):
+ if self._is_iceberg_table(self.get_load_table(table_chain[0]["name"])):
return [SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True})]
return super()._create_replace_followup_jobs(table_chain)

Expand Down Expand Up @@ -400,10 +399,10 @@ def should_load_data_to_staging_dataset_on_staging_destination(self, table: TTab

def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema:
table = super().get_load_table(table_name, staging)
- if self.config.force_iceberg:
- table["table_format"] ="iceberg"
if staging and table.get("table_format", None) == "iceberg":
table.pop("table_format")
+ elif self.config.force_iceberg:
+ table["table_format"] = "iceberg"
return table

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion dlt/destinations/job_client_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def _build_schema_update_sql(self, only_tables: Iterable[str]) -> Tuple[List[str
sql += ";"
sql_updates.append(sql)
# create a schema update for particular table
- partial_table = copy(self.schema.get_table(table_name))
+ partial_table = copy(self.get_load_table(table_name))
# keep only new columns
partial_table["columns"] = {c["name"]: c for c in new_columns}
schema_update[table_name] = partial_table
Expand Down

0 comments on commit 0707629

Please sign in to comment.