Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gtir concat where 2 #1806

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivi

@WhereBuiltinFunction
def concat_where(
mask: common.Field,
mask: common.Domain,
true_field: common.Field | core_defs.ScalarT | Tuple,
false_field: common.Field | core_defs.ScalarT | Tuple,
/,
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp
return ts.OffsetType
elif t is core_defs.ScalarT:
return ts.ScalarType
elif t is common.Domain:
return ts.DomainType
elif t is type:
return (
ts.FunctionType
Expand Down Expand Up @@ -127,7 +129,7 @@ def __gt_type__(self) -> ts.FunctionType:
)


MaskT = TypeVar("MaskT", bound=common.Field)
MaskT = TypeVar("MaskT", bound=Union[common.Field, common.Domain])
FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple])


Expand Down
34 changes: 33 additions & 1 deletion src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
type_specifications as ts_ffront,
)
from gt4py.next.ffront.foast_passes.utils import compute_assign_indices
from gt4py.next.iterator import ir as itir
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation


Expand Down Expand Up @@ -570,6 +571,19 @@ def _deduce_compare_type(
self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
) -> Optional[ts.TypeSpec]:
# check both types compatible
if (
isinstance(left.type, ts.DimensionType)
and isinstance(right.type, ts.ScalarType)
and right.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
):
return ts.DomainType(dims=[left.type.dim])
if (
isinstance(right.type, ts.DimensionType)
and isinstance(left.type, ts.ScalarType)
and left.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
):
return ts.DomainType(dims=[right.type.dim])
# TODO
for arg in (left, right):
if not type_info.is_arithmetic(arg.type):
raise errors.DSLError(
Expand Down Expand Up @@ -908,6 +922,7 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
)

try:
# TODO(tehrengruber): the construct_tuple_type function doesn't look correct
if isinstance(true_branch_type, ts.TupleType) and isinstance(
false_branch_type, ts.TupleType
):
Expand Down Expand Up @@ -943,7 +958,24 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
location=node.location,
)

_visit_concat_where = _visit_where
def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
true_branch_type = node.args[1].type
false_branch_type = node.args[2].type
if true_branch_type != false_branch_type:
raise errors.DSLError(
node.location,
f"Incompatible argument in call to '{node.func!s}': expected "
f"'{true_branch_type}' and '{false_branch_type}' to be equal.",
)
return_type = true_branch_type

return foast.Call(
func=node.func,
args=node.args,
kwargs=node.kwargs,
type=return_type,
location=node.location,
)

def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> foast.Call:
arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type)
Expand Down
16 changes: 13 additions & 3 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,9 @@ def visit_Assign(
def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym:
return im.sym(node.id)

def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef:
def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef | itir.AxisLiteral:
if isinstance(node.type, ts.DimensionType):
return itir.AxisLiteral(value=node.type.dim.value, kind=node.type.dim.kind)
return im.ref(node.id)

def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr:
Expand Down Expand Up @@ -261,6 +263,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC
)

def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall:
# TODO: double-check if we need the changes in the original PR
return self._lower_and_map(node.op.value, node.left, node.right)

