From cdc603eb4bdc5e288d01c1ba22489d3f9c251772 Mon Sep 17 00:00:00 2001 From: William Cox Date: Wed, 19 May 2021 16:08:46 -0400 Subject: [PATCH] Add better partition string escaping --- dask_sql/input_utils/hive.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/dask_sql/input_utils/hive.py b/dask_sql/input_utils/hive.py index 179c208d2..ad55d0e2c 100644 --- a/dask_sql/input_utils/hive.py +++ b/dask_sql/input_utils/hive.py @@ -1,6 +1,7 @@ import ast import logging import os +import re from functools import partial from typing import Any, Union @@ -159,6 +160,24 @@ def wrapped_read_function(location, column_information, **kwargs): df = wrapped_read_function(location, column_information, **kwargs) return df + def _escape_partition(self, partition: str): # pragma: no cover + """ + Given a partition string like `key=value` escape the string properly for Hive. + Wrap anything but digits in quotes. Don't wrap the column name. + """ + contains_only_digits = re.compile(r"^\d+$") + + try: + k, v = partition.split("=") + if re.match(contains_only_digits, v): + escaped_value = v + else: + escaped_value = f'"{v}"' + return f"{k}={escaped_value}" + except ValueError: + logger.warning(f"{partition} didn't contain a `=`") + return partition + def _parse_hive_table_description( self, cursor: Union["sqlalchemy.engine.base.Connection", "hive.Cursor"], @@ -173,6 +192,7 @@ def _parse_hive_table_description( """ cursor.execute(f"USE {schema}") if partition: + partition = self._escape_partition(partition) result = self._fetch_all_results( cursor, f"DESCRIBE FORMATTED {table_name} PARTITION ({partition})" )