Skip to content

Commit

Permalink
(3/x) Add e2e tests for DECIMAL inferrence
Browse files Browse the repository at this point in the history
Signed-off-by: Jesse Whitehouse <[email protected]>
  • Loading branch information
Jesse Whitehouse committed Sep 25, 2023
1 parent 7685438 commit 99f1364
Showing 1 changed file with 66 additions and 1 deletion.
67 changes: 66 additions & 1 deletion tests/e2e/common/parameterized_query_tests.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import datetime
from decimal import Decimal
from enum import Enum
from typing import Dict, List, Tuple, Union

import pytz

from databricks.sql.client import Connection
from databricks.sql.exc import DatabaseError
from databricks.sql.utils import DbSqlParameter, DbSqlType


class MyCustomDecimalType(Enum):
DECIMAL_38_0 = "DECIMAL(38,0)"
DECIMAL_38_2 = "DECIMAL(38,2)"
DECIMAL_18_9 = "DECIMAL(18,9)"


class PySQLParameterizedQueryTestSuiteMixin:
"""Namespace for tests of server-side parameterized queries"""

Expand Down Expand Up @@ -63,6 +70,11 @@ def test_primitive_inferred_string(self):
result = self._get_one_result(self.QUERY, params)
assert result.col == "Hello"

def test_primitive_inferred_decimal(self):
params = {"p": Decimal("1234.56")}
result = self._get_one_result(self.QUERY, params)
assert result.col == Decimal("1234.56")

def test_dbsqlparam_inferred_bool(self):

params = [DbSqlParameter(name="p", value=True, type=None)]
Expand Down Expand Up @@ -103,6 +115,11 @@ def test_dbsqlparam_inferred_string(self):
result = self._get_one_result(self.QUERY, params)
assert result.col == "Hello"

def test_dbsqlparam_inferred_decimal(self):
params = [DbSqlParameter(name="p", value=Decimal("1234.56"), type=None)]
result = self._get_one_result(self.QUERY, params)
assert result.col == Decimal("1234.56")

def test_dbsqlparam_explicit_bool(self):

params = [DbSqlParameter(name="p", value=True, type=DbSqlType.BOOLEAN)]
Expand Down Expand Up @@ -142,3 +159,51 @@ def test_dbsqlparam_explicit_string(self):
params = [DbSqlParameter(name="p", value="Hello", type=DbSqlType.STRING)]
result = self._get_one_result(self.QUERY, params)
assert result.col == "Hello"

def test_dbsqlparam_explicit_decimal(self):
params = [
DbSqlParameter(name="p", value=Decimal("1234.56"), type=DbSqlType.DECIMAL)
]
result = self._get_one_result(self.QUERY, params)
assert result.col == Decimal("1234.56")

def test_dbsqlparam_custom_explicit_decimal_38_0(self):

# This DECIMAL can be contained in a DECIMAL(38,0) column in Databricks
value = Decimal("12345678912345678912345678912345678912")
params = [
DbSqlParameter(name="p", value=value, type=MyCustomDecimalType.DECIMAL_38_0)
]
result = self._get_one_result(self.QUERY, params)
assert result.col == value

def test_dbsqlparam_custom_explicit_decimal_38_2(self):

# This DECIMAL can be contained in a DECIMAL(38,2) column in Databricks
value = Decimal("123456789123456789123456789123456789.12")
params = [
DbSqlParameter(name="p", value=value, type=MyCustomDecimalType.DECIMAL_38_2)
]
result = self._get_one_result(self.QUERY, params)
assert result.col == value

def test_dbsqlparam_custom_explicit_decimal_18_9(self):

# This DECIMAL can be contained in a DECIMAL(38,2) column in Databricks
value = Decimal("123456789.123456789")
params = [
DbSqlParameter(name="p", value=value, type=MyCustomDecimalType.DECIMAL_18_9)
]
result = self._get_one_result(self.QUERY, params)
assert result.col == value

def test_primitive_inferred_decimal_raises_exception(self):
"""The default precision is DECIMAL(6,2). Without a custom DbsqlParameter type, the value will be rounded"""

# This DECIMAL would require DECIMAL(10,2) but the default is DECIMAL(6,2)
params = {"p": Decimal("12345678.91")}

with self.assertRaises(
DatabaseError, msg="cannot be cast to DECIMAL(6,2) because it is malformed"
):
result = self._get_one_result(self.QUERY, params)

0 comments on commit 99f1364

Please sign in to comment.