diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index e6cd703b..71359536 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -1,25 +1,26 @@ from __future__ import annotations + +import copy +import datetime +import decimal from abc import ABC, abstractmethod -from collections import namedtuple, OrderedDict +from collections import OrderedDict, namedtuple from collections.abc import Iterable from decimal import Decimal -import datetime -import decimal from enum import Enum +from typing import Any, Dict, List, Union + import lz4.frame -from typing import Dict, List, Union, Any import pyarrow -from enum import Enum -import copy -from databricks.sql import exc, OperationalError +from databricks.sql import OperationalError, exc from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager from databricks.sql.thrift_api.TCLIService.ttypes import ( - TSparkArrowResultLink, - TSparkRowSetType, TRowSet, + TSparkArrowResultLink, TSparkParameter, TSparkParameterValue, + TSparkRowSetType, ) BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] @@ -478,6 +479,10 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type): class DbSqlType(Enum): + """The values of this enumeration are passed as literals to be used in a CAST + evaluation by the thrift server. + """ + STRING = "STRING" DATE = "DATE" TIMESTAMP = "TIMESTAMP" @@ -495,7 +500,7 @@ class DbSqlType(Enum): class DbSqlParameter: name: str value: Any - type: DbSqlType + type: Union[DbSqlType, DbsqlDynamicDecimalType, Enum] def __init__(self, name="", value=None, type=None): self.name = name @@ -506,6 +511,11 @@ def __eq__(self, other): return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ +class DbsqlDynamicDecimalType: + def __init__(self, value): + self.value = value + + def named_parameters_to_dbsqlparams_v1(parameters: Dict[str, str]): dbsqlparams = [] for name, parameter in parameters.items(): @@ -531,16 +541,49 @@ def infer_types(params: list[DbSqlParameter]): datetime.datetime: DbSqlType.TIMESTAMP, datetime.date: DbSqlType.DATE, bool: DbSqlType.BOOLEAN, + Decimal: DbSqlType.DECIMAL, } - newParams = copy.deepcopy(params) - for param in newParams: + new_params = copy.deepcopy(params) + for param in new_params: if not param.type: if type(param.value) in type_lookup_table: param.type = type_lookup_table[type(param.value)] else: raise ValueError("Parameter type cannot be inferred") + + if param.type == DbSqlType.DECIMAL: + cast_exp = calculate_decimal_cast_string(param.value) + param.type = DbsqlDynamicDecimalType(cast_exp) + param.value = str(param.value) - return newParams + return new_params + + +def calculate_decimal_cast_string(input: Decimal) -> str: + """Returns the smallest SQL cast argument that can contain the passed decimal + + Example: + Input: Decimal("1234.5678") + Output: DECIMAL(8,4) + """ + + string_decimal = str(input) + + if string_decimal.startswith("0."): + # This decimal is less than 1 + overall = after = len(string_decimal) - 2 + elif "." not in string_decimal: + # This decimal has no fractional component + overall = len(string_decimal) + after = 0 + else: + # This decimal has both whole and fractional parts + parts = string_decimal.split(".") + parts_lengths = [len(i) for i in parts] + before, after = parts_lengths[:2] + overall = before + after + + return f"DECIMAL({overall},{after})" def named_parameters_to_tsparkparams(parameters: Union[List[Any], Dict[str, str]]): diff --git a/tests/e2e/common/parameterized_query_tests.py b/tests/e2e/common/parameterized_query_tests.py index e6302aa3..8d33cfee 100644 --- a/tests/e2e/common/parameterized_query_tests.py +++ b/tests/e2e/common/parameterized_query_tests.py @@ -1,11 +1,22 @@ import datetime from decimal import Decimal +from enum import Enum from typing import Dict, List, Tuple, Union import pytz -from databricks.sql.client import Connection -from databricks.sql.utils import DbSqlParameter, DbSqlType +from databricks.sql.exc import DatabaseError +from databricks.sql.utils import ( + DbSqlParameter, + DbSqlType, + calculate_decimal_cast_string, +) + + +class MyCustomDecimalType(Enum): + DECIMAL_38_0 = "DECIMAL(38,0)" + DECIMAL_38_2 = "DECIMAL(38,2)" + DECIMAL_18_9 = "DECIMAL(18,9)" class PySQLParameterizedQueryTestSuiteMixin: @@ -63,6 +74,11 @@ def test_primitive_inferred_string(self): result = self._get_one_result(self.QUERY, params) assert result.col == "Hello" + def test_primitive_inferred_decimal(self): + params = {"p": Decimal("1234.56")} + result = self._get_one_result(self.QUERY, params) + assert result.col == Decimal("1234.56") + def test_dbsqlparam_inferred_bool(self): params = [DbSqlParameter(name="p", value=True, type=None)] @@ -103,6 +119,11 @@ def test_dbsqlparam_inferred_string(self): result = self._get_one_result(self.QUERY, params) assert result.col == "Hello" + def test_dbsqlparam_inferred_decimal(self): + params = [DbSqlParameter(name="p", value=Decimal("1234.56"), type=None)] + result = self._get_one_result(self.QUERY, params) + assert result.col == Decimal("1234.56") + def test_dbsqlparam_explicit_bool(self): params = [DbSqlParameter(name="p", value=True, type=DbSqlType.BOOLEAN)] @@ -142,3 +163,50 @@ def test_dbsqlparam_explicit_string(self): params = [DbSqlParameter(name="p", value="Hello", type=DbSqlType.STRING)] result = self._get_one_result(self.QUERY, params) assert result.col == "Hello" + + def test_dbsqlparam_explicit_decimal(self): + params = [ + DbSqlParameter(name="p", value=Decimal("1234.56"), type=DbSqlType.DECIMAL) + ] + result = self._get_one_result(self.QUERY, params) + assert result.col == Decimal("1234.56") + + def test_dbsqlparam_custom_explicit_decimal_38_0(self): + + # This DECIMAL can be contained in a DECIMAL(38,0) column in Databricks + value = Decimal("12345678912345678912345678912345678912") + params = [ + DbSqlParameter(name="p", value=value, type=MyCustomDecimalType.DECIMAL_38_0) + ] + result = self._get_one_result(self.QUERY, params) + assert result.col == value + + def test_dbsqlparam_custom_explicit_decimal_38_2(self): + + # This DECIMAL can be contained in a DECIMAL(38,2) column in Databricks + value = Decimal("123456789123456789123456789123456789.12") + params = [ + DbSqlParameter(name="p", value=value, type=MyCustomDecimalType.DECIMAL_38_2) + ] + result = self._get_one_result(self.QUERY, params) + assert result.col == value + + def test_dbsqlparam_custom_explicit_decimal_18_9(self): + + # This DECIMAL can be contained in a DECIMAL(18,9) column in Databricks + value = Decimal("123456789.123456789") + params = [ + DbSqlParameter(name="p", value=value, type=MyCustomDecimalType.DECIMAL_18_9) + ] + result = self._get_one_result(self.QUERY, params) + assert result.col == value + + def test_calculate_decimal_cast_string(self): + + assert calculate_decimal_cast_string(Decimal("10.00")) == "DECIMAL(4,2)" + assert ( + calculate_decimal_cast_string( + Decimal("123456789123456789.123456789123456789") + ) + == "DECIMAL(36,18)" + ) diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index daa2e232..787403c2 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -3,6 +3,8 @@ infer_types, named_parameters_to_dbsqlparams_v1, named_parameters_to_dbsqlparams_v2, + calculate_decimal_cast_string, + DbsqlDynamicDecimalType ) from databricks.sql.thrift_api.TCLIService.ttypes import ( TSparkParameter, @@ -11,6 +13,10 @@ from databricks.sql.utils import DbSqlParameter, DbSqlType import pytest +from decimal import Decimal + +from typing import List + class TestTSparkParameterConversion(object): def test_conversion_e2e(self): @@ -31,7 +37,7 @@ def test_conversion_e2e(self): name="", type="FLOAT", value=TSparkParameterValue(stringValue="1.0") ), TSparkParameter( - name="", type="DECIMAL", value=TSparkParameterValue(stringValue="1.0") + name="", type="DECIMAL(2,1)", value=TSparkParameterValue(stringValue="1.0") ), ] @@ -53,23 +59,64 @@ def test_basic_conversions_v2(self): DbSqlParameter("3", "foo"), ] - def test_type_inference(self): + def test_infer_types_none(self): with pytest.raises(ValueError): infer_types([DbSqlParameter("", None)]) + + def test_infer_types_dict(self): with pytest.raises(ValueError): infer_types([DbSqlParameter("", {1: 1})]) - assert infer_types([DbSqlParameter("", 1)]) == [ - DbSqlParameter("", "1", DbSqlType.INTEGER) - ] - assert infer_types([DbSqlParameter("", True)]) == [ - DbSqlParameter("", "True", DbSqlType.BOOLEAN) - ] - assert infer_types([DbSqlParameter("", 1.0)]) == [ - DbSqlParameter("", "1.0", DbSqlType.FLOAT) - ] - assert infer_types([DbSqlParameter("", "foo")]) == [ - DbSqlParameter("", "foo", DbSqlType.STRING) - ] - assert infer_types([DbSqlParameter("", 1.0, DbSqlType.DECIMAL)]) == [ - DbSqlParameter("", "1.0", DbSqlType.DECIMAL) - ] + + def test_infer_types_integer(self): + input = DbSqlParameter("", 1) + output = infer_types([input]) + assert output == [DbSqlParameter("", "1", DbSqlType.INTEGER)] + + def test_infer_types_boolean(self): + input = DbSqlParameter("", True) + output = infer_types([input]) + assert output == [DbSqlParameter("", "True", DbSqlType.BOOLEAN)] + + def test_infer_types_float(self): + input = DbSqlParameter("", 1.0) + output = infer_types([input]) + assert output == [DbSqlParameter("", "1.0", DbSqlType.FLOAT)] + + def test_infer_types_string(self): + input = DbSqlParameter("", "foo") + output = infer_types([input]) + assert output == [DbSqlParameter("", "foo", DbSqlType.STRING)] + + def test_infer_types_decimal(self): + # The output decimal will have a dynamically calculated decimal type with a value of DECIMAL(2,1) + input = DbSqlParameter("", Decimal("1.0")) + output: List[DbSqlParameter] = infer_types([input]) + + x = output[0] + + assert x.value == "1.0" + assert isinstance(x.type, DbsqlDynamicDecimalType) + assert x.type.value == "DECIMAL(2,1)" + + +class TestCalculateDecimalCast(object): + + def test_38_38(self): + input = Decimal(".12345678912345678912345678912345678912") + output = calculate_decimal_cast_string(input) + assert output == "DECIMAL(38,38)" + + def test_18_9(self): + input = Decimal("123456789.123456789") + output = calculate_decimal_cast_string(input) + assert output == "DECIMAL(18,9)" + + def test_38_0(self): + input = Decimal("12345678912345678912345678912345678912") + output = calculate_decimal_cast_string(input) + assert output == "DECIMAL(38,0)" + + def test_6_2(self): + input = Decimal("1234.56") + output = calculate_decimal_cast_string(input) + assert output == "DECIMAL(6,2)" \ No newline at end of file