diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index d10091998a..da0e57ce36 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -53,6 +53,8 @@ class TColumnSchema(TColumnSchemaBase, total=False): root_key: Optional[bool] merge_key: Optional[bool] variant: Optional[bool] + precision: Optional[int] + scale: Optional[int] TTableSchemaColumns = Dict[str, TColumnSchema] diff --git a/dlt/destinations/athena/athena.py b/dlt/destinations/athena/athena.py index 45ca91a7b2..6ef9945daa 100644 --- a/dlt/destinations/athena/athena.py +++ b/dlt/destinations/athena/athena.py @@ -31,20 +31,10 @@ from dlt.destinations.typing import DBApiCursor from dlt.destinations.job_client_impl import SqlJobClientBase, StorageSchemaInfo from dlt.destinations.athena.configuration import AthenaClientConfiguration +from dlt.destinations.type_mapping import TypeMapper from dlt.destinations import path_utils -SCT_TO_HIVET: Dict[TDataType, str] = { - "complex": "string", - "text": "string", - "double": "double", - "bool": "boolean", - "date": "date", - "timestamp": "timestamp", - "bigint": "bigint", - "binary": "binary", - "decimal": "decimal(%i,%i)", - "time": "string" -} + HIVET_TO_SCT: Dict[str, TDataType] = { "varchar": "text", @@ -59,6 +49,25 @@ } +class AthenaTypeMapper(TypeMapper): + sct_to_unbound_dbt = { + "complex": "string", + "text": "string", + "double": "double", + "bool": "boolean", + "date": "date", + "timestamp": "timestamp", + "bigint": "bigint", + "binary": "binary", + "time": "string" + } + + sct_to_dbt = { + "decimal": "decimal(%i,%i)", + "wei": "decimal(%i,%i)" + } + + # add a formatter for pendulum to be used by pyathen dbapi def _format_pendulum_datetime(formatter: Formatter, escaper: Callable[[str], str], val: Any) -> Any: # copied from https://github.com/laughingman7743/PyAthena/blob/f4b21a0b0f501f5c3504698e25081f491a541d4e/pyathena/formatter.py#L114 @@ -265,19 +274,12 @@ def __init__(self, schema: Schema, config: AthenaClientConfiguration) -> None: super().__init__(schema, config, sql_client) self.sql_client: AthenaSQLClient = sql_client # type: ignore self.config: AthenaClientConfiguration = config + self.type_mapper = AthenaTypeMapper(self.capabilities) def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: # never truncate tables in athena super().initialize_storage([]) - @classmethod - def _to_db_type(cls, sc_t: TDataType) -> str: - if sc_t == "wei": - return SCT_TO_HIVET["decimal"] % cls.capabilities.wei_precision - if sc_t == "decimal": - return SCT_TO_HIVET["decimal"] % cls.capabilities.decimal_precision - return SCT_TO_HIVET[sc_t] - @classmethod def _from_db_type(cls, hive_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: for key, val in HIVET_TO_SCT.items(): @@ -286,7 +288,7 @@ def _from_db_type(cls, hive_t: str, precision: Optional[int], scale: Optional[in return None def _get_column_def_sql(self, c: TColumnSchema) -> str: - return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self._to_db_type(c['data_type'])}" + return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c)}" def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]: diff --git a/dlt/destinations/bigquery/bigquery.py b/dlt/destinations/bigquery/bigquery.py index ffbc1b546b..a3cdcb0ba0 100644 --- a/dlt/destinations/bigquery/bigquery.py +++ b/dlt/destinations/bigquery/bigquery.py @@ -22,6 +22,7 @@ from dlt.destinations.sql_jobs import 
SqlMergeJob, SqlStagingCopyJob from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase +from dlt.destinations.type_mapping import TypeMapper from dlt.common.schema.utils import table_schema_has_type @@ -53,6 +54,27 @@ "TIME": "time", } + +class BigQueryTypeMapper(TypeMapper): + sct_to_unbound_dbt = { + "complex": "JSON", + "text": "STRING", + "double": "FLOAT64", + "bool": "BOOLEAN", + "date": "DATE", + "timestamp": "TIMESTAMP", + "bigint": "INTEGER", + "binary": "BYTES", + "wei": "BIGNUMERIC", # non parametrized should hold wei values + "time": "TIME", + } + + sct_to_dbt = { + "text": "STRING(%i)", + "binary": "BYTES(%i)", + "decimal": "NUMERIC(%i,%i)", + } + class BigQueryLoadJob(LoadJob, FollowupJob): def __init__( self, @@ -146,6 +168,7 @@ def __init__(self, schema: Schema, config: BigQueryClientConfiguration) -> None: super().__init__(schema, config, sql_client) self.config: BigQueryClientConfiguration = config self.sql_client: BigQuerySqlClient = sql_client # type: ignore + self.type_mapper = BigQueryTypeMapper(self.capabilities) def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: return BigQueryMergeJob.from_table_chain(table_chain, self.sql_client) @@ -229,7 +252,7 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc def _get_column_def_sql(self, c: TColumnSchema) -> str: name = self.capabilities.escape_identifier(c["name"]) - return f"{name} {self._to_db_type(c['data_type'])} {self._gen_not_null(c.get('nullable', True))}" + return f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: schema_table: TTableSchemaColumns = {} @@ -311,11 +334,11 @@ def _retrieve_load_job(self, file_path: str) -> bigquery.LoadJob: job_id = BigQueryLoadJob.get_job_id_from_file_path(file_path) return cast(bigquery.LoadJob, self.sql_client.native_connection.get_job(job_id)) - @classmethod - def _to_db_type(cls, sc_t: TDataType) -> str: - if sc_t == "decimal": - return SCT_TO_BQT["decimal"] % cls.capabilities.decimal_precision - return SCT_TO_BQT[sc_t] + # @classmethod + # def _to_db_type(cls, sc_t: TDataType) -> str: + # if sc_t == "decimal": + # return SCT_TO_BQT["decimal"] % cls.capabilities.decimal_precision + # return SCT_TO_BQT[sc_t] @classmethod def _from_db_type(cls, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: diff --git a/dlt/destinations/duckdb/duck.py b/dlt/destinations/duckdb/duck.py index 0a9fe1fb3b..d71b387af2 100644 --- a/dlt/destinations/duckdb/duck.py +++ b/dlt/destinations/duckdb/duck.py @@ -14,20 +14,21 @@ from dlt.destinations.duckdb import capabilities from dlt.destinations.duckdb.sql_client import DuckDbSqlClient from dlt.destinations.duckdb.configuration import DuckDbClientConfiguration - - -SCT_TO_PGT: Dict[TDataType, str] = { - "complex": "JSON", - "text": "VARCHAR", - "double": "DOUBLE", - "bool": "BOOLEAN", - "date": "DATE", - "timestamp": "TIMESTAMP WITH TIME ZONE", - "bigint": "BIGINT", - "binary": "BLOB", - "decimal": "DECIMAL(%i,%i)", - "time": "TIME" -} +from dlt.destinations.type_mapping import TypeMapper + + +# SCT_TO_PGT: Dict[TDataType, str] = { +# "complex": "JSON", +# "text": "VARCHAR", +# "double": "DOUBLE", +# "bool": "BOOLEAN", +# "date": "DATE", +# "timestamp": "TIMESTAMP WITH TIME ZONE", +# "bigint": "BIGINT", +# "binary": "BLOB", +# "decimal": "DECIMAL(%i,%i)", +# "time": "TIME" +# } PGT_TO_SCT: Dict[str, 
TDataType] = { "VARCHAR": "text", @@ -51,6 +52,27 @@ TABLES_LOCKS: Dict[str, threading.Lock] = {} +class DuckDbTypeMapper(TypeMapper): + sct_to_unbound_dbt = { + "complex": "JSON", + "text": "VARCHAR", + "double": "DOUBLE", + "bool": "BOOLEAN", + "date": "DATE", + # Duck does not allow specifying precision on timestamp with tz + "timestamp": "TIMESTAMP WITH TIME ZONE", + "bigint": "BIGINT", + "binary": "BLOB", + "time": "TIME" + } + + sct_to_dbt = { + # VARCHAR(n) is alias for VARCHAR in duckdb + # "text": "VARCHAR(%i)", + "decimal": "DECIMAL(%i,%i)", + "wei": "DECIMAL(%i,%i)", + } + class DuckDbCopyJob(LoadJob, FollowupJob): def __init__(self, table_name: str, file_path: str, sql_client: DuckDbSqlClient) -> None: super().__init__(FileStorage.get_file_name_from_file_path(file_path)) @@ -95,6 +117,7 @@ def __init__(self, schema: Schema, config: DuckDbClientConfiguration) -> None: self.config: DuckDbClientConfiguration = config self.sql_client: DuckDbSqlClient = sql_client # type: ignore self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {} + self.type_mapper = DuckDbTypeMapper(self.capabilities) def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: job = super().start_file_load(table, file_path, load_id) @@ -105,15 +128,15 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> def _get_column_def_sql(self, c: TColumnSchema) -> str: hints_str = " ".join(self.active_hints.get(h, "") for h in self.active_hints.keys() if c.get(h, False) is True) column_name = self.capabilities.escape_identifier(c["name"]) - return f"{column_name} {self._to_db_type(c['data_type'])} {hints_str} {self._gen_not_null(c.get('nullable', True))}" - - @classmethod - def _to_db_type(cls, sc_t: TDataType) -> str: - if sc_t == "wei": - return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision - if sc_t == "decimal": - return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision - return SCT_TO_PGT[sc_t] + return f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + + # @classmethod + # def _to_db_type(cls, sc_t: TDataType) -> str: + # if sc_t == "wei": + # return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision + # if sc_t == "decimal": + # return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision + # return SCT_TO_PGT[sc_t] @classmethod def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 1c0d55d1c8..d9ef55ba77 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -247,11 +247,6 @@ def _null_to_bool(v: str) -> bool: schema_table[c[0]] = schema_c # type: ignore return True, schema_table - @classmethod - @abstractmethod - def _to_db_type(cls, schema_type: TDataType) -> str: - pass - @classmethod @abstractmethod def _from_db_type(cls, db_type: str, precision: Optional[int], scale: Optional[int]) -> TDataType: diff --git a/dlt/destinations/mssql/__init__.py b/dlt/destinations/mssql/__init__.py index d15d54676c..56051a324e 100644 --- a/dlt/destinations/mssql/__init__.py +++ b/dlt/destinations/mssql/__init__.py @@ -36,6 +36,7 @@ def capabilities() -> DestinationCapabilitiesContext: caps.is_max_text_data_type_length_in_bytes = False caps.supports_ddl_transactions = True caps.max_rows_per_insert = 1000 + caps.timestamp_precision = 7 return caps diff --git a/dlt/destinations/mssql/mssql.py 
b/dlt/destinations/mssql/mssql.py index 701ccb88fa..5ebb49a495 100644 --- a/dlt/destinations/mssql/mssql.py +++ b/dlt/destinations/mssql/mssql.py @@ -16,20 +16,21 @@ from dlt.destinations.mssql.sql_client import PyOdbcMsSqlClient from dlt.destinations.mssql.configuration import MsSqlClientConfiguration from dlt.destinations.sql_client import SqlClientBase - - -SCT_TO_PGT: Dict[TDataType, str] = { - "complex": "nvarchar(max)", - "text": "nvarchar(max)", - "double": "float", - "bool": "bit", - "timestamp": "datetimeoffset", - "date": "date", - "bigint": "bigint", - "binary": "varbinary(max)", - "decimal": "decimal(%i,%i)", - "time": "time" -} +from dlt.destinations.type_mapping import TypeMapper + + +# SCT_TO_PGT: Dict[TDataType, str] = { +# "complex": "nvarchar(max)", +# "text": "nvarchar(max)", +# "double": "float", +# "bool": "bit", +# "timestamp": "datetimeoffset", +# "date": "date", +# "bigint": "bigint", +# "binary": "varbinary(max)", +# "decimal": "decimal(%i,%i)", +# "time": "time" +# } PGT_TO_SCT: Dict[str, TDataType] = { "nvarchar": "text", @@ -47,6 +48,27 @@ "unique": "UNIQUE" } + +class MsSqlTypeMapper(TypeMapper): + sct_to_unbound_dbt = { + "complex": "nvarchar(max)", + "text": "nvarchar(max)", + "double": "float", + "bool": "bit", + "bigint": "bigint", + "binary": "varbinary(max)", + } + + sct_to_dbt = { + "complex": "nvarchar(%i)", + "text": "nvarchar(%i)", + "timestamp": "datetimeoffset(%i)", + "binary": "varbinary(%i)", + "decimal": "decimal(%i,%i)", + "time": "time(%i)" + } + + class MsSqlStagingCopyJob(SqlStagingCopyJob): @classmethod @@ -97,6 +119,7 @@ def __init__(self, schema: Schema, config: MsSqlClientConfiguration) -> None: self.config: MsSqlClientConfiguration = config self.sql_client = sql_client self.active_hints = HINT_TO_MSSQL_ATTR if self.config.create_indexes else {} + self.type_mapper = MsSqlTypeMapper(self.capabilities) def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: return MsSqlMergeJob.from_table_chain(table_chain, self.sql_client) @@ -109,9 +132,9 @@ def _get_column_def_sql(self, c: TColumnSchema) -> str: sc_type = c["data_type"] if sc_type == "text" and c.get("unique"): # MSSQL does not allow index on large TEXT columns - db_type = "nvarchar(900)" + db_type = "nvarchar(%i)" % (c.get("precision") or 900) else: - db_type = self._to_db_type(sc_type) + db_type = self.type_mapper.to_db_type(c) hints_str = " ".join(self.active_hints.get(h, "") for h in self.active_hints.keys() if c.get(h, False) is True) column_name = self.capabilities.escape_identifier(c["name"]) @@ -120,16 +143,16 @@ def _get_column_def_sql(self, c: TColumnSchema) -> str: def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: return MsSqlStagingCopyJob.from_table_chain(table_chain, self.sql_client) - @classmethod - def _to_db_type(cls, sc_t: TDataType) -> str: - if sc_t == "wei": - return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision - if sc_t == "decimal": - return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision - - if sc_t == "wei": - return f"numeric({2*EVM_DECIMAL_PRECISION},{EVM_DECIMAL_PRECISION})" - return SCT_TO_PGT[sc_t] + # @classmethod + # def _to_db_type(cls, sc_t: TDataType) -> str: + # if sc_t == "wei": + # return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision + # if sc_t == "decimal": + # return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision + + # if sc_t == "wei": + # return f"numeric({2*EVM_DECIMAL_PRECISION},{EVM_DECIMAL_PRECISION})" + # return SCT_TO_PGT[sc_t] 
@classmethod def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: diff --git a/dlt/destinations/postgres/postgres.py b/dlt/destinations/postgres/postgres.py index 6e7e049bde..6ecf0a8dde 100644 --- a/dlt/destinations/postgres/postgres.py +++ b/dlt/destinations/postgres/postgres.py @@ -15,21 +15,9 @@ from dlt.destinations.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.postgres.configuration import PostgresClientConfiguration from dlt.destinations.sql_client import SqlClientBase +from dlt.destinations.type_mapping import TypeMapper -SCT_TO_PGT: Dict[TDataType, str] = { - "complex": "jsonb", - "text": "varchar", - "double": "double precision", - "bool": "boolean", - "timestamp": "timestamp with time zone", - "date": "date", - "bigint": "bigint", - "binary": "bytea", - "decimal": "numeric(%i,%i)", - "time": "time without time zone" -} - PGT_TO_SCT: Dict[str, TDataType] = { "varchar": "text", "jsonb": "complex", @@ -47,6 +35,27 @@ "unique": "UNIQUE" } +class PostgresTypeMapper(TypeMapper): + sct_to_unbound_dbt = { + "complex": "jsonb", + "text": "varchar", + "double": "double precision", + "bool": "boolean", + "date": "date", + "bigint": "bigint", + "binary": "bytea", + } + + sct_to_dbt = { + "text": "varchar(%i)", + "timestamp": "timestamp (%i) with time zone", + "binary": "bytea(%i)", + "decimal": "numeric(%i,%i)", + "time": "time (%i) without time zone", + "wei": "numeric(%i,%i)" + } + + class PostgresStagingCopyJob(SqlStagingCopyJob): @classmethod @@ -64,6 +73,7 @@ def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClient sql.append(f"CREATE TABLE {staging_table_name} (like {table_name} including all);") return sql + class PostgresClient(InsertValuesJobClient): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -77,25 +87,37 @@ def __init__(self, schema: Schema, config: PostgresClientConfiguration) -> None: self.config: PostgresClientConfiguration = config self.sql_client = sql_client self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {} + self.type_mapper = PostgresTypeMapper(self.capabilities) def _get_column_def_sql(self, c: TColumnSchema) -> str: hints_str = " ".join(self.active_hints.get(h, "") for h in self.active_hints.keys() if c.get(h, False) is True) column_name = self.capabilities.escape_identifier(c["name"]) - return f"{column_name} {self._to_db_type(c['data_type'])} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + return f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: return PostgresStagingCopyJob.from_table_chain(table_chain, self.sql_client) - @classmethod - def _to_db_type(cls, sc_t: TDataType) -> str: - if sc_t == "wei": - return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision - if sc_t == "decimal": - return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision - - if sc_t == "wei": - return f"numeric({2*EVM_DECIMAL_PRECISION},{EVM_DECIMAL_PRECISION})" - return SCT_TO_PGT[sc_t] + # @classmethod + # def _to_db_type(cls, column: TColumnSchema) -> str: + # sc_t = column["data_type"] + # precision, scale = column.get("precision"), column.get("scale") + # if sc_t in ("timestamp", "time"): + # return SCT_TO_PGT[sc_t] % (precision or cls.capabilities.timestamp_precision) + # if sc_t == "wei": + # default_precision, default_scale = cls.capabilities.wei_precision + # precision 
= precision if precision is not None else default_precision + # scale = scale if scale is not None else default_scale + # return SCT_TO_PGT["decimal"] % (precision, scale) + # if sc_t == "decimal": + # default_precision, default_scale = cls.capabilities.decimal_precision + # precision = precision if precision is not None else default_precision + # scale = scale if scale is not None else default_scale + # return SCT_TO_PGT["decimal"] % (precision, scale) + # if sc_t == "text": + # if precision is not None: + # return "varchar (%i)" % precision + # return "varchar" + # return SCT_TO_PGT[sc_t] @classmethod def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: @@ -103,4 +125,3 @@ def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int] if (precision, scale) == cls.capabilities.wei_precision: return "wei" return PGT_TO_SCT.get(pq_t, "text") - diff --git a/dlt/destinations/redshift/redshift.py b/dlt/destinations/redshift/redshift.py index 2604f52c53..42e971be59 100644 --- a/dlt/destinations/redshift/redshift.py +++ b/dlt/destinations/redshift/redshift.py @@ -29,22 +29,9 @@ from dlt.destinations.redshift.configuration import RedshiftClientConfiguration from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase +from dlt.destinations.type_mapping import TypeMapper - -SCT_TO_PGT: Dict[TDataType, str] = { - "complex": "super", - "text": "varchar(max)", - "double": "double precision", - "bool": "boolean", - "date": "date", - "timestamp": "timestamp with time zone", - "bigint": "bigint", - "binary": "varbinary", - "decimal": "numeric(%i,%i)", - "time": "time without time zone" -} - PGT_TO_SCT: Dict[str, TDataType] = { "super": "complex", "varchar(max)": "text", @@ -66,6 +53,27 @@ } +class RedshiftTypeMapper(TypeMapper): + sct_to_unbound_dbt = { + "complex": "super", + "text": "varchar(max)", + "double": "double precision", + "bool": "boolean", + "date": "date", + "timestamp": "timestamp with time zone", + "bigint": "bigint", + "binary": "varbinary", + "time": "time without time zone" + } + + sct_to_dbt = { + "decimal": "numeric(%i,%i)", + "wei": "numeric(%i,%i)", + "text": "varchar(%i)", + "binary": "varbinary(%i)", + } + + class RedshiftSqlClient(Psycopg2SqlClient): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -168,6 +176,7 @@ def __init__(self, schema: Schema, config: RedshiftClientConfiguration) -> None: super().__init__(schema, config, sql_client) self.sql_client = sql_client self.config: RedshiftClientConfiguration = config + self.type_mapper = RedshiftTypeMapper(self.capabilities) def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: return RedshiftMergeJob.from_table_chain(table_chain, self.sql_client) @@ -175,7 +184,7 @@ def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: def _get_column_def_sql(self, c: TColumnSchema) -> str: hints_str = " ".join(HINT_TO_REDSHIFT_ATTR.get(h, "") for h in HINT_TO_REDSHIFT_ATTR.keys() if c.get(h, False) is True) column_name = self.capabilities.escape_identifier(c["name"]) - return f"{column_name} {self._to_db_type(c['data_type'])} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + return f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let 
derived classes to handle their specific jobs""" @@ -185,13 +194,13 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> job = RedshiftCopyFileLoadJob(table, file_path, self.sql_client, staging_credentials=self.config.staging_config.credentials, staging_iam_role=self.config.staging_iam_role) return job - @classmethod - def _to_db_type(cls, sc_t: TDataType) -> str: - if sc_t == "wei": - return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision - if sc_t == "decimal": - return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision - return SCT_TO_PGT[sc_t] + # @classmethod + # def _to_db_type(cls, sc_t: TDataType) -> str: + # if sc_t == "wei": + # return SCT_TO_PGT["decimal"] % cls.capabilities.wei_precision + # if sc_t == "decimal": + # return SCT_TO_PGT["decimal"] % cls.capabilities.decimal_precision + # return SCT_TO_PGT[sc_t] @classmethod def _from_db_type(cls, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: diff --git a/dlt/destinations/snowflake/__init__.py b/dlt/destinations/snowflake/__init__.py index 5d32bc41fd..c253adafc9 100644 --- a/dlt/destinations/snowflake/__init__.py +++ b/dlt/destinations/snowflake/__init__.py @@ -34,6 +34,7 @@ def capabilities() -> DestinationCapabilitiesContext: caps.is_max_text_data_type_length_in_bytes = True caps.supports_ddl_transactions = True caps.alter_add_multi_column = True + caps.timestamp_precision = 9 return caps diff --git a/dlt/destinations/snowflake/snowflake.py b/dlt/destinations/snowflake/snowflake.py index 426a1e53a1..33fa3e3cfd 100644 --- a/dlt/destinations/snowflake/snowflake.py +++ b/dlt/destinations/snowflake/snowflake.py @@ -21,23 +21,25 @@ from dlt.destinations.snowflake.sql_client import SnowflakeSqlClient from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase +from dlt.destinations.type_mapping import TypeMapper + BIGINT_PRECISION = 19 -MAX_NUMERIC_PRECISION = 38 - - -SCT_TO_SNOW: Dict[TDataType, str] = { - "complex": "VARIANT", - "text": "VARCHAR", - "double": "FLOAT", - "bool": "BOOLEAN", - "date": "DATE", - "timestamp": "TIMESTAMP_TZ", - "bigint": f"NUMBER({BIGINT_PRECISION},0)", # Snowflake has no integer types - "binary": "BINARY", - "decimal": "NUMBER(%i,%i)", - "time": "TIME", -} +# MAX_NUMERIC_PRECISION = 38 + + +# SCT_TO_SNOW: Dict[TDataType, str] = { +# "complex": "VARIANT", +# "text": "VARCHAR", +# "double": "FLOAT", +# "bool": "BOOLEAN", +# "date": "DATE", +# "timestamp": "TIMESTAMP_TZ", +# "bigint": f"NUMBER({BIGINT_PRECISION},0)", # Snowflake has no integer types +# "binary": "BINARY", +# "decimal": "NUMBER(%i,%i)", +# "time": "TIME", +# } SNOW_TO_SCT: Dict[str, TDataType] = { "VARCHAR": "text", @@ -51,6 +53,30 @@ } +class SnowflakeTypeMapper(TypeMapper): + BIGINT_PRECISION = 19 + MAX_NUMERIC_PRECISION = 38 + sct_to_unbound_dbt = { + "complex": "VARIANT", + "text": "VARCHAR", + "double": "FLOAT", + "bool": "BOOLEAN", + "date": "DATE", + "timestamp": "TIMESTAMP_TZ", + "bigint": f"NUMBER({BIGINT_PRECISION},0)", # Snowflake has no integer types + "binary": "BINARY", + "time": "TIME", + } + + sct_to_dbt = { + "text": "VARCHAR(%i)", + "timestamp": "TIMESTAMP_TZ(%i)", + "decimal": "NUMBER(%i,%i)", + "time": "TIME(%i)", + "wei": "NUMBER(%i,%i)", + } + + class SnowflakeLoadJob(LoadJob, FollowupJob): def __init__( self, file_path: str, table_name: str, load_id: str, client: SnowflakeSqlClient, @@ -155,7 +181,6 @@ def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClient class 
SnowflakeClient(SqlJobClientWithStaging): - capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: SnowflakeClientConfiguration) -> None: @@ -166,6 +191,7 @@ def __init__(self, schema: Schema, config: SnowflakeClientConfiguration) -> None super().__init__(schema, config, sql_client) self.config: SnowflakeClientConfiguration = config self.sql_client: SnowflakeSqlClient = sql_client # type: ignore + self.type_mapper = SnowflakeTypeMapper(self.capabilities) def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: job = super().start_file_load(table, file_path, load_id) @@ -202,14 +228,6 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc return sql - @classmethod - def _to_db_type(cls, sc_t: TDataType) -> str: - if sc_t == "wei": - return SCT_TO_SNOW["decimal"] % cls.capabilities.wei_precision - if sc_t == "decimal": - return SCT_TO_SNOW["decimal"] % cls.capabilities.decimal_precision - return SCT_TO_SNOW[sc_t] - @classmethod def _from_db_type(cls, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TDataType: if bq_t == "NUMBER": @@ -222,7 +240,7 @@ def _from_db_type(cls, bq_t: str, precision: Optional[int], scale: Optional[int] def _get_column_def_sql(self, c: TColumnSchema) -> str: name = self.capabilities.escape_identifier(c["name"]) - return f"{name} {self._to_db_type(c['data_type'])} {self._gen_not_null(c.get('nullable', True))}" + return f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: table_name = table_name.upper() # All snowflake tables are uppercased in information schema diff --git a/dlt/destinations/type_mapping.py b/dlt/destinations/type_mapping.py new file mode 100644 index 0000000000..b89a5a7b4f --- /dev/null +++ b/dlt/destinations/type_mapping.py @@ -0,0 +1,56 @@ +from typing import Tuple, ClassVar, Dict, Optional + +from dlt.common.schema.typing import TColumnSchema, TDataType +from dlt.common.destination.capabilities import DestinationCapabilitiesContext + + +class TypeMapper: + capabilities: DestinationCapabilitiesContext + + sct_to_unbound_dbt: Dict[TDataType, str] + """Data types without precision or scale specified (e.g. `"text": "varchar"` in postgres)""" + sct_to_dbt: Dict[TDataType, str] + """Data types that require a precision or scale (e.g. `"text": "varchar(%i)"` or `"decimal": "numeric(%i,%i)"` in postgres). 
+ Values should have printf placeholders for precision (and scale if applicable) + """ + + def __init__(self, capabilities: DestinationCapabilitiesContext) -> None: + self.capabilities = capabilities + + def to_db_type(self, column: TColumnSchema) -> str: + precision, scale = column.get("precision"), column.get("scale") + sc_t = column["data_type"] + bounded_template = self.sct_to_dbt.get(sc_t) + precision_tuple = self.precision_tuple_or_default(sc_t, precision, scale) + if not precision_tuple or not bounded_template: + return self.sct_to_unbound_dbt[sc_t] + return self.sct_to_dbt[sc_t] % precision_tuple + + def precision_tuple_or_default(self, data_type: TDataType, precision: Optional[int], scale: Optional[int]) -> Optional[Tuple[int, ...]]: + if data_type in ("timestamp", "time"): + return (precision or self.capabilities.timestamp_precision, ) + elif data_type == "decimal": + return self.decimal_precision(precision, scale) + elif data_type == "wei": + return self.wei_precision(precision, scale) + + if precision is None: + return None + elif scale is None: + return (precision, ) + return (precision, scale) + + def decimal_precision(self, precision: Optional[int], scale: Optional[int]) -> Tuple[int, int]: + default_precision, default_scale = self.capabilities.decimal_precision + return ( + precision if precision is not None else default_precision, scale if scale is not None else default_scale + ) + + def wei_precision(self, precision: Optional[int], scale: Optional[int]) -> Tuple[int, int]: + default_precision, default_scale = self.capabilities.wei_precision + return ( + precision if precision is not None else default_precision, scale if scale is not None else default_scale + ) + + def timestamp_precision(self, precision: Optional[int]) -> int: + return precision or self.capabilities.timestamp_precision diff --git a/dlt/destinations/weaviate/weaviate_client.py b/dlt/destinations/weaviate/weaviate_client.py index bbd01f0381..103be3d8d6 100644 --- a/dlt/destinations/weaviate/weaviate_client.py +++ b/dlt/destinations/weaviate/weaviate_client.py @@ -48,20 +48,8 @@ from dlt.destinations.weaviate import capabilities from dlt.destinations.weaviate.configuration import WeaviateClientConfiguration from dlt.destinations.weaviate.exceptions import PropertyNameConflict, WeaviateBatchError +from dlt.destinations.type_mapping import TypeMapper -SCT_TO_WT: Dict[TDataType, str] = { - "text": "text", - "double": "number", - "bool": "boolean", - "timestamp": "date", - "date": "date", - "time": "text", - "bigint": "int", - "binary": "blob", - "decimal": "text", - "wei": "number", - "complex": "text", -} WT_TO_SCT: Dict[str, TDataType] = { "text": "text", @@ -80,6 +68,24 @@ } +class WeaviateTypeMapper(TypeMapper): + sct_to_unbound_dbt = { + "text": "text", + "double": "number", + "bool": "boolean", + "timestamp": "date", + "date": "date", + "time": "text", + "bigint": "int", + "binary": "blob", + "decimal": "text", + "wei": "number", + "complex": "text", + } + + sct_to_dbt = {} + + def wrap_weaviate_error(f: TFun) -> TFun: @wraps(f) def _wrap(self: JobClientBase, *args: Any, **kwargs: Any) -> Any: @@ -247,6 +253,7 @@ def __init__(self, schema: Schema, config: WeaviateClientConfiguration) -> None: "vectorizer": config.vectorizer, "moduleConfig": config.module_config, } + self.type_mapper = WeaviateTypeMapper(self.capabilities) @property def dataset_name(self) -> str: @@ -620,7 +627,7 @@ def _make_property_schema(self, column_name: str, column: TColumnSchema) -> Dict return { "name": column_name, - "dataType": 
[self._to_db_type(column["data_type"])], + "dataType": [self.type_mapper.to_db_type(column)], **extra_kv, } @@ -673,10 +680,6 @@ def _update_schema_in_storage(self, schema: Schema) -> None: } self.create_object(properties, self.schema.version_table_name) - @staticmethod - def _to_db_type(sc_t: TDataType) -> str: - return SCT_TO_WT[sc_t] - @staticmethod def _from_db_type(wt_t: str) -> TDataType: return WT_TO_SCT.get(wt_t, "text") diff --git a/tests/load/postgres/test_postgres_table_builder.py b/tests/load/postgres/test_postgres_table_builder.py index 50d7d3f245..5dc6866b42 100644 --- a/tests/load/postgres/test_postgres_table_builder.py +++ b/tests/load/postgres/test_postgres_table_builder.py @@ -29,7 +29,7 @@ def test_create_table(client: PostgresClient) -> None: assert '"col1" bigint NOT NULL' in sql assert '"col2" double precision NOT NULL' in sql assert '"col3" boolean NOT NULL' in sql - assert '"col4" timestamp with time zone NOT NULL' in sql + assert '"col4" timestamp (6) with time zone NOT NULL' in sql assert '"col5" varchar' in sql assert '"col6" numeric(38,9) NOT NULL' in sql assert '"col7" bytea' in sql @@ -49,7 +49,7 @@ def test_alter_table(client: PostgresClient) -> None: assert '"col1" bigint NOT NULL' in sql assert '"col2" double precision NOT NULL' in sql assert '"col3" boolean NOT NULL' in sql - assert '"col4" timestamp with time zone NOT NULL' in sql + assert '"col4" timestamp (6) with time zone NOT NULL' in sql assert '"col5" varchar' in sql assert '"col6" numeric(38,9) NOT NULL' in sql assert '"col7" bytea' in sql @@ -72,7 +72,7 @@ def test_create_table_with_hints(client: PostgresClient) -> None: assert '"col5" varchar ' in sql # no hints assert '"col3" boolean NOT NULL' in sql - assert '"col4" timestamp with time zone NOT NULL' in sql + assert '"col4" timestamp (6) with time zone NOT NULL' in sql # same thing without indexes client = PostgresClient(client.schema, PostgresClientConfiguration(dataset_name="test_" + uniq_id(), create_indexes=False, credentials=PostgresCredentials()))