From 45d7f3ad8188c0de462780818af6bd8df0c4e869 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 28 Apr 2023 10:58:23 +0200 Subject: [PATCH] Disable CSE for an `if` with a condition calling `can_deref` (#1246) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently the common subexpression elimination happily extracts from expressions of the form `if_(can_deref(...), ..., ...)` which results in _deref_-ing iterators before it is checked whether they are _deref_-able. For example the following expression ``` if can_deref(⟪Iₒ, 1ₒ⟫(it)) then ·⟪Iₒ, 1ₒ⟫(it) else ·⟪Iₒ, 1ₒ⟫(it) ``` is transformed into ``` (λ(_cs_3) → (λ(_cs_1) → if can_deref(_cs_3) then _cs_1 else _cs_1)(·_cs_3))(⟪Iₒ, 1ₒ⟫(it)) ``` With the evaluation order used in our backends `·_cs_3` is executed unconditionally which works for the embedded backend, but fails with an assertion error in the gtfn backend. This PR fixes that by disabling CSE on all such expressions. Since we want to avoid such cases also for the embedded backend it is enabled everywhere. The pass could still be improved as we miss extracting expressions that could be extracted, but it works for now. --- src/gt4py/next/iterator/transforms/cse.py | 21 ++++++++++++++++-- .../ffront_tests/test_gt4py_builtins.py | 14 ++++++++++++ .../transforms_tests/test_cse.py | 22 +++++++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index fde6f90292..a030d61b26 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -66,6 +66,16 @@ def _is_collectable_expr(node: ir.Node): return False +def _is_if_can_deref(node: ir.Node): + # `if_(can_deref(...), ..., ...)` + return ( + isinstance(node, ir.FunCall) + and node.fun == ir.SymRef(id="if_") + and isinstance(node.args[0], ir.FunCall) + and node.args[0].fun == ir.SymRef(id="can_deref") + ) + + @dataclasses.dataclass class CollectSubexpressions(VisitorWithSymbolTableTrait, NodeVisitor): subexprs: dict[ir.Node, list[tuple[int, set[int]]]] = dataclasses.field( @@ -75,12 +85,19 @@ class CollectSubexpressions(VisitorWithSymbolTableTrait, NodeVisitor): @classmethod def apply(cls, node: ir.Node): obj = cls() - obj.visit(node, used_symbol_ids=set(), collected_child_node_ids=set()) + obj.visit( + node, used_symbol_ids=set(), collected_child_node_ids=set(), allow_collection=True + ) # return subexpression in pre-order of the tree, i.e. the nodes closer to the root come # first, and skip the root node itself return {k: v for k, v in reversed(obj.subexprs.items()) if k is not node} def visit(self, node, **kwargs): + # TODO(tehrengruber): improve this case as we might miss subexpression that could be eliminated + # disable collection (for all child nodes) if node matches `if_(can_deref(...), ..., ...)` + if _is_if_can_deref(node): + kwargs["allow_collection"] = False + if not isinstance(node, SymbolTableTrait) and not _is_collectable_expr(node): return super().visit(node, **kwargs) @@ -101,7 +118,7 @@ def visit(self, node, **kwargs): # if no symbols are used that are defined in the root node, i.e. the node given to `apply`, # we collect the subexpression - if not used_symbol_ids and _is_collectable_expr(node): + if not used_symbol_ids and _is_collectable_expr(node) and kwargs["allow_collection"]: self.subexprs.setdefault(node, []).append((id(node), collected_child_node_ids)) # propagate to parent that we have collected its child diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index d01a1bef6c..871528197d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -153,6 +153,20 @@ def fencil(edge_f: Field[[Edge], int64], out: Field[[Vertex], int64]) -> None: assert np.allclose(ref, rs.out.array()) +def test_reduction_with_common_expression(reduction_setup, fieldview_backend): + rs = reduction_setup + V2EDim, V2E = rs.V2EDim, rs.V2E + + @field_operator(backend=fieldview_backend) + def testee(flux: Field[[Edge], int64]) -> Field[[Vertex], int64]: + return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) + + testee(rs.inp, out=rs.out, offset_provider=rs.offset_provider) + + ref = np.sum(rs.v2e_table * 2, axis=1) + assert np.allclose(ref, rs.out.array()) + + def test_conditional_nested_tuple(fieldview_backend): a_I_float = np_as_located_field(IDim)(np.random.randn(size).astype("float64")) b_I_float = np_as_located_field(IDim)(np.random.randn(size).astype("float64")) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index afbd0c1536..dbbbfd97a9 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -131,3 +131,25 @@ def common_expr(): ) actual = CSE().visit(testee) assert actual == expected + + +def test_if_can_deref(): + """ + Test no subexpression is moved outside expressions of the form `if_(can_deref(...), ..., ...)` + """ + # if can_deref(⟪Iₒ, 1ₒ⟫(it)) then ·⟪Iₒ, 1ₒ⟫(it) else ·⟪Iₒ, 1ₒ⟫(it) + ·⟪Iₒ, 1ₒ⟫(it) + testee = im.if_( + im.call("can_deref")(im.shift("I", 1)("it")), + im.deref(im.shift("I", 1)("it")), + # use something more involved where a subexpression can still be eliminated + im.plus(im.deref(im.shift("I", 1)("it")), im.deref(im.shift("I", 1)("it"))), + ) + # if can_deref(⟪Iₒ, 1ₒ⟫(it)) then ·⟪Iₒ, 1ₒ⟫(it) else (λ(_cs_1) → _cs_1 + _cs_1)(·⟪Iₒ, 1ₒ⟫(it)) + expected = im.if_( + im.call("can_deref")(im.shift("I", 1)("it")), + im.deref(im.shift("I", 1)("it")), + im.let("_cs_1", im.deref(im.shift("I", 1)("it")))(im.plus("_cs_1", "_cs_1")), + ) + + actual = CSE().visit(testee) + assert actual == expected