From 6f4f16fa6bb621ed356adad26ccd328dbb33076f Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 21 Sep 2023 13:16:39 -0400 Subject: [PATCH] Don't add None precision --- dlt/common/utils.py | 8 ++++++++ dlt/destinations/athena/athena.py | 3 ++- dlt/destinations/mssql/mssql.py | 2 +- dlt/destinations/postgres/postgres.py | 2 +- dlt/destinations/redshift/redshift.py | 2 +- dlt/destinations/snowflake/snowflake.py | 4 ++-- dlt/destinations/type_mapping.py | 7 ++++--- tests/load/weaviate/test_weaviate_client.py | 2 +- 8 files changed, 20 insertions(+), 10 deletions(-) diff --git a/dlt/common/utils.py b/dlt/common/utils.py index d73daeeff8..c77a35157c 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -19,6 +19,9 @@ T = TypeVar("T") TDict = TypeVar("TDict", bound=DictStrAny) +TKey = TypeVar("TKey") +TValue = TypeVar("TValue") + # row counts TRowCount = Dict[str, int] @@ -457,3 +460,8 @@ def maybe_context(manager: ContextManager[TAny]) -> Iterator[TAny]: else: with manager as ctx: yield ctx + + +def without_none(d: Mapping[TKey, Optional[TValue]]) -> Mapping[TKey, TValue]: + """Return a new dict with all `None` values removed""" + return {k: v for k, v in d.items() if v is not None} diff --git a/dlt/destinations/athena/athena.py b/dlt/destinations/athena/athena.py index 92e647f4ea..ed8364aa3a 100644 --- a/dlt/destinations/athena/athena.py +++ b/dlt/destinations/athena/athena.py @@ -13,6 +13,7 @@ from pyathena.formatter import DefaultParameterFormatter, _DEFAULT_FORMATTERS, Formatter, _format_date from dlt.common import logger +from dlt.common.utils import without_none from dlt.common.data_types import TDataType from dlt.common.schema import TColumnSchema, Schema from dlt.common.schema.typing import TTableSchema, TColumnType @@ -82,7 +83,7 @@ def to_db_integer_type(self, precision: Optional[int]) -> str: def from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[int]) -> TColumnType: for key, val in self.dbt_to_sct.items(): if db_type.startswith(key): - return dict(data_type=val, precision=precision, scale=scale) + return without_none(dict(data_type=val, precision=precision, scale=scale)) # type: ignore[return-value] return dict(data_type=None) diff --git a/dlt/destinations/mssql/mssql.py b/dlt/destinations/mssql/mssql.py index 13a494682e..d36a8a2362 100644 --- a/dlt/destinations/mssql/mssql.py +++ b/dlt/destinations/mssql/mssql.py @@ -74,7 +74,7 @@ def to_db_integer_type(self, precision: Optional[int]) -> str: def from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[int]) -> TColumnType: if db_type == "numeric": if (precision, scale) == self.capabilities.wei_precision: - return dict(data_type="wei", precision=precision, scale=scale) + return dict(data_type="wei") return super().from_db_type(db_type, precision, scale) diff --git a/dlt/destinations/postgres/postgres.py b/dlt/destinations/postgres/postgres.py index 1eceeef7e3..2a8c3d791b 100644 --- a/dlt/destinations/postgres/postgres.py +++ b/dlt/destinations/postgres/postgres.py @@ -71,7 +71,7 @@ def to_db_integer_type(self, precision: Optional[int]) -> str: def from_db_type(self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None) -> TColumnType: if db_type == "numeric": if (precision, scale) == self.capabilities.wei_precision: - return dict(data_type="wei", precision=precision, scale=scale) + return dict(data_type="wei") return super().from_db_type(db_type, precision, scale) diff --git a/dlt/destinations/redshift/redshift.py b/dlt/destinations/redshift/redshift.py index aeb7b4f1b4..dcfa6f1f25 100644 --- a/dlt/destinations/redshift/redshift.py +++ b/dlt/destinations/redshift/redshift.py @@ -88,7 +88,7 @@ def to_db_integer_type(self, precision: Optional[int]) -> str: def from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[int]) -> TColumnType: if db_type == "numeric": if (precision, scale) == self.capabilities.wei_precision: - return dict(data_type="wei", precision=precision, scale=scale) + return dict(data_type="wei") return super().from_db_type(db_type, precision, scale) diff --git a/dlt/destinations/snowflake/snowflake.py b/dlt/destinations/snowflake/snowflake.py index e1407bfccf..38f53cf060 100644 --- a/dlt/destinations/snowflake/snowflake.py +++ b/dlt/destinations/snowflake/snowflake.py @@ -72,9 +72,9 @@ class SnowflakeTypeMapper(TypeMapper): def from_db_type(self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None) -> TColumnType: if db_type == "NUMBER": if precision == self.BIGINT_PRECISION and scale == 0: - return dict(data_type='bigint', precision=precision, scale=scale) + return dict(data_type='bigint') elif (precision, scale) == self.capabilities.wei_precision: - return dict(data_type='wei', precision=precision, scale=scale) + return dict(data_type='wei') return dict(data_type='decimal', precision=precision, scale=scale) return super().from_db_type(db_type, precision, scale) diff --git a/dlt/destinations/type_mapping.py b/dlt/destinations/type_mapping.py index 43a0362a44..333859e685 100644 --- a/dlt/destinations/type_mapping.py +++ b/dlt/destinations/type_mapping.py @@ -2,6 +2,7 @@ from dlt.common.schema.typing import TColumnSchema, TDataType, TColumnType from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.common.utils import without_none class TypeMapper: @@ -72,8 +73,8 @@ def timestamp_precision(self, precision: Optional[int]) -> Optional[int]: return precision or self.capabilities.timestamp_precision def from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[int]) -> TColumnType: - return dict( - data_type=self.dbt_to_sct[db_type], + return without_none(dict( # type: ignore[return-value] + data_type=self.dbt_to_sct.get(db_type, "text"), precision=precision, scale=scale - ) + )) diff --git a/tests/load/weaviate/test_weaviate_client.py b/tests/load/weaviate/test_weaviate_client.py index 0e234cd7a8..8ef3ddd660 100644 --- a/tests/load/weaviate/test_weaviate_client.py +++ b/tests/load/weaviate/test_weaviate_client.py @@ -129,7 +129,7 @@ def test_case_insensitive_properties_create(ci_client: WeaviateClient) -> None: ci_client.update_stored_schema() _, table_columns = ci_client.get_storage_table("ColClass") # later column overwrites earlier one so: double - assert table_columns == {'col1': {'name': 'col1', 'data_type': 'double', 'precision': None, 'scale': None}} + assert table_columns == {'col1': {'name': 'col1', 'data_type': 'double'}} def test_case_sensitive_properties_add(client: WeaviateClient) -> None: