Skip to content

Commit

Permalink
(4/x) Flesh out how to declare different decimal precisions up to the…
Browse files Browse the repository at this point in the history
… max

precision supported by Databricks: DECIMAL(38,x)

Signed-off-by: Jesse Whitehouse <[email protected]>
  • Loading branch information
Jesse Whitehouse committed Sep 23, 2023
1 parent 2621f8f commit ad5314f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 14 deletions.
5 changes: 4 additions & 1 deletion src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,9 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type):


class DbSqlType(Enum):
"""The values of this enumeration are passed as literals to be used in a CAST
evaluation by the thrift server.
"""
STRING = "STRING"
DATE = "DATE"
TIMESTAMP = "TIMESTAMP"
Expand All @@ -495,7 +498,7 @@ class DbSqlType(Enum):
class DbSqlParameter:
name: str
value: Any
type: DbSqlType
type: Union[DbSqlType, Enum]

def __init__(self, name="", value=None, type=None):
self.name = name
Expand Down
49 changes: 36 additions & 13 deletions tests/e2e/common/parameterized_query_tests.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from collections import namedtuple
from datetime import datetime
from decimal import Decimal
from enum import Enum

import pytz
import pytest
from collections import namedtuple
import pytz

from databricks.sql.client import Connection
from databricks.sql.utils import DbSqlParameter, DbSqlType
from databricks.sql.exc import DatabaseError
from databricks.sql.utils import DbSqlParameter, DbSqlType


class PySQLParameterizedQueryTestSuiteMixin:
Expand Down Expand Up @@ -201,7 +202,7 @@ def test_decimal_parameterization_not_inferred(self):
base_query = "SELECT :p_decimal col_decimal"

# first, demonstrate how the default inferrence will raise a server exception because
# the default precision is (6,2)
# the default precision is (6,2). This example Decimal should be a DECIMAL(18,9)
inferred_decimal_param = DbSqlParameter(
name="p_decimal", value=Decimal("123456789.123456789")
)
Expand All @@ -215,15 +216,37 @@ def test_decimal_parameterization_not_inferred(self):
):
cursor.execute(base_query, parameters=[inferred_decimal_param])

# now we bypass inferrence and set it ourselves
custom_type_tuple_maker = namedtuple("CustomDbsqlType", "value")
decimal_18_9 = custom_type_tuple_maker("DECIMAL(18,9)")
class MyCustomDecimalType(Enum):
DECIMAL_38_0 = "DECIMAL(38,0)"
DECIMAL_38_2 = "DECIMAL(38,2)"
DECIMAL_18_9 = "DECIMAL(18,9)"

explicit_decimal_param = DbSqlParameter(
name="p_decimal", value=Decimal("123456789.123456789"), type=decimal_18_9
)
example_decimals = [
Decimal("12345678912345678912345678912345678912"),
Decimal("123456789123456789123456789123456789.12"),
Decimal("123456789.123456789"),
]

explicit_decimal_params = [
DbSqlParameter(
name="p_decimal",
value=example_decimals[0],
type=MyCustomDecimalType.DECIMAL_38_0,
),
DbSqlParameter(
name="p_decimal",
value=example_decimals[1],
type=MyCustomDecimalType.DECIMAL_38_2,
),
DbSqlParameter(
name="p_decimal",
value=example_decimals[2],
type=MyCustomDecimalType.DECIMAL_18_9,
),
]
with self.connection() as conn:
cursor = conn.cursor()
cursor.execute(base_query, parameters=[explicit_decimal_param])
result = cursor.fetchone()
assert result.col_decimal == Decimal("123456789.123456789")
for idx, param in enumerate(explicit_decimal_params):
cursor.execute(base_query, parameters=[param])
result = cursor.fetchone()
assert result.col_decimal == example_decimals[idx]

0 comments on commit ad5314f

Please sign in to comment.