Skip to content

Commit

Permalink
Disable CSE for an if with a condition calling can_deref (#1246)
Browse files Browse the repository at this point in the history
Currently, the common subexpression elimination (CSE) pass happily extracts subexpressions from expressions of the form `if_(can_deref(...), ..., ...)`, which results in iterators being dereferenced before it has been checked whether they can be dereferenced. For example, the following expression
```
if can_deref(⟪Iₒ, 1ₒ⟫(it)) then ·⟪Iₒ, 1ₒ⟫(it) else ·⟪Iₒ, 1ₒ⟫(it)
```
is transformed into
```
(λ(_cs_3) → (λ(_cs_1) → if can_deref(_cs_3) then _cs_1 else _cs_1)(·_cs_3))(⟪Iₒ, 1ₒ⟫(it))
```
With the evaluation order used in our backends, `·_cs_3` is executed unconditionally, which works for the embedded backend but fails with an assertion error in the gtfn backend. This PR fixes that by disabling CSE on all such expressions. Since we also want to avoid such cases for the embedded backend, the restriction is applied to all backends.

The pass could still be improved, as it now misses some subexpressions that could safely be eliminated, but it works for now.
  • Loading branch information
tehrengruber authored Apr 28, 2023
1 parent ae683f0 commit 45d7f3a
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 2 deletions.
21 changes: 19 additions & 2 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ def _is_collectable_expr(node: ir.Node):
return False


def _is_if_can_deref(node: ir.Node) -> bool:
    """
    Return whether `node` has the form `if_(can_deref(...), ..., ...)`.

    Only the called function and the first argument (the condition) are inspected; the
    branches of the `if_` are not looked at.

    Args:
        node: The IR node to test. Any node type is accepted.

    Returns:
        `True` if `node` is a call to `if_` whose first argument is a call to `can_deref`,
        `False` otherwise (including for an `if_` call without arguments).
    """
    if not (isinstance(node, ir.FunCall) and node.fun == ir.SymRef(id="if_")):
        return False

    # a malformed `if_` call without arguments simply does not match instead of
    # raising an `IndexError` on the argument lookup below
    if not node.args:
        return False

    condition = node.args[0]
    return isinstance(condition, ir.FunCall) and condition.fun == ir.SymRef(id="can_deref")


@dataclasses.dataclass
class CollectSubexpressions(VisitorWithSymbolTableTrait, NodeVisitor):
subexprs: dict[ir.Node, list[tuple[int, set[int]]]] = dataclasses.field(
Expand All @@ -75,12 +85,19 @@ class CollectSubexpressions(VisitorWithSymbolTableTrait, NodeVisitor):
@classmethod
def apply(cls, node: ir.Node):
obj = cls()
obj.visit(node, used_symbol_ids=set(), collected_child_node_ids=set())
obj.visit(
node, used_symbol_ids=set(), collected_child_node_ids=set(), allow_collection=True
)
# return subexpression in pre-order of the tree, i.e. the nodes closer to the root come
# first, and skip the root node itself
return {k: v for k, v in reversed(obj.subexprs.items()) if k is not node}

def visit(self, node, **kwargs):
# TODO(tehrengruber): improve this case as we might miss subexpression that could be eliminated
# disable collection (for all child nodes) if node matches `if_(can_deref(...), ..., ...)`
if _is_if_can_deref(node):
kwargs["allow_collection"] = False

if not isinstance(node, SymbolTableTrait) and not _is_collectable_expr(node):
return super().visit(node, **kwargs)

Expand All @@ -101,7 +118,7 @@ def visit(self, node, **kwargs):

# if no symbols are used that are defined in the root node, i.e. the node given to `apply`,
# we collect the subexpression
if not used_symbol_ids and _is_collectable_expr(node):
if not used_symbol_ids and _is_collectable_expr(node) and kwargs["allow_collection"]:
self.subexprs.setdefault(node, []).append((id(node), collected_child_node_ids))

# propagate to parent that we have collected its child
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,20 @@ def fencil(edge_f: Field[[Edge], int64], out: Field[[Vertex], int64]) -> None:
assert np.allclose(ref, rs.out.array())


def test_reduction_with_common_expression(reduction_setup, fieldview_backend):
    """
    Run a `neighbor_sum` whose argument contains the same shifted field twice.

    The summand `flux(V2E) + flux(V2E)` repeats one subexpression, so the result is
    compared against the row-wise sum of the doubled vertex-to-edge table.
    """
    rs = reduction_setup
    # keep these exact names: they are referenced from inside the field operator below
    V2EDim = rs.V2EDim
    V2E = rs.V2E

    @field_operator(backend=fieldview_backend)
    def testee(flux: Field[[Edge], int64]) -> Field[[Vertex], int64]:
        return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim)

    testee(rs.inp, out=rs.out, offset_provider=rs.offset_provider)

    # NOTE(review): the reference is built from the connectivity table itself, which
    # presumably relies on `rs.inp` holding the edge indices -- confirm against the fixture
    expected = np.sum(rs.v2e_table * 2, axis=1)
    assert np.allclose(expected, rs.out.array())


def test_conditional_nested_tuple(fieldview_backend):
a_I_float = np_as_located_field(IDim)(np.random.randn(size).astype("float64"))
b_I_float = np_as_located_field(IDim)(np.random.randn(size).astype("float64"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,25 @@ def common_expr():
)
actual = CSE().visit(testee)
assert actual == expected


def test_if_can_deref():
    """
    Test no subexpression is moved outside expressions of the form `if_(can_deref(...), ..., ...)`

    Elimination inside a single branch is still expected: the repeated deref in the
    `else` branch is bound by a `let` that stays inside that branch.
    """

    def shifted_it():
        # ⟪Iₒ, 1ₒ⟫(it) -- built anew on every call so that no node object is shared
        # between different positions of the tree
        return im.shift("I", 1)("it")

    def deref_shifted_it():
        # ·⟪Iₒ, 1ₒ⟫(it)
        return im.deref(shifted_it())

    # if can_deref(⟪Iₒ, 1ₒ⟫(it)) then ·⟪Iₒ, 1ₒ⟫(it) else ·⟪Iₒ, 1ₒ⟫(it) + ·⟪Iₒ, 1ₒ⟫(it)
    testee = im.if_(
        im.call("can_deref")(shifted_it()),
        deref_shifted_it(),
        # use something more involved where a subexpression can still be eliminated
        im.plus(deref_shifted_it(), deref_shifted_it()),
    )

    # if can_deref(⟪Iₒ, 1ₒ⟫(it)) then ·⟪Iₒ, 1ₒ⟫(it) else (λ(_cs_1) → _cs_1 + _cs_1)(·⟪Iₒ, 1ₒ⟫(it))
    expected = im.if_(
        im.call("can_deref")(shifted_it()),
        deref_shifted_it(),
        im.let("_cs_1", deref_shifted_it())(im.plus("_cs_1", "_cs_1")),
    )

    result = CSE().visit(testee)
    assert result == expected

0 comments on commit 45d7f3a

Please sign in to comment.