Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix volatile WITH bindings in GROUP statements. #8020

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions edb/edgeql/compiler/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,31 +309,32 @@ def compile_InternalGroupQuery(
expr: qlast.InternalGroupQuery, *, ctx: context.ContextLevel
) -> irast.Set:
# We disallow use of FOR GROUP except for when running in test mode.
if not expr.from_desugaring and not ctx.env.options.testmode:
raise errors.UnsupportedFeatureError(
"'FOR GROUP' is an internal testing feature",
span=expr.span,
)

_protect_expr(expr.subject, ctx=ctx)
_protect_expr(expr.result, ctx=ctx)

with ctx.subquery() as sctx:
stmt = irast.GroupStmt(by=expr.by)

init_stmt(stmt, expr, ctx=sctx, parent_ctx=ctx)

with sctx.newscope(fenced=True) as topctx:

# N.B: Subject is exposed because we want any shape on the
# subject to be exposed on bare references to the group
# alias. This is frankly pretty dodgy behavior for
# FOR GROUP to have but the real GROUP needs to
# maintain shapes, and this is the easiest way to handle
# that.
stmt.subject = compile_result_clause(
expr.subject,
result_alias=expr.subject_alias,
exprtype=s_types.ExprType.Group,
ctx=topctx)
stmt.subject = setgen.scoped_set(
compile_result_clause(
(expr.subject),
result_alias=expr.subject_alias,
exprtype=s_types.ExprType.Group,
ctx=topctx,
),
ctx=topctx,
)

