Skip to content

Commit

Permalink
fix: DB-specific quoting in Jinja macro (apache#25779)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Oct 30, 2023
1 parent ed14f36 commit 5659c87
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 16 deletions.
45 changes: 31 additions & 14 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from jinja2 import DebugUndefined
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.expression import bindparam
from sqlalchemy.types import String

from superset.constants import LRU_CACHE_MAX_SIZE
Expand Down Expand Up @@ -396,23 +397,39 @@ def validate_template_context(
return validate_context_types(context)


def where_in(values: list[Any], mark: str = "'") -> str:
"""
Given a list of values, build a parenthesis list suitable for an IN expression.
class WhereInMacro: # pylint: disable=too-few-public-methods
def __init__(self, dialect: Dialect):
self.dialect = dialect

>>> where_in([1, "b", 3])
(1, 'b', 3)
def __call__(self, values: list[Any], mark: Optional[str] = None) -> str:
"""
Given a list of values, build a parenthesis list suitable for an IN expression.
"""
>>> from sqlalchemy.dialects import mysql
>>> where_in = WhereInMacro(dialect=mysql.dialect())
>>> where_in([1, "Joe's", 3])
(1, 'Joe''s', 3)
def quote(value: Any) -> str:
if isinstance(value, str):
value = value.replace(mark, mark * 2)
return f"{mark}{value}{mark}"
return str(value)
"""
binds = [bindparam(f"value_{i}", value) for i, value in enumerate(values)]
string_representations = [
str(
bind.compile(
dialect=self.dialect, compile_kwargs={"literal_binds": True}
)
)
for bind in binds
]
joined_values = ", ".join(string_representations)
result = f"({joined_values})"

if mark:
result += (
"\n-- WARNING: the `mark` parameter was removed from the `where_in` "
"macro for security reasons\n"
)

joined_values = ", ".join(quote(value) for value in values)
return f"({joined_values})"
return result


class BaseTemplateProcessor:
Expand Down Expand Up @@ -448,7 +465,7 @@ def __init__(
self.set_context(**kwargs)

# custom filters
self._env.filters["where_in"] = where_in
self._env.filters["where_in"] = WhereInMacro(database.get_dialect())

def set_context(self, **kwargs: Any) -> None:
self._context.update(kwargs)
Expand Down
9 changes: 7 additions & 2 deletions tests/unit_tests/jinja_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,22 @@

import pytest
from pytest_mock import MockFixture
from sqlalchemy.dialects import mysql

from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.jinja_context import dataset_macro, where_in
from superset.jinja_context import dataset_macro, WhereInMacro


def test_where_in() -> None:
"""
Test the ``where_in`` Jinja2 filter.
"""
where_in = WhereInMacro(mysql.dialect())
assert where_in([1, "b", 3]) == "(1, 'b', 3)"
assert where_in([1, "b", 3], '"') == '(1, "b", 3)'
assert where_in([1, "b", 3], '"') == (
"(1, 'b', 3)\n-- WARNING: the `mark` parameter was removed from the "
"`where_in` macro for security reasons\n"
)
assert where_in(["O'Malley's"]) == "('O''Malley''s')"


Expand Down

0 comments on commit 5659c87

Please sign in to comment.