From d513d2c7fe1eb40e982b072b90ee062cb4ffdff2 Mon Sep 17 00:00:00 2001
From: Steinthor Palsson
Date: Wed, 20 Sep 2023 20:32:08 -0400
Subject: [PATCH] Init from_db_type

---
 dlt/common/schema/typing.py             | 11 ++--
 dlt/destinations/athena/athena.py       | 42 +++++++--------
 dlt/destinations/bigquery/bigquery.py   | 70 ++++++++++----------------
 dlt/destinations/duckdb/duck.py         | 55 +++++++++-----------
 dlt/destinations/job_client_impl.py     |  7 ++-
 dlt/destinations/mssql/mssql.py         | 43 ++++++----------
 dlt/destinations/postgres/postgres.py   | 65 +++++++++--------------
 dlt/destinations/redshift/redshift.py   | 51 +++++++++----------
 dlt/destinations/snowflake/snowflake.py | 52 ++++++++++---------
 dlt/destinations/type_mapping.py        | 15 ++++--
 10 files changed, 177 insertions(+), 234 deletions(-)

diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py
index da0e57ce36..ae24691e2d 100644
--- a/dlt/common/schema/typing.py
+++ b/dlt/common/schema/typing.py
@@ -34,10 +34,15 @@
 WRITE_DISPOSITIONS: Set[TWriteDisposition] = set(get_args(TWriteDisposition))
 
 
-class TColumnSchemaBase(TypedDict, total=False):
+class TColumnType(TypedDict, total=False):
+    data_type: Optional[TDataType]
+    precision: Optional[int]
+    scale: Optional[int]
+
+
+class TColumnSchemaBase(TColumnType, total=False):
     """TypedDict that defines basic properties of a column: name, data type and nullable"""
     name: Optional[str]
-    data_type: Optional[TDataType]
     nullable: Optional[bool]
 
 
@@ -53,8 +58,6 @@ class TColumnSchema(TColumnSchemaBase, total=False):
     root_key: Optional[bool]
     merge_key: Optional[bool]
     variant: Optional[bool]
-    precision: Optional[int]
-    scale: Optional[int]
 
 
 TTableSchemaColumns = Dict[str, TColumnSchema]

diff --git a/dlt/destinations/athena/athena.py b/dlt/destinations/athena/athena.py
index 6ef9945daa..2f9e52e0ff 100644
--- a/dlt/destinations/athena/athena.py
+++ b/dlt/destinations/athena/athena.py
@@ -15,7 +15,7 @@
 from dlt.common import logger
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, Schema
-from dlt.common.schema.typing import TTableSchema
+from dlt.common.schema.typing import TTableSchema, TColumnType
 from dlt.common.schema.utils import table_schema_has_type
 from dlt.common.destination import DestinationCapabilitiesContext
 from dlt.common.destination.reference import LoadJob
@@ -35,20 +35,6 @@
 
 from dlt.destinations import path_utils
 
-
-HIVET_TO_SCT: Dict[str, TDataType] = {
-    "varchar": "text",
-    "double": "double",
-    "boolean": "bool",
-    "date": "date",
-    "timestamp": "timestamp",
-    "bigint": "bigint",
-    "binary": "binary",
-    "varbinary": "binary",
-    "decimal": "decimal",
-}
-
-
 class AthenaTypeMapper(TypeMapper):
     sct_to_unbound_dbt = {
         "complex": "string",
@@ -67,6 +53,24 @@ class AthenaTypeMapper(TypeMapper):
         "wei": "decimal(%i,%i)"
     }
 
+    dbt_to_sct = {
+        "varchar": "text",
+        "double": "double",
+        "boolean": "bool",
+        "date": "date",
+        "timestamp": "timestamp",
+        "bigint": "bigint",
+        "binary": "binary",
+        "varbinary": "binary",
+        "decimal": "decimal",
+    }
+
+    def from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
+        for key, val in self.dbt_to_sct.items():
+            if db_type.startswith(key):
+                return dict(data_type=val, precision=precision, scale=scale)
+        return dict(data_type=None)  # unknown hive types stay untyped, as before
+
 
 # add a formatter for pendulum to be used by pyathen dbapi
 def _format_pendulum_datetime(formatter: Formatter, escaper: Callable[[str], str], val: Any) -> Any:
@@ -280,12 +284,8 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None:
         # never truncate tables in athena
         super().initialize_storage([])
 
-    @classmethod
-    def _from_db_type(cls, hive_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
-        for key, val in HIVET_TO_SCT.items():
-            if hive_t.startswith(key):
-                return val
-        return None
+    def _from_db_type(self, hive_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
+        return self.type_mapper.from_db_type(hive_t, precision, scale)
 
     def _get_column_def_sql(self, c: TColumnSchema) -> str:
         return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c)}"

diff --git a/dlt/destinations/bigquery/bigquery.py b/dlt/destinations/bigquery/bigquery.py
index a3cdcb0ba0..473fee2113 100644
--- a/dlt/destinations/bigquery/bigquery.py
+++ b/dlt/destinations/bigquery/bigquery.py
@@ -11,7 +11,7 @@
 from dlt.common.data_types import TDataType
 from dlt.common.storages.file_storage import FileStorage
 from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns
-from dlt.common.schema.typing import TTableSchema
+from dlt.common.schema.typing import TTableSchema, TColumnType
 from dlt.destinations.job_client_impl import SqlJobClientWithStaging
 from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate, DestinationTransientException, LoadJobNotExistsException, LoadJobTerminalException, LoadJobUnknownTableException
 
@@ -26,32 +26,5 @@
 from dlt.common.schema.utils import table_schema_has_type
 
 
-SCT_TO_BQT: Dict[TDataType, str] = {
-    "complex": "JSON",
-    "text": "STRING",
-    "double": "FLOAT64",
-    "bool": "BOOLEAN",
-    "date": "DATE",
-    "timestamp": "TIMESTAMP",
-    "bigint": "INTEGER",
-    "binary": "BYTES",
-    "decimal": "NUMERIC(%i,%i)",
-    "wei": "BIGNUMERIC",  # non parametrized should hold wei values
-    "time": "TIME",
-}
-
-BQT_TO_SCT: Dict[str, TDataType] = {
-    "STRING": "text",
-    "FLOAT": "double",
-    "BOOLEAN": "bool",
-    "DATE": "date",
-    "TIMESTAMP": "timestamp",
-    "INTEGER": "bigint",
-    "BYTES": "binary",
-    "NUMERIC": "decimal",
-    "BIGNUMERIC": "decimal",
-    "JSON": "complex",
-    "TIME": "time",
-}
-
 
 
@@ -75,6 +48,27 @@ class BigQueryTypeMapper(TypeMapper):
         "decimal": "NUMERIC(%i,%i)",
     }
 
+    dbt_to_sct = {
+        "STRING": "text",
+        "FLOAT": "double",
+        "BOOLEAN": "bool",
+        "DATE": "date",
+        "TIMESTAMP": "timestamp",
+        "INTEGER": "bigint",
+        "BYTES": "binary",
+        "NUMERIC": "decimal",
+        "BIGNUMERIC": "decimal",
+        "JSON": "complex",
+        "TIME": "time",
+    }
+
+    def from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
+        if db_type == "BIGNUMERIC":
+            if precision is None:  # biggest numeric possible
+                return dict(data_type="wei")
+        return super().from_db_type(db_type, precision, scale)
+
+
 class BigQueryLoadJob(LoadJob, FollowupJob):
     def __init__(
         self,
@@ -267,13 +261,13 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]:
             schema_c: TColumnSchema = {
                 "name": c.name,
                 "nullable": c.is_nullable,
-                "data_type": self._from_db_type(c.field_type, c.precision, c.scale),
                 "unique": False,
                 "sort": False,
                 "primary_key": False,
                 "foreign_key": False,
                 "cluster": c.name in (table.clustering_fields or []),
-                "partition": c.name == partition_field
+                "partition": c.name == partition_field,
+                **self._from_db_type(c.field_type, c.precision, c.scale)  # type: ignore[misc]
             }
             schema_table[c.name] = schema_c
         return True, schema_table
@@ -334,17 +328,5 @@ def _retrieve_load_job(self, file_path: str) -> bigquery.LoadJob:
         job_id = BigQueryLoadJob.get_job_id_from_file_path(file_path)
         return cast(bigquery.LoadJob, self.sql_client.native_connection.get_job(job_id))
 
-    # @classmethod
-    # def _to_db_type(cls, sc_t: TDataType) -> str:
-    #     if sc_t == "decimal":
-    #         return SCT_TO_BQT["decimal"] % cls.capabilities.decimal_precision
-    #     return SCT_TO_BQT[sc_t]
-
-    @classmethod
-    def _from_db_type(cls, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
-        if bq_t == "BIGNUMERIC":
-            if precision is None:  # biggest numeric possible
-                return "wei"
-        return BQT_TO_SCT.get(bq_t, "text")
-
-
+    def _from_db_type(self, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
+        return self.type_mapper.from_db_type(bq_t, precision, scale)

diff --git a/dlt/destinations/duckdb/duck.py b/dlt/destinations/duckdb/duck.py
index d71b387af2..ffe7310ddc 100644
--- a/dlt/destinations/duckdb/duck.py
+++ b/dlt/destinations/duckdb/duck.py
@@ -5,7 +5,7 @@
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, TColumnHint, Schema
 from dlt.common.destination.reference import LoadJob, FollowupJob, TLoadJobState
-from dlt.common.schema.typing import TTableSchema
+from dlt.common.schema.typing import TTableSchema, TColumnType
 from dlt.common.storages.file_storage import FileStorage
 from dlt.common.utils import maybe_context
 
@@ -17,19 +17,6 @@
 from dlt.destinations.type_mapping import TypeMapper
 
 
-# SCT_TO_PGT: Dict[TDataType, str] = {
-#     "complex": "JSON",
-#     "text": "VARCHAR",
-#     "double": "DOUBLE",
-#     "bool": "BOOLEAN",
-#     "date": "DATE",
-#     "timestamp": "TIMESTAMP WITH TIME ZONE",
-#     "bigint": "BIGINT",
-#     "binary": "BLOB",
-#     "decimal": "DECIMAL(%i,%i)",
-#     "time": "TIME"
-# }
-
 PGT_TO_SCT: Dict[str, TDataType] = {
     "VARCHAR": "text",
     "JSON": "complex",
@@ -73,6 +60,28 @@ class DuckDbTypeMapper(TypeMapper):
         "wei": "DECIMAL(%i,%i)",
     }
 
+    dbt_to_sct = {
+        "VARCHAR": "text",
+        "JSON": "complex",
+        "DOUBLE": "double",
+        "BOOLEAN": "bool",
+        "DATE": "date",
+        "TIMESTAMP WITH TIME ZONE": "timestamp",
+        "BIGINT": "bigint",
+        "BLOB": "binary",
+        "DECIMAL": "decimal",
+        "TIME": "time"
+    }
+
+    def from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
+        # duckdb provides the types with scale and precision
+        db_type = db_type.split("(")[0].upper()
+        if db_type == "DECIMAL":
+            if precision == 38 and scale == 0:
+                return dict(data_type="wei", precision=precision, scale=scale)
+        return super().from_db_type(db_type, precision, scale)
+
+
 class DuckDbCopyJob(LoadJob, FollowupJob):
     def __init__(self, table_name: str, file_path: str, sql_client: DuckDbSqlClient) -> None:
         super().__init__(FileStorage.get_file_name_from_file_path(file_path))
@@ -130,19 +139,5 @@ def _get_column_def_sql(self, c: TColumnSchema) -> str:
         column_name = self.capabilities.escape_identifier(c["name"])
         return f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
 
-    # @classmethod
-    # def _to_db_type(cls, sc_t: TDataType) -> str:
-    #     if sc_t == "wei":
-    #         return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision
-    #     if sc_t == "decimal":
-    #         return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision
-    #     return SCT_TO_PGT[sc_t]
-
-    @classmethod
-    def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
-        # duckdb provides the types with scale and precision
-        pq_t = pq_t.split("(")[0].upper()
-        if pq_t == "DECIMAL":
-            if precision == 38 and scale == 0:
-                return "wei"
-        return PGT_TO_SCT[pq_t]
+    def _from_db_type(self, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
+        return self.type_mapper.from_db_type(pq_t, precision, scale)

diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py
index d9ef55ba77..c082eefb93 100644
--- a/dlt/destinations/job_client_impl.py
+++ b/dlt/destinations/job_client_impl.py
@@ -12,7 +12,7 @@
 
 from dlt.common import json, pendulum, logger
 from dlt.common.data_types import TDataType
-from dlt.common.schema.typing import COLUMN_HINTS, TColumnSchemaBase, TTableSchema, TWriteDisposition
+from dlt.common.schema.typing import COLUMN_HINTS, TColumnType, TColumnSchemaBase, TTableSchema, TWriteDisposition
 from dlt.common.storages import FileStorage
 from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables
 from dlt.common.destination.reference import StateInfo, StorageSchemaInfo,WithStateSync, DestinationClientConfiguration, DestinationClientDwhConfiguration, DestinationClientDwhWithStagingConfiguration, NewLoadJob, WithStagingDataset, TLoadJobState, LoadJob, JobClientBase, FollowupJob, CredentialsConfiguration
@@ -242,14 +242,13 @@ def _null_to_bool(v: str) -> bool:
             schema_c: TColumnSchemaBase = {
                 "name": c[0],
                 "nullable": _null_to_bool(c[2]),
-                "data_type": self._from_db_type(c[1], numeric_precision, numeric_scale),
+                **self._from_db_type(c[1], numeric_precision, numeric_scale),  # type: ignore[misc]
             }
             schema_table[c[0]] = schema_c  # type: ignore
         return True, schema_table
 
-    @classmethod
     @abstractmethod
-    def _from_db_type(cls, db_type: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
+    def _from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
         pass
 
     def get_stored_schema(self) -> StorageSchemaInfo:

diff --git a/dlt/destinations/mssql/mssql.py b/dlt/destinations/mssql/mssql.py
index e99e00b350..c2c38ad333 100644
--- a/dlt/destinations/mssql/mssql.py
+++ b/dlt/destinations/mssql/mssql.py
@@ -5,7 +5,7 @@
 from dlt.common.destination import DestinationCapabilitiesContext
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, TColumnHint, Schema
-from dlt.common.schema.typing import TTableSchema
+from dlt.common.schema.typing import TTableSchema, TColumnType
 from dlt.common.utils import uniq_id
 from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob
 
@@ -19,19 +19,6 @@
 from dlt.destinations.type_mapping import TypeMapper
 
 
-# SCT_TO_PGT: Dict[TDataType, str] = {
-#     "complex": "nvarchar(max)",
-#     "text": "nvarchar(max)",
-#     "double": "float",
-#     "bool": "bit",
-#     "timestamp": "datetimeoffset",
-#     "date": "date",
-#     "bigint": "bigint",
-#     "binary": "varbinary(max)",
-#     "decimal": "decimal(%i,%i)",
-#     "time": "time"
-# }
-
 PGT_TO_SCT: Dict[str, TDataType] = {
     "nvarchar": "text",
     "float": "double",
@@ -70,6 +57,14 @@ class MsSqlTypeMapper(TypeMapper):
         "wei": "decimal(%i,%i)"
    }
 
+    # reverse mapping reuses the module-level PGT_TO_SCT dict
+    dbt_to_sct = PGT_TO_SCT
+
+    def from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
+        if db_type == "numeric":
+            if (precision, scale) == self.capabilities.wei_precision:
+                return dict(data_type="wei", precision=precision, scale=scale)
+        return super().from_db_type(db_type, precision, scale)
 
 
 class MsSqlStagingCopyJob(SqlStagingCopyJob):
@@ -108,6 +103,7 @@ def _new_temp_table_name(cls, name_prefix: str) -> str:
         name = SqlMergeJob._new_temp_table_name(name_prefix)
         return '#' + name
 
+
 class MsSqlClient(InsertValuesJobClient):
 
     capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()
@@ -145,20 +141,5 @@ def _get_column_def_sql(self, c: TColumnSchema) -> str:
     def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
         return MsSqlStagingCopyJob.from_table_chain(table_chain, self.sql_client)
 
-    # @classmethod
-    # def _to_db_type(cls, sc_t: TDataType) -> str:
-    #     if sc_t == "wei":
-    #         return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision
-    #     if sc_t == "decimal":
-    #         return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision
-
-    #     if sc_t == "wei":
-    #         return f"numeric({2*EVM_DECIMAL_PRECISION},{EVM_DECIMAL_PRECISION})"
-    #     return SCT_TO_PGT[sc_t]
-
-    @classmethod
-    def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
-        if pq_t == "numeric":
-            if (precision, scale) == cls.capabilities.wei_precision:
-                return "wei"
-        return PGT_TO_SCT[pq_t]
+    def _from_db_type(self, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
+        return self.type_mapper.from_db_type(pq_t, precision, scale)

diff --git a/dlt/destinations/postgres/postgres.py b/dlt/destinations/postgres/postgres.py
index 6ecf0a8dde..9bfaba0616 100644
--- a/dlt/destinations/postgres/postgres.py
+++ b/dlt/destinations/postgres/postgres.py
@@ -5,7 +5,7 @@
 from dlt.common.destination import DestinationCapabilitiesContext
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, TColumnHint, Schema
-from dlt.common.schema.typing import TTableSchema
+from dlt.common.schema.typing import TTableSchema, TColumnType
 
 from dlt.destinations.sql_jobs import SqlStagingCopyJob
 
@@ -18,19 +18,6 @@
 from dlt.destinations.type_mapping import TypeMapper
 
 
-PGT_TO_SCT: Dict[str, TDataType] = {
-    "varchar": "text",
-    "jsonb": "complex",
-    "double precision": "double",
-    "boolean": "bool",
-    "timestamp with time zone": "timestamp",
-    "date": "date",
-    "bigint": "bigint",
-    "bytea": "binary",
-    "numeric": "decimal",
-    "time without time zone": "time"
-}
-
 HINT_TO_POSTGRES_ATTR: Dict[TColumnHint, str] = {
     "unique": "UNIQUE"
 }
@@ -55,6 +42,26 @@ class PostgresTypeMapper(TypeMapper):
         "wei": "numeric(%i,%i)"
     }
 
+    dbt_to_sct = {
+        "varchar": "text",
+        "jsonb": "complex",
+        "double precision": "double",
+        "boolean": "bool",
+        "timestamp with time zone": "timestamp",
+        "date": "date",
+        "bigint": "bigint",
+        "bytea": "binary",
+        "numeric": "decimal",
+        "time without time zone": "time",
+        "character varying": "text",
+    }
+
+    def from_db_type(self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None) -> TColumnType:
+        if db_type == "numeric":
+            if (precision, scale) == self.capabilities.wei_precision:
+                return dict(data_type="wei", precision=precision, scale=scale)
+        return super().from_db_type(db_type, precision, scale)
+
 
 class PostgresStagingCopyJob(SqlStagingCopyJob):
 
@@ -97,31 +104,5 @@ def _get_column_def_sql(self, c: TColumnSchema) -> str:
     def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
         return PostgresStagingCopyJob.from_table_chain(table_chain, self.sql_client)
 
-    # @classmethod
-    # def _to_db_type(cls, column: TColumnSchema) -> str:
-    #     sc_t = column["data_type"]
-    #     precision, scale = column.get("precision"), column.get("scale")
-    #     if sc_t in ("timestamp", "time"):
-    #         return SCT_TO_PGT[sc_t] % (precision or cls.capabilities.timestamp_precision)
-    #     if sc_t == "wei":
-    #         default_precision, default_scale = cls.capabilities.wei_precision
-    #         precision = precision if precision is not None else default_precision
-    #         scale = scale if scale is not None else default_scale
-    #         return SCT_TO_PGT["decimal"] % (precision, scale)
-    #     if sc_t == "decimal":
-    #         default_precision, default_scale = cls.capabilities.decimal_precision
-    #         precision = precision if precision is not None else default_precision
-    #         scale = scale if scale is not None else default_scale
-    #         return SCT_TO_PGT["decimal"] % (precision, scale)
-    #     if sc_t == "text":
-    #         if precision is not None:
-    #             return "varchar (%i)" % precision
-    #         return "varchar"
-    #     return SCT_TO_PGT[sc_t]
-
-    @classmethod
-    def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
-        if pq_t == "numeric":
-            if (precision, scale) == cls.capabilities.wei_precision:
-                return "wei"
-        return PGT_TO_SCT.get(pq_t, "text")
+    def _from_db_type(self, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
+        return self.type_mapper.from_db_type(pq_t, precision, scale)

diff --git a/dlt/destinations/redshift/redshift.py b/dlt/destinations/redshift/redshift.py
index 42e971be59..b53182899f 100644
--- a/dlt/destinations/redshift/redshift.py
+++ b/dlt/destinations/redshift/redshift.py
@@ -17,7 +17,7 @@
 from dlt.common.destination.reference import NewLoadJob, CredentialsConfiguration
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, TColumnHint, Schema
-from dlt.common.schema.typing import TTableSchema
+from dlt.common.schema.typing import TTableSchema, TColumnType
 from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults
 
 from dlt.destinations.insert_job_client import InsertValuesJobClient
@@ -32,19 +32,6 @@
 from dlt.destinations.type_mapping import TypeMapper
 
 
-PGT_TO_SCT: Dict[str, TDataType] = {
-    "super": "complex",
-    "varchar(max)": "text",
-    "double precision": "double",
-    "boolean": "bool",
-    "date": "date",
-    "timestamp with time zone": "timestamp",
-    "bigint": "bigint",
-    "binary varying": "binary",
-    "numeric": "decimal",
-    "time without time zone": "time"
-}
-
 HINT_TO_REDSHIFT_ATTR: Dict[TColumnHint, str] = {
     "cluster": "DISTKEY",
     # it is better to not enforce constraints in redshift
@@ -73,6 +60,26 @@ class RedshiftTypeMapper(TypeMapper):
         "binary": "varbinary(%i)",
     }
 
+    dbt_to_sct = {
+        "super": "complex",
+        "varchar(max)": "text",
+        "double precision": "double",
+        "boolean": "bool",
+        "date": "date",
+        "timestamp with time zone": "timestamp",
+        "bigint": "bigint",
+        "binary varying": "binary",
+        "numeric": "decimal",
+        "time without time zone": "time",
+        "varchar": "text",
+    }
+
+    def from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
+        if db_type == "numeric":
+            if (precision, scale) == self.capabilities.wei_precision:
+                return dict(data_type="wei", precision=precision, scale=scale)
+        return super().from_db_type(db_type, precision, scale)
+
 
 class RedshiftSqlClient(Psycopg2SqlClient):
 
@@ -194,17 +201,5 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) ->
         job = RedshiftCopyFileLoadJob(table, file_path, self.sql_client, staging_credentials=self.config.staging_config.credentials, staging_iam_role=self.config.staging_iam_role)
         return job
 
-    # @classmethod
-    # def _to_db_type(cls, sc_t: TDataType) -> str:
-    #     if sc_t == "wei":
-    #         return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision
-    #     if sc_t == "decimal":
-    #         return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision
-    #     return SCT_TO_PGT[sc_t]
-
-    @classmethod
-    def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
-        if pq_t == "numeric":
-            if (precision, scale) == cls.capabilities.wei_precision:
-                return "wei"
-        return PGT_TO_SCT.get(pq_t, "text")
+    def _from_db_type(self, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
+        return self.type_mapper.from_db_type(pq_t, precision, scale)

diff --git a/dlt/destinations/snowflake/snowflake.py b/dlt/destinations/snowflake/snowflake.py
index 33fa3e3cfd..6d19717e17 100644
--- a/dlt/destinations/snowflake/snowflake.py
+++ b/dlt/destinations/snowflake/snowflake.py
@@ -7,7 +7,7 @@
 from dlt.common.data_types import TDataType
 from dlt.common.storages.file_storage import FileStorage
 from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns
-from dlt.common.schema.typing import TTableSchema
+from dlt.common.schema.typing import TTableSchema, TColumnType
 
 from dlt.destinations.job_client_impl import SqlJobClientWithStaging
 
@@ -24,23 +24,6 @@
 from dlt.destinations.type_mapping import TypeMapper
 
 
-BIGINT_PRECISION = 19
-# MAX_NUMERIC_PRECISION = 38
-
-
-# SCT_TO_SNOW: Dict[TDataType, str] = {
-#     "complex": "VARIANT",
-#     "text": "VARCHAR",
-#     "double": "FLOAT",
-#     "bool": "BOOLEAN",
-#     "date": "DATE",
-#     "timestamp": "TIMESTAMP_TZ",
-#     "bigint": f"NUMBER({BIGINT_PRECISION},0)",  # Snowflake has no integer types
-#     "binary": "BINARY",
-#     "decimal": "NUMBER(%i,%i)",
-#     "time": "TIME",
-# }
-
 SNOW_TO_SCT: Dict[str, TDataType] = {
     "VARCHAR": "text",
     "FLOAT": "double",
@@ -76,6 +59,28 @@ class SnowflakeTypeMapper(TypeMapper):
         "wei": "NUMBER(%i,%i)",
     }
 
+    BIGINT_PRECISION = 19  # Snowflake has no integer types; NUMBER(19,0) holds a 64 bit integer
+
+    dbt_to_sct = {
+        "VARCHAR": "text",
+        "FLOAT": "double",
+        "BOOLEAN": "bool",
+        "DATE": "date",
+        "TIMESTAMP_TZ": "timestamp",
+        "BINARY": "binary",
+        "VARIANT": "complex",
+        "TIME": "time"
+    }
+
+    def from_db_type(self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None) -> TColumnType:
+        if db_type == "NUMBER":
+            if precision == self.BIGINT_PRECISION and scale == 0:
+                return dict(data_type="bigint", precision=precision, scale=scale)
+            elif (precision, scale) == self.capabilities.wei_precision:
+                return dict(data_type="wei", precision=precision, scale=scale)
+            return dict(data_type="decimal", precision=precision, scale=scale)
+        return super().from_db_type(db_type, precision, scale)
+
 
 class SnowflakeLoadJob(LoadJob, FollowupJob):
     def __init__(
@@ -228,15 +233,8 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc
         return sql
 
-    @classmethod
-    def _from_db_type(cls, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType:
-        if bq_t == "NUMBER":
-            if precision == BIGINT_PRECISION and scale == 0:
-                return 'bigint'
-            elif (precision, scale) == cls.capabilities.wei_precision:
-                return 'wei'
-            return 'decimal'
-        return SNOW_TO_SCT.get(bq_t, "text")
-
+    def _from_db_type(self, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
+        return self.type_mapper.from_db_type(bq_t, precision, scale)
+
     def _get_column_def_sql(self, c: TColumnSchema) -> str:
         name = self.capabilities.escape_identifier(c["name"])

diff --git a/dlt/destinations/type_mapping.py b/dlt/destinations/type_mapping.py
index 2a6b851123..c884186293 100644
--- a/dlt/destinations/type_mapping.py
+++ b/dlt/destinations/type_mapping.py
@@ -1,6 +1,6 @@
 from typing import Tuple, ClassVar, Dict, Optional
 
-from dlt.common.schema.typing import TColumnSchema, TDataType
+from dlt.common.schema.typing import TColumnSchema, TDataType, TColumnType
 from dlt.common.destination.capabilities import DestinationCapabilitiesContext
 
 
@@ -14,6 +14,8 @@ class TypeMapper:
     Values should have printf placeholders for precision (and scale if applicable)
     """
 
+    dbt_to_sct: Dict[str, TDataType]
+
     def __init__(self, capabilities: DestinationCapabilitiesContext) -> None:
         self.capabilities = capabilities
 
@@ -42,7 +44,7 @@ def precision_tuple_or_default(self, data_type: TDataType, precision: Optional[i
             return (precision, )
         return (precision, scale)
 
-    def decimal_precision(self, precision: Optional[int], scale: Optional[int]) -> Optional[Tuple[int, int]]:
+    def decimal_precision(self, precision: Optional[int] = None, scale: Optional[int] = None) -> Optional[Tuple[int, int]]:
         defaults = self.capabilities.decimal_precision
         if not defaults:
             return None
@@ -51,7 +53,7 @@ def decimal_precision(self, precision: Optional[int], scale: Optional[int]) -> O
             precision if precision is not None else default_precision,
             scale if scale is not None else default_scale
         )
-    def wei_precision(self, precision: Optional[int], scale: Optional[int]) -> Optional[Tuple[int, int]]:
+    def wei_precision(self, precision: Optional[int] = None, scale: Optional[int] = None) -> Optional[Tuple[int, int]]:
         defaults = self.capabilities.wei_precision
         if not defaults:
             return None
@@ -62,3 +64,10 @@ def wei_precision(self, precision: Optional[int], scale: Optional[int]) -> Optio
 
     def timestamp_precision(self, precision: Optional[int]) -> Optional[int]:
         return precision or self.capabilities.timestamp_precision
+
+    def from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[int]) -> TColumnType:
+        return dict(
+            data_type=self.dbt_to_sct.get(db_type, "text"),  # unknown db types fall back to "text", as before
+            precision=precision,
+            scale=scale
+        )
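
Note for reviewers, not part of the diff: because TColumnType is a total=False TypedDict, a reflected type unpacks straight into a column schema literal, which is what the get_storage_table changes above rely on. A minimal sketch of that merge (the "amount" column is hypothetical, chosen only for illustration):

    from dlt.common.schema.typing import TColumnSchemaBase, TColumnType

    # a reflected column type, as the new from_db_type implementations return it
    col_type: TColumnType = {"data_type": "decimal", "precision": 38, "scale": 9}

    schema_c: TColumnSchemaBase = {
        "name": "amount",  # hypothetical column name
        "nullable": True,
        **col_type,  # type: ignore[misc] - same unpacking pattern as in get_storage_table
    }
    # schema_c == {'name': 'amount', 'nullable': True,
    #              'data_type': 'decimal', 'precision': 38, 'scale': 9}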
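
The reverse lookup itself can be exercised against the patched TypeMapper base class. In the sketch below, ToyTypeMapper and the generic_capabilities() call are stand-ins for a concrete destination mapper and its capabilities context, not APIs introduced by this patch:

    from dlt.common.destination import DestinationCapabilitiesContext
    from dlt.destinations.type_mapping import TypeMapper

    class ToyTypeMapper(TypeMapper):
        # forward mappings left empty; only the reverse lookup is exercised here
        sct_to_unbound_dbt = {}
        sct_to_dbt = {}
        dbt_to_sct = {"varchar": "text", "numeric": "decimal"}

    mapper = ToyTypeMapper(DestinationCapabilitiesContext.generic_capabilities())

    mapper.from_db_type("numeric", 38, 9)
    # -> {'data_type': 'decimal', 'precision': 38, 'scale': 9}
    mapper.from_db_type("geometry", None, None)
    # -> {'data_type': 'text', 'precision': None, 'scale': None}  (the "text" fallback)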