if topctx.partial_path_prefix:
pathctx.register_set_in_scope(
Expand Down Expand Up @@ -424,10 +425,14 @@ def compile_InternalGroupQuery(
stmt.grouping_binding, path_scope=bctx.path_scope, ctx=bctx
)

stmt.result = compile_result_clause(
astutils.ensure_ql_query(expr.result),
result_alias=expr.result_alias,
ctx=bctx)
stmt.result = setgen.scoped_set(
compile_result_clause(
astutils.ensure_ql_query(expr.result),
result_alias=expr.result_alias,
ctx=bctx,
),
ctx=bctx,
)

stmt.where = clauses.compile_where_clause(expr.where, ctx=bctx)

Expand Down
44 changes: 37 additions & 7 deletions edb/pgsql/compiler/clauses.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import random

from edb.common import topological
from edb.common import ast as ast_visitor

from edb.edgeql import qltypes
Expand Down Expand Up @@ -114,12 +115,34 @@ def compile_materialized_exprs(
matctx.materializing |= {stmt}
matctx.expr_exposed = True

# HACK: Sort longer paths before shorter ones
# We want foo->bar to appear before foo
mat_sets = sorted(
(stmt.materialized_sets.values()),
key=lambda m: -len(m.materialized.path_id),
)
# Determine the order to materialize sets. If a set A references another
# set B, B should materialize first.
mat_set_dependency_graph: dict[
irast.PathId,
topological.DepGraphEntry[
irast.PathId, irast.MaterializedSet, None
],
] = {}
for mat_set in stmt.materialized_sets.values():
path_id = mat_set.materialized.path_id

mat_set_dependency_graph[path_id] = topological.DepGraphEntry(
item=mat_set,
deps={
child.path_id
for child in ast_visitor.find_children(
mat_set.materialized, irast.Set
)
if isinstance(child.expr, irast.MaterializedExpr)
},
)

mat_sets: list[irast.MaterializedSet] = [
mat_set
for mat_set in topological.sort(
mat_set_dependency_graph, allow_unresolved=True,
)
]

for mat_set in mat_sets:
if len(mat_set.uses) <= 1:
Expand Down Expand Up @@ -274,7 +297,9 @@ def compile_volatile_bindings(
stmt: irast.Stmt,
*,
ctx: context.CompilerContextLevel
) -> None:
) -> list[irast.Set]:
uncompiled_bindings = []

for binding, volatility in (stmt.bindings or ()):
# If something we are WITH binding contains DML, we want to
# compile it *now*, in the context of its initial appearance
Expand All @@ -299,6 +324,11 @@ def compile_volatile_bindings(
with ctx.substmt() as bctx:
dispatch.compile(binding, ctx=bctx)

else:
uncompiled_bindings.append(binding)

return uncompiled_bindings


def _compile_volatile_binding_for_dml(
stmt: irast.Stmt,
Expand Down
27 changes: 23 additions & 4 deletions edb/pgsql/compiler/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,17 +173,39 @@ def _compile_group(
ctx: context.CompilerContextLevel,
parent_ctx: context.CompilerContextLevel) -> pgast.BaseExpr:

clauses.compile_volatile_bindings(stmt, ctx=ctx)
uncompiled_bindings = clauses.compile_volatile_bindings(stmt, ctx=ctx)

query = ctx.stmt

# Process materialized sets
clauses.compile_materialized_exprs(ctx.rel, stmt, ctx=ctx)

# Compile a GROUP BY into a subquery, along with all the aggregations
with ctx.subrel() as groupctx:
grouprel = groupctx.rel

# Compile the remaining bindings here
if False:
for binding in uncompiled_bindings:
with groupctx.subrel() as bindctx:
dispatch.compile(binding, ctx=bindctx)

bind_rvar = relctx.rvar_for_rel(
bindctx.rel, ctx=groupctx, lateral=True
)
relctx.include_rvar(
groupctx.rel,
bind_rvar,
binding.path_id,
ctx=groupctx,
)

# First compile the actual subject
# subrel *solely* for path id map reasons
with groupctx.subrel() as subjctx:
# Process materialized sets
#clauses.compile_materialized_exprs(groupctx.rel, stmt, ctx=subjctx)

subjctx.expr_exposed = False

dispatch.visit(stmt.subject, ctx=subjctx)
Expand Down Expand Up @@ -392,9 +414,6 @@ def _get_volatility_ref() -> Optional[pgast.BaseExpr]:

outctx.volatility_ref += (lambda stmt, xctx: _get_volatility_ref(),)

# Process materialized sets
clauses.compile_materialized_exprs(query, stmt, ctx=outctx)

clauses.compile_output(stmt.result, ctx=outctx)

with ctx.new() as ictx:
Expand Down
7 changes: 4 additions & 3 deletions edb/pgsql/compiler/relgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,9 +1559,10 @@ def process_set_as_subquery(
return _new_subquery_stmt_set_rvar(ir_set, stmt, ctx=newctx)

# materialized refs should always get picked up by now
assert not isinstance(expr, irast.MaterializedExpr), (
f"Can't find materialized set {ir_set.path_id}"
)
if isinstance(expr, irast.MaterializedExpr):
raise AssertionError(
f"Can't find materialized set {ir_set.path_id}"
)
assert isinstance(expr, irast.Stmt)

inner_set = expr.result
Expand Down
170 changes: 170 additions & 0 deletions tests/test_edgeql_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,176 @@ async def test_edgeql_group_binding_01(self):
])
)

async def test_edgeql_group_binding_02(self):
# GROUP whose subject is a parenthesized FOR query that declares its own
# WITH binding (X := {8, 9}) inside the subject expression.
await self.assert_query_result(
'''
group (
with X := {8, 9}
for x in enumerate(X)
select {a := 1, b := x.0, c := x.1}
) by .a;
''',
# Every row has a = 1, so a single group keyed on a is expected; b and c
# carry the two components (x.0, x.1) of each enumerate(X) tuple.
# NOTE(review): tb.bag presumably compares ignoring order -- confirm.
tb.bag([
{
'key': {'a': 1},
'grouping': ['a'],
'elements': [
{'a': 1, 'b': 0, 'c': 8},
{'a': 1, 'b': 1, 'c': 9},
]
}
])
)

async def test_edgeql_group_binding_03(self):
# Same query shape as test_edgeql_group_binding_02, but the WITH binding
# inside the subject is built from random() calls instead of literals.
await self.assert_query_result(
'''
group (
with X := {random(), random()}
for x in enumerate(X)
select {a := 1, b := x.0, c := (x.1 <= 1)}
) by .a;
''',
# c is compared against 1 rather than returned directly, so the expected
# output can be written down as a fixed value.
# NOTE(review): presumably random() never exceeds 1, which is what makes
# c always true -- confirm.
tb.bag([
{
'key': {'a': 1},
'grouping': ['a'],
'elements': [
{'a': 1, 'b': 0, 'c': True},
{'a': 1, 'b': 1, 'c': True},
]
}
])
)

async def test_edgeql_group_binding_04(self):
# The WITH binding (X := {8, 9}) is declared on the GROUP statement itself
# and referenced from the FOR query used as the group subject.
await self.assert_query_result(
'''
with X := {8, 9}
group (
for x in enumerate(X)
select {a := 1, b := x.0, c := x.1}
) by .a;
''',
# Expected output is the same as when X is bound inside the subject
# (test_edgeql_group_binding_02): one group with both enumerated rows.
tb.bag([
{
'key': {'a': 1},
'grouping': ['a'],
'elements': [
{'a': 1, 'b': 0, 'c': 8},
{'a': 1, 'b': 1, 'c': 9},
]
}
])
)

async def test_edgeql_group_binding_05(self):
# Statement-level WITH binding built from random() calls, referenced from
# the FOR query used as the group subject.
await self.assert_query_result(
'''
with X := {random(), random()}
group (
for x in enumerate(X)
select {a := 1, b := x.0, c := (x.1 <= 1)}
) by .a;
''',
# c is a comparison against 1 so the expected value is fixed.
# NOTE(review): presumably random() never exceeds 1 -- confirm.
tb.bag([
{
'key': {'a': 1},
'grouping': ['a'],
'elements': [
{'a': 1, 'b': 0, 'c': True},
{'a': 1, 'b': 1, 'c': True},
]
}
])
)

async def test_edgeql_group_binding_06(self):
# Two chained statement-level bindings: Y is a FOR query over the earlier
# binding X, and the GROUP subject is Y with an explicit { a, b, c } shape.
await self.assert_query_result(
'''
with X := {8, 9},
Y := (
for x in enumerate(X)
select {a := 1, b := x.0, c := x.1}
)
group Y { a, b, c } by .a;
''',
# All three shape fields are checked on each grouped element.
tb.bag([
{
'key': {'a': 1},
'grouping': ['a'],
'elements': [
{'a': 1, 'b': 0, 'c': 8},
{'a': 1, 'b': 1, 'c': 9},
]
}
])
)

async def test_edgeql_group_binding_07(self):
# Chained bindings as in test_edgeql_group_binding_06, except that X is
# built from random() calls; Y iterates over X and is grouped with an
# explicit { a, b, c } shape.
await self.assert_query_result(
'''
with X := {random(), random()},
Y := (
for x in enumerate(X)
select {a := 1, b := x.0, c := (x.1 <= 1)}
)
group Y { a, b, c } by .a;
''',
# c is a comparison against 1 so the expected value is fixed.
# NOTE(review): presumably random() never exceeds 1 -- confirm.
tb.bag([
{
'key': {'a': 1},
'grouping': ['a'],
'elements': [
{'a': 1, 'b': 0, 'c': True},
{'a': 1, 'b': 1, 'c': True},
]
}
])
)

async def test_edgeql_group_binding_08(self):
# Single statement-level binding Y whose FOR query enumerates the literal
# set {8, 9} inline (no separate X binding); Y is grouped with an explicit
# { a, b, c } shape.
await self.assert_query_result(
'''
with Y := (
for x in enumerate({8, 9})
select {a := 1, b := x.0, c := x.1}
)
group Y { a, b, c } by .a;
''',
# The expected elements list only 'a' and 'b', although the shape also
# selects c.
# NOTE(review): presumably assert_query_result ignores keys missing from
# the expected dicts -- confirm that leaving 'c' unchecked is intentional.
tb.bag([
{
'key': {'a': 1},
'grouping': ['a'],
'elements': [
{'a': 1, 'b': 0},
{'a': 1, 'b': 1},
]
}
])
)

async def test_edgeql_group_binding_09(self):
# Single statement-level binding Y whose FOR query enumerates a set of
# random() calls inline; Y is grouped with an explicit { a, b, c } shape.
await self.assert_query_result(
'''
with Y := (
for x in enumerate({random(), random()})
select {a := 1, b := x.0, c := (x.1 <= 1)}
)
group Y { a, b, c } by .a;
''',
# The expected elements list only 'a' and 'b', although the shape also
# selects c.
# NOTE(review): presumably assert_query_result ignores keys missing from
# the expected dicts -- confirm that leaving 'c' unchecked is intentional.
tb.bag([
{
'key': {'a': 1},
'grouping': ['a'],
'elements': [
{'a': 1, 'b': 0},
{'a': 1, 'b': 1},
]
}
])
)

async def test_edgeql_group_ordering_01(self):
res = [
{
Expand Down
Loading
Loading