From bdf29cb7c2d61662cbbae8035b5a8cc682140b45 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 13 Sep 2024 16:24:19 -0400 Subject: [PATCH] chore: organize SQL parsing files (#30258) --- superset/db_engine_specs/base.py | 3 +- superset/db_engine_specs/postgres.py | 2 +- superset/exceptions.py | 27 +- superset/models/helpers.py | 3 +- superset/sql/__init__.py | 16 + superset/sql/parse.py | 648 +++++++++++++++++++ superset/sql_lab.py | 3 +- superset/sql_parse.py | 649 +------------------ superset/sqllab/api.py | 2 +- superset/tasks/thumbnails.py | 3 +- tests/unit_tests/sql/__init__.py | 16 + tests/unit_tests/sql/parse_tests.py | 920 +++++++++++++++++++++++++++ tests/unit_tests/sql_parse_tests.py | 244 +------ 13 files changed, 1650 insertions(+), 886 deletions(-) create mode 100644 superset/sql/__init__.py create mode 100644 superset/sql/parse.py create mode 100644 tests/unit_tests/sql/__init__.py create mode 100644 tests/unit_tests/sql/parse_tests.py diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 51ea6caf5e6df..dcdfff6c3f242 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -63,7 +63,8 @@ from superset.databases.utils import get_table_metadata, make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError -from superset.sql_parse import ParsedQuery, SQLScript, Table +from superset.sql.parse import SQLScript, Table +from superset.sql_parse import ParsedQuery from superset.superset_typing import ( OAuth2ClientConfig, OAuth2State, diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 8525ea05da9b6..0a638f65fe136 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -35,7 +35,7 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import 
SupersetException, SupersetSecurityException from superset.models.sql_lab import Query -from superset.sql_parse import SQLScript +from superset.sql.parse import SQLScript from superset.utils import core as utils, json from superset.utils.core import GenericDataType diff --git a/superset/exceptions.py b/superset/exceptions.py index a000e08165c47..492007523de3f 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -14,6 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +from __future__ import annotations + from collections import defaultdict from typing import Any, Optional @@ -304,12 +307,30 @@ class SupersetParseError(SupersetErrorException): status = 422 - def __init__(self, sql: str, engine: Optional[str] = None): + def __init__( # pylint: disable=too-many-arguments + self, + sql: str, + engine: Optional[str] = None, + message: Optional[str] = None, + highlight: Optional[str] = None, + line: Optional[int] = None, + column: Optional[int] = None, + ): + if message is None: + parts = [_("Error parsing")] + if highlight: + parts.append(_(" near '%(highlight)s'", highlight=highlight)) + if line: + parts.append(_(" at line %(line)d", line=line)) + if column: + parts.append(_(":%(column)d", column=column)) + message = "".join(parts) + error = SupersetError( - message=_("The SQL is invalid and cannot be parsed."), + message=message, error_type=SupersetErrorType.INVALID_SQL_ERROR, level=ErrorLevel.ERROR, - extra={"sql": sql, "engine": engine}, + extra={"sql": sql, "engine": engine, "line": line, "column": column}, ) super().__init__(error) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 295ecea70ea45..80e66f50270c8 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -68,13 +68,12 @@ ) from superset.extensions import feature_flag_manager from superset.jinja_context import BaseTemplateProcessor +from 
superset.sql.parse import SQLScript, SQLStatement from superset.sql_parse import ( has_table_query, insert_rls_in_predicate, ParsedQuery, sanitize_clause, - SQLScript, - SQLStatement, ) from superset.superset_typing import ( AdhocMetric, diff --git a/superset/sql/__init__.py b/superset/sql/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/superset/sql/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/sql/parse.py b/superset/sql/parse.py new file mode 100644 index 0000000000000..3ec928fabdd3e --- /dev/null +++ b/superset/sql/parse.py @@ -0,0 +1,648 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import enum +import logging +import re +import urllib.parse +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any, Generic, TypeVar + +import sqlglot +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect, Dialects +from sqlglot.errors import ParseError +from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope + +from superset.exceptions import SupersetParseError + +logger = logging.getLogger(__name__) + + +# mapping between DB engine specs and sqlglot dialects +SQLGLOT_DIALECTS = { + "base": Dialects.DIALECT, + "ascend": Dialects.HIVE, + "awsathena": Dialects.PRESTO, + "bigquery": Dialects.BIGQUERY, + "clickhouse": Dialects.CLICKHOUSE, + "clickhousedb": Dialects.CLICKHOUSE, + "cockroachdb": Dialects.POSTGRES, + "couchbase": Dialects.MYSQL, + # "crate": ??? + # "databend": ??? + "databricks": Dialects.DATABRICKS, + # "db2": ??? + # "dremio": ??? + "drill": Dialects.DRILL, + # "druid": ??? + "duckdb": Dialects.DUCKDB, + # "dynamodb": ??? + # "elasticsearch": ??? + # "exa": ??? + # "firebird": ??? + # "firebolt": ??? + "gsheets": Dialects.SQLITE, + "hana": Dialects.POSTGRES, + "hive": Dialects.HIVE, + # "ibmi": ??? + # "impala": ??? + # "kustokql": ??? + # "kylin": ??? + "mssql": Dialects.TSQL, + "mysql": Dialects.MYSQL, + "netezza": Dialects.POSTGRES, + # "ocient": ??? + # "odelasticsearch": ??? + "oracle": Dialects.ORACLE, + # "pinot": ??? 
+ "postgresql": Dialects.POSTGRES, + "presto": Dialects.PRESTO, + "pydoris": Dialects.DORIS, + "redshift": Dialects.REDSHIFT, + # "risingwave": ??? + # "rockset": ??? + "shillelagh": Dialects.SQLITE, + "snowflake": Dialects.SNOWFLAKE, + # "solr": ??? + "spark": Dialects.SPARK, + "sqlite": Dialects.SQLITE, + "starrocks": Dialects.STARROCKS, + "superset": Dialects.SQLITE, + "teradatasql": Dialects.TERADATA, + "trino": Dialects.TRINO, + "vertica": Dialects.POSTGRES, +} + + +@dataclass(eq=True, frozen=True) +class Table: + """ + A fully qualified SQL table conforming to [[catalog.]schema.]table. + """ + + table: str + schema: str | None = None + catalog: str | None = None + + def __str__(self) -> str: + """ + Return the fully qualified SQL table name. + + Should not be used for SQL generation, only for logging and debugging, since the + quoting is not engine-specific. + """ + return ".".join( + urllib.parse.quote(part, safe="").replace(".", "%2E") + for part in [self.catalog, self.schema, self.table] + if part + ) + + def __eq__(self, other: Any) -> bool: + return str(self) == str(other) + + +# To avoid unnecessary parsing/formatting of queries, the statement has the concept of +# an "internal representation", which is the AST of the SQL statement. For most of the +# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special +# case: KustoKQL uses a different syntax and there are no Python parsers for it, so we +# store the AST as a string (the original query), and manipulate it with regular +# expressions. +InternalRepresentation = TypeVar("InternalRepresentation") + +# The base type. This helps type checking the `split_query` method correctly, since each +# derived class has a more specific return type (the class itself). This will no longer +# be needed once Python 3.11 is the lowest version supported. 
See PEP 673 for more +# information: https://peps.python.org/pep-0673/ +TBaseSQLStatement = TypeVar("TBaseSQLStatement") # pylint: disable=invalid-name + + +class BaseSQLStatement(Generic[InternalRepresentation]): + """ + Base class for SQL statements. + + The class can be instantiated with a string representation of the script or, for + efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`, + which will split a script in multiple already parsed statements. + + The `engine` parameters comes from the `engine` attribute in a Superset DB engine + spec. + """ + + def __init__( + self, + statement: str | InternalRepresentation, + engine: str, + ): + self._parsed: InternalRepresentation = ( + self._parse_statement(statement, engine) + if isinstance(statement, str) + else statement + ) + self.engine = engine + self.tables = self._extract_tables_from_statement(self._parsed, self.engine) + + @classmethod + def split_script( + cls: type[TBaseSQLStatement], + script: str, + engine: str, + ) -> list[TBaseSQLStatement]: + """ + Split a script into multiple instantiated statements. + + This is a helper function to split a full SQL script into multiple + `BaseSQLStatement` instances. It's used by `SQLScript` when instantiating the + statements within a script. + """ + raise NotImplementedError() + + @classmethod + def _parse_statement( + cls, + statement: str, + engine: str, + ) -> InternalRepresentation: + """ + Parse a string containing a single SQL statement, and returns the parsed AST. + + Derived classes should not assume that `statement` contains a single statement, + and MUST explicitly validate that. Since this validation is parser dependent the + responsibility is left to the children classes. + """ + raise NotImplementedError() + + @classmethod + def _extract_tables_from_statement( + cls, + parsed: InternalRepresentation, + engine: str, + ) -> set[Table]: + """ + Extract all table references in a given statement. 
+ """ + raise NotImplementedError() + + def format(self, comments: bool = True) -> str: + """ + Format the statement, optionally omitting comments. + """ + raise NotImplementedError() + + def get_settings(self) -> dict[str, str | bool]: + """ + Return any settings set by the statement. + + For example, for this statement: + + sql> SET foo = 'bar'; + + The method should return `{"foo": "'bar'"}`. Note the single quotes. + """ + raise NotImplementedError() + + def is_mutating(self) -> bool: + """ + Check if the statement mutates data (DDL/DML). + + :return: True if the statement mutates data. + """ + raise NotImplementedError() + + def __str__(self) -> str: + return self.format() + + +class SQLStatement(BaseSQLStatement[exp.Expression]): + """ + A SQL statement. + + This class is used for all engines with dialects that can be parsed using sqlglot. + """ + + def __init__( + self, + statement: str | exp.Expression, + engine: str, + ): + self._dialect = SQLGLOT_DIALECTS.get(engine) + super().__init__(statement, engine) + + @classmethod + def _parse(cls, script: str, engine: str) -> list[exp.Expression]: + """ + Parse helper. + """ + dialect = SQLGLOT_DIALECTS.get(engine) + try: + return sqlglot.parse(script, dialect=dialect) + except sqlglot.errors.ParseError as ex: + error = ex.errors[0] + raise SupersetParseError( + script, + engine, + highlight=error["highlight"], + line=error["line"], + column=error["col"], + ) from ex + except sqlglot.errors.SqlglotError as ex: + raise SupersetParseError( + script, + engine, + message="Unable to parse script", + ) from ex + + @classmethod + def split_script( + cls, + script: str, + engine: str, + ) -> list[SQLStatement]: + return [ + cls(statement, engine) + for statement in cls._parse(script, engine) + if statement + ] + + @classmethod + def _parse_statement( + cls, + statement: str, + engine: str, + ) -> exp.Expression: + """ + Parse a single SQL statement. 
+ """ + statements = cls.split_script(statement, engine) + if len(statements) != 1: + raise SupersetParseError("SQLStatement should have exactly one statement") + + return statements[0]._parsed # pylint: disable=protected-access + + @classmethod + def _extract_tables_from_statement( + cls, + parsed: exp.Expression, + engine: str, + ) -> set[Table]: + """ + Find all referenced tables. + """ + dialect = SQLGLOT_DIALECTS.get(engine) + return extract_tables_from_statement(parsed, dialect) + + def is_mutating(self) -> bool: + """ + Check if the statement mutates data (DDL/DML). + + :return: True if the statement mutates data. + """ + for node in self._parsed.walk(): + if isinstance( + node, + ( + exp.Insert, + exp.Update, + exp.Delete, + exp.Merge, + exp.Create, + exp.Drop, + exp.TruncateTable, + ), + ): + return True + + if isinstance(node, exp.Command) and node.name == "ALTER": + return True + + # Postgres runs DMLs prefixed by `EXPLAIN ANALYZE`, see + # https://www.postgresql.org/docs/current/sql-explain.html + if ( + self._dialect == Dialects.POSTGRES + and isinstance(self._parsed, exp.Command) + and self._parsed.name == "EXPLAIN" + and self._parsed.expression.name.upper().startswith("ANALYZE ") + ): + analyzed_sql = self._parsed.expression.name[len("ANALYZE ") :] + return SQLStatement(analyzed_sql, self.engine).is_mutating() + + return False + + def format(self, comments: bool = True) -> str: + """ + Pretty-format the SQL statement. + """ + write = Dialect.get_or_raise(self._dialect) + return write.generate(self._parsed, copy=False, comments=comments, pretty=True) + + def get_settings(self) -> dict[str, str | bool]: + """ + Return the settings for the SQL statement. 
+ + >>> statement = SQLStatement("SET foo = 'bar'") + >>> statement.get_settings() + {"foo": "'bar'"} + + """ + return { + eq.this.sql(): eq.expression.sql() + for set_item in self._parsed.find_all(exp.SetItem) + for eq in set_item.find_all(exp.EQ) + } + + +class KQLSplitState(enum.Enum): + """ + State machine for splitting a KQL script. + + The state machine keeps track of whether we're inside a string or not, so we + don't split the script in a semi-colon that's part of a string. + """ + + OUTSIDE_STRING = enum.auto() + INSIDE_SINGLE_QUOTED_STRING = enum.auto() + INSIDE_DOUBLE_QUOTED_STRING = enum.auto() + INSIDE_MULTILINE_STRING = enum.auto() + + +def split_kql(kql: str) -> list[str]: + """ + Custom function for splitting KQL statements. + """ + statements = [] + state = KQLSplitState.OUTSIDE_STRING + statement_start = 0 + script = kql if kql.endswith(";") else kql + ";" + for i, character in enumerate(script): + if state == KQLSplitState.OUTSIDE_STRING: + if character == ";": + statements.append(script[statement_start:i]) + statement_start = i + 1 + elif character == "'": + state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING + elif character == '"': + state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING + elif character == "`" and script[i - 2 : i] == "``": + state = KQLSplitState.INSIDE_MULTILINE_STRING + + elif ( + state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING + and character == "'" + and script[i - 1] != "\\" + ): + state = KQLSplitState.OUTSIDE_STRING + + elif ( + state == KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING + and character == '"' + and script[i - 1] != "\\" + ): + state = KQLSplitState.OUTSIDE_STRING + + elif ( + state == KQLSplitState.INSIDE_MULTILINE_STRING + and character == "`" + and script[i - 2 : i] == "``" + ): + state = KQLSplitState.OUTSIDE_STRING + + return statements + + +class KustoKQLStatement(BaseSQLStatement[str]): + """ + Special class for Kusto KQL. + + Kusto KQL is a SQL-like language, but it's not supported by sqlglot. 
Queries look + like this: + + StormEvents + | summarize PropertyDamage = sum(DamageProperty) by State + | join kind=innerunique PopulationData on State + | project State, PropertyDamagePerCapita = PropertyDamage / Population + | sort by PropertyDamagePerCapita + + See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for more + details about it. + """ + + @classmethod + def split_script( + cls, + script: str, + engine: str, + ) -> list[KustoKQLStatement]: + """ + Split a script at semi-colons. + + Since we don't have a parser, we use a simple state machine based function. See + https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string + for more information. + """ + return [cls(statement, engine) for statement in split_kql(script)] + + @classmethod + def _parse_statement( + cls, + statement: str, + engine: str, + ) -> str: + if engine != "kustokql": + raise SupersetParseError(f"Invalid engine: {engine}") + + statements = split_kql(statement) + if len(statements) != 1: + raise SupersetParseError("SQLStatement should have exactly one statement") + + return statements[0].strip() + + @classmethod + def _extract_tables_from_statement( + cls, + parsed: str, + engine: str, + ) -> set[Table]: + """ + Extract all tables referenced in the statement. + + StormEvents + | where InjuriesDirect + InjuriesIndirect > 50 + | join (PopulationData) on State + | project State, Population, TotalInjuries = InjuriesDirect + InjuriesIndirect + + """ + logger.warning( + "Kusto KQL doesn't support table extraction. This means that data access " + "roles will not be enforced by Superset in the database." + ) + return set() + + def format(self, comments: bool = True) -> str: + """ + Pretty-format the SQL statement. + """ + return self._parsed + + def get_settings(self) -> dict[str, str | bool]: + """ + Return the settings for the SQL statement. 
+ + >>> statement = KustoKQLStatement("set querytrace;") + >>> statement.get_settings() + {"querytrace": True} + + """ + set_regex = r"^set\s+(?P<name>\w+)(?:\s*=\s*(?P<value>\w+))?$" + if match := re.match(set_regex, self._parsed, re.IGNORECASE): + return {match.group("name"): match.group("value") or True} + + return {} + + def is_mutating(self) -> bool: + """ + Check if the statement mutates data (DDL/DML). + + :return: True if the statement mutates data. + """ + return self._parsed.startswith(".") and not self._parsed.startswith(".show") + + +class SQLScript: + """ + A SQL script, with 0+ statements. + """ + + # Special engines that can't be parsed using sqlglot. Supporting non-SQL engines + # adds a lot of complexity to Superset, so we should avoid adding new engines to + # this data structure. + special_engines = { + "kustokql": KustoKQLStatement, + } + + def __init__( + self, + script: str, + engine: str, + ): + statement_class = self.special_engines.get(engine, SQLStatement) + self.engine = engine + self.statements = statement_class.split_script(script, engine) + + def format(self, comments: bool = True) -> str: + """ + Pretty-format the SQL script. + """ + return ";\n".join(statement.format(comments) for statement in self.statements) + + def get_settings(self) -> dict[str, str | bool]: + """ + Return the settings for the SQL script. + + >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'") + >>> statement.get_settings() + {"foo": "'baz'"} + + """ + settings: dict[str, str | bool] = {} + for statement in self.statements: + settings.update(statement.get_settings()) + + return settings + + def has_mutation(self) -> bool: + """ + Check if the script contains mutating statements. 
+ + :return: True if the script contains mutating statements + """ + return any(statement.is_mutating() for statement in self.statements) + + +def extract_tables_from_statement( + statement: exp.Expression, + dialect: Dialects | None, +) -> set[Table]: + """ + Extract all table references in a single statement. + + Please note that this is not trivial; consider the following queries: + + DESCRIBE some_table; + SHOW PARTITIONS FROM some_table; + WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name; + + See the unit tests for other tricky cases. + """ + sources: Iterable[exp.Table] + + if isinstance(statement, exp.Describe): + # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly + # query for all tables. + sources = statement.find_all(exp.Table) + elif isinstance(statement, exp.Command): + # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a + # `SELECT` statement in order to extract tables. + literal = statement.find(exp.Literal) + if not literal: + return set() + + try: + pseudo_query = sqlglot.parse_one(f"SELECT {literal.this}", dialect=dialect) + except ParseError: + return set() + sources = pseudo_query.find_all(exp.Table) + else: + sources = [ + source + for scope in traverse_scope(statement) + for source in scope.sources.values() + if isinstance(source, exp.Table) and not is_cte(source, scope) + ] + + return { + Table( + source.name, + source.db if source.db != "" else None, + source.catalog if source.catalog != "" else None, + ) + for source in sources + } + + +def is_cte(source: exp.Table, scope: Scope) -> bool: + """ + Is the source a CTE? 
+ + CTEs in the parent scope look like tables (and are represented by + exp.Table objects), but should not be considered as such; + otherwise a user with access to table `foo` could access any table + with a query like this: + + WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo + + """ + parent_sources = scope.parent.sources if scope.parent else {} + ctes_in_scope = { + name + for name, parent_scope in parent_sources.items() + if isinstance(parent_scope, Scope) and parent_scope.scope_type == ScopeType.CTE + } + + return source.name in ctes_in_scope diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 3c35899706127..3d3b2898fafa6 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -51,13 +51,12 @@ from superset.models.core import Database from superset.models.sql_lab import Query from superset.result_set import SupersetResultSet +from superset.sql.parse import SQLStatement, Table from superset.sql_parse import ( CtasMethod, insert_rls_as_subquery, insert_rls_in_predicate, ParsedQuery, - SQLStatement, - Table, ) from superset.sqllab.limiting_factor import LimitingFactor from superset.sqllab.utils import write_ipc_buffer diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 59d6d643eb2a5..91b68126356a2 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -19,23 +19,16 @@ from __future__ import annotations -import enum import logging import re -import urllib.parse -from collections.abc import Iterable, Iterator -from dataclasses import dataclass -from typing import Any, cast, Generic, TYPE_CHECKING, TypeVar +from collections.abc import Iterator +from typing import Any, cast, TYPE_CHECKING -import sqlglot import sqlparse from flask_babel import gettext as __ from jinja2 import nodes from sqlalchemy import and_ -from sqlglot import exp, parse, parse_one -from sqlglot.dialects.dialect import Dialect, Dialects -from sqlglot.errors import ParseError, SqlglotError -from sqlglot.optimizer.scope import Scope, ScopeType, 
traverse_scope +from sqlglot.dialects.dialect import Dialects from sqlparse import keywords from sqlparse.lexer import Lexer from sqlparse.sql import ( @@ -68,6 +61,7 @@ SupersetParseError, SupersetSecurityException, ) +from superset.sql.parse import extract_tables_from_statement, SQLScript, Table from superset.utils.backports import StrEnum try: @@ -226,7 +220,9 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]: def check_sql_functions_exist( - sql: str, function_list: set[str], engine: str | None = None + sql: str, + function_list: set[str], + engine: str = "base", ) -> bool: """ Check if the SQL statement contains any of the specified functions. @@ -238,7 +234,7 @@ def check_sql_functions_exist( return ParsedQuery(sql, engine=engine).check_functions_exist(function_list) -def strip_comments_from_sql(statement: str, engine: str | None = None) -> str: +def strip_comments_from_sql(statement: str, engine: str = "base") -> str: """ Strips comments from a SQL statement, does a simple test first to avoid always instantiating the expensive ParsedQuery constructor @@ -255,554 +251,18 @@ def strip_comments_from_sql(statement: str, engine: str | None = None) -> str: ) -@dataclass(eq=True, frozen=True) -class Table: - """ - A fully qualified SQL table conforming to [[catalog.]schema.]table. - """ - - table: str - schema: str | None = None - catalog: str | None = None - - def __str__(self) -> str: - """ - Return the fully qualified SQL table name. - """ - - return ".".join( - urllib.parse.quote(part, safe="").replace(".", "%2E") - for part in [self.catalog, self.schema, self.table] - if part - ) - - def __eq__(self, __o: object) -> bool: - return str(self) == str(__o) - - -def extract_tables_from_statement( - statement: exp.Expression, - dialect: Dialects | None, -) -> set[Table]: - """ - Extract all table references in a single statement. 
- - Please not that this is not trivial; consider the following queries: - - DESCRIBE some_table; - SHOW PARTITIONS FROM some_table; - WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name; - - See the unit tests for other tricky cases. - """ - sources: Iterable[exp.Table] - - if isinstance(statement, exp.Describe): - # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly - # query for all tables. - sources = statement.find_all(exp.Table) - elif isinstance(statement, exp.Command): - # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a - # `SELECT` statetement in order to extract tables. - literal = statement.find(exp.Literal) - if not literal: - return set() - - try: - pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect) - except ParseError: - return set() - sources = pseudo_query.find_all(exp.Table) - else: - sources = [ - source - for scope in traverse_scope(statement) - for source in scope.sources.values() - if isinstance(source, exp.Table) and not is_cte(source, scope) - ] - - return { - Table( - source.name, - source.db if source.db != "" else None, - source.catalog if source.catalog != "" else None, - ) - for source in sources - } - - -def is_cte(source: exp.Table, scope: Scope) -> bool: - """ - Is the source a CTE? 
- - CTEs in the parent scope look like tables (and are represented by - exp.Table objects), but should not be considered as such; - otherwise a user with access to table `foo` could access any table - with a query like this: - - WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo - - """ - parent_sources = scope.parent.sources if scope.parent else {} - ctes_in_scope = { - name - for name, parent_scope in parent_sources.items() - if isinstance(parent_scope, Scope) and parent_scope.scope_type == ScopeType.CTE - } - - return source.name in ctes_in_scope - - -# To avoid unnecessary parsing/formatting of queries, the statement has the concept of -# an "internal representation", which is the AST of the SQL statement. For most of the -# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special -# case: KustoKQL uses a different syntax and there are no Python parsers for it, so we -# store the AST as a string (the original query), and manipulate it with regular -# expressions. -InternalRepresentation = TypeVar("InternalRepresentation") - -# The base type. This helps type checking the `split_query` method correctly, since each -# derived class has a more specific return type (the class itself). This will no longer -# be needed once Python 3.11 is the lowest version supported. See PEP 673 for more -# information: https://peps.python.org/pep-0673/ -TBaseSQLStatement = TypeVar("TBaseSQLStatement") # pylint: disable=invalid-name - - -class BaseSQLStatement(Generic[InternalRepresentation]): - """ - Base class for SQL statements. - - The class can be instantiated with a string representation of the query or, for - efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`, - which will split a query in multiple already parsed statements. - - The `engine` parameters comes from the `engine` attribute in a Superset DB engine - spec. 
- """ - - def __init__( - self, - statement: str | InternalRepresentation, - engine: str, - ): - self._parsed: InternalRepresentation = ( - self._parse_statement(statement, engine) - if isinstance(statement, str) - else statement - ) - self.engine = engine - self.tables = self._extract_tables_from_statement(self._parsed, self.engine) - - @classmethod - def split_query( - cls: type[TBaseSQLStatement], - query: str, - engine: str, - ) -> list[TBaseSQLStatement]: - """ - Split a query into multiple instantiated statements. - - This is a helper function to split a full SQL query into multiple - `BaseSQLStatement` instances. It's used by `SQLScript` when instantiating the - statements within a query. - """ - raise NotImplementedError() - - @classmethod - def _parse_statement( - cls, - statement: str, - engine: str, - ) -> InternalRepresentation: - """ - Parse a string containing a single SQL statement, and returns the parsed AST. - - Derived classes should not assume that `statement` contains a single statement, - and MUST explicitly validate that. Since this validation is parser dependent the - responsibility is left to the children classes. - """ - raise NotImplementedError() - - @classmethod - def _extract_tables_from_statement( - cls, - parsed: InternalRepresentation, - engine: str, - ) -> set[Table]: - """ - Extract all table references in a given statement. - """ - raise NotImplementedError() - - def format(self, comments: bool = True) -> str: - """ - Format the statement, optionally ommitting comments. - """ - raise NotImplementedError() - - def get_settings(self) -> dict[str, str | bool]: - """ - Return any settings set by the statement. - - For example, for this statement: - - sql> SET foo = 'bar'; - - The method should return `{"foo": "'bar'"}`. Note the single quotes. - """ - raise NotImplementedError() - - def is_mutating(self) -> bool: - """ - Check if the statement mutates data (DDL/DML). - - :return: True if the statement mutates data. 
- """ - raise NotImplementedError() - - def __str__(self) -> str: - return self.format() - - -class SQLStatement(BaseSQLStatement[exp.Expression]): - """ - A SQL statement. - - This class is used for all engines with dialects that can be parsed using sqlglot. - """ - - def __init__( - self, - statement: str | exp.Expression, - engine: str, - ): - self._dialect = SQLGLOT_DIALECTS.get(engine) - super().__init__(statement, engine) - - @classmethod - def split_query( - cls, - query: str, - engine: str, - ) -> list[SQLStatement]: - dialect = SQLGLOT_DIALECTS.get(engine) - - try: - statements = sqlglot.parse(query, dialect=dialect) - except sqlglot.errors.ParseError as ex: - raise SupersetParseError("Unable to split query") from ex - - return [cls(statement, engine) for statement in statements if statement] - - @classmethod - def _parse_statement( - cls, - statement: str, - engine: str, - ) -> exp.Expression: - """ - Parse a single SQL statement. - """ - dialect = SQLGLOT_DIALECTS.get(engine) - - # We could parse with `sqlglot.parse_one` to get a single statement, but we need - # to verify that the string contains exactly one statement. - try: - statements = sqlglot.parse(statement, dialect=dialect) - except sqlglot.errors.ParseError as ex: - raise SupersetParseError("Unable to split query") from ex - - statements = [statement for statement in statements if statement] - if len(statements) != 1: - raise SupersetParseError("SQLStatement should have exactly one statement") - - return statements[0] - - @classmethod - def _extract_tables_from_statement( - cls, - parsed: exp.Expression, - engine: str, - ) -> set[Table]: - """ - Find all referenced tables. - """ - dialect = SQLGLOT_DIALECTS.get(engine) - return extract_tables_from_statement(parsed, dialect) - - def is_mutating(self) -> bool: - """ - Check if the statement mutates data (DDL/DML). - - :return: True if the statement mutates data. 
- """ - for node in self._parsed.walk(): - if isinstance( - node, - ( - exp.Insert, - exp.Update, - exp.Delete, - exp.Merge, - exp.Create, - exp.Drop, - exp.TruncateTable, - ), - ): - return True - - if isinstance(node, exp.Command) and node.name == "ALTER": - return True - - # Postgres runs DMLs prefixed by `EXPLAIN ANALYZE`, see - # https://www.postgresql.org/docs/current/sql-explain.html - if ( - self._dialect == Dialects.POSTGRES - and isinstance(self._parsed, exp.Command) - and self._parsed.name == "EXPLAIN" - and self._parsed.expression.name.upper().startswith("ANALYZE ") - ): - analyzed_sql = self._parsed.expression.name[len("ANALYZE ") :] - return SQLStatement(analyzed_sql, self.engine).is_mutating() - - return False - - def format(self, comments: bool = True) -> str: - """ - Pretty-format the SQL statement. - """ - write = Dialect.get_or_raise(self._dialect) - return write.generate(self._parsed, copy=False, comments=comments, pretty=True) - - def get_settings(self) -> dict[str, str | bool]: - """ - Return the settings for the SQL statement. - - >>> statement = SQLStatement("SET foo = 'bar'") - >>> statement.get_settings() - {"foo": "'bar'"} - - """ - return { - eq.this.sql(): eq.expression.sql() - for set_item in self._parsed.find_all(exp.SetItem) - for eq in set_item.find_all(exp.EQ) - } - - -class KQLSplitState(enum.Enum): - """ - State machine for splitting a KQL query. - - The state machine keeps track of whether we're inside a string or not, so we - don't split the query in a semi-colon that's part of a string. - """ - - OUTSIDE_STRING = enum.auto() - INSIDE_SINGLE_QUOTED_STRING = enum.auto() - INSIDE_DOUBLE_QUOTED_STRING = enum.auto() - INSIDE_MULTILINE_STRING = enum.auto() - - -def split_kql(kql: str) -> list[str]: - """ - Custom function for splitting KQL statements. 
- """ - statements = [] - state = KQLSplitState.OUTSIDE_STRING - statement_start = 0 - query = kql if kql.endswith(";") else kql + ";" - for i, character in enumerate(query): - if state == KQLSplitState.OUTSIDE_STRING: - if character == ";": - statements.append(query[statement_start:i]) - statement_start = i + 1 - elif character == "'": - state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING - elif character == '"': - state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING - elif character == "`" and query[i - 2 : i] == "``": - state = KQLSplitState.INSIDE_MULTILINE_STRING - - elif ( - state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING - and character == "'" - and query[i - 1] != "\\" - ): - state = KQLSplitState.OUTSIDE_STRING - - elif ( - state == KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING - and character == '"' - and query[i - 1] != "\\" - ): - state = KQLSplitState.OUTSIDE_STRING - - elif ( - state == KQLSplitState.INSIDE_MULTILINE_STRING - and character == "`" - and query[i - 2 : i] == "``" - ): - state = KQLSplitState.OUTSIDE_STRING - - return statements - - -class KustoKQLStatement(BaseSQLStatement[str]): - """ - Special class for Kusto KQL. - - Kusto KQL is a SQL-like language, but it's not supported by sqlglot. Queries look - like this: - - StormEvents - | summarize PropertyDamage = sum(DamageProperty) by State - | join kind=innerunique PopulationData on State - | project State, PropertyDamagePerCapita = PropertyDamage / Population - | sort by PropertyDamagePerCapita - - See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for more - details about it. - """ - - @classmethod - def split_query( - cls, - query: str, - engine: str, - ) -> list[KustoKQLStatement]: - """ - Split a query at semi-colons. - - Since we don't have a parser, we use a simple state machine based function. See - https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string - for more information. 
- """ - return [cls(statement, engine) for statement in split_kql(query)] - - @classmethod - def _parse_statement( - cls, - statement: str, - engine: str, - ) -> str: - if engine != "kustokql": - raise SupersetParseError(f"Invalid engine: {engine}") - - statements = split_kql(statement) - if len(statements) != 1: - raise SupersetParseError("SQLStatement should have exactly one statement") - - return statements[0].strip() - - @classmethod - def _extract_tables_from_statement(cls, parsed: str, engine: str) -> set[Table]: - """ - Extract all tables referenced in the statement. - - StormEvents - | where InjuriesDirect + InjuriesIndirect > 50 - | join (PopulationData) on State - | project State, Population, TotalInjuries = InjuriesDirect + InjuriesIndirect - - """ - logger.warning( - "Kusto KQL doesn't support table extraction. This means that data access " - "roles will not be enforced by Superset in the database." - ) - return set() - - def format(self, comments: bool = True) -> str: - """ - Pretty-format the SQL statement. - """ - return self._parsed - - def get_settings(self) -> dict[str, str | bool]: - """ - Return the settings for the SQL statement. - - >>> statement = KustoKQLStatement("set querytrace;") - >>> statement.get_settings() - {"querytrace": True} - - """ - set_regex = r"^set\s+(?P\w+)(?:\s*=\s*(?P\w+))?$" - if match := re.match(set_regex, self._parsed, re.IGNORECASE): - return {match.group("name"): match.group("value") or True} - - return {} - - def is_mutating(self) -> bool: - """ - Check if the statement mutates data (DDL/DML). - - :return: True if the statement mutates data. - """ - return self._parsed.startswith(".") and not self._parsed.startswith(".show") - - -class SQLScript: - """ - A SQL script, with 0+ statements. - """ - - # Special engines that can't be parsed using sqlglot. Supporting non-SQL engines - # adds a lot of complexity to Superset, so we should avoid adding new engines to - # this data structure. 
- special_engines = { - "kustokql": KustoKQLStatement, - } - - def __init__( - self, - query: str, - engine: str, - ): - statement_class = self.special_engines.get(engine, SQLStatement) - self.statements = statement_class.split_query(query, engine) - - def format(self, comments: bool = True) -> str: - """ - Pretty-format the SQL query. - """ - return ";\n".join(statement.format(comments) for statement in self.statements) - - def get_settings(self) -> dict[str, str | bool]: - """ - Return the settings for the SQL query. - - >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'") - >>> statement.get_settings() - {"foo": "'baz'"} - - """ - settings: dict[str, str | bool] = {} - for statement in self.statements: - settings.update(statement.get_settings()) - - return settings - - def has_mutation(self) -> bool: - """ - Check if the script contains mutating statements. - - :return: True if the script contains mutating statements - """ - return any(statement.is_mutating() for statement in self.statements) - - class ParsedQuery: def __init__( self, sql_statement: str, strip_comments: bool = False, - engine: str | None = None, + engine: str = "base", ): if strip_comments: sql_statement = sqlparse.format(sql_statement, strip_comments=True) self.sql: str = sql_statement + self._engine = engine self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None self._tables: set[Table] = set() self._alias_names: set[str] = set() @@ -854,24 +314,18 @@ def _extract_tables_from_sql(self) -> set[Table]: Note: this uses sqlglot, since it's better at catching more edge cases. 
""" try: - statements = parse(self.stripped(), dialect=self._dialect) - except SqlglotError as ex: + statements = [ + statement._parsed # pylint: disable=protected-access + for statement in SQLScript(self.stripped(), self._engine).statements + ] + except SupersetParseError as ex: logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql) - - message = ( - "Error parsing near '{highlight}' at line {line}:{col}".format( # pylint: disable=consider-using-f-string - **ex.errors[0] - ) - if isinstance(ex, ParseError) - else str(ex) - ) - raise SupersetSecurityException( SupersetError( error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR, message=__( "You may have an error in your SQL statement. {message}" - ).format(message=message), + ).format(message=ex.error.message), level=ErrorLevel.ERROR, ) ) from ex @@ -883,77 +337,6 @@ def _extract_tables_from_sql(self) -> set[Table]: if statement } - def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]: - """ - Extract all table references in a single statement. - - Please not that this is not trivial; consider the following queries: - - DESCRIBE some_table; - SHOW PARTITIONS FROM some_table; - WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name; - - See the unit tests for other tricky cases. - """ - sources: Iterable[exp.Table] - - if isinstance(statement, exp.Describe): - # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly - # query for all tables. - sources = statement.find_all(exp.Table) - elif isinstance(statement, exp.Command): - # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a - # `SELECT` statetement in order to extract tables. 
- if not (literal := statement.find(exp.Literal)): - return set() - - try: - pseudo_query = parse_one( - f"SELECT {literal.this}", - dialect=self._dialect, - ) - sources = pseudo_query.find_all(exp.Table) - except SqlglotError: - return set() - else: - sources = [ - source - for scope in traverse_scope(statement) - for source in scope.sources.values() - if isinstance(source, exp.Table) and not self._is_cte(source, scope) - ] - - return { - Table( - source.name, - source.db if source.db != "" else None, - source.catalog if source.catalog != "" else None, - ) - for source in sources - } - - def _is_cte(self, source: exp.Table, scope: Scope) -> bool: - """ - Is the source a CTE? - - CTEs in the parent scope look like tables (and are represented by - exp.Table objects), but should not be considered as such; - otherwise a user with access to table `foo` could access any table - with a query like this: - - WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo - - """ - parent_sources = scope.parent.sources if scope.parent else {} - ctes_in_scope = { - name - for name, parent_scope in parent_sources.items() - if isinstance(parent_scope, Scope) - and parent_scope.scope_type == ScopeType.CTE - } - - return source.name in ctes_in_scope - @property def limit(self) -> int | None: return self._limit diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py index f7d66ed4e19fa..2403a36583e7b 100644 --- a/superset/sqllab/api.py +++ b/superset/sqllab/api.py @@ -35,8 +35,8 @@ from superset.extensions import event_logger from superset.jinja_context import get_template_processor from superset.models.sql_lab import Query +from superset.sql.parse import SQLScript from superset.sql_lab import get_sql_results -from superset.sql_parse import SQLScript from superset.sqllab.command_status import SqlJsonExecutionStatus from superset.sqllab.exceptions import ( QueryIsForbiddenToAccessException, diff --git a/superset/tasks/thumbnails.py b/superset/tasks/thumbnails.py index 
1c5db2a7f3d45..35781608b8abf 100644 --- a/superset/tasks/thumbnails.py +++ b/superset/tasks/thumbnails.py @@ -108,9 +108,8 @@ def cache_dashboard_thumbnail( ) -# pylint: disable=too-many-arguments @celery_app.task(name="cache_dashboard_screenshot", soft_time_limit=300) -def cache_dashboard_screenshot( +def cache_dashboard_screenshot( # pylint: disable=too-many-arguments username: str, dashboard_id: int, dashboard_url: str, diff --git a/tests/unit_tests/sql/__init__.py b/tests/unit_tests/sql/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/unit_tests/sql/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py new file mode 100644 index 0000000000000..f5d55bc13bd0a --- /dev/null +++ b/tests/unit_tests/sql/parse_tests.py @@ -0,0 +1,920 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, redefined-outer-name, too-many-lines + + +import pytest + +from superset.exceptions import SupersetParseError +from superset.sql.parse import ( + extract_tables_from_statement, + KustoKQLStatement, + split_kql, + SQLGLOT_DIALECTS, + SQLScript, + SQLStatement, + Table, +) + + +def test_table() -> None: + """ + Test the `Table` class and its string conversion. + + Special characters in the table, schema, or catalog name should be escaped correctly. + """ + assert str(Table("tbname")) == "tbname" + assert str(Table("tbname", "schemaname")) == "schemaname.tbname" + assert ( + str(Table("tbname", "schemaname", "catalogname")) + == "catalogname.schemaname.tbname" + ) + assert ( + str(Table("table.name", "schema/name", "catalog\nname")) + == "catalog%0Aname.schema%2Fname.table%2Ename" + ) + + +def extract_tables_from_sql(sql: str, engine: str = "postgresql") -> set[Table]: + """ + Helper function to extract tables from SQL. + """ + dialect = SQLGLOT_DIALECTS.get(engine) + return { + table + for statement in SQLScript(sql, engine).statements + for table in extract_tables_from_statement(statement._parsed, dialect) + } + + +def test_extract_tables_from_sql() -> None: + """ + Test that referenced tables are parsed correctly from the SQL. 
+ """ + assert extract_tables_from_sql("SELECT * FROM tbname") == {Table("tbname")} + assert extract_tables_from_sql("SELECT * FROM tbname foo") == {Table("tbname")} + assert extract_tables_from_sql("SELECT * FROM tbname AS foo") == {Table("tbname")} + + # underscore + assert extract_tables_from_sql("SELECT * FROM tb_name") == {Table("tb_name")} + + # quotes + assert extract_tables_from_sql('SELECT * FROM "tbname"') == {Table("tbname")} + + # unicode + assert extract_tables_from_sql('SELECT * FROM "tb_name" WHERE city = "Lübeck"') == { + Table("tb_name") + } + + # columns + assert extract_tables_from_sql("SELECT field1, field2 FROM tb_name") == { + Table("tb_name") + } + assert extract_tables_from_sql("SELECT t1.f1, t2.f2 FROM t1, t2") == { + Table("t1"), + Table("t2"), + } + + # named table + assert extract_tables_from_sql( + "SELECT a.date, a.field FROM left_table a LIMIT 10" + ) == {Table("left_table")} + + assert extract_tables_from_sql( + "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;" + ) == {Table("forbidden_table")} + + assert extract_tables_from_sql( + "select * from (select * from forbidden_table) forbidden_table" + ) == {Table("forbidden_table")} + + +def test_extract_tables_subselect() -> None: + """ + Test that tables inside subselects are parsed correctly. 
+ """ + assert extract_tables_from_sql( + """ +SELECT sub.* +FROM ( + SELECT * + FROM s1.t1 + WHERE day_of_week = 'Friday' + ) sub, s2.t2 +WHERE sub.resolution = 'NONE' +""" + ) == {Table("t1", "s1"), Table("t2", "s2")} + + assert extract_tables_from_sql( + """ +SELECT sub.* +FROM ( + SELECT * + FROM s1.t1 + WHERE day_of_week = 'Friday' +) sub +WHERE sub.resolution = 'NONE' +""" + ) == {Table("t1", "s1")} + + assert extract_tables_from_sql( + """ +SELECT * FROM t1 +WHERE s11 > ANY ( + SELECT COUNT(*) /* no hint */ FROM t2 + WHERE NOT EXISTS ( + SELECT * FROM t3 + WHERE ROW(5*t2.s1,77)=( + SELECT 50,11*s1 FROM t4 + ) + ) +) +""" + ) == {Table("t1"), Table("t2"), Table("t3"), Table("t4")} + + +def test_extract_tables_select_in_expression() -> None: + """ + Test that parser works with `SELECT`s used as expressions. + """ + assert extract_tables_from_sql("SELECT f1, (SELECT count(1) FROM t2) FROM t1") == { + Table("t1"), + Table("t2"), + } + assert extract_tables_from_sql( + "SELECT f1, (SELECT count(1) FROM t2) as f2 FROM t1" + ) == { + Table("t1"), + Table("t2"), + } + + +def test_extract_tables_parenthesis() -> None: + """ + Test that parenthesis are parsed correctly. + """ + assert extract_tables_from_sql("SELECT f1, (x + y) AS f2 FROM t1") == {Table("t1")} + + +def test_extract_tables_with_schema() -> None: + """ + Test that schemas are parsed correctly. + """ + assert extract_tables_from_sql("SELECT * FROM schemaname.tbname") == { + Table("tbname", "schemaname") + } + assert extract_tables_from_sql('SELECT * FROM "schemaname"."tbname"') == { + Table("tbname", "schemaname") + } + assert extract_tables_from_sql('SELECT * FROM "schemaname"."tbname" foo') == { + Table("tbname", "schemaname") + } + assert extract_tables_from_sql('SELECT * FROM "schemaname"."tbname" AS foo') == { + Table("tbname", "schemaname") + } + + +def test_extract_tables_union() -> None: + """ + Test that `UNION` queries work as expected. 
+ """ + assert extract_tables_from_sql("SELECT * FROM t1 UNION SELECT * FROM t2") == { + Table("t1"), + Table("t2"), + } + assert extract_tables_from_sql("SELECT * FROM t1 UNION ALL SELECT * FROM t2") == { + Table("t1"), + Table("t2"), + } + assert extract_tables_from_sql( + "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2" + ) == { + Table("t1"), + Table("t2"), + } + + +def test_extract_tables_select_from_values() -> None: + """ + Test that selecting from values returns no tables. + """ + assert extract_tables_from_sql("SELECT * FROM VALUES (13, 42)") == set() + + +def test_extract_tables_select_array() -> None: + """ + Test that queries selecting arrays work as expected. + """ + assert extract_tables_from_sql( + """ +SELECT ARRAY[1, 2, 3] AS my_array +FROM t1 LIMIT 10 +""" + ) == {Table("t1")} + + +def test_extract_tables_select_if() -> None: + """ + Test that queries with an `IF` work as expected. + """ + assert extract_tables_from_sql( + """ +SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL) +FROM t1 LIMIT 10 +""" + ) == {Table("t1")} + + +def test_extract_tables_with_catalog() -> None: + """ + Test that catalogs are parsed correctly. + """ + assert extract_tables_from_sql("SELECT * FROM catalogname.schemaname.tbname") == { + Table("tbname", "schemaname", "catalogname") + } + + +def test_extract_tables_illdefined() -> None: + """ + Test that ill-defined tables return an empty set. + """ + with pytest.raises(SupersetParseError) as excinfo: + extract_tables_from_sql("SELECT * FROM schemaname.") + assert str(excinfo.value) == "Error parsing near '.' at line 1:25" + + with pytest.raises(SupersetParseError) as excinfo: + extract_tables_from_sql("SELECT * FROM catalogname.schemaname.") + assert str(excinfo.value) == "Error parsing near '.' at line 1:37" + + with pytest.raises(SupersetParseError) as excinfo: + extract_tables_from_sql("SELECT * FROM catalogname..") + assert str(excinfo.value) == "Error parsing near '.' 
at line 1:27" + + with pytest.raises(SupersetParseError) as excinfo: + extract_tables_from_sql('SELECT * FROM "tbname') + assert str(excinfo.value) == "Unable to parse script" + + # odd edge case that works + assert extract_tables_from_sql("SELECT * FROM catalogname..tbname") == { + Table(table="tbname", schema=None, catalog="catalogname") + } + + +def test_extract_tables_show_tables_from() -> None: + """ + Test `SHOW TABLES FROM`. + """ + assert ( + extract_tables_from_sql("SHOW TABLES FROM s1 like '%order%'", "mysql") == set() + ) + + +def test_extract_tables_show_columns_from() -> None: + """ + Test `SHOW COLUMNS FROM`. + """ + assert extract_tables_from_sql("SHOW COLUMNS FROM t1") == {Table("t1")} + + +def test_extract_tables_where_subquery() -> None: + """ + Test that tables in a `WHERE` subquery are parsed correctly. + """ + assert extract_tables_from_sql( + """ +SELECT name +FROM t1 +WHERE regionkey = (SELECT max(regionkey) FROM t2) +""" + ) == {Table("t1"), Table("t2")} + + assert extract_tables_from_sql( + """ +SELECT name +FROM t1 +WHERE regionkey IN (SELECT regionkey FROM t2) +""" + ) == {Table("t1"), Table("t2")} + + assert extract_tables_from_sql( + """ +SELECT name +FROM t1 +WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey); +""" + ) == {Table("t1"), Table("t2")} + + +def test_extract_tables_describe() -> None: + """ + Test `DESCRIBE`. + """ + assert extract_tables_from_sql("DESCRIBE t1") == {Table("t1")} + + +def test_extract_tables_show_partitions() -> None: + """ + Test `SHOW PARTITIONS`. + """ + assert extract_tables_from_sql( + """ +SHOW PARTITIONS FROM orders +WHERE ds >= '2013-01-01' ORDER BY ds DESC +""" + ) == {Table("orders")} + + +def test_extract_tables_join() -> None: + """ + Test joins. 
+ """ + assert extract_tables_from_sql( + "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;" + ) == { + Table("t1"), + Table("t2"), + } + + assert extract_tables_from_sql( + """ +SELECT a.date, b.name +FROM left_table a +JOIN ( + SELECT + CAST((b.year) as VARCHAR) date, + name + FROM right_table +) b +ON a.date = b.date +""" + ) == {Table("left_table"), Table("right_table")} + + assert extract_tables_from_sql( + """ +SELECT a.date, b.name +FROM left_table a +LEFT INNER JOIN ( + SELECT + CAST((b.year) as VARCHAR) date, + name + FROM right_table +) b +ON a.date = b.date +""" + ) == {Table("left_table"), Table("right_table")} + + assert extract_tables_from_sql( + """ +SELECT a.date, b.name +FROM left_table a +RIGHT OUTER JOIN ( + SELECT + CAST((b.year) as VARCHAR) date, + name + FROM right_table +) b +ON a.date = b.date +""" + ) == {Table("left_table"), Table("right_table")} + + assert extract_tables_from_sql( + """ +SELECT a.date, b.name +FROM left_table a +FULL OUTER JOIN ( + SELECT + CAST((b.year) as VARCHAR) date, + name + FROM right_table +) b +ON a.date = b.date +""" + ) == {Table("left_table"), Table("right_table")} + + +def test_extract_tables_semi_join() -> None: + """ + Test `LEFT SEMI JOIN`. + """ + assert extract_tables_from_sql( + """ +SELECT a.date, b.name +FROM left_table a +LEFT SEMI JOIN ( + SELECT + CAST((b.year) as VARCHAR) date, + name + FROM right_table +) b +ON a.data = b.date +""" + ) == {Table("left_table"), Table("right_table")} + + +def test_extract_tables_combinations() -> None: + """ + Test a complex case with nested queries. 
+ """ + assert extract_tables_from_sql( + """ +SELECT * FROM t1 +WHERE s11 > ANY ( + SELECT * FROM t1 UNION ALL SELECT * FROM ( + SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a + ) tmp_join + WHERE NOT EXISTS ( + SELECT * FROM t3 + WHERE ROW(5*t3.s1,77)=( + SELECT 50,11*s1 FROM t4 + ) + ) +) +""" + ) == {Table("t1"), Table("t3"), Table("t4"), Table("t6")} + + assert extract_tables_from_sql( + """ +SELECT * FROM ( + SELECT * FROM ( + SELECT * FROM ( + SELECT * FROM EmployeeS + ) AS S1 + ) AS S2 +) AS S3 +""" + ) == {Table("EmployeeS")} + + +def test_extract_tables_with() -> None: + """ + Test `WITH`. + """ + assert extract_tables_from_sql( + """ +WITH + x AS (SELECT a FROM t1), + y AS (SELECT a AS b FROM t2), + z AS (SELECT b AS c FROM t3) +SELECT c FROM z +""" + ) == {Table("t1"), Table("t2"), Table("t3")} + + assert extract_tables_from_sql( + """ +WITH + x AS (SELECT a FROM t1), + y AS (SELECT a AS b FROM x), + z AS (SELECT b AS c FROM y) +SELECT c FROM z +""" + ) == {Table("t1")} + + +def test_extract_tables_reusing_aliases() -> None: + """ + Test that the parser follows aliases. + """ + assert extract_tables_from_sql( + """ +with q1 as ( select key from q2 where key = '5'), +q2 as ( select key from src where key = '5') +select * from (select key from q1) a +""" + ) == {Table("src")} + + # weird query with circular dependency + assert ( + extract_tables_from_sql( + """ +with src as ( select key from q2 where key = '5'), +q2 as ( select key from src where key = '5') +select * from (select key from src) a +""" + ) + == set() + ) + + +def test_extract_tables_multistatement() -> None: + """ + Test that the parser works with multiple statements. 
+ """ + assert extract_tables_from_sql("SELECT * FROM t1; SELECT * FROM t2") == { + Table("t1"), + Table("t2"), + } + assert extract_tables_from_sql("SELECT * FROM t1; SELECT * FROM t2;") == { + Table("t1"), + Table("t2"), + } + assert extract_tables_from_sql( + "ADD JAR file:///hive.jar; SELECT * FROM t1;", + engine="hive", + ) == {Table("t1")} + + +def test_extract_tables_complex() -> None: + """ + Test a few complex queries. + """ + assert extract_tables_from_sql( + """ +SELECT sum(m_examples) AS "sum__m_example" +FROM ( + SELECT + COUNT(DISTINCT id_userid) AS m_examples, + some_more_info + FROM my_b_table b + JOIN my_t_table t ON b.ds=t.ds + JOIN my_l_table l ON b.uid=l.uid + WHERE + b.rid IN ( + SELECT other_col + FROM inner_table + ) + AND l.bla IN ('x', 'y') + GROUP BY 2 + ORDER BY 2 ASC +) AS "meh" +ORDER BY "sum__m_example" DESC +LIMIT 10; +""" + ) == { + Table("my_l_table"), + Table("my_b_table"), + Table("my_t_table"), + Table("inner_table"), + } + + assert extract_tables_from_sql( + """ +SELECT * +FROM table_a AS a, table_b AS b, table_c as c +WHERE a.id = b.id and b.id = c.id +""" + ) == {Table("table_a"), Table("table_b"), Table("table_c")} + + assert extract_tables_from_sql( + """ +SELECT somecol AS somecol +FROM ( + WITH bla AS ( + SELECT col_a + FROM a + WHERE + 1=1 + AND column_of_choice NOT IN ( + SELECT interesting_col + FROM b + ) + ), + rb AS ( + SELECT yet_another_column + FROM ( + SELECT a + FROM c + GROUP BY the_other_col + ) not_table + LEFT JOIN bla foo + ON foo.prop = not_table.bad_col0 + WHERE 1=1 + GROUP BY + not_table.bad_col1 , + not_table.bad_col2 , + ORDER BY not_table.bad_col_3 DESC , + not_table.bad_col4 , + not_table.bad_col5 + ) + SELECT random_col + FROM d + WHERE 1=1 + UNION ALL SELECT even_more_cols + FROM e + WHERE 1=1 + UNION ALL SELECT lets_go_deeper + FROM f + WHERE 1=1 + WHERE 2=2 + GROUP BY last_col + LIMIT 50000 +) +""" + ) == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")} + + +def 
test_extract_tables_mixed_from_clause() -> None: + """ + Test that the parser handles a `FROM` clause with table and subselect. + """ + assert extract_tables_from_sql( + """ +SELECT * +FROM table_a AS a, (select * from table_b) AS b, table_c as c +WHERE a.id = b.id and b.id = c.id +""" + ) == {Table("table_a"), Table("table_b"), Table("table_c")} + + +def test_extract_tables_nested_select() -> None: + """ + Test that the parser handles selects inside functions. + """ + assert extract_tables_from_sql( + """ +select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME) +from INFORMATION_SCHEMA.COLUMNS +WHERE TABLE_SCHEMA like "%bi%"),0x7e))); +""", + "mysql", + ) == {Table("COLUMNS", "INFORMATION_SCHEMA")} + + assert extract_tables_from_sql( + """ +select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME) +from INFORMATION_SCHEMA.COLUMNS +WHERE TABLE_NAME="bi_achievement_daily"),0x7e))); +""", + "mysql", + ) == {Table("COLUMNS", "INFORMATION_SCHEMA")} + + +def test_extract_tables_complex_cte_with_prefix() -> None: + """ + Test that the parser handles CTEs with prefixes. + """ + assert extract_tables_from_sql( + """ +WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear) +AS ( + SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear + FROM SalesOrderHeader + WHERE SalesPersonID IS NOT NULL +) +SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear +FROM CTE__test +GROUP BY SalesYear, SalesPersonID +ORDER BY SalesPersonID, SalesYear; +""" + ) == {Table("SalesOrderHeader")} + + +def test_extract_tables_identifier_list_with_keyword_as_alias() -> None: + """ + Test that aliases that are keywords are parsed correctly. + """ + assert extract_tables_from_sql( + """ +WITH + f AS (SELECT * FROM foo), + match AS (SELECT * FROM f) +SELECT * FROM match +""" + ) == {Table("foo")} + + +def test_sqlscript() -> None: + """ + Test the `SQLScript` class. 
+ """ + script = SQLScript("SELECT 1; SELECT 2;", "sqlite") + + assert len(script.statements) == 2 + assert script.format() == "SELECT\n 1;\nSELECT\n 2" + assert script.statements[0].format() == "SELECT\n 1" + + script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite") + assert script.get_settings() == {"a": "2"} + + query = SQLScript( + """set querytrace; +Events | take 100""", + "kustokql", + ) + assert query.get_settings() == {"querytrace": True} + + +def test_sqlstatement() -> None: + """ + Test the `SQLStatement` class. + """ + statement = SQLStatement( + "SELECT * FROM table1 UNION ALL SELECT * FROM table2", + "sqlite", + ) + + assert statement.tables == { + Table(table="table1", schema=None, catalog=None), + Table(table="table2", schema=None, catalog=None), + } + assert ( + statement.format() + == "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2" + ) + + statement = SQLStatement("SET a=1", "sqlite") + assert statement.get_settings() == {"a": "1"} + + +def test_kustokqlstatement_split_script() -> None: + """ + Test the `KustoKQLStatement` split method. + """ + statements = KustoKQLStatement.split_script( + """ +let totalPagesPerDay = PageViews +| summarize by Page, Day = startofday(Timestamp) +| summarize count() by Day; +let materializedScope = PageViews +| summarize by Page, Day = startofday(Timestamp); +let cachedResult = materialize(materializedScope); +cachedResult +| project Page, Day1 = Day +| join kind = inner +( + cachedResult + | project Page, Day2 = Day +) +on Page +| where Day2 > Day1 +| summarize count() by Day1, Day2 +| join kind = inner + totalPagesPerDay +on $left.Day1 == $right.Day +| project Day1, Day2, Percentage = count_*100.0/count_1 + """, + "kustokql", + ) + assert len(statements) == 4 + + +def test_kustokqlstatement_with_program() -> None: + """ + Test the `KustoKQLStatement` split method when the KQL has a program. 
+ """ + statements = KustoKQLStatement.split_script( + """ +print program = ``` + public class Program { + public static void Main() { + System.Console.WriteLine("Hello!"); + } + }``` + """, + "kustokql", + ) + assert len(statements) == 1 + + +def test_kustokqlstatement_with_set() -> None: + """ + Test the `KustoKQLStatement` split method when the KQL has a set command. + """ + statements = KustoKQLStatement.split_script( + """ +set querytrace; +Events | take 100 + """, + "kustokql", + ) + assert len(statements) == 2 + assert statements[0].format() == "set querytrace" + assert statements[1].format() == "Events | take 100" + + +@pytest.mark.parametrize( + "kql,statements", + [ + ('print banner=strcat("Hello", ", ", "World!")', 1), + (r"print 'O\'Malley\'s'", 1), + (r"print 'O\'Mal;ley\'s'", 1), + ("print ```foo;\nbar;\nbaz;```\n", 1), + ], +) +def test_kustokql_statement_split_special(kql: str, statements: int) -> None: + assert len(KustoKQLStatement.split_script(kql, "kustokql")) == statements + + +def test_split_kql() -> None: + """ + Test the `split_kql` function. 
+ """ + kql = """ +let totalPagesPerDay = PageViews +| summarize by Page, Day = startofday(Timestamp) +| summarize count() by Day; +let materializedScope = PageViews +| summarize by Page, Day = startofday(Timestamp); +let cachedResult = materialize(materializedScope); +cachedResult +| project Page, Day1 = Day +| join kind = inner +( + cachedResult + | project Page, Day2 = Day +) +on Page +| where Day2 > Day1 +| summarize count() by Day1, Day2 +| join kind = inner + totalPagesPerDay +on $left.Day1 == $right.Day +| project Day1, Day2, Percentage = count_*100.0/count_1 + """ + assert split_kql(kql) == [ + """ +let totalPagesPerDay = PageViews +| summarize by Page, Day = startofday(Timestamp) +| summarize count() by Day""", + """ +let materializedScope = PageViews +| summarize by Page, Day = startofday(Timestamp)""", + """ +let cachedResult = materialize(materializedScope)""", + """ +cachedResult +| project Page, Day1 = Day +| join kind = inner +( + cachedResult + | project Page, Day2 = Day +) +on Page +| where Day2 > Day1 +| summarize count() by Day1, Day2 +| join kind = inner + totalPagesPerDay +on $left.Day1 == $right.Day +| project Day1, Day2, Percentage = count_*100.0/count_1 + """, + ] + + +@pytest.mark.parametrize( + ("engine", "sql", "expected"), + [ + # SQLite tests + ("sqlite", "SELECT 1", False), + ("sqlite", "INSERT INTO foo VALUES (1)", True), + ("sqlite", "UPDATE foo SET bar = 2 WHERE id = 1", True), + ("sqlite", "DELETE FROM foo WHERE id = 1", True), + ("sqlite", "CREATE TABLE foo (id INT, bar TEXT)", True), + ("sqlite", "DROP TABLE foo", True), + ("sqlite", "EXPLAIN SELECT * FROM foo", False), + ("sqlite", "PRAGMA table_info(foo)", False), + ("postgresql", "SELECT 1", False), + ("postgresql", "INSERT INTO foo (id, bar) VALUES (1, 'test')", True), + ("postgresql", "UPDATE foo SET bar = 'new' WHERE id = 1", True), + ("postgresql", "DELETE FROM foo WHERE id = 1", True), + ("postgresql", "CREATE TABLE foo (id SERIAL PRIMARY KEY, bar TEXT)", True), + 
("postgresql", "DROP TABLE foo", True), + ("postgresql", "EXPLAIN ANALYZE SELECT * FROM foo", False), + ("postgresql", "EXPLAIN ANALYZE DELETE FROM foo", True), + ("postgresql", "SHOW search_path", False), + ("postgresql", "SET search_path TO public", False), + ( + "postgres", + """ + with source as ( + select 1 as one + ) + select * from source + """, + False, + ), + ("trino", "SELECT 1", False), + ("trino", "INSERT INTO foo VALUES (1, 'bar')", True), + ("trino", "UPDATE foo SET bar = 'baz' WHERE id = 1", True), + ("trino", "DELETE FROM foo WHERE id = 1", True), + ("trino", "CREATE TABLE foo (id INT, bar VARCHAR)", True), + ("trino", "DROP TABLE foo", True), + ("trino", "EXPLAIN SELECT * FROM foo", False), + ("trino", "SHOW SCHEMAS", False), + ("trino", "SET SESSION optimization_level = '3'", False), + ("kustokql", "tbl | limit 100", False), + ("kustokql", "let foo = 1; tbl | where bar == foo", False), + ("kustokql", ".show tables", False), + ("kustokql", "print 1", False), + ("kustokql", "set querytrace; Events | take 100", False), + ("kustokql", ".drop table foo", True), + ("kustokql", ".set-or-append table foo <| bar", True), + ], +) +def test_has_mutation(engine: str, sql: str, expected: bool) -> None: + """ + Test the `has_mutation` method. 
+ """ + assert SQLScript(sql, engine).has_mutation() == expected diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 216d896a5ea99..23d51de64cdeb 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -30,6 +30,7 @@ QueryClauseValidationException, SupersetSecurityException, ) +from superset.sql.parse import Table from superset.sql_parse import ( add_table_name, check_sql_functions_exist, @@ -39,18 +40,13 @@ has_table_query, insert_rls_as_subquery, insert_rls_in_predicate, - KustoKQLStatement, ParsedQuery, sanitize_clause, - split_kql, - SQLScript, - SQLStatement, strip_comments_from_sql, - Table, ) -def extract_tables(query: str, engine: Optional[str] = None) -> set[Table]: +def extract_tables(query: str, engine: str = "base") -> set[Table]: """ Helper function to extract tables referenced in a query. """ @@ -285,7 +281,7 @@ def test_extract_tables_illdefined() -> None: extract_tables('SELECT * FROM "tbname') assert ( str(excinfo.value) - == "You may have an error in your SQL statement. Error tokenizing 'SELECT * FROM \"tbnam'" + == "You may have an error in your SQL statement. Unable to parse script" ) # odd edge case that works @@ -1834,49 +1830,6 @@ def test_is_select() -> None: assert ParsedQuery("USE foo; SELECT * FROM bar").is_select() -def test_sqlquery() -> None: - """ - Test the `SQLScript` class. - """ - script = SQLScript("SELECT 1; SELECT 2;", "sqlite") - - assert len(script.statements) == 2 - assert script.format() == "SELECT\n 1;\nSELECT\n 2" - assert script.statements[0].format() == "SELECT\n 1" - - script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite") - assert script.get_settings() == {"a": "2"} - - query = SQLScript( - """set querytrace; -Events | take 100""", - "kustokql", - ) - assert query.get_settings() == {"querytrace": True} - - -def test_sqlstatement() -> None: - """ - Test the `SQLStatement` class. 
- """ - statement = SQLStatement( - "SELECT * FROM table1 UNION ALL SELECT * FROM table2", - "sqlite", - ) - - assert statement.tables == { - Table(table="table1", schema=None, catalog=None), - Table(table="table2", schema=None, catalog=None), - } - assert ( - statement.format() - == "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2" - ) - - statement = SQLStatement("SET a=1", "sqlite") - assert statement.get_settings() == {"a": "1"} - - @pytest.mark.parametrize( "engine", [ @@ -1924,194 +1877,3 @@ def test_extract_tables_from_jinja_sql( ) == expected ) - - -def test_kustokqlstatement_split_query() -> None: - """ - Test the `KustoKQLStatement` split method. - """ - statements = KustoKQLStatement.split_query( - """ -let totalPagesPerDay = PageViews -| summarize by Page, Day = startofday(Timestamp) -| summarize count() by Day; -let materializedScope = PageViews -| summarize by Page, Day = startofday(Timestamp); -let cachedResult = materialize(materializedScope); -cachedResult -| project Page, Day1 = Day -| join kind = inner -( - cachedResult - | project Page, Day2 = Day -) -on Page -| where Day2 > Day1 -| summarize count() by Day1, Day2 -| join kind = inner - totalPagesPerDay -on $left.Day1 == $right.Day -| project Day1, Day2, Percentage = count_*100.0/count_1 - """, - "kustokql", - ) - assert len(statements) == 4 - - -def test_kustokqlstatement_with_program() -> None: - """ - Test the `KustoKQLStatement` split method when the KQL has a program. - """ - statements = KustoKQLStatement.split_query( - """ -print program = ``` - public class Program { - public static void Main() { - System.Console.WriteLine("Hello!"); - } - }``` - """, - "kustokql", - ) - assert len(statements) == 1 - - -def test_kustokqlstatement_with_set() -> None: - """ - Test the `KustoKQLStatement` split method when the KQL has a set command. 
- """ - statements = KustoKQLStatement.split_query( - """ -set querytrace; -Events | take 100 - """, - "kustokql", - ) - assert len(statements) == 2 - assert statements[0].format() == "set querytrace" - assert statements[1].format() == "Events | take 100" - - -@pytest.mark.parametrize( - "kql,statements", - [ - ('print banner=strcat("Hello", ", ", "World!")', 1), - (r"print 'O\'Malley\'s'", 1), - (r"print 'O\'Mal;ley\'s'", 1), - ("print ```foo;\nbar;\nbaz;```\n", 1), - ], -) -def test_kustokql_statement_split_special(kql: str, statements: int) -> None: - assert len(KustoKQLStatement.split_query(kql, "kustokql")) == statements - - -def test_split_kql() -> None: - """ - Test the `split_kql` function. - """ - kql = """ -let totalPagesPerDay = PageViews -| summarize by Page, Day = startofday(Timestamp) -| summarize count() by Day; -let materializedScope = PageViews -| summarize by Page, Day = startofday(Timestamp); -let cachedResult = materialize(materializedScope); -cachedResult -| project Page, Day1 = Day -| join kind = inner -( - cachedResult - | project Page, Day2 = Day -) -on Page -| where Day2 > Day1 -| summarize count() by Day1, Day2 -| join kind = inner - totalPagesPerDay -on $left.Day1 == $right.Day -| project Day1, Day2, Percentage = count_*100.0/count_1 - """ - assert split_kql(kql) == [ - """ -let totalPagesPerDay = PageViews -| summarize by Page, Day = startofday(Timestamp) -| summarize count() by Day""", - """ -let materializedScope = PageViews -| summarize by Page, Day = startofday(Timestamp)""", - """ -let cachedResult = materialize(materializedScope)""", - """ -cachedResult -| project Page, Day1 = Day -| join kind = inner -( - cachedResult - | project Page, Day2 = Day -) -on Page -| where Day2 > Day1 -| summarize count() by Day1, Day2 -| join kind = inner - totalPagesPerDay -on $left.Day1 == $right.Day -| project Day1, Day2, Percentage = count_*100.0/count_1 - """, - ] - - -@pytest.mark.parametrize( - ("engine", "sql", "expected"), - [ - # SQLite tests 
- ("sqlite", "SELECT 1", False), - ("sqlite", "INSERT INTO foo VALUES (1)", True), - ("sqlite", "UPDATE foo SET bar = 2 WHERE id = 1", True), - ("sqlite", "DELETE FROM foo WHERE id = 1", True), - ("sqlite", "CREATE TABLE foo (id INT, bar TEXT)", True), - ("sqlite", "DROP TABLE foo", True), - ("sqlite", "EXPLAIN SELECT * FROM foo", False), - ("sqlite", "PRAGMA table_info(foo)", False), - ("postgresql", "SELECT 1", False), - ("postgresql", "INSERT INTO foo (id, bar) VALUES (1, 'test')", True), - ("postgresql", "UPDATE foo SET bar = 'new' WHERE id = 1", True), - ("postgresql", "DELETE FROM foo WHERE id = 1", True), - ("postgresql", "CREATE TABLE foo (id SERIAL PRIMARY KEY, bar TEXT)", True), - ("postgresql", "DROP TABLE foo", True), - ("postgresql", "EXPLAIN ANALYZE SELECT * FROM foo", False), - ("postgresql", "EXPLAIN ANALYZE DELETE FROM foo", True), - ("postgresql", "SHOW search_path", False), - ("postgresql", "SET search_path TO public", False), - ( - "postgres", - """ - with source as ( - select 1 as one - ) - select * from source - """, - False, - ), - ("trino", "SELECT 1", False), - ("trino", "INSERT INTO foo VALUES (1, 'bar')", True), - ("trino", "UPDATE foo SET bar = 'baz' WHERE id = 1", True), - ("trino", "DELETE FROM foo WHERE id = 1", True), - ("trino", "CREATE TABLE foo (id INT, bar VARCHAR)", True), - ("trino", "DROP TABLE foo", True), - ("trino", "EXPLAIN SELECT * FROM foo", False), - ("trino", "SHOW SCHEMAS", False), - ("trino", "SET SESSION optimization_level = '3'", False), - ("kustokql", "tbl | limit 100", False), - ("kustokql", "let foo = 1; tbl | where bar == foo", False), - ("kustokql", ".show tables", False), - ("kustokql", "print 1", False), - ("kustokql", "set querytrace; Events | take 100", False), - ("kustokql", ".drop table foo", True), - ("kustokql", ".set-or-append table foo <| bar", True), - ], -) -def test_has_mutation(engine: str, sql: str, expected: bool) -> None: - """ - Test the `has_mutation` method. 
- """ - assert SQLScript(sql, engine).has_mutation() == expected