Skip to content

Commit

Permalink
[PECO-1109] Parameterized Query: add suport for inferring decimal typ…
Browse files Browse the repository at this point in the history
…es (#228)

Signed-off-by: Jesse Whitehouse <[email protected]>
  • Loading branch information
Jesse authored Sep 26, 2023
1 parent fcc262f commit 1239bff
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 32 deletions.
69 changes: 56 additions & 13 deletions src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
from __future__ import annotations

import copy
import datetime
import decimal
from abc import ABC, abstractmethod
from collections import namedtuple, OrderedDict
from collections import OrderedDict, namedtuple
from collections.abc import Iterable
from decimal import Decimal
import datetime
import decimal
from enum import Enum
from typing import Any, Dict, List, Union

import lz4.frame
from typing import Dict, List, Union, Any
import pyarrow
from enum import Enum
import copy

from databricks.sql import exc, OperationalError
from databricks.sql import OperationalError, exc
from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
from databricks.sql.thrift_api.TCLIService.ttypes import (
TSparkArrowResultLink,
TSparkRowSetType,
TRowSet,
TSparkArrowResultLink,
TSparkParameter,
TSparkParameterValue,
TSparkRowSetType,
)

BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
Expand Down Expand Up @@ -478,6 +479,10 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type):


class DbSqlType(Enum):
"""The values of this enumeration are passed as literals to be used in a CAST
evaluation by the thrift server.
"""

STRING = "STRING"
DATE = "DATE"
TIMESTAMP = "TIMESTAMP"
Expand All @@ -495,7 +500,7 @@ class DbSqlType(Enum):
class DbSqlParameter:
name: str
value: Any
type: DbSqlType
type: Union[DbSqlType, DbsqlDynamicDecimalType, Enum]

def __init__(self, name="", value=None, type=None):
self.name = name
Expand All @@ -506,6 +511,11 @@ def __eq__(self, other):
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__


class DbsqlDynamicDecimalType:
def __init__(self, value):
self.value = value


def named_parameters_to_dbsqlparams_v1(parameters: Dict[str, str]):
dbsqlparams = []
for name, parameter in parameters.items():
Expand All @@ -531,16 +541,49 @@ def infer_types(params: list[DbSqlParameter]):
datetime.datetime: DbSqlType.TIMESTAMP,
datetime.date: DbSqlType.DATE,
bool: DbSqlType.BOOLEAN,
Decimal: DbSqlType.DECIMAL,
}
newParams = copy.deepcopy(params)
for param in newParams:
new_params = copy.deepcopy(params)
for param in new_params:
if not param.type:
if type(param.value) in type_lookup_table:
param.type = type_lookup_table[type(param.value)]
else:
raise ValueError("Parameter type cannot be inferred")

if param.type == DbSqlType.DECIMAL:
cast_exp = calculate_decimal_cast_string(param.value)
param.type = DbsqlDynamicDecimalType(cast_exp)

param.value = str(param.value)
return newParams
return new_params


def calculate_decimal_cast_string(input: Decimal) -> str:
"""Returns the smallest SQL cast argument that can contain the passed decimal
Example:
Input: Decimal("1234.5678")
Output: DECIMAL(8,4)
"""

string_decimal = str(input)

if string_decimal.startswith("0."):
# This decimal is less than 1
overall = after = len(string_decimal) - 2
elif "." not in string_decimal:
# This decimal has no fractional component
overall = len(string_decimal)
after = 0
else:
# This decimal has both whole and fractional parts
parts = string_decimal.split(".")
parts_lengths = [len(i) for i in parts]
before, after = parts_lengths[:2]
overall = before + after

return f"DECIMAL({overall},{after})"


def named_parameters_to_tsparkparams(parameters: Union[List[Any], Dict[str, str]]):
Expand Down
72 changes: 70 additions & 2 deletions tests/e2e/common/parameterized_query_tests.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
import datetime
from decimal import Decimal
from enum import Enum
from typing import Dict, List, Tuple, Union

