Skip to content

Commit

Permalink
Extend concat_where, now also working for nested concat_wheres and ex…
Browse files Browse the repository at this point in the history
…pressions connected by logical operators (they are resolved by NestConcatWheres): running fine in roundtrip_default, gtfn_run_gtfn and gtfn_run_gtfn_imperative
  • Loading branch information
SF-N committed Jan 30, 2025
1 parent c3a18c4 commit ba8343b
Show file tree
Hide file tree
Showing 20 changed files with 381 additions and 89 deletions.
60 changes: 35 additions & 25 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import gt4py.next.ffront.field_operator_ast as foast
from gt4py.eve import NodeTranslator, NodeVisitor, traits
from gt4py.next import errors
from gt4py.next.common import DimensionKind
from gt4py.next.common import DimensionKind, promote_dims
from gt4py.next.ffront import ( # noqa
dialect_ast_enums,
experimental,
Expand All @@ -20,7 +20,7 @@
type_specifications as ts_ffront,
)
from gt4py.next.ffront.foast_passes.utils import compute_assign_indices
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator import builtins
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation


Expand Down Expand Up @@ -574,13 +574,13 @@ def _deduce_compare_type(
if (
isinstance(left.type, ts.DimensionType)
and isinstance(right.type, ts.ScalarType)
and right.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
and right.type.kind == getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())
):
return ts.DomainType(dims=[left.type.dim])
if (
isinstance(right.type, ts.DimensionType)
and isinstance(left.type, ts.ScalarType)
and left.type.kind == getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
and left.type.kind == getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())
):
return ts.DomainType(dims=[right.type.dim])
# TODO
Expand Down Expand Up @@ -626,37 +626,47 @@ def _deduce_binop_type(
dialect_ast_enums.BinaryOperator.BIT_OR,
dialect_ast_enums.BinaryOperator.BIT_XOR,
}
is_compatible = type_info.is_logical if node.op in logical_ops else type_info.is_arithmetic

def tmp(arg):
return type_info.is_logical(arg) or isinstance(arg, ts.DomainType)

is_compatible = tmp if node.op in logical_ops else type_info.is_arithmetic

# check both types compatible
for arg in (left, right):
if not is_compatible(arg.type):
raise errors.DSLError(
arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'."
)
if isinstance(left.type, (ts.ScalarType, ts.FieldType)) and isinstance(
right.type, (ts.ScalarType, ts.FieldType)
):
left_type = cast(ts.FieldType | ts.ScalarType, left.type)
right_type = cast(ts.FieldType | ts.ScalarType, right.type)

left_type = cast(ts.FieldType | ts.ScalarType, left.type)
right_type = cast(ts.FieldType | ts.ScalarType, right.type)

if node.op == dialect_ast_enums.BinaryOperator.POW:
return left_type
if node.op == dialect_ast_enums.BinaryOperator.POW:
return left_type

if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral(
right_type
):
raise errors.DSLError(
arg.location,
f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.",
)
if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral(
right_type
):
raise errors.DSLError(
arg.location,
f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.",
)

try:
return type_info.promote(left_type, right_type)
except ValueError as ex:
raise errors.DSLError(
node.location,
f"Could not promote '{left_type}' and '{right_type}' to common type"
f" in call to '{node.op}'.",
) from ex
try:
return type_info.promote(left_type, right_type)
except ValueError as ex:
raise errors.DSLError(
node.location,
f"Could not promote '{left_type}' and '{right_type}' to common type"
f" in call to '{node.op}'.",
) from ex
elif isinstance(left.type, ts.DomainType) and isinstance(right.type, ts.DomainType):
return ts.DomainType(dims=promote_dims(left.type.dims, right.type.dims))
else:
raise ValueError("TODO")

