Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Precompile SELECT set_config to SET #7810

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 82 additions & 3 deletions edb/server/compiler/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@
'allow_user_specified_id': True,
'server_version': False,
'server_version_num': False,
# This is front-end only, because we cannot simply pass it trough,
# because queries for "system" views are compiled to "nonsystem" views.
# So having it as a FE setting, it is basically ignored.
'restrict_nonsystem_relation_kind': True,
}
)

Expand Down Expand Up @@ -82,6 +86,8 @@ def compile_sql(
query=orig_text,
)

stmt = precompile_set_config(stmt)

if isinstance(stmt, (pgast.VariableSetStmt, pgast.VariableResetStmt)):
value: Optional[dbstate.SQLSetting]
if isinstance(stmt, pgast.VariableSetStmt):
Expand Down Expand Up @@ -124,9 +130,11 @@ def compile_sql(
if stmt.scope == pgast.OptionsScope.SESSION:
unit.set_vars = {
f"default_{name}": (
value.val
if isinstance(value, pgast.StringConstant)
else pg_codegen.generate_source(value),
(
value.val
if isinstance(value, pgast.StringConstant)
else pg_codegen.generate_source(value)
),
)
for name, value in stmt.options.options.items()
}
Expand Down Expand Up @@ -337,3 +345,74 @@ def pg_const_to_python(expr: pgast.BaseExpr) -> str | int | float:
return float(expr.val)

raise NotImplementedError()


def precompile_set_config(stmt: pgast.BaseExpr) -> pgast.BaseExpr:
'''
Turn
SELECT set_config('...', ...)
and
SELECT set_config(name, ...) FROM pg_setting WHERE name = ...
into
SET ... TO ...
'''

if not (
isinstance(stmt, pgast.SelectStmt)
and len(stmt.target_list) == 1
and isinstance(stmt.target_list[0].val, pgast.FuncCall)
and stmt.target_list[0].val.name[-1] == 'set_config'
and len(stmt.target_list[0].val.args) == 3
and stmt.limit_count == None
and stmt.sort_clause == None
and (stmt.ctes == None or len(stmt.ctes) == 0)
):
return stmt

args = stmt.target_list[0].val.args
name = args[0]
val = args[1]
is_local = args[2]
if not (
isinstance(val, pgast.StringConstant)
and isinstance(is_local, pgast.BooleanConstant)
):
return stmt

setting_name: Optional[str] = None
if (
isinstance(name, pgast.StringConstant)
and len(stmt.from_clause) == 0
and (stmt.where_clause == None or len(stmt.where_clause) == 0)
):
# basic case of SELECT set_config('...', ...)
setting_name = name.val

elif (
len(stmt.from_clause) == 1
and isinstance(stmt.from_clause[0], pgast.RelRangeVar)
and stmt.from_clause[0].relation.name == 'pg_settings'
and isinstance(name, pgast.ColumnRef)
and tuple(name.name) == ('name',)
and isinstance(stmt.where_clause, pgast.Expr)
and stmt.where_clause.name == '='
and isinstance(stmt.where_clause.lexpr, pgast.ColumnRef)
and tuple(stmt.where_clause.lexpr.name) == ('name',)
and isinstance(stmt.where_clause.rexpr, pgast.StringConstant)
):
# SELECT set_config(name, ...) FROM pg_settings WHERE name = '...'
setting_name = stmt.where_clause.rexpr.val

if setting_name != None:
stmt = pgast.VariableSetStmt(
name=setting_name,
args=pgast.ArgsList(
args=[pgast.StringConstant(val=p) for p in val.val.split(', ')]
),
scope=(
pgast.OptionsScope.TRANSACTION
if is_local.val
else pgast.OptionsScope.SESSION
),
)
return stmt
Loading