import pytz

from databricks.sql.client import Connection
from databricks.sql.utils import DbSqlParameter, DbSqlType
from databricks.sql.exc import DatabaseError
from databricks.sql.utils import (
DbSqlParameter,
DbSqlType,
calculate_decimal_cast_string,
)


class MyCustomDecimalType(Enum):
DECIMAL_38_0 = "DECIMAL(38,0)"
DECIMAL_38_2 = "DECIMAL(38,2)"
DECIMAL_18_9 = "DECIMAL(18,9)"


class PySQLParameterizedQueryTestSuiteMixin:
Expand Down Expand Up @@ -63,6 +74,11 @@ def test_primitive_inferred_string(self):
result = self._get_one_result(self.QUERY, params)
assert result.col == "Hello"

def test_primitive_inferred_decimal(self):
params = {"p": Decimal("1234.56")}
result = self._get_one_result(self.QUERY, params)
assert result.col == Decimal("1234.56")

def test_dbsqlparam_inferred_bool(self):

params = [DbSqlParameter(name="p", value=True, type=None)]
Expand Down Expand Up @@ -103,6 +119,11 @@ def test_dbsqlparam_inferred_string(self):
result = self._get_one_result(self.QUERY, params)
assert result.col == "Hello"

def test_dbsqlparam_inferred_decimal(self):
params = [DbSqlParameter(name="p", value=Decimal("1234.56"), type=None)]
result = self._get_one_result(self.QUERY, params)
assert result.col == Decimal("1234.56")

def test_dbsqlparam_explicit_bool(self):

params = [DbSqlParameter(name="p", value=True, type=DbSqlType.BOOLEAN)]
Expand Down Expand Up @@ -142,3 +163,50 @@ def test_dbsqlparam_explicit_string(self):
params = [DbSqlParameter(name="p", value="Hello", type=DbSqlType.STRING)]
result = self._get_one_result(self.QUERY, params)
assert result.col == "Hello"

def test_dbsqlparam_explicit_decimal(self):
params = [
DbSqlParameter(name="p", value=Decimal("1234.56"), type=DbSqlType.DECIMAL)
]
result = self._get_one_result(self.QUERY, params)
assert result.col == Decimal("1234.56")

def test_dbsqlparam_custom_explicit_decimal_38_0(self):

# This DECIMAL can be contained in a DECIMAL(38,0) column in Databricks
value = Decimal("12345678912345678912345678912345678912")
params = [
DbSqlParameter(name="p", value=value, type=MyCustomDecimalType.DECIMAL_38_0)
]
result = self._get_one_result(self.QUERY, params)
assert result.col == value

def test_dbsqlparam_custom_explicit_decimal_38_2(self):

# This DECIMAL can be contained in a DECIMAL(38,2) column in Databricks
value = Decimal("123456789123456789123456789123456789.12")
params = [
DbSqlParameter(name="p", value=value, type=MyCustomDecimalType.DECIMAL_38_2)
]
result = self._get_one_result(self.QUERY, params)
assert result.col == value

def test_dbsqlparam_custom_explicit_decimal_18_9(self):

# This DECIMAL can be contained in a DECIMAL(18,9) column in Databricks
value = Decimal("123456789.123456789")
params = [
DbSqlParameter(name="p", value=value, type=MyCustomDecimalType.DECIMAL_18_9)
]
result = self._get_one_result(self.QUERY, params)
assert result.col == value

def test_calculate_decimal_cast_string(self):

assert calculate_decimal_cast_string(Decimal("10.00")) == "DECIMAL(4,2)"
assert (
calculate_decimal_cast_string(
Decimal("123456789123456789.123456789123456789")
)
== "DECIMAL(36,18)"
)
81 changes: 64 additions & 17 deletions tests/unit/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
infer_types,
named_parameters_to_dbsqlparams_v1,
named_parameters_to_dbsqlparams_v2,
calculate_decimal_cast_string,
DbsqlDynamicDecimalType
)
from databricks.sql.thrift_api.TCLIService.ttypes import (
TSparkParameter,
Expand All @@ -11,6 +13,10 @@
from databricks.sql.utils import DbSqlParameter, DbSqlType
import pytest

from decimal import Decimal

from typing import List


class TestTSparkParameterConversion(object):
def test_conversion_e2e(self):
Expand All @@ -31,7 +37,7 @@ def test_conversion_e2e(self):
name="", type="FLOAT", value=TSparkParameterValue(stringValue="1.0")
),
TSparkParameter(
name="", type="DECIMAL", value=TSparkParameterValue(stringValue="1.0")
name="", type="DECIMAL(2,1)", value=TSparkParameterValue(stringValue="1.0")
),
]

Expand All @@ -53,23 +59,64 @@ def test_basic_conversions_v2(self):
DbSqlParameter("3", "foo"),
]

def test_type_inference(self):
def test_infer_types_none(self):
with pytest.raises(ValueError):
infer_types([DbSqlParameter("", None)])

def test_infer_types_dict(self):
with pytest.raises(ValueError):
infer_types([DbSqlParameter("", {1: 1})])
assert infer_types([DbSqlParameter("", 1)]) == [
DbSqlParameter("", "1", DbSqlType.INTEGER)
]
assert infer_types([DbSqlParameter("", True)]) == [
DbSqlParameter("", "True", DbSqlType.BOOLEAN)
]
assert infer_types([DbSqlParameter("", 1.0)]) == [
DbSqlParameter("", "1.0", DbSqlType.FLOAT)
]
assert infer_types([DbSqlParameter("", "foo")]) == [
DbSqlParameter("", "foo", DbSqlType.STRING)
]
assert infer_types([DbSqlParameter("", 1.0, DbSqlType.DECIMAL)]) == [
DbSqlParameter("", "1.0", DbSqlType.DECIMAL)
]

def test_infer_types_integer(self):
input = DbSqlParameter("", 1)
output = infer_types([input])
assert output == [DbSqlParameter("", "1", DbSqlType.INTEGER)]

def test_infer_types_boolean(self):
input = DbSqlParameter("", True)
output = infer_types([input])
assert output == [DbSqlParameter("", "True", DbSqlType.BOOLEAN)]

def test_infer_types_float(self):
input = DbSqlParameter("", 1.0)
output = infer_types([input])
assert output == [DbSqlParameter("", "1.0", DbSqlType.FLOAT)]

def test_infer_types_string(self):
input = DbSqlParameter("", "foo")
output = infer_types([input])
assert output == [DbSqlParameter("", "foo", DbSqlType.STRING)]

def test_infer_types_decimal(self):
# The output decimal will have a dynamically calculated decimal type with a value of DECIMAL(2,1)
input = DbSqlParameter("", Decimal("1.0"))
output: List[DbSqlParameter] = infer_types([input])

x = output[0]

assert x.value == "1.0"
assert isinstance(x.type, DbsqlDynamicDecimalType)
assert x.type.value == "DECIMAL(2,1)"


class TestCalculateDecimalCast(object):

def test_38_38(self):
input = Decimal(".12345678912345678912345678912345678912")
output = calculate_decimal_cast_string(input)
assert output == "DECIMAL(38,38)"

def test_18_9(self):
input = Decimal("123456789.123456789")
output = calculate_decimal_cast_string(input)
assert output == "DECIMAL(18,9)"

def test_38_0(self):
input = Decimal("12345678912345678912345678912345678912")
output = calculate_decimal_cast_string(input)
assert output == "DECIMAL(38,0)"

def test_6_2(self):
input = Decimal("1234.56")
output = calculate_decimal_cast_string(input)
assert output == "DECIMAL(6,2)"

0 comments on commit 1239bff

Please sign in to comment.