def _check_operand_dtypes_match(
self, node: foast.BinOp | foast.Compare, left: foast.Expr, right: foast.Expr
Expand Down
23 changes: 22 additions & 1 deletion src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,28 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr:
raise NotImplementedError(f"Unary operator '{node.op}' is not supported.")

def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall:
return self._lower_and_map(node.op.value, node.left, node.right)
if (
node.op == dialect_ast_enums.BinaryOperator.BIT_AND
and isinstance(node.left.type, ts.DomainType)
and isinstance(node.right.type, ts.DomainType)
):
return im.and_(self.visit(node.left), self.visit(node.right))
if (
node.op == dialect_ast_enums.BinaryOperator.BIT_OR
and isinstance(node.left.type, ts.DomainType)
and isinstance(node.right.type, ts.DomainType)
):
return im.or_(self.visit(node.left), self.visit(node.right))
if (
node.op == dialect_ast_enums.BinaryOperator.BIT_XOR
and isinstance(node.left.type, ts.DomainType)
and isinstance(node.right.type, ts.DomainType)
):
raise NotImplementedError(
f"Binary operator '{node.op}' is not supported for '{node.right.type}' and '{node.right.type}'."
)
else:
return self._lower_and_map(node.op.value, node.left, node.right)

def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall:
assert (
Expand Down
2 changes: 0 additions & 2 deletions src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,6 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
"unstructured_domain",
"concat_where",
"in",
"inf", # TODO: discuss
"neg_inf", # TODO: discuss
*ARITHMETIC_BUILTINS,
*TYPE_BUILTINS,
}
Expand Down
10 changes: 10 additions & 0 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ class NoneLiteral(Expr):
_none_literal: int = 0


class InfinityLiteral(Expr):
pass


class NegInfinityLiteral(Expr):
pass


class OffsetLiteral(Expr):
value: Union[int, str]

Expand Down Expand Up @@ -142,3 +150,5 @@ class Program(Node, ValidatedSymbolTableTrait):
Program.__hash__ = Node.__hash__ # type: ignore[method-assign]
SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign]
IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign]
InfinityLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign]
NegInfinityLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign]
10 changes: 5 additions & 5 deletions src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,10 @@ def domain_complement(domain: SymbolicDomain) -> SymbolicDomain:
dims_dict = {}
for dim in domain.ranges.keys():
lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop
if lb == im.ref("neg_inf"):
dims_dict[dim] = SymbolicRange(start=ub, stop=im.ref("inf"))
elif ub == im.ref("inf"):
dims_dict[dim] = SymbolicRange(start=im.ref("neg_inf"), stop=lb)
if isinstance(lb, itir.NegInfinityLiteral):
dims_dict[dim] = SymbolicRange(start=ub, stop=itir.InfinityLiteral())
elif isinstance(ub, itir.InfinityLiteral):
dims_dict[dim] = SymbolicRange(start=itir.NegInfinityLiteral(), stop=lb)
else:
raise ValueError("Invalid domain ranges")
return SymbolicDomain(domain.grid_type, dims_dict)
Expand All @@ -218,5 +218,5 @@ def promote_to_same_dimensions(
lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop
dims_dict[dim] = SymbolicRange(lb, ub)
else:
dims_dict[dim] = SymbolicRange(im.ref("neg_inf"), im.ref("inf"))
dims_dict[dim] = SymbolicRange(itir.NegInfinityLiteral(), itir.InfinityLiteral())
return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured
6 changes: 6 additions & 0 deletions src/gt4py/next/iterator/pretty_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ def visit_Sym(self, node: ir.Sym, *, prec: int) -> list[str]:
def visit_Literal(self, node: ir.Literal, *, prec: int) -> list[str]:
return [str(node.value)]

def visit_InfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]:
return ["INF"]

def visit_NegInfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]:
return ["NEG"]

def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str]:
return [str(node.value) + "ₒ"]

Expand Down
50 changes: 31 additions & 19 deletions src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,37 +26,49 @@ def visit_FunCall(self, node: ir.FunCall):
): # `minimum(a, a)` -> `a`
return new_node.args[0]

