diff --git a/dask_sql/input_utils/hive.py b/dask_sql/input_utils/hive.py index 4e1bdde62..b786c0913 100644 --- a/dask_sql/input_utils/hive.py +++ b/dask_sql/input_utils/hive.py @@ -1,6 +1,7 @@ import ast import logging import os +import re from functools import partial from typing import Any, Union @@ -184,6 +185,24 @@ def wrapped_read_function(location, column_information, **kwargs): df = wrapped_read_function(location, column_information, **kwargs) return df + def _escape_partition(self, partition: str): # pragma: no cover + """ + Given a partition string like `key=value` escape the string properly for Hive. + Wrap anything but digits in quotes. Don't wrap the column name. + """ + contains_only_digits = re.compile(r"^\d+$") + + try: + k, v = partition.split("=") + if re.match(contains_only_digits, v): + escaped_value = v + else: + escaped_value = f'"{v}"' + return f"{k}={escaped_value}" + except ValueError: + logger.warning(f"{partition} didn't contain a `=`") + return partition + def _parse_hive_table_description( self, cursor: Union["sqlalchemy.engine.base.Connection", "hive.Cursor"], @@ -198,6 +217,7 @@ def _parse_hive_table_description( """ cursor.execute(f"USE {schema}") if partition: + partition = self._escape_partition(partition) result = self._fetch_all_results( cursor, f"DESCRIBE FORMATTED {table_name} PARTITION ({partition})" )