Skip to content

Commit

Permalink
Precompile SELECT set_config to SET
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen committed Sep 27, 2024
1 parent f42d362 commit 120a207
Showing 1 changed file with 82 additions and 3 deletions.
85 changes: 82 additions & 3 deletions edb/server/compiler/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@
'allow_user_specified_id': True,
'server_version': False,
'server_version_num': False,
# This is front-end only, because we cannot simply pass it trough,
# because queries for "system" views are compiled to "nonsystem" views.
# So having it as a FE setting, it is basically ignored.
'restrict_nonsystem_relation_kind': True,
}
)

Expand Down Expand Up @@ -82,6 +86,8 @@ def compile_sql(
query=orig_text,
)

stmt = precompile_set_config(stmt)

if isinstance(stmt, (pgast.VariableSetStmt, pgast.VariableResetStmt)):
value: Optional[dbstate.SQLSetting]
if isinstance(stmt, pgast.VariableSetStmt):
Expand Down Expand Up @@ -124,9 +130,11 @@ def compile_sql(
if stmt.scope == pgast.OptionsScope.SESSION:
unit.set_vars = {
f"default_{name}": (
value.val
if isinstance(value, pgast.StringConstant)
else pg_codegen.generate_source(value),
(
value.val
if isinstance(value, pgast.StringConstant)
else pg_codegen.generate_source(value)
),
)
for name, value in stmt.options.options.items()
}
Expand Down Expand Up @@ -337,3 +345,74 @@ def pg_const_to_python(expr: pgast.BaseExpr) -> str | int | float:
return float(expr.val)

raise NotImplementedError()


def precompile_set_config(stmt: pgast.BaseExpr) -> pgast.BaseExpr:
'''
Turn
SELECT set_config('...', ...)
and
SELECT set_config(name, ...) FROM pg_setting WHERE name = ...
into
SET ... TO ...
'''

if not (
isinstance(stmt, pgast.SelectStmt)
and len(stmt.target_list) == 1
and isinstance(stmt.target_list[0].val, pgast.FuncCall)
and stmt.target_list[0].val.name[-1] == 'set_config'
and len(stmt.target_list[0].val.args) == 3
and stmt.limit_count == None
and stmt.sort_clause == None
and (stmt.ctes == None or len(stmt.ctes) == 0)
):
return stmt

args = stmt.target_list[0].val.args
name = args[0]
val = args[1]
is_local = args[2]
if not (
isinstance(val, pgast.StringConstant)
and isinstance(is_local, pgast.BooleanConstant)
):
return stmt

setting_name: Optional[str] = None
if (
isinstance(name, pgast.StringConstant)
and len(stmt.from_clause) == 0
and (stmt.where_clause == None or len(stmt.where_clause) == 0)
):
# basic case of SELECT set_config('...', ...)
setting_name = name.val

elif (
len(stmt.from_clause) == 1
and isinstance(stmt.from_clause[0], pgast.RelRangeVar)
and stmt.from_clause[0].relation.name == 'pg_settings'
and isinstance(name, pgast.ColumnRef)
and tuple(name.name) == ('name',)
and isinstance(stmt.where_clause, pgast.Expr)
and stmt.where_clause.name == '='
and isinstance(stmt.where_clause.lexpr, pgast.ColumnRef)
and tuple(stmt.where_clause.lexpr.name) == ('name',)
and isinstance(stmt.where_clause.rexpr, pgast.StringConstant)
):
# SELECT set_config(name, ...) FROM pg_settings WHERE name = '...'
setting_name = stmt.where_clause.rexpr.val

if setting_name != None:
stmt = pgast.VariableSetStmt(
name=setting_name,
args=pgast.ArgsList(
args=[pgast.StringConstant(val=p) for p in val.val.split(', ')]
),
scope=(
pgast.OptionsScope.TRANSACTION
if is_local.val
else pgast.OptionsScope.SESSION
),
)
return stmt

0 comments on commit 120a207

Please sign in to comment.