if cpm.is_call_to(new_node, "minimum"): # TODO: add tests
if cpm.is_call_to(new_node, "minimum"):
# `minimum(neg_inf, neg_inf)` -> `neg_inf`
if cpm.is_ref_to(new_node.args[0], "neg_inf") or cpm.is_ref_to(
new_node.args[1], "neg_inf"
if isinstance(new_node.args[0], ir.NegInfinityLiteral) or isinstance(
new_node.args[1], ir.NegInfinityLiteral
):
return im.ref("neg_inf")
return ir.NegInfinityLiteral()
# `minimum(inf, a)` -> `a`
elif cpm.is_ref_to(new_node.args[0], "inf"):
elif isinstance(new_node.args[0], ir.InfinityLiteral):
return new_node.args[1]
# `minimum(a, inf)` -> `a`
elif cpm.is_ref_to(new_node.args[1], "inf"):
elif isinstance(new_node.args[1], ir.InfinityLiteral):
return new_node.args[0]

if cpm.is_call_to(new_node, "maximum"): # TODO: add tests
if cpm.is_call_to(new_node, "maximum"):
# `minimum(inf, inf)` -> `inf`
if cpm.is_ref_to(new_node.args[0], "inf") or cpm.is_ref_to(new_node.args[1], "inf"):
return im.ref("inf")
if isinstance(new_node.args[0], ir.InfinityLiteral) or isinstance(
new_node.args[1], ir.InfinityLiteral
):
return ir.InfinityLiteral()
# `minimum(neg_inf, a)` -> `a`
elif cpm.is_ref_to(new_node.args[0], "neg_inf"):
elif isinstance(new_node.args[0], ir.NegInfinityLiteral):
return new_node.args[1]
# `minimum(a, neg_inf)` -> `a`
elif cpm.is_ref_to(new_node.args[1], "neg_inf"):
elif isinstance(new_node.args[1], ir.NegInfinityLiteral):
return new_node.args[0]
if cpm.is_call_to(new_node, ("less", "less_equal")) and cpm.is_ref_to(
new_node.args[0], "neg_inf"
):
return im.literal_from_value(True) # TODO: add tests
if cpm.is_call_to(new_node, ("greater", "greater_equal")) and cpm.is_ref_to(
new_node.args[0], "inf"
):
return im.literal_from_value(True) # TODO: add tests
if cpm.is_call_to(new_node, ("less", "less_equal")):
if isinstance(new_node.args[0], ir.NegInfinityLiteral) or isinstance(
new_node.args[1], ir.InfinityLiteral
):
return im.literal_from_value(True)
if isinstance(new_node.args[0], ir.InfinityLiteral) or isinstance(
new_node.args[1], ir.NegInfinityLiteral
):
return im.literal_from_value(False)
if cpm.is_call_to(new_node, ("greater", "greater_equal")):
if isinstance(new_node.args[0], ir.NegInfinityLiteral) or isinstance(
new_node.args[1], ir.InfinityLiteral
):
return im.literal_from_value(False)
if isinstance(new_node.args[0], ir.InfinityLiteral) or isinstance(
new_node.args[1], ir.NegInfinityLiteral
):
return im.literal_from_value(True)
if (
isinstance(new_node.fun, ir.SymRef)
and new_node.fun.id == "if_"
Expand Down
16 changes: 15 additions & 1 deletion src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,21 @@ def _is_collectable_expr(node: itir.Node) -> bool:
# conceptual problems (other parts of the tool chain rely on the arguments being present directly
# on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend
# backend (single pass eager depth first visit approach)
if isinstance(node.fun, itir.SymRef) and node.fun.id in ["lift", "shift", "reduce", "map_"]:
# do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement
# instead of an as_fieldop
if isinstance(node.fun, itir.SymRef) and node.fun.id in [
"lift",
"shift",
"reduce",
"map_",
"index",
]:
return False
# do also not collect make_tuple(index) nodes because otherwise the right hand side of SetAts becomes a let statement
# instead of an as_fieldop
if cpm.is_call_to(node, "make_tuple") and all(
cpm.is_call_to(arg, "index") for arg in node.args
):
return False
return True
elif isinstance(node, itir.Lambda):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall:
if cpm.is_call_to(node, "in"):
ret = []
pos, domain = node.args
for i, (k, v) in enumerate(
for i, (_, v) in enumerate(
domain_utils.SymbolicDomain.from_expr(node.args[1]).ranges.items()
):
ret.append(
Expand Down
Loading

0 comments on commit ba8343b

Please sign in to comment.