diff --git a/edb/edgeql/compiler/stmt.py b/edb/edgeql/compiler/stmt.py index 02bdf674d09..91060e0dca4 100644 --- a/edb/edgeql/compiler/stmt.py +++ b/edb/edgeql/compiler/stmt.py @@ -309,31 +309,32 @@ def compile_InternalGroupQuery( expr: qlast.InternalGroupQuery, *, ctx: context.ContextLevel ) -> irast.Set: # We disallow use of FOR GROUP except for when running in test mode. - if not expr.from_desugaring and not ctx.env.options.testmode: - raise errors.UnsupportedFeatureError( - "'FOR GROUP' is an internal testing feature", - span=expr.span, - ) _protect_expr(expr.subject, ctx=ctx) _protect_expr(expr.result, ctx=ctx) with ctx.subquery() as sctx: stmt = irast.GroupStmt(by=expr.by) + init_stmt(stmt, expr, ctx=sctx, parent_ctx=ctx) with sctx.newscope(fenced=True) as topctx: + # N.B: Subject is exposed because we want any shape on the # subject to be exposed on bare references to the group # alias. This is frankly pretty dodgy behavior for # FOR GROUP to have but the real GROUP needs to # maintain shapes, and this is the easiest way to handle # that. - stmt.subject = compile_result_clause( - expr.subject, - result_alias=expr.subject_alias, - exprtype=s_types.ExprType.Group, - ctx=topctx) + stmt.subject = setgen.scoped_set( + compile_result_clause( + (expr.subject), + result_alias=expr.subject_alias, + exprtype=s_types.ExprType.Group, + ctx=topctx, + ), + ctx=topctx, + ) if topctx.partial_path_prefix: pathctx.register_set_in_scope( @@ -424,10 +425,14 @@ def compile_InternalGroupQuery( stmt.grouping_binding, path_scope=bctx.path_scope, ctx=bctx ) - stmt.result = compile_result_clause( - astutils.ensure_ql_query(expr.result), - result_alias=expr.result_alias, - ctx=bctx) + stmt.result = setgen.scoped_set( + compile_result_clause( + astutils.ensure_ql_query(expr.result), + result_alias=expr.result_alias, + ctx=bctx, + ), + ctx=bctx, + ) stmt.where = clauses.compile_where_clause(expr.where, ctx=bctx) diff --git a/edb/pgsql/compiler/clauses.py b/edb/pgsql/compiler/clauses.py index 31e531edad4..76bce2ad89d 100644 --- a/edb/pgsql/compiler/clauses.py +++ b/edb/pgsql/compiler/clauses.py @@ -23,6 +23,7 @@ import random +from edb.common import topological from edb.common import ast as ast_visitor from edb.edgeql import qltypes @@ -114,12 +115,34 @@ def compile_materialized_exprs( matctx.materializing |= {stmt} matctx.expr_exposed = True - # HACK: Sort longer paths before shorter ones - # We want foo->bar to appear before foo - mat_sets = sorted( - (stmt.materialized_sets.values()), - key=lambda m: -len(m.materialized.path_id), - ) + # Determine the order to materialize sets. If a set A references another + # set B, B should materialize first. + mat_set_dependency_graph: dict[ + irast.PathId, + topological.DepGraphEntry[ + irast.PathId, irast.MaterializedSet, None + ], + ] = {} + for mat_set in stmt.materialized_sets.values(): + path_id = mat_set.materialized.path_id + + mat_set_dependency_graph[path_id] = topological.DepGraphEntry( + item=mat_set, + deps={ + child.path_id + for child in ast_visitor.find_children( + mat_set.materialized, irast.Set + ) + if isinstance(child.expr, irast.MaterializedExpr) + }, + ) + + mat_sets: list[irast.MaterializedSet] = [ + mat_set + for mat_set in topological.sort( + mat_set_dependency_graph, allow_unresolved=True, + ) + ] for mat_set in mat_sets: if len(mat_set.uses) <= 1: @@ -274,7 +297,9 @@ def compile_volatile_bindings( stmt: irast.Stmt, *, ctx: context.CompilerContextLevel -) -> None: +) -> list[irast.Set]: + uncompiled_bindings = [] + for binding, volatility in (stmt.bindings or ()): # If something we are WITH binding contains DML, we want to # compile it *now*, in the context of its initial appearance @@ -299,6 +324,11 @@ def compile_volatile_bindings( with ctx.substmt() as bctx: dispatch.compile(binding, ctx=bctx) + else: + uncompiled_bindings.append(binding) + + return uncompiled_bindings + def _compile_volatile_binding_for_dml( stmt: irast.Stmt, diff --git a/edb/pgsql/compiler/group.py b/edb/pgsql/compiler/group.py index a3c8e5a0c9a..32ce3ee09f5 100644 --- a/edb/pgsql/compiler/group.py +++ b/edb/pgsql/compiler/group.py @@ -173,17 +173,39 @@ def _compile_group( ctx: context.CompilerContextLevel, parent_ctx: context.CompilerContextLevel) -> pgast.BaseExpr: - clauses.compile_volatile_bindings(stmt, ctx=ctx) + uncompiled_bindings = clauses.compile_volatile_bindings(stmt, ctx=ctx) query = ctx.stmt + # Process materialized sets + clauses.compile_materialized_exprs(ctx.rel, stmt, ctx=ctx) + # Compile a GROUP BY into a subquery, along with all the aggregations with ctx.subrel() as groupctx: grouprel = groupctx.rel + # Compile the remaining bindings here + if False: + for binding in uncompiled_bindings: + with groupctx.subrel() as bindctx: + dispatch.compile(binding, ctx=bindctx) + + bind_rvar = relctx.rvar_for_rel( + bindctx.rel, ctx=groupctx, lateral=True + ) + relctx.include_rvar( + groupctx.rel, + bind_rvar, + binding.path_id, + ctx=groupctx, + ) + # First compile the actual subject # subrel *solely* for path id map reasons with groupctx.subrel() as subjctx: + # Process materialized sets + #clauses.compile_materialized_exprs(groupctx.rel, stmt, ctx=subjctx) + subjctx.expr_exposed = False dispatch.visit(stmt.subject, ctx=subjctx) @@ -392,9 +414,6 @@ def _get_volatility_ref() -> Optional[pgast.BaseExpr]: outctx.volatility_ref += (lambda stmt, xctx: _get_volatility_ref(),) - # Process materialized sets - clauses.compile_materialized_exprs(query, stmt, ctx=outctx) - clauses.compile_output(stmt.result, ctx=outctx) with ctx.new() as ictx: diff --git a/edb/pgsql/compiler/relgen.py b/edb/pgsql/compiler/relgen.py index 361de222023..d77cbe951dd 100644 --- a/edb/pgsql/compiler/relgen.py +++ b/edb/pgsql/compiler/relgen.py @@ -1559,9 +1559,10 @@ def process_set_as_subquery( return _new_subquery_stmt_set_rvar(ir_set, stmt, ctx=newctx) # materialized refs should always get picked up by now - assert not isinstance(expr, irast.MaterializedExpr), ( - f"Can't find materialized set {ir_set.path_id}" - ) + if isinstance(expr, irast.MaterializedExpr): + raise AssertionError( + f"Can't find materialized set {ir_set.path_id}" + ) assert isinstance(expr, irast.Stmt) inner_set = expr.result diff --git a/tests/test_edgeql_group.py b/tests/test_edgeql_group.py index 90bbd714536..6bd606b0a6a 100644 --- a/tests/test_edgeql_group.py +++ b/tests/test_edgeql_group.py @@ -1089,6 +1089,176 @@ async def test_edgeql_group_binding_01(self): ]) ) + async def test_edgeql_group_binding_02(self): + await self.assert_query_result( + ''' + group ( + with X := {8, 9} + for x in enumerate(X) + select {a := 1, b := x.0, c := x.1} + ) by .a; + ''', + tb.bag([ + { + 'key': {'a': 1}, + 'grouping': ['a'], + 'elements': [ + {'a': 1, 'b': 0, 'c': 8}, + {'a': 1, 'b': 1, 'c': 9}, + ] + } + ]) + ) + + async def test_edgeql_group_binding_03(self): + await self.assert_query_result( + ''' + group ( + with X := {random(), random()} + for x in enumerate(X) + select {a := 1, b := x.0, c := (x.1 <= 1)} + ) by .a; + ''', + tb.bag([ + { + 'key': {'a': 1}, + 'grouping': ['a'], + 'elements': [ + {'a': 1, 'b': 0, 'c': True}, + {'a': 1, 'b': 1, 'c': True}, + ] + } + ]) + ) + + async def test_edgeql_group_binding_04(self): + await self.assert_query_result( + ''' + with X := {8, 9} + group ( + for x in enumerate(X) + select {a := 1, b := x.0, c := x.1} + ) by .a; + ''', + tb.bag([ + { + 'key': {'a': 1}, + 'grouping': ['a'], + 'elements': [ + {'a': 1, 'b': 0, 'c': 8}, + {'a': 1, 'b': 1, 'c': 9}, + ] + } + ]) + ) + + async def test_edgeql_group_binding_05(self): + await self.assert_query_result( + ''' + with X := {random(), random()} + group ( + for x in enumerate(X) + select {a := 1, b := x.0, c := (x.1 <= 1)} + ) by .a; + ''', + tb.bag([ + { + 'key': {'a': 1}, + 'grouping': ['a'], + 'elements': [ + {'a': 1, 'b': 0, 'c': True}, + {'a': 1, 'b': 1, 'c': True}, + ] + } + ]) + ) + + async def test_edgeql_group_binding_06(self): + await self.assert_query_result( + ''' + with X := {8, 9}, + Y := ( + for x in enumerate(X) + select {a := 1, b := x.0, c := x.1} + ) + group Y { a, b, c } by .a; + ''', + tb.bag([ + { + 'key': {'a': 1}, + 'grouping': ['a'], + 'elements': [ + {'a': 1, 'b': 0, 'c': 8}, + {'a': 1, 'b': 1, 'c': 9}, + ] + } + ]) + ) + + async def test_edgeql_group_binding_07(self): + await self.assert_query_result( + ''' + with X := {random(), random()}, + Y := ( + for x in enumerate(X) + select {a := 1, b := x.0, c := (x.1 <= 1)} + ) + group Y { a, b, c } by .a; + ''', + tb.bag([ + { + 'key': {'a': 1}, + 'grouping': ['a'], + 'elements': [ + {'a': 1, 'b': 0, 'c': True}, + {'a': 1, 'b': 1, 'c': True}, + ] + } + ]) + ) + + async def test_edgeql_group_binding_08(self): + await self.assert_query_result( + ''' + with Y := ( + for x in enumerate({8, 9}) + select {a := 1, b := x.0, c := x.1} + ) + group Y { a, b, c } by .a; + ''', + tb.bag([ + { + 'key': {'a': 1}, + 'grouping': ['a'], + 'elements': [ + {'a': 1, 'b': 0}, + {'a': 1, 'b': 1}, + ] + } + ]) + ) + + async def test_edgeql_group_binding_09(self): + await self.assert_query_result( + ''' + with Y := ( + for x in enumerate({random(), random()}) + select {a := 1, b := x.0, c := (x.1 <= 1)} + ) + group Y { a, b, c } by .a; + ''', + tb.bag([ + { + 'key': {'a': 1}, + 'grouping': ['a'], + 'elements': [ + {'a': 1, 'b': 0}, + {'a': 1, 'b': 1}, + ] + } + ]) + ) + async def test_edgeql_group_ordering_01(self): res = [ { diff --git a/tests/test_edgeql_internal_group.py b/tests/test_edgeql_internal_group.py index 42535926152..0b8e09b671a 100644 --- a/tests/test_edgeql_internal_group.py +++ b/tests/test_edgeql_internal_group.py @@ -1540,3 +1540,199 @@ async def test_edgeql_igroup_reshape_02(self): {"avg_cost": 2, "element": "Water"}, ]) ) + + async def test_edgeql_with_alias_simple_01(self): + await self.assert_query_result( + r""" + WITH + MODULE cards, + C := Card + FOR GROUP C + USING e := .element + BY e + IN g + UNION + { + key := e, + names := g.name, + }; + """, + tb.bag([ + {'key': 'Water', 'names': ['Bog monster', 'Giant turtle']}, + {'key': 'Fire', 'names': ['Imp', 'Dragon']}, + {'key': 'Earth', 'names': ['Dwarf', 'Golem']}, + {'key': 'Air', 'names': ['Sprite', 'Giant eagle', 'Djinn']}, + ]) + ) + + async def test_edgeql_with_alias_simple_02(self): + await self.assert_query_result( + r""" + WITH + MODULE cards, + C := Card + FOR GROUP (select C) + USING e := .element + BY e + IN g + UNION + { + key := e, + names := g.name, + }; + """, + tb.bag([ + {'key': 'Water', 'names': ['Bog monster', 'Giant turtle']}, + {'key': 'Fire', 'names': ['Imp', 'Dragon']}, + {'key': 'Earth', 'names': ['Dwarf', 'Golem']}, + {'key': 'Air', 'names': ['Sprite', 'Giant eagle', 'Djinn']}, + ]) + ) + + async def test_edgeql_with_alias_simple_03(self): + await self.assert_query_result( + r""" + WITH + MODULE cards, + C := (select Card { name_lower := str_lower(.name)}) + FOR GROUP C + USING e := .element + BY e + IN g + UNION + { + key := e, + names := g.name_lower, + }; + """, + tb.bag([ + {'key': 'Water', 'names': ['bog monster', 'giant turtle']}, + {'key': 'Fire', 'names': ['imp', 'dragon']}, + {'key': 'Earth', 'names': ['dwarf', 'golem']}, + {'key': 'Air', 'names': ['sprite', 'giant eagle', 'djinn']}, + ]) + ) + + async def test_edgeql_with_alias_simple_04(self): + await self.assert_query_result( + r""" + WITH + MODULE cards, + C := (select Card { number := 9 }) + FOR GROUP C { number } + USING e := .element + BY e + IN g + UNION + { + key := e, + numbers := g.number, + }; + """, + tb.bag([ + {'key': 'Water', 'numbers': {9, 9}}, + {'key': 'Fire', 'numbers': {9, 9}}, + {'key': 'Earth', 'numbers': {9, 9}}, + {'key': 'Air', 'numbers': {9, 9, 9}}, + ]) + ) + + async def test_edgeql_with_alias_for_01(self): + await self.assert_query_result( + r""" + WITH + MODULE cards, + U := (for x in {8, 9} select User { a := x }), + FOR GROUP U { name, a } + USING e := .name + BY e + IN g + UNION + { e_ := e, g_ := g} { + key := .e_, + numbers := .g_.a, + }; + """, + tb.bag([ + {'key': 'Alice', 'numbers': {8, 9}}, + {'key': 'Bob', 'numbers': {8, 9}}, + {'key': 'Carol', 'numbers': {8, 9}}, + {'key': 'Dave', 'numbers': {8, 9}}, + ]) + ) + + async def test_edgeql_with_alias_for_02(self): + await self.assert_query_result( + r""" + WITH + MODULE cards, + X := {8, 9}, + U := (for x in X select User { a := x }), + FOR GROUP U + USING e := .name + BY e + IN g + UNION + { e_ := e, g_ := g} { + key := .e_, + numbers := .g_.a, + }; + """, + tb.bag([ + {'key': 'Alice', 'numbers': {8, 9}}, + {'key': 'Bob', 'numbers': {8, 9}}, + {'key': 'Carol', 'numbers': {8, 9}}, + {'key': 'Dave', 'numbers': {8, 9}}, + ]) + ) + + async def test_edgeql_with_alias_for_03(self): + await self.assert_query_result( + r""" + WITH + MODULE cards, + X := {8, 9}, + FOR GROUP (for x in X select User { a := x }) + USING e := .name + BY e + IN g + UNION + { e_ := e, g_ := g} { + key := .e_, + numbers := .g_.a, + }; + """, + tb.bag([ + {'key': 'Alice', 'numbers': {8, 9}}, + {'key': 'Bob', 'numbers': {8, 9}}, + {'key': 'Carol', 'numbers': {8, 9}}, + {'key': 'Dave', 'numbers': {8, 9}}, + ]) + ) + + async def test_edgeql_with_alias_for_04(self): + await self.assert_query_result( + r""" + WITH + MODULE cards, + U := ( + for x in enumerate({8, 9}) + select User { a := x.0, b := x.1 } + ), + FOR GROUP U + USING e := .name + BY e + IN g + UNION + { e_ := e, g_ := g} { + key := .e_, + numbers := (.g_.a, .g_.b), + }; + """, + tb.bag([ + {'key': 'Alice', 'numbers': {(0, 8), (1, 9)}}, + {'key': 'Bob', 'numbers': {(0, 8), (1, 9)}}, + {'key': 'Carol', 'numbers': {(0, 8), (1, 9)}}, + {'key': 'Dave', 'numbers': {(0, 8), (1, 9)}}, + ]) + )