Skip to content

Commit

Permalink
Handle databricks 2.9 paramstyle
Browse files Browse the repository at this point in the history
  • Loading branch information
steinitzu committed Jan 17, 2024
1 parent ebf16fb commit 9b712aa
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
34 changes: 21 additions & 13 deletions dlt/destinations/impl/databricks/sql_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextlib import contextmanager, suppress
from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List
from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List, Union, Dict

from databricks import sql as databricks_lib
from databricks.sql.client import (
Expand Down Expand Up @@ -101,18 +101,26 @@ def execute_many(
@raise_database_error
def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]:
curr: DBApiCursor = None
if args:
keys = [f"arg{i}" for i in range(len(args))]
# Replace positional arguments (%s) with named arguments (:arg0, :arg1, ...)
query = query % tuple(f":{key}" for key in keys)
db_args = {}
for key, db_arg in zip(keys, args):
# Databricks connector doesn't accept pendulum objects
if isinstance(db_arg, pendulum.DateTime):
db_arg = to_py_datetime(db_arg)
elif isinstance(db_arg, pendulum.Date):
db_arg = to_py_date(db_arg)
db_args[key] = db_arg
# TODO: databricks connector 3.0.0 will use :named paramstyle only
# if args:
# keys = [f"arg{i}" for i in range(len(args))]
# # Replace positional arguments (%s) with named arguments (:arg0, :arg1, ...)
# # query = query % tuple(f":{key}" for key in keys)
# db_args = {}
# for key, db_arg in zip(keys, args):
# # Databricks connector doesn't accept pendulum objects
# if isinstance(db_arg, pendulum.DateTime):
# db_arg = to_py_datetime(db_arg)
# elif isinstance(db_arg, pendulum.Date):
# db_arg = to_py_date(db_arg)
# db_args[key] = db_arg
# else:
# db_args = None
db_args: Optional[Union[Dict[str, Any], Sequence[Any]]]
if kwargs:
db_args = kwargs
elif args:
db_args = args
else:
db_args = None
with self._conn.cursor() as curr:
Expand Down
1 change: 1 addition & 0 deletions tests/load/pipeline/test_arrow_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_load_item(
include_time = destination_config.destination not in (
"athena",
"redshift",
"databricks",
) # athena/redshift can't load TIME columns from parquet; databricks is excluded as well
item, records = arrow_table_all_data_types(
item_type, include_json=False, include_time=include_time
Expand Down

0 comments on commit 9b712aa

Please sign in to comment.