Skip to content

Commit

Permalink
Merge pull request #1719 from moj-analytical-services/1712-bug-inputc…
Browse files Browse the repository at this point in the history
…olumn-does-not-work-properly-with-spark-columns-that-need-escaping

Fix InputColumn quoting for spark and improve code quality
  • Loading branch information
RobinL authored Nov 13, 2023
2 parents fe9d0f3 + 16de63e commit 9ae3014
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 39 deletions.
116 changes: 77 additions & 39 deletions splink/input_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sqlglot
import sqlglot.expressions as exp
from sqlglot.errors import ParseError
from sqlglot.expressions import Expression

from .default_from_jsonschema import default_value_from_schema

Expand All @@ -21,39 +22,58 @@ def sqlglot_tree_signature(tree):
return " ".join(n[0].key for n in tree.walk())


def add_suffix(tree, suffix):
def add_suffix(tree, suffix) -> Expression:
tree = tree.copy()
identifier_string = tree.find(exp.Identifier).this
identifier_string = f"{identifier_string}{suffix}"
tree.find(exp.Identifier).args["this"] = identifier_string
return tree


def add_prefix(tree, prefix):
def add_prefix(tree, prefix) -> Expression:
tree = tree.copy()
identifier_string = tree.find(exp.Identifier).this
identifier_string = f"{prefix}{identifier_string}"
tree.find(exp.Identifier).args["this"] = identifier_string
return tree


def add_table(tree, tablename):
def add_table(tree, tablename) -> Expression:
tree = tree.copy()
table_identifier = exp.Identifier(this=tablename, quoted=True)
identifier = tree.find(exp.Column)
identifier.args["table"] = table_identifier
return tree


def remove_quotes_from_identifiers(tree):
def remove_quotes_from_identifiers(tree) -> Expression:
tree = tree.copy()
for identifier in tree.find_all(exp.Identifier):
identifier.args["quoted"] = False
return tree


class InputColumn:
def __init__(self, name, settings_obj=None, sql_dialect=None):
"""
Represents a SQL column or column reference
Handles SQL dialect-specific issues such as identifier quoting.
The input should be the raw identifier, without SQL-specific identifier quotes.
For example, if a column is named 'first name' (with a space), the input should be
'first name', not '"first name"'.
Examples of valid inputs include:
- 'first_name'
- 'first name'
- 'coordinates['lat']'
- 'coordinates[1]'
"""

def __init__(
self, raw_column_name_or_column_reference, settings_obj=None, sql_dialect=None
):
# If settings_obj is None, then default values will be used
# from the jsonschama
self._settings_obj = settings_obj
Expand All @@ -65,40 +85,48 @@ def __init__(self, name, settings_obj=None, sql_dialect=None):
else:
self._sql_dialect = None

self.input_name = self._quote_name(name)
self.input_name = self._quote_if_sql_keyword(
raw_column_name_or_column_reference
)

self.input_name_as_tree = self.parse_input_name_to_sqlglot_tree()

for identifier in self.input_name_as_tree.find_all(exp.Identifier):
identifier.args["quoted"] = True

def quote(self):
def quote(self) -> "InputColumn":
self_copy = deepcopy(self)
for identifier in self_copy.input_name_as_tree.find_all(exp.Identifier):
identifier.args["quoted"] = True
return self_copy

def unquote(self):
def unquote(self) -> "InputColumn":
self_copy = deepcopy(self)
for identifier in self_copy.input_name_as_tree.find_all(exp.Identifier):
identifier.args["quoted"] = False
return self_copy

def parse_input_name_to_sqlglot_tree(self):
# Cases that could occur for self.input_name:
# SUR name -> parses to 'alias column identifier identifier'
# first and surname -> parses to 'and column column identifier identifier'
# a b c -> parse error
# "SUR name" -> parses to 'column identifier'
# geocode['lat'] -> parsees to bracket column literal identifier
# geocode[1] -> parsees to bracket column literal identifier
def parse_input_name_to_sqlglot_tree(self) -> Expression:
"""
Parses the input name into a SQLglot expression tree.
Fiddly because we need to deal with escaping issues. For example
the column name in the input dataset may be 'first and surname', but
if we naively parse this using sqlglot it will be interpreted as an AND
expression
# Note we don't expect SUR name[1] since the user should have quoted this
Note: We do not support inputs like 'SUR name[1]', in this case the user
would have to quote e.g. `SUR name`[1]
"""

q_s, q_e = _get_dialect_quotes(self._sql_dialect)

try:
tree = sqlglot.parse_one(self.input_name, read=self._sql_dialect)
except ParseError:
tree = sqlglot.parse_one(f'"{self.input_name}"', read=self._sql_dialect)
tree = sqlglot.parse_one(
f"{q_s}{self.input_name}{q_e}", read=self._sql_dialect
)

tree_signature = sqlglot_tree_signature(tree)
valid_signatures = ["column identifier", "bracket column literal identifier"]
Expand All @@ -108,7 +136,9 @@ def parse_input_name_to_sqlglot_tree(self):
else:
# e.g. SUR name parses to 'alias column identifier identifier'
# but we want "SUR name"
tree = sqlglot.parse_one(f'"{self.input_name}"', read=self._sql_dialect)
tree = sqlglot.parse_one(
f"{q_s}{self.input_name}{q_e}", read=self._sql_dialect
)
return tree

def from_settings_obj_else_default(self, key, schema_key=None):
Expand All @@ -121,101 +151,109 @@ def from_settings_obj_else_default(self, key, schema_key=None):
return default_value_from_schema(schema_key, "root")

@property
def gamma_prefix(self):
def gamma_prefix(self) -> str:
return self.from_settings_obj_else_default(
"_gamma_prefix", "comparison_vector_value_column_prefix"
)

@property
def bf_prefix(self):
def bf_prefix(self) -> str:
return self.from_settings_obj_else_default(
"_bf_prefix", "bayes_factor_column_prefix"
)

@property
def tf_prefix(self):
def tf_prefix(self) -> str:
return self.from_settings_obj_else_default(
"_tf_prefix", "term_frequency_adjustment_column_prefix"
)

def name(self):
def name(self) -> str:
return self.input_name_as_tree.sql(dialect=self._sql_dialect)

def name_l(self):
def name_l(self) -> str:
return add_suffix(self.input_name_as_tree, suffix="_l").sql(
dialect=self._sql_dialect
)

def name_r(self):
def name_r(self) -> str:
return add_suffix(self.input_name_as_tree, suffix="_r").sql(
dialect=self._sql_dialect
)

def names_l_r(self):
def names_l_r(self) -> list[str]:
return [self.name_l(), self.name_r()]

def l_name_as_l(self):
def l_name_as_l(self) -> str:
name_with_l_table = add_table(self.input_name_as_tree, "l").sql(
dialect=self._sql_dialect
)
return f"{name_with_l_table} as {self.name_l()}"

def r_name_as_r(self):
def r_name_as_r(self) -> str:
name_with_r_table = add_table(self.input_name_as_tree, "r").sql(
dialect=self._sql_dialect
)
return f"{name_with_r_table} as {self.name_r()}"

def l_r_names_as_l_r(self):
def l_r_names_as_l_r(self) -> list[str]:
return [self.l_name_as_l(), self.r_name_as_r()]

def bf_name(self):
def bf_name(self) -> str:
return add_prefix(self.input_name_as_tree, prefix=self.bf_prefix).sql(
dialect=self._sql_dialect
)

def tf_name(self):
def tf_name(self) -> str:
return add_prefix(self.input_name_as_tree, prefix=self.tf_prefix).sql(
dialect=self._sql_dialect
)

def tf_name_l(self):
def tf_name_l(self) -> str:
tree = add_prefix(self.input_name_as_tree, prefix=self.tf_prefix)
return add_suffix(tree, suffix="_l").sql(dialect=self._sql_dialect)

def tf_name_r(self):
def tf_name_r(self) -> str:
tree = add_prefix(self.input_name_as_tree, prefix=self.tf_prefix)
return add_suffix(tree, suffix="_r").sql(dialect=self._sql_dialect)

def tf_name_l_r(self):
def tf_name_l_r(self) -> list[str]:
return [self.tf_name_l(), self.tf_name_r()]

def l_tf_name_as_l(self):
def l_tf_name_as_l(self) -> str:
tree = add_prefix(self.input_name_as_tree, prefix=self.tf_prefix)
tf_name_with_l_table = add_table(tree, tablename="l").sql(
dialect=self._sql_dialect
)
return f"{tf_name_with_l_table} as {self.tf_name_l()}"

def r_tf_name_as_r(self):
def r_tf_name_as_r(self) -> str:
tree = add_prefix(self.input_name_as_tree, prefix=self.tf_prefix)
tf_name_with_r_table = add_table(tree, tablename="r").sql(
dialect=self._sql_dialect
)
return f"{tf_name_with_r_table} as {self.tf_name_r()}"

def l_r_tf_names_as_l_r(self):
def l_r_tf_names_as_l_r(self) -> list[str]:
return [self.l_tf_name_as_l(), self.r_tf_name_as_r()]

def _quote_name(self, name: str) -> str:
# Quote column names that are also SQL keywords
def _quote_if_sql_keyword(self, name: str) -> str:
if name not in {"group", "index"}:
return name
start, end = _get_dialect_quotes(self._sql_dialect)
return start + name + end


def _get_dialect_quotes(dialect):
"""
Returns the appropriate quotation marks for identifiers based on the SQL dialect.
For most SQL dialects, identifiers are quoted using double quotes.
For example, "first name" is a quoted identifier that
allows for a space in the column name.
However, some SQL dialects, use other identifiers e.g. ` in Spark SQL
"""
start = end = '"'
if dialect is None:
return start, end
Expand Down
3 changes: 3 additions & 0 deletions tests/test_input_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,6 @@ def test_input_column():

assert c.unquote().name() == "col['lat']"
assert c.unquote().quote().name() == name

c = InputColumn("first name", sql_dialect="spark")
assert c.name() == "`first name`"

0 comments on commit 9ae3014

Please sign in to comment.