From 47c1e09c755b0dd93d46a3a203ee6bf644c66ea1 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 11 Oct 2024 15:45:40 -0400 Subject: [PATCH] fix: `sqlparse` fallback for formatting queries (#30578) --- superset/sql/parse.py | 108 ++++++++++++++---- tests/integration_tests/sql_lab/api_tests.py | 2 +- tests/unit_tests/db_engine_specs/test_base.py | 16 +-- tests/unit_tests/sql/parse_tests.py | 34 ++++++ 4 files changed, 125 insertions(+), 35 deletions(-) diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 377411b944814..33ed76473facc 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -26,6 +26,8 @@ from typing import Any, Generic, TypeVar import sqlglot +import sqlparse +from deprecation import deprecated from sqlglot import exp from sqlglot.dialects.dialect import Dialect, Dialects from sqlglot.errors import ParseError @@ -138,9 +140,9 @@ class BaseSQLStatement(Generic[InternalRepresentation]): """ Base class for SQL statements. - The class can be instantiated with a string representation of the script or, for - efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`, - which will split a script in multiple already parsed statements. + The class should be instantiated with a string representation of the script and, for + efficiency reasons, optionally with a pre-parsed AST. This is useful with + `sqlglot.parse`, which will split a script in multiple already parsed statements. The `engine` parameters comes from the `engine` attribute in a Superset DB engine spec. @@ -148,14 +150,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]): def __init__( self, - statement: str | InternalRepresentation, + statement: str, engine: str, + ast: InternalRepresentation | None = None, ): - self._parsed: InternalRepresentation = ( - self._parse_statement(statement, engine) - if isinstance(statement, str) - else statement - ) + self._sql = statement + self._parsed = ast or self._parse_statement(statement, engine) self.engine = engine self.tables = self._extract_tables_from_statement(self._parsed, self.engine) @@ -239,11 +239,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): def __init__( self, - statement: str | exp.Expression, + statement: str, engine: str, + ast: exp.Expression | None = None, ): self._dialect = SQLGLOT_DIALECTS.get(engine) - super().__init__(statement, engine) + super().__init__(statement, engine, ast) @classmethod def _parse(cls, script: str, engine: str) -> list[exp.Expression]: @@ -275,11 +276,47 @@ def split_script( script: str, engine: str, ) -> list[SQLStatement]: - return [ - cls(statement, engine) - for statement in cls._parse(script, engine) - if statement - ] + if engine in SQLGLOT_DIALECTS: + try: + return [ + cls(ast.sql(), engine, ast) + for ast in cls._parse(script, engine) + if ast + ] + except ValueError: + # `ast.sql()` might raise an error on some cases (eg, `SHOW TABLES + # FROM`). In this case, we rely on the tokenizer to generate the + # statements. + pass + + # When we don't have a sqlglot dialect we can't rely on `ast.sql()` to correctly + # generate the SQL of each statement, so we tokenize the script and split it + # based on the location of semi-colons. + statements = [] + start = 0 + remainder = script + + try: + tokens = sqlglot.tokenize(script) + except sqlglot.errors.TokenError as ex: + raise SupersetParseError( + script, + engine, + message="Unable to tokenize script", + ) from ex + + for token in tokens: + if token.token_type == sqlglot.TokenType.SEMICOLON: + statement, start = script[start : token.start], token.end + 1 + ast = cls._parse(statement, engine)[0] + statements.append(cls(statement.strip(), engine, ast)) + remainder = script[start:] + + if remainder.strip(): + ast = cls._parse(remainder, engine)[0] + statements.append(cls(remainder.strip(), engine, ast)) + + return statements @classmethod def _parse_statement( @@ -349,8 +386,34 @@ def format(self, comments: bool = True) -> str: """ Pretty-format the SQL statement. """ - write = Dialect.get_or_raise(self._dialect) - return write.generate(self._parsed, copy=False, comments=comments, pretty=True) + if self._dialect: + try: + write = Dialect.get_or_raise(self._dialect) + return write.generate( + self._parsed, + copy=False, + comments=comments, + pretty=True, + ) + except ValueError: + pass + + return self._fallback_formatting() + + @deprecated(deprecated_in="4.0", removed_in="5.0") + def _fallback_formatting(self) -> str: + """ + Format SQL without a specific dialect. + + Reformatting SQL using the generic sqlglot dialect is known to break queries. + For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN (1,2)`, which + breaks the query for Firebolt. To avoid this, we use sqlparse for formatting + when the dialect is not known. + + In 5.0 we should remove `sqlparse`, and the method should return the query + unmodified. + """ + return sqlparse.format(self._sql, reindent=True, keyword_case="upper") def get_settings(self) -> dict[str, str | bool]: """ @@ -456,7 +519,9 @@ def split_script( https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string for more information. """ - return [cls(statement, engine) for statement in split_kql(script)] + return [ + cls(statement, engine, statement.strip()) for statement in split_kql(script) + ] @classmethod def _parse_statement( @@ -498,7 +563,7 @@ def format(self, comments: bool = True) -> str: """ Pretty-format the SQL statement. """ - return self._parsed + return self._sql.strip() def get_settings(self) -> dict[str, str | bool]: """ @@ -548,6 +613,9 @@ def __init__( def format(self, comments: bool = True) -> str: """ Pretty-format the SQL script. + + Note that even though KQL is very different from SQL, multiple statements are + still separated by semi-colons. """ return ";\n".join(statement.format(comments) for statement in self.statements) diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py index 19d6e56fb6441..cf1e190bbb9ba 100644 --- a/tests/integration_tests/sql_lab/api_tests.py +++ b/tests/integration_tests/sql_lab/api_tests.py @@ -281,7 +281,7 @@ def test_format_sql_request(self): "/api/v1/sqllab/format_sql/", json=data, ) - success_resp = {"result": "SELECT\n 1\nFROM my_table"} + success_resp = {"result": "SELECT 1\nFROM my_table"} resp_data = json.loads(rv.data.decode("utf-8")) self.assertDictEqual(resp_data, success_resp) # noqa: PT009 assert rv.status_code == 200 diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index a3af155815d1c..d8e632ce09336 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -241,14 +241,7 @@ class NoLimitDBEngineSpec(BaseEngineSpec): latest_partition=False, cols=cols, ) - assert ( - sql - == """SELECT - a -FROM my_table -LIMIT ? -OFFSET ?""" - ) + assert sql == "SELECT a\nFROM my_table\nLIMIT ?\nOFFSET ?" sql = NoLimitDBEngineSpec.select_star( database=database, @@ -260,12 +253,7 @@ class NoLimitDBEngineSpec(BaseEngineSpec): latest_partition=False, cols=cols, ) - assert ( - sql - == """SELECT - a -FROM my_table""" - ) + assert sql == "SELECT a\nFROM my_table" def test_extra_table_metadata(mocker: MockerFixture) -> None: diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index ae5ebf89a8b96..ada6314457fbc 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -284,6 +284,40 @@ def test_extract_tables_show_tables_from() -> None: ) +def test_format_show_tables() -> None: + """ + Test format when `ast.sql()` raises an exception. + + In that case sqlparse should be used instead. + """ + assert ( + SQLScript("SHOW TABLES FROM s1 like '%order%'", "mysql").format() + == "SHOW TABLES FROM s1 LIKE '%order%'" + ) + + +def test_format_no_dialect() -> None: + """ + Test format with an engine that has no corresponding dialect. + """ + assert ( + SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "firebolt").format() + == "SELECT col\nFROM t\nWHERE col NOT IN (1,\n 2)" + ) + + +def test_split_no_dialect() -> None: + """ + Test the statement split when the engine has no corresponding dialect. + """ + sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t; SELECT foo" + statements = SQLScript(sql, "firebolt").statements + assert len(statements) == 3 + assert statements[0]._sql == "SELECT col FROM t WHERE col NOT IN (1, 2)" + assert statements[1]._sql == "SELECT * FROM t" + assert statements[2]._sql == "SELECT foo" + + def test_extract_tables_show_columns_from() -> None: """ Test `SHOW COLUMNS FROM`.