def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
Expand Down Expand Up @@ -394,7 +397,13 @@ def create_if(

return im.let(cond_symref_name, cond_)(result)

_visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where
def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
if not isinstance(node.type, ts.TupleType): # to keep the IR simpler
return im.call("concat_where")(*self.visit(node.args))
else:
raise NotImplementedError()

# TODO: tuple case

def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
expr = self.visit(node.args[0], **kwargs)
Expand Down Expand Up @@ -476,8 +485,9 @@ def _map(
"""
Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists.
"""
# TODO double-check that this code is consistent with the changes in the original PR
if all(
isinstance(t, ts.ScalarType)
isinstance(t, (ts.ScalarType, ts.DimensionType))
for arg_type in original_arg_types
for t in type_info.primitive_constituents(arg_type)
):
Expand Down
4 changes: 4 additions & 0 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ class FunctionDefinition(Node, SymbolTableTrait):
"if_",
"index", # `index(dim)` creates a dim-field that has the current index at each point
"as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution)
"concat_where",
"in",
"inf", # TODO: discuss
"neg_inf", # TODO: discuss
*ARITHMETIC_BUILTINS,
*TYPEBUILTINS,
}
Expand Down
46 changes: 46 additions & 0 deletions src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,49 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain:
new_domain_ranges[dim] = SymbolicRange(start, stop)

return SymbolicDomain(domains[0].grid_type, new_domain_ranges)


def domain_intersection(*domains: SymbolicDomain) -> SymbolicDomain:
"""Return the (set) intersection of a list of domains."""
new_domain_ranges = {}
assert all(domain.grid_type == domains[0].grid_type for domain in domains)
for dim in domains[0].ranges.keys():
start = functools.reduce(
lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr),
[domain.ranges[dim].start for domain in domains],
)
stop = functools.reduce(
lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr),
[domain.ranges[dim].stop for domain in domains],
)
new_domain_ranges[dim] = SymbolicRange(start, stop)

return SymbolicDomain(domains[0].grid_type, new_domain_ranges)


def domain_complement(domain: SymbolicDomain) -> SymbolicDomain:
"""Return the (set) complement of a domain."""
dims_dict = {}
for dim in domain.ranges.keys():
lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop
if lb == im.ref("neg_inf"):
dims_dict[dim] = SymbolicRange(start=ub, stop=im.ref("inf"))
elif ub == im.ref("inf"):
dims_dict[dim] = SymbolicRange(start=im.ref("neg_inf"), stop=lb)
else:
raise ValueError("Invalid domain ranges")
return SymbolicDomain(domain.grid_type, dims_dict)


def promote_to_same_dimensions(
domain_small: SymbolicDomain, domain_large: SymbolicDomain
) -> SymbolicDomain:
"""Return an extended domain based on a smaller input domain and a larger domain containing the target dimensions."""
dims_dict = {}
for dim in domain_large.ranges.keys():
if dim in domain_small.ranges.keys():
lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop
dims_dict[dim] = SymbolicRange(lb, ub)
else:
dims_dict[dim] = SymbolicRange(im.ref("neg_inf"), im.ref("inf"))
return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured
5 changes: 5 additions & 0 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,11 @@ def if_(cond, true_val, false_val):
return call("if_")(cond, true_val, false_val)


def concat_where(cond, true_field, false_field):
"""Create a concat_where FunCall, shorthand for ``call("concat_where")(expr)``."""
return call("concat_where")(cond, true_field, false_field)


def lift(expr):
"""Create a lift FunCall, shorthand for ``call(call("lift")(expr))``."""
return call(call("lift")(expr))
Expand Down
38 changes: 34 additions & 4 deletions src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import embedded, ir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im


class ConstantFolding(PreserveLocationVisitor, NodeTranslator):
Expand All @@ -21,12 +21,42 @@ def visit_FunCall(self, node: ir.FunCall):
new_node = self.generic_visit(node)

if (
isinstance(new_node.fun, ir.SymRef)
and new_node.fun.id in ["minimum", "maximum"]
cpm.is_call_to(new_node, ("minimum", "maximum"))
and new_node.args[0] == new_node.args[1]
): # `minimum(a, a)` -> `a`
return new_node.args[0]

if cpm.is_call_to(new_node, "minimum"): # TODO: add tests
# `minimum(neg_inf, neg_inf)` -> `neg_inf`
if cpm.is_ref_to(new_node.args[0], "neg_inf") or cpm.is_ref_to(
new_node.args[1], "neg_inf"
):
return im.ref("neg_inf")
# `minimum(inf, a)` -> `a`
elif cpm.is_ref_to(new_node.args[0], "inf"):
return new_node.args[1]
# `minimum(a, inf)` -> `a`
elif cpm.is_ref_to(new_node.args[1], "inf"):
return new_node.args[0]

