From f2bac9642caafc56ba6ff834326086d306428ebb Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Tue, 17 Dec 2024 00:23:37 +0100 Subject: [PATCH] refactors column and constraint sql in job client --- .../impl/databricks/databricks.py | 6 ---- dlt/destinations/impl/dremio/dremio.py | 6 ---- dlt/destinations/impl/duckdb/duck.py | 11 ------ dlt/destinations/impl/mssql/mssql.py | 6 +--- dlt/destinations/impl/postgres/postgres.py | 12 ------- dlt/destinations/impl/redshift/redshift.py | 12 +------ dlt/destinations/job_client_impl.py | 36 ++++++++++++++----- 7 files changed, 29 insertions(+), 60 deletions(-) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 2bb68a607e..a83db6ec34 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -264,12 +264,6 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_destination_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - name = self.sql_client.escape_column_name(c["name"]) - return ( - f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" - ) - def _get_storage_table_query_columns(self) -> List[str]: fields = super()._get_storage_table_query_columns() fields[2] = ( # Override because this is the only way to get data type with precision diff --git a/dlt/destinations/impl/dremio/dremio.py b/dlt/destinations/impl/dremio/dremio.py index ab23f58ab4..e3a090c824 100644 --- a/dlt/destinations/impl/dremio/dremio.py +++ b/dlt/destinations/impl/dremio/dremio.py @@ -151,12 +151,6 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_destination_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - name = self.sql_client.escape_column_name(c["name"]) - return ( - f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" - ) - def _create_merge_followup_jobs( self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: diff --git a/dlt/destinations/impl/duckdb/duck.py b/dlt/destinations/impl/duckdb/duck.py index 3bd4c83e1f..2b3370270b 100644 --- a/dlt/destinations/impl/duckdb/duck.py +++ b/dlt/destinations/impl/duckdb/duck.py @@ -74,17 +74,6 @@ def create_load_job( job = DuckDbCopyJob(file_path) return job - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - hints_str = " ".join( - self.active_hints.get(h, "") - for h in self.active_hints.keys() - if c.get(h, False) is True - ) - column_name = self.sql_client.escape_column_name(c["name"]) - return ( - f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" - ) - def _from_db_type( self, pq_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: diff --git a/dlt/destinations/impl/mssql/mssql.py b/dlt/destinations/impl/mssql/mssql.py index 27aebe07f2..7b48a6b551 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -115,11 +115,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = Non else: db_type = self.type_mapper.to_destination_type(c, table) - hints_str = " ".join( - self.active_hints.get(h, "") - for h in self.active_hints.keys() - if c.get(h, False) is True - ) + hints_str = self._get_column_hints_sql(c) column_name = self.sql_client.escape_column_name(c["name"]) return f"{column_name} {db_type} {hints_str} {self._gen_not_null(c.get('nullable', True))}" diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index 2459ee1dbe..3d54b59f93 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -161,18 +161,6 @@ def create_load_job( job = PostgresCsvCopyJob(file_path) return job - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - hints_ = " ".join( - self.active_hints.get(h, "") - for h in self.active_hints.keys() - if c.get(h, False) is True - ) - column_name = self.sql_client.escape_column_name(c["name"]) - nullability = self._gen_not_null(c.get("nullable", True)) - column_type = self.type_mapper.to_destination_type(c, table) - - return f"{column_name} {column_type} {hints_} {nullability}" - def _create_replace_followup_jobs( self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: diff --git a/dlt/destinations/impl/redshift/redshift.py b/dlt/destinations/impl/redshift/redshift.py index 2335166761..b1aa37ce6a 100644 --- a/dlt/destinations/impl/redshift/redshift.py +++ b/dlt/destinations/impl/redshift/redshift.py @@ -153,6 +153,7 @@ def __init__( capabilities, ) super().__init__(schema, config, sql_client) + self.active_hints = HINT_TO_REDSHIFT_ATTR self.sql_client = sql_client self.config: RedshiftClientConfiguration = config self.type_mapper = self.capabilities.get_type_mapper() @@ -162,17 +163,6 @@ def _create_merge_followup_jobs( ) -> List[FollowupJobRequest]: return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)] - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - hints_str = " ".join( - HINT_TO_REDSHIFT_ATTR.get(h, "") - for h in HINT_TO_REDSHIFT_ATTR.keys() - if c.get(h, False) is True - ) - column_name = self.sql_client.escape_column_name(c["name"]) - return ( - f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" - ) - def create_load_job( self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index d1f211b1e9..12cb129812 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -7,6 +7,7 @@ from typing import ( Any, ClassVar, + Dict, List, Optional, Sequence, @@ -14,21 +15,18 @@ Type, Iterable, Iterator, - Generator, ) import zlib import re -from contextlib import contextmanager -from contextlib import suppress from dlt.common import pendulum, logger +from dlt.common.destination.capabilities import DataTypeMapper from dlt.common.json import json from dlt.common.schema.typing import ( C_DLT_LOAD_ID, COLUMN_HINTS, TColumnType, TColumnSchemaBase, - TTableFormat, ) from dlt.common.schema.utils import ( get_inherited_table_hint, @@ -40,11 +38,11 @@ from dlt.common.storages import FileStorage from dlt.common.storages.load_package import LoadJobInfo, ParsedLoadJobFileName from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables +from dlt.common.schema import TColumnHint from dlt.common.destination.reference import ( PreparedTableSchema, StateInfo, StorageSchemaInfo, - SupportsReadableDataset, WithStateSync, DestinationClientConfiguration, DestinationClientDwhConfiguration, @@ -55,9 +53,7 @@ JobClientBase, HasFollowupJobs, CredentialsConfiguration, - SupportsReadableRelation, ) -from dlt.destinations.dataset import ReadableDBAPIDataset from dlt.destinations.exceptions import DatabaseUndefinedRelation from dlt.destinations.job_impl import ( @@ -154,6 +150,8 @@ def __init__( self.state_table_columns = ", ".join( sql_client.escape_column_name(col) for col in state_table_["columns"] ) + self.active_hints: Dict[TColumnHint, str] = {} + self.type_mapper: DataTypeMapper = None super().__init__(schema, config, sql_client.capabilities) self.sql_client = sql_client assert isinstance(config, DestinationClientDwhConfiguration) @@ -569,6 +567,7 @@ def _get_table_update_sql( # build CREATE sql = self._make_create_table(qualified_name, table) + " (\n" sql += ",\n".join([self._get_column_def_sql(c, table) for c in new_columns]) + sql += self._get_constraints_sql(table_name, new_columns, generate_alter) sql += ")" sql_result.append(sql) else: @@ -582,8 +581,16 @@ def _get_table_update_sql( sql_result.extend( [sql_base + col_statement for col_statement in add_column_statements] ) + constraints_sql = self._get_constraints_sql(table_name, new_columns, generate_alter) + if constraints_sql: + sql_result.append(constraints_sql) return sql_result + def _get_constraints_sql( + self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool + ) -> str: + return "" + def _check_table_update_hints( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool ) -> None: @@ -613,9 +620,20 @@ def _check_table_update_hints( " existing tables." ) - @abstractmethod def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - pass + hints_ = self._get_column_hints_sql(c) + column_name = self.sql_client.escape_column_name(c["name"]) + nullability = self._gen_not_null(c.get("nullable", True)) + column_type = self.type_mapper.to_destination_type(c, table) + + return f"{column_name} {column_type} {hints_} {nullability}" + + def _get_column_hints_sql(self, c: TColumnSchema) -> str: + return " ".join( + self.active_hints.get(h, "") + for h in self.active_hints.keys() + if c.get(h, False) is True # use ColumnPropInfos to get default value + ) @staticmethod def _gen_not_null(nullable: bool) -> str: