diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index cff31053..8d7ff67b 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -478,6 +478,9 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type): class DbSqlType(Enum): + """The values of this enumeration are passed as literals to be used in a CAST + evaluation by the thrift server. + """ STRING = "STRING" DATE = "DATE" TIMESTAMP = "TIMESTAMP" @@ -495,7 +498,7 @@ class DbSqlType(Enum): class DbSqlParameter: name: str value: Any - type: DbSqlType + type: Union[DbSqlType, Enum] def __init__(self, name="", value=None, type=None): self.name = name diff --git a/tests/e2e/common/parameterized_query_tests.py b/tests/e2e/common/parameterized_query_tests.py index f2708c15..31b13c18 100644 --- a/tests/e2e/common/parameterized_query_tests.py +++ b/tests/e2e/common/parameterized_query_tests.py @@ -1,13 +1,14 @@ +from collections import namedtuple from datetime import datetime from decimal import Decimal +from enum import Enum -import pytz import pytest -from collections import namedtuple +import pytz from databricks.sql.client import Connection -from databricks.sql.utils import DbSqlParameter, DbSqlType from databricks.sql.exc import DatabaseError +from databricks.sql.utils import DbSqlParameter, DbSqlType class PySQLParameterizedQueryTestSuiteMixin: @@ -201,7 +202,7 @@ def test_decimal_parameterization_not_inferred(self): base_query = "SELECT :p_decimal col_decimal" # first, demonstrate how the default inferrence will raise a server exception because - # the default precision is (6,2) + # the default precision is (6,2). This example Decimal should be a DECIMAL(18,9) inferred_decimal_param = DbSqlParameter( name="p_decimal", value=Decimal("123456789.123456789") ) @@ -215,15 +216,37 @@ def test_decimal_parameterization_not_inferred(self): ): cursor.execute(base_query, parameters=[inferred_decimal_param]) - # now we bypass inferrence and set it ourselves - custom_type_tuple_maker = namedtuple("CustomDbsqlType", "value") - decimal_18_9 = custom_type_tuple_maker("DECIMAL(18,9)") + class MyCustomDecimalType(Enum): + DECIMAL_38_0 = "DECIMAL(38,0)" + DECIMAL_38_2 = "DECIMAL(38,2)" + DECIMAL_18_9 = "DECIMAL(18,9)" - explicit_decimal_param = DbSqlParameter( - name="p_decimal", value=Decimal("123456789.123456789"), type=decimal_18_9 - ) + example_decimals = [ + Decimal("12345678912345678912345678912345678912"), + Decimal("123456789123456789123456789123456789.12"), + Decimal("123456789.123456789"), + ] + + explicit_decimal_params = [ + DbSqlParameter( + name="p_decimal", + value=example_decimals[0], + type=MyCustomDecimalType.DECIMAL_38_0, + ), + DbSqlParameter( + name="p_decimal", + value=example_decimals[1], + type=MyCustomDecimalType.DECIMAL_38_2, + ), + DbSqlParameter( + name="p_decimal", + value=example_decimals[2], + type=MyCustomDecimalType.DECIMAL_18_9, + ), + ] with self.connection() as conn: cursor = conn.cursor() - cursor.execute(base_query, parameters=[explicit_decimal_param]) - result = cursor.fetchone() - assert result.col_decimal == Decimal("123456789.123456789") + for idx, param in enumerate(explicit_decimal_params): + cursor.execute(base_query, parameters=[param]) + result = cursor.fetchone() + assert result.col_decimal == example_decimals[idx]