if cpm.is_call_to(new_node, "maximum"): # TODO: add tests
# `minimum(inf, inf)` -> `inf`
if cpm.is_ref_to(new_node.args[0], "inf") or cpm.is_ref_to(new_node.args[1], "inf"):
return im.ref("inf")
# `minimum(neg_inf, a)` -> `a`
elif cpm.is_ref_to(new_node.args[0], "neg_inf"):
return new_node.args[1]
# `minimum(a, neg_inf)` -> `a`
elif cpm.is_ref_to(new_node.args[1], "neg_inf"):
return new_node.args[0]
if cpm.is_call_to(new_node, ("less", "less_equal")) and cpm.is_ref_to(
new_node.args[0], "neg_inf"
):
return im.literal_from_value(True) # TODO: add tests
if cpm.is_call_to(new_node, ("greater", "greater_equal")) and cpm.is_ref_to(
new_node.args[0], "inf"
):
return im.literal_from_value(True) # TODO: add tests
if (
isinstance(new_node.fun, ir.SymRef)
and new_node.fun.id == "if_"
Expand All @@ -52,6 +82,6 @@ def visit_FunCall(self, node: ir.FunCall):
]
new_node = im.literal_from_value(fun(*arg_values))
except ValueError:
pass # happens for inf and neginf
pass # happens for SymRefs which are not inf or neg_inf

return new_node
39 changes: 39 additions & 0 deletions src/gt4py/next/iterator/transforms/expand_library_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from functools import reduce

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import (
common_pattern_matcher as cpm,
domain_utils,
ir_makers as im,
)


class ExpandLibraryFunctions(PreserveLocationVisitor, NodeTranslator):
@classmethod
def apply(cls, node: ir.Node):
return cls().visit(node)

def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall:
if cpm.is_call_to(node, "in"):
ret = []
pos, domain = node.args
for i, (k, v) in enumerate(
domain_utils.SymbolicDomain.from_expr(node.args[1]).ranges.items()
):
ret.append(
im.and_(
im.less_equal(v.start, im.tuple_get(i, pos)),
im.less(im.tuple_get(i, pos), v.stop),
)
) # TODO: avoid pos duplication
return reduce(im.and_, ret)
return self.generic_visit(node)
32 changes: 32 additions & 0 deletions src/gt4py/next/iterator/transforms/infer_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,36 @@ def _infer_if(
return result_expr, actual_domains


def _infer_concat_where(
expr: itir.Expr,
domain: DomainAccess,
**kwargs: Unpack[InferenceOptions],
) -> tuple[itir.Expr, AccessedDomains]:
assert cpm.is_call_to(expr, "concat_where")
assert isinstance(domain, domain_utils.SymbolicDomain)
infered_args_expr = []
actual_domains: AccessedDomains = {}
cond, true_field, false_field = expr.args
symbolic_cond = domain_utils.SymbolicDomain.from_expr(cond)
for arg in [true_field, false_field]:
if arg == true_field:
extended_cond = domain_utils.promote_to_same_dimensions(symbolic_cond, domain)
domain_ = domain_utils.domain_intersection(domain, extended_cond)
elif arg == false_field:
cond_complement = domain_utils.domain_complement(symbolic_cond)
extended_cond_complement = domain_utils.promote_to_same_dimensions(
cond_complement, domain
)
domain_ = domain_utils.domain_intersection(domain, extended_cond_complement)

infered_arg_expr, actual_domains_arg = infer_expr(arg, domain_, **kwargs)
infered_args_expr.append(infered_arg_expr)
actual_domains = _merge_domains(actual_domains, actual_domains_arg)

result_expr = im.call(expr.fun)(cond, *infered_args_expr)
return result_expr, actual_domains


def _infer_expr(
expr: itir.Expr,
domain: DomainAccess,
Expand All @@ -382,6 +412,8 @@ def _infer_expr(
return _infer_tuple_get(expr, domain, **kwargs)
elif cpm.is_call_to(expr, "if_"):
return _infer_if(expr, domain, **kwargs)
elif cpm.is_call_to(expr, "concat_where"):
return _infer_concat_where(expr, domain, **kwargs)
elif (
cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS)
or cpm.is_call_to(expr, itir.TYPEBUILTINS)
Expand Down
Loading
Loading