Skip to content

Commit b94f59e

Browse files
author
Jesse
authored
[PECO-1109] Parameterized Query: add suport for inferring decimal types (#228)
Signed-off-by: Jesse Whitehouse <[email protected]>
1 parent 9489087 commit b94f59e

File tree

3 files changed

+190
-32
lines changed

3 files changed

+190
-32
lines changed

src/databricks/sql/utils.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,26 @@
11
from __future__ import annotations
2+
3+
import copy
4+
import datetime
5+
import decimal
26
from abc import ABC, abstractmethod
3-
from collections import namedtuple, OrderedDict
7+
from collections import OrderedDict, namedtuple
48
from collections.abc import Iterable
59
from decimal import Decimal
6-
import datetime
7-
import decimal
810
from enum import Enum
11+
from typing import Any, Dict, List, Union
12+
913
import lz4.frame
10-
from typing import Dict, List, Union, Any
1114
import pyarrow
12-
from enum import Enum
13-
import copy
1415

15-
from databricks.sql import exc, OperationalError
16+
from databricks.sql import OperationalError, exc
1617
from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
1718
from databricks.sql.thrift_api.TCLIService.ttypes import (
18-
TSparkArrowResultLink,
19-
TSparkRowSetType,
2019
TRowSet,
20+
TSparkArrowResultLink,
2121
TSparkParameter,
2222
TSparkParameterValue,
23+
TSparkRowSetType,
2324
)
2425

2526
BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128]
@@ -478,6 +479,10 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type):
478479

479480

480481
class DbSqlType(Enum):
482+
"""The values of this enumeration are passed as literals to be used in a CAST
483+
evaluation by the thrift server.
484+
"""
485+
481486
STRING = "STRING"
482487
DATE = "DATE"
483488
TIMESTAMP = "TIMESTAMP"
@@ -495,7 +500,7 @@ class DbSqlType(Enum):
495500
class DbSqlParameter:
496501
name: str
497502
value: Any
498-
type: DbSqlType
503+
type: Union[DbSqlType, DbsqlDynamicDecimalType, Enum]
499504

500505
def __init__(self, name="", value=None, type=None):
501506
self.name = name
@@ -506,6 +511,11 @@ def __eq__(self, other):
506511
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
507512

508513

514+
class DbsqlDynamicDecimalType:
515+
def __init__(self, value):
516+
self.value = value
517+
518+
509519
def named_parameters_to_dbsqlparams_v1(parameters: Dict[str, str]):
510520
dbsqlparams = []
511521
for name, parameter in parameters.items():
@@ -531,16 +541,49 @@ def infer_types(params: list[DbSqlParameter]):
531541
datetime.datetime: DbSqlType.TIMESTAMP,
532542
datetime.date: DbSqlType.DATE,
533543
bool: DbSqlType.BOOLEAN,
544+
Decimal: DbSqlType.DECIMAL,
534545
}
535-
newParams = copy.deepcopy(params)
536-
for param in newParams:
546+
new_params = copy.deepcopy(params)
547+
for param in new_params:
537548
if not param.type:
538549
if type(param.value) in type_lookup_table:
539550
param.type = type_lookup_table[type(param.value)]
540551
else:
541552
raise ValueError("Parameter type cannot be inferred")
553+
554+
if param.type == DbSqlType.DECIMAL:
555+
cast_exp = calculate_decimal_cast_string(param.value)
556+
param.type = DbsqlDynamicDecimalType(cast_exp)
557+
542558
param.value = str(param.value)
543-
return newParams
559+
return new_params
560+
561+
562+
def calculate_decimal_cast_string(input: Decimal) -> str:
563+
"""Returns the smallest SQL cast argument that can contain the passed decimal
564+
565+
Example:
566+
Input: Decimal("1234.5678")
567+
Output: DECIMAL(8,4)
568+
"""
569+
570+
string_decimal = str(input)
571+
572+
if string_decimal.startswith("0."):
573+
# This decimal is less than 1
574+
overall = after = len(string_decimal) - 2
575+
elif "." not in string_decimal:
576+
# This decimal has no fractional component
577+
overall = len(string_decimal)
578+
after = 0
579+
else:
580+
# This decimal has both whole and fractional parts
581+
parts = string_decimal.split(".")
582+
parts_lengths = [len(i) for i in parts]
583+
before, after = parts_lengths[:2]
584+
overall = before + after
585+
586+
return f"DECIMAL({overall},{after})"
544587

545588

546589
def named_parameters_to_tsparkparams(parameters: Union[List[Any], Dict[str, str]]):

tests/e2e/common/parameterized_query_tests.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,22 @@
11
import datetime
22
from decimal import Decimal
3+
from enum import Enum
34
from typing import Dict, List, Tuple, Union
45

56
import pytz
67

7-
from databricks.sql.client import Connection
8-
from databricks.sql.utils import DbSqlParameter, DbSqlType
8+
from databricks.sql.exc import DatabaseError
9+
from databricks.sql.utils import (
10+
DbSqlParameter,
11+
DbSqlType,
12+
calculate_decimal_cast_string,
13+
)
14+
15+
16+
class MyCustomDecimalType(Enum):
17+
DECIMAL_38_0 = "DECIMAL(38,0)"
18+
DECIMAL_38_2 = "DECIMAL(38,2)"
19+
DECIMAL_18_9 = "DECIMAL(18,9)"
920

1021

1122
class PySQLParameterizedQueryTestSuiteMixin:
@@ -63,6 +74,11 @@ def test_primitive_inferred_string(self):
6374
result = self._get_one_result(self.QUERY, params)
6475
assert result.col == "Hello"
6576

77+
def test_primitive_inferred_decimal(self):
78+
params = {"p": Decimal("1234.56")}
79+
result = self._get_one_result(self.QUERY, params)
80+
assert result.col == Decimal("1234.56")
81+
6682
def test_dbsqlparam_inferred_bool(self):
6783

6884
params = [DbSqlParameter(name="p", value=True, type=None)]
@@ -103,6 +119,11 @@ def test_dbsqlparam_inferred_string(self):
103119
result = self._get_one_result(self.QUERY, params)
104120
assert result.col == "Hello"
105121

122+
def test_dbsqlparam_inferred_decimal(self):
123+
params = [DbSqlParameter(name="p", value=Decimal("1234.56"), type=None)]
124+
result = self._get_one_result(self.QUERY, params)
125+
assert result.col == Decimal("1234.56")
126+
106127
def test_dbsqlparam_explicit_bool(self):
107128

108129
params = [DbSqlParameter(name="p", value=True, type=DbSqlType.BOOLEAN)]
@@ -142,3 +163,50 @@ def test_dbsqlparam_explicit_string(self):
142163
params = [DbSqlParameter(name="p", value="Hello", type=DbSqlType.STRING)]
143164
result = self._get_one_result(self.QUERY, params)
144165
assert result.col == "Hello"
166+
167+
def test_dbsqlparam_explicit_decimal(self):
168+
params = [
169+
DbSqlParameter(name="p", value=Decimal("1234.56"), type=DbSqlType.DECIMAL)
170+
]
171+
result = self._get_one_result(self.QUERY, params)
172+
assert result.col == Decimal("1234.56")
173+
174+
def test_dbsqlparam_custom_explicit_decimal_38_0(self):
175+
176+
# This DECIMAL can be contained in a DECIMAL(38,0) column in Databricks
177+
value = Decimal("12345678912345678912345678912345678912")
178+
params = [
179+
DbSqlParameter(name="p", value=value, type=MyCustomDecimalType.DECIMAL_38_0)
180+
]
181+
result = self._get_one_result(self.QUERY, params)
182+
assert result.col == value
183+
184+
def test_dbsqlparam_custom_explicit_decimal_38_2(self):
185+
186+
# This DECIMAL can be contained in a DECIMAL(38,2) column in Databricks
187+
value = Decimal("123456789123456789123456789123456789.12")
188+
params = [
189+
DbSqlParameter(name="p", value=value, type=MyCustomDecimalType.DECIMAL_38_2)
190+
]
191+
result = self._get_one_result(self.QUERY, params)
192+
assert result.col == value
193+
194+
def test_dbsqlparam_custom_explicit_decimal_18_9(self):
195+
196+
# This DECIMAL can be contained in a DECIMAL(18,9) column in Databricks
197+
value = Decimal("123456789.123456789")
198+
params = [
199+
DbSqlParameter(name="p", value=value, type=MyCustomDecimalType.DECIMAL_18_9)
200+
]
201+
result = self._get_one_result(self.QUERY, params)
202+
assert result.col == value
203+
204+
def test_calculate_decimal_cast_string(self):
205+
206+
assert calculate_decimal_cast_string(Decimal("10.00")) == "DECIMAL(4,2)"
207+
assert (
208+
calculate_decimal_cast_string(
209+
Decimal("123456789123456789.123456789123456789")
210+
)
211+
== "DECIMAL(36,18)"
212+
)

tests/unit/test_parameters.py

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
infer_types,
44
named_parameters_to_dbsqlparams_v1,
55
named_parameters_to_dbsqlparams_v2,
6+
calculate_decimal_cast_string,
7+
DbsqlDynamicDecimalType
68
)
79
from databricks.sql.thrift_api.TCLIService.ttypes import (
810
TSparkParameter,
@@ -11,6 +13,10 @@
1113
from databricks.sql.utils import DbSqlParameter, DbSqlType
1214
import pytest
1315

16+
from decimal import Decimal
17+
18+
from typing import List
19+
1420

1521
class TestTSparkParameterConversion(object):
1622
def test_conversion_e2e(self):
@@ -31,7 +37,7 @@ def test_conversion_e2e(self):
3137
name="", type="FLOAT", value=TSparkParameterValue(stringValue="1.0")
3238
),
3339
TSparkParameter(
34-
name="", type="DECIMAL", value=TSparkParameterValue(stringValue="1.0")
40+
name="", type="DECIMAL(2,1)", value=TSparkParameterValue(stringValue="1.0")
3541
),
3642
]
3743

@@ -53,23 +59,64 @@ def test_basic_conversions_v2(self):
5359
DbSqlParameter("3", "foo"),
5460
]
5561

56-
def test_type_inference(self):
62+
def test_infer_types_none(self):
5763
with pytest.raises(ValueError):
5864
infer_types([DbSqlParameter("", None)])
65+
66+
def test_infer_types_dict(self):
5967
with pytest.raises(ValueError):
6068
infer_types([DbSqlParameter("", {1: 1})])
61-
assert infer_types([DbSqlParameter("", 1)]) == [
62-
DbSqlParameter("", "1", DbSqlType.INTEGER)
63-
]
64-
assert infer_types([DbSqlParameter("", True)]) == [
65-
DbSqlParameter("", "True", DbSqlType.BOOLEAN)
66-
]
67-
assert infer_types([DbSqlParameter("", 1.0)]) == [
68-
DbSqlParameter("", "1.0", DbSqlType.FLOAT)
69-
]
70-
assert infer_types([DbSqlParameter("", "foo")]) == [
71-
DbSqlParameter("", "foo", DbSqlType.STRING)
72-
]
73-
assert infer_types([DbSqlParameter("", 1.0, DbSqlType.DECIMAL)]) == [
74-
DbSqlParameter("", "1.0", DbSqlType.DECIMAL)
75-
]
69+
70+
def test_infer_types_integer(self):
71+
input = DbSqlParameter("", 1)
72+
output = infer_types([input])
73+
assert output == [DbSqlParameter("", "1", DbSqlType.INTEGER)]
74+
75+
def test_infer_types_boolean(self):
76+
input = DbSqlParameter("", True)
77+
output = infer_types([input])
78+
assert output == [DbSqlParameter("", "True", DbSqlType.BOOLEAN)]
79+
80+
def test_infer_types_float(self):
81+
input = DbSqlParameter("", 1.0)
82+
output = infer_types([input])
83+
assert output == [DbSqlParameter("", "1.0", DbSqlType.FLOAT)]
84+
85+
def test_infer_types_string(self):
86+
input = DbSqlParameter("", "foo")
87+
output = infer_types([input])
88+
assert output == [DbSqlParameter("", "foo", DbSqlType.STRING)]
89+
90+
def test_infer_types_decimal(self):
91+
# The output decimal will have a dynamically calculated decimal type with a value of DECIMAL(2,1)
92+
input = DbSqlParameter("", Decimal("1.0"))
93+
output: List[DbSqlParameter] = infer_types([input])
94+
95+
x = output[0]
96+
97+
assert x.value == "1.0"
98+
assert isinstance(x.type, DbsqlDynamicDecimalType)
99+
assert x.type.value == "DECIMAL(2,1)"
100+
101+
102+
class TestCalculateDecimalCast(object):
103+
104+
def test_38_38(self):
105+
input = Decimal(".12345678912345678912345678912345678912")
106+
output = calculate_decimal_cast_string(input)
107+
assert output == "DECIMAL(38,38)"
108+
109+
def test_18_9(self):
110+
input = Decimal("123456789.123456789")
111+
output = calculate_decimal_cast_string(input)
112+
assert output == "DECIMAL(18,9)"
113+
114+
def test_38_0(self):
115+
input = Decimal("12345678912345678912345678912345678912")
116+
output = calculate_decimal_cast_string(input)
117+
assert output == "DECIMAL(38,0)"
118+
119+
def test_6_2(self):
120+
input = Decimal("1234.56")
121+
output = calculate_decimal_cast_string(input)
122+
assert output == "DECIMAL(6,2)"

0 commit comments

Comments
 (0)