Skip to content

Commit

Permalink
feat[next]: Extend the IR pass for pruning of unnecessary casts (#1728)
Browse files Browse the repository at this point in the history
Extend the IR pass delivered in #1688 

Pruning of cast expressions may appear as a `as_fieldop` expression with
form `(⇑(λ(__val) → cast_(·__val, float64)))(a)`, where `a` is already a
field with data type `float64` in this example. This PR adds pruning of
such trivial expressions.
  • Loading branch information
edopao authored Nov 20, 2024
1 parent 9dbc884 commit 5e93736
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 40 deletions.
4 changes: 1 addition & 3 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr:

def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall:
if isinstance(t[0], ts.FieldType):
return im.as_fieldop(
im.lambda_("__val")(im.call("cast_")(im.deref("__val"), str(new_type)))
)(expr)
return im.cast_as_fieldop(str(new_type))(expr)
else:
assert isinstance(t[0], ts.ScalarType)
return im.call("cast_")(expr, str(new_type))
Expand Down
26 changes: 26 additions & 0 deletions src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import TypeGuard

from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im


def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
Expand Down Expand Up @@ -84,3 +85,28 @@ def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunC

def is_ref_to(node, ref: str):
return isinstance(node, itir.SymRef) and node.id == ref


def is_identity_as_fieldop(node: itir.Expr):
"""
Match field operators implementing element-wise copy of a field argument,
that is expressions of the form `as_fieldop(stencil)(*args)`
>>> from gt4py.next.iterator.ir_utils import ir_makers as im
>>> node = im.as_fieldop(im.lambda_("__arg0")(im.deref("__arg0")))("a")
>>> is_identity_as_fieldop(node)
True
>>> node = im.as_fieldop("deref")("a")
>>> is_identity_as_fieldop(node)
False
"""
if not is_applied_as_fieldop(node):
return False
stencil = node.fun.args[0] # type: ignore[attr-defined]
if (
isinstance(stencil, itir.Lambda)
and len(stencil.params) == 1
and stencil == im.lambda_(stencil.params[0])(im.deref(stencil.params[0].id))
):
return True
return False
22 changes: 22 additions & 0 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,28 @@ def _impl(*its: itir.Expr) -> itir.FunCall:
return _impl


def cast_as_fieldop(type_: str, domain: Optional[itir.FunCall] = None):
"""
Promotes the function `cast_` to a field_operator.
Args:
type_: the target type to be passed as argument to `cast_` function.
domain: the domain of the returned field.
Returns:
A function from Fields to Field.
Examples:
>>> str(cast_as_fieldop("float32")("a"))
'(⇑(λ(__arg0) → cast_(·__arg0, float32)))(a)'
"""

def _impl(it: itir.Expr) -> itir.FunCall:
return op_as_fieldop(lambda v: call("cast_")(v, type_), domain)(it)

return _impl


def map_(op):
"""Create a `map_` call."""
return call(call("map_")(op))
31 changes: 18 additions & 13 deletions src/gt4py/next/iterator/transforms/prune_casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,42 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py import eve
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.type_system import type_specifications as ts


class PruneCasts(PreserveLocationVisitor, NodeTranslator):
class PruneCasts(eve.NodeTranslator):
"""
Removes cast expressions where the argument is already in the target type.
This transformation requires the IR to be fully type-annotated,
therefore it should be applied after type-inference.
"""

PRESERVED_ANNEX_ATTRS = ("domain",)

def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
node = self.generic_visit(node)

if not cpm.is_call_to(node, "cast_"):
return node
if cpm.is_call_to(node, "cast_"):
value, type_constructor = node.args

value, type_constructor = node.args
assert (
value.type
and isinstance(type_constructor, ir.SymRef)
and (type_constructor.id in ir.TYPEBUILTINS)
)
dtype = ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper()))

assert (
value.type
and isinstance(type_constructor, ir.SymRef)
and (type_constructor.id in ir.TYPEBUILTINS)
)
dtype = ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper()))
if value.type == dtype:
return value

if value.type == dtype:
return value
elif cpm.is_identity_as_fieldop(node):
# pruning of cast expressions may leave some trivial `as_fieldop` expressions
# with form '(⇑(λ(__arg) → ·__arg))(a)'
return node.args[0]

return node

Expand Down
28 changes: 7 additions & 21 deletions tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,7 @@ def foo(a: gtx.Field[[TDim], float64]):
parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)

reference = im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))(
"a"
)
reference = im.cast_as_fieldop("int32")("a")

assert lowered.expr == reference

Expand All @@ -312,12 +310,8 @@ def foo(a: tuple[gtx.Field[[TDim], float64], gtx.Field[[TDim], float64]]):
lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered)

reference = im.make_tuple(
im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))(
im.tuple_get(0, "a")
),
im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))(
im.tuple_get(1, "a")
),
im.cast_as_fieldop("int32")(im.tuple_get(0, "a")),
im.cast_as_fieldop("int32")(im.tuple_get(1, "a")),
)

assert lowered_inlined.expr == reference
Expand All @@ -332,9 +326,7 @@ def foo(a: tuple[gtx.Field[[TDim], float64], float64]):
lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered)

reference = im.make_tuple(
im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))(
im.tuple_get(0, "a")
),
im.cast_as_fieldop("int32")(im.tuple_get(0, "a")),
im.call("cast_")(im.tuple_get(1, "a"), "int32"),
)

Expand All @@ -356,16 +348,10 @@ def foo(

reference = im.make_tuple(
im.make_tuple(
im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))(
im.tuple_get(0, im.tuple_get(0, "a"))
),
im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))(
im.tuple_get(1, im.tuple_get(0, "a"))
),
),
im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))(
im.tuple_get(1, "a")
im.cast_as_fieldop("int32")(im.tuple_get(0, im.tuple_get(0, "a"))),
im.cast_as_fieldop("int32")(im.tuple_get(1, im.tuple_get(0, "a"))),
),
im.cast_as_fieldop("int32")(im.tuple_get(1, "a")),
)

assert lowered_inlined.expr == reference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py import next as gtx
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.type_system import type_specifications as ts
from gt4py.next.iterator.transforms.prune_casts import PruneCasts
Expand All @@ -21,3 +22,21 @@ def test_prune_casts_simple():
expected = im.call("plus")(im.call("cast_")(x_ref, "float64"), y_ref)
actual = PruneCasts.apply(testee)
assert actual == expected


def test_prune_casts_fieldop():
IDim = gtx.Dimension("IDim")
x_ref = im.ref("x", ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)))
y_ref = im.ref("y", ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)))
testee = im.op_as_fieldop("plus")(
im.cast_as_fieldop("float64")(x_ref),
im.cast_as_fieldop("float64")(y_ref),
)
testee = type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True)

expected = im.op_as_fieldop("plus")(
im.cast_as_fieldop("float64")(x_ref),
y_ref,
)
actual = PruneCasts.apply(testee)
assert actual == expected
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,7 @@ def test_gtir_cast():
body=[
gtir.SetAt(
expr=im.op_as_fieldop("eq", domain)(
im.as_fieldop(
im.lambda_("a")(im.call("cast_")(im.deref("a"), "float32")), domain
)("x"),
im.cast_as_fieldop("float32", domain)("x"),
"y",
),
domain=domain,
Expand Down

0 comments on commit 5e93736

Please sign in to comment.