diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index fde6f90292..a030d61b26 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -66,6 +66,16 @@ def _is_collectable_expr(node: ir.Node): return False +def _is_if_can_deref(node: ir.Node): + # `if_(can_deref(...), ..., ...)` + return ( + isinstance(node, ir.FunCall) + and node.fun == ir.SymRef(id="if_") + and isinstance(node.args[0], ir.FunCall) + and node.args[0].fun == ir.SymRef(id="can_deref") + ) + + @dataclasses.dataclass class CollectSubexpressions(VisitorWithSymbolTableTrait, NodeVisitor): subexprs: dict[ir.Node, list[tuple[int, set[int]]]] = dataclasses.field( @@ -75,12 +85,19 @@ class CollectSubexpressions(VisitorWithSymbolTableTrait, NodeVisitor): @classmethod def apply(cls, node: ir.Node): obj = cls() - obj.visit(node, used_symbol_ids=set(), collected_child_node_ids=set()) + obj.visit( + node, used_symbol_ids=set(), collected_child_node_ids=set(), allow_collection=True + ) # return subexpression in pre-order of the tree, i.e. the nodes closer to the root come # first, and skip the root node itself return {k: v for k, v in reversed(obj.subexprs.items()) if k is not node} def visit(self, node, **kwargs): + # TODO(tehrengruber): improve this case as we might miss subexpression that could be eliminated + # disable collection (for all child nodes) if node matches `if_(can_deref(...), ..., ...)` + if _is_if_can_deref(node): + kwargs["allow_collection"] = False + if not isinstance(node, SymbolTableTrait) and not _is_collectable_expr(node): return super().visit(node, **kwargs) @@ -101,7 +118,7 @@ def visit(self, node, **kwargs): # if no symbols are used that are defined in the root node, i.e. the node given to `apply`, # we collect the subexpression - if not used_symbol_ids and _is_collectable_expr(node): + if not used_symbol_ids and _is_collectable_expr(node) and kwargs["allow_collection"]: self.subexprs.setdefault(node, []).append((id(node), collected_child_node_ids)) # propagate to parent that we have collected its child diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index d01a1bef6c..871528197d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -153,6 +153,20 @@ def fencil(edge_f: Field[[Edge], int64], out: Field[[Vertex], int64]) -> None: assert np.allclose(ref, rs.out.array()) +def test_reduction_with_common_expression(reduction_setup, fieldview_backend): + rs = reduction_setup + V2EDim, V2E = rs.V2EDim, rs.V2E + + @field_operator(backend=fieldview_backend) + def testee(flux: Field[[Edge], int64]) -> Field[[Vertex], int64]: + return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) + + testee(rs.inp, out=rs.out, offset_provider=rs.offset_provider) + + ref = np.sum(rs.v2e_table * 2, axis=1) + assert np.allclose(ref, rs.out.array()) + + def test_conditional_nested_tuple(fieldview_backend): a_I_float = np_as_located_field(IDim)(np.random.randn(size).astype("float64")) b_I_float = np_as_located_field(IDim)(np.random.randn(size).astype("float64")) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index afbd0c1536..dbbbfd97a9 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -131,3 +131,25 @@ def common_expr(): ) actual = CSE().visit(testee) assert actual == expected + + +def test_if_can_deref(): + """ + Test no subexpression is moved outside expressions of the form `if_(can_deref(...), ..., ...)` + """ + # if can_deref(⟪Iₒ, 1ₒ⟫(it)) then ·⟪Iₒ, 1ₒ⟫(it) else ·⟪Iₒ, 1ₒ⟫(it) + ·⟪Iₒ, 1ₒ⟫(it) + testee = im.if_( + im.call("can_deref")(im.shift("I", 1)("it")), + im.deref(im.shift("I", 1)("it")), + # use something more involved where a subexpression can still be eliminated + im.plus(im.deref(im.shift("I", 1)("it")), im.deref(im.shift("I", 1)("it"))), + ) + # if can_deref(⟪Iₒ, 1ₒ⟫(it)) then ·⟪Iₒ, 1ₒ⟫(it) else (λ(_cs_1) → _cs_1 + _cs_1)(·⟪Iₒ, 1ₒ⟫(it)) + expected = im.if_( + im.call("can_deref")(im.shift("I", 1)("it")), + im.deref(im.shift("I", 1)("it")), + im.let("_cs_1", im.deref(im.shift("I", 1)("it")))(im.plus("_cs_1", "_cs_1")), + ) + + actual = CSE().visit(testee) + assert actual == expected