Skip to content

Commit

Permalink
Some fixes, tuples still not supported
Browse files Browse the repository at this point in the history
  • Loading branch information
SF-N committed Jan 31, 2025
1 parent ba8343b commit f90329e
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 58 deletions.
67 changes: 39 additions & 28 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,18 +571,35 @@ def _deduce_compare_type(
self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any
) -> Optional[ts.TypeSpec]:
# check both types compatible
left_t, right_t = left.type, right.type
integer_kind = getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())
if (
isinstance(left.type, ts.DimensionType)
and isinstance(right.type, ts.ScalarType)
and right.type.kind == getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())
isinstance(left_t, ts.DimensionType)
and isinstance(right_t, ts.ScalarType)
and right_t.kind == integer_kind
):
return ts.DomainType(dims=[left_t.dim])
if (
isinstance(right_t, ts.DimensionType)
and isinstance(left_t, ts.ScalarType)
and left_t.kind == integer_kind
):
return ts.DomainType(dims=[left.type.dim])
return ts.DomainType(dims=[right_t.dim])
if (
isinstance(right.type, ts.DimensionType)
and isinstance(left.type, ts.ScalarType)
and left.type.kind == getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())
isinstance(left_t, ts.OffsetType)
and left.op == dialect_ast_enums.BinaryOperator.MOD
and isinstance(right_t, ts.ScalarType)
and right_t.kind == integer_kind
) or (
isinstance(right_t, ts.OffsetType)
and right.op == dialect_ast_enums.BinaryOperator.MOD
and isinstance(left_t, ts.ScalarType)
and left_t.kind == integer_kind
):
return ts.DomainType(dims=[right.type.dim])
raise errors.DSLError(
left.location, "Type 'ts.OffsetType' can not be used in operator 'mod'."
)

# TODO
for arg in (left, right):
if not type_info.is_arithmetic(arg.type):
Expand All @@ -596,13 +613,13 @@ def _deduce_compare_type(
# transform operands to have bool dtype and use regular promotion
# mechanism to handle dimension promotion
return type_info.promote(
with_altered_scalar_kind(left.type, ts.ScalarKind.BOOL),
with_altered_scalar_kind(right.type, ts.ScalarKind.BOOL),
with_altered_scalar_kind(left_t, ts.ScalarKind.BOOL),
with_altered_scalar_kind(right_t, ts.ScalarKind.BOOL),
)
except ValueError as ex:
raise errors.DSLError(
node.location,
f"Could not promote '{left.type}' and '{right.type}' to common type"
f"Could not promote '{left_t}' and '{right_t}' to common type"
f" in call to '{node.op}'.",
) from ex

Expand All @@ -627,10 +644,10 @@ def _deduce_binop_type(
dialect_ast_enums.BinaryOperator.BIT_XOR,
}

def tmp(arg):
def is_logical_or_domain(arg: ts.TypeSpec) -> bool:
return type_info.is_logical(arg) or isinstance(arg, ts.DomainType)

is_compatible = tmp if node.op in logical_ops else type_info.is_arithmetic
is_compatible = is_logical_or_domain if node.op in logical_ops else type_info.is_arithmetic

# check both types compatible
for arg in (left, right):
Expand All @@ -641,26 +658,23 @@ def tmp(arg):
if isinstance(left.type, (ts.ScalarType, ts.FieldType)) and isinstance(
right.type, (ts.ScalarType, ts.FieldType)
):
left_type = cast(ts.FieldType | ts.ScalarType, left.type)
right_type = cast(ts.FieldType | ts.ScalarType, right.type)

if node.op == dialect_ast_enums.BinaryOperator.POW:
return left_type
return left.type

if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral(
right_type
right.type
):
raise errors.DSLError(
arg.location,
f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.",
f"Type '{right.type}' can not be used in operator '{node.op}', it only accepts 'int'.",
)

try:
return type_info.promote(left_type, right_type)
return type_info.promote(left.type, right.type)
except ValueError as ex:
raise errors.DSLError(
node.location,
f"Could not promote '{left_type}' and '{right_type}' to common type"
f"Could not promote '{left.type}' and '{right.type}' to common type"
f" in call to '{node.op}'.",
) from ex
elif isinstance(left.type, ts.DomainType) and isinstance(right.type, ts.DomainType):
Expand Down Expand Up @@ -971,13 +985,10 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call:
true_branch_type = node.args[1].type
false_branch_type = node.args[2].type
if true_branch_type != false_branch_type:
raise errors.DSLError(
node.location,
f"Incompatible argument in call to '{node.func!s}': expected "
f"'{true_branch_type}' and '{false_branch_type}' to be equal.",
)
return_type = true_branch_type
true_branch_fieldtype = cast(ts.FieldType, true_branch_type)
false_branch_fieldtype = cast(ts.FieldType, false_branch_type)
promoted_type = type_info.promote(true_branch_fieldtype, false_branch_fieldtype)
return_type = promoted_type

return foast.Call(
func=node.func,
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/pretty_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def visit_InfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]:
return ["INF"]

def visit_NegInfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]:
return ["NEG"]
return ["-INF"]

def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str]:
return [str(node.value) + "ₒ"]
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class ConstantFolding(PreserveLocationVisitor, NodeTranslator):
def apply(cls, node: ir.Node) -> ir.Node:
return cls().visit(node)

def visit_FunCall(self, node: ir.FunCall):
def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
# visit depth-first such that nested constant expressions (e.g. `(1+2)+3`) are properly folded
new_node = self.generic_visit(node)

Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/transforms/infer_domain_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class InferDomainOps(PreserveLocationVisitor, NodeTranslator):
def apply(cls, node: ir.Node):
return cls().visit(node)

def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall:
def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
node = self.generic_visit(node)
if (
cpm.is_call_to(node, builtins.BINARY_MATH_COMPARISON_BUILTINS)
Expand All @@ -33,11 +33,11 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall:
arg1, arg2 = node.args
fun = node.fun
if isinstance(arg1, ir.AxisLiteral) and isinstance(arg2, ir.Literal):
dim = common.Dimension(value=arg1.value, kind=common.DimensionKind.VERTICAL)
dim = common.Dimension(value=arg1.value, kind=arg1.kind)
value = int(arg2.value)
reverse = False
elif isinstance(arg1, ir.Literal) and isinstance(arg2, ir.AxisLiteral):
dim = common.Dimension(value=arg2.value, kind=common.DimensionKind.VERTICAL)
dim = common.Dimension(value=arg2.value, kind=arg2.kind)
value = int(arg1.value)
reverse = True
else:
Expand Down
20 changes: 11 additions & 9 deletions src/gt4py/next/iterator/transforms/nest_concat_wheres.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,10 @@

from gt4py.eve import NodeTranslator, PreserveLocationVisitor
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import (
common_pattern_matcher as cpm,
domain_utils,
ir_makers as im,
)
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im


class NestConcatWheres(PreserveLocationVisitor, NodeTranslator):

@classmethod
def apply(cls, node: ir.Node):
return cls().visit(node)
Expand All @@ -27,10 +22,17 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall:
cond_expr, field_a, field_b = node.args
if cpm.is_call_to(cond_expr, ("and_")):
conds = cond_expr.args
return im.concat_where(conds[0], im.concat_where(conds[1],field_a, field_b), field_b)
return im.concat_where(
conds[0], im.concat_where(conds[1], field_a, field_b), field_b
)
if cpm.is_call_to(cond_expr, ("or_")):
conds = cond_expr.args
return im.concat_where(conds[0], field_a, im.concat_where(conds[1],field_a, field_b))

return im.concat_where(
conds[0], field_a, im.concat_where(conds[1], field_a, field_b)
)
if cpm.is_call_to(cond_expr, ("eq")):
cond1 = im.less(cond_expr.args[0], cond_expr.args[1])
cond2 = im.greater(cond_expr.args[0], cond_expr.args[1])
return im.concat_where(cond1, field_b, im.concat_where(cond2, field_b, field_a))

return self.generic_visit(node)
3 changes: 1 addition & 2 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,7 @@ def concat_where(
true_field: ts.FieldType | ts.TupleType,
false_field: ts.FieldType | ts.TupleType,
) -> ts.FieldType | ts.TupleType:
assert true_field == false_field
return true_field
return type_info.promote(true_field, false_field)


@_register_builtin_type_synthesizer
Expand Down
1 change: 1 addition & 0 deletions tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
IField: TypeAlias = gtx.Field[[IDim], np.int32] # type: ignore [valid-type]
IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type]
IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type]
JField: TypeAlias = gtx.Field[[JDim], np.int32] # type: ignore [valid-type]
KField: TypeAlias = gtx.Field[[KDim], np.int32] # type: ignore [valid-type]
IJField: TypeAlias = gtx.Field[[IDim, JDim], np.int32] # type: ignore [valid-type]
IKField: TypeAlias = gtx.Field[[IDim, KDim], np.int32] # type: ignore [valid-type]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import numpy as np
from typing import Tuple
import pytest
from next_tests.integration_tests.cases import KDim, IDim, cartesian_case
from next_tests.integration_tests.cases import IDim, JDim, KDim, cartesian_case
from gt4py import next as gtx
from gt4py.next import errors
from gt4py.next.ffront.experimental import concat_where
from next_tests.integration_tests import cases
from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
Expand All @@ -23,7 +24,7 @@ def test_boundary_same_size_fields(cartesian_case):
def testee(
k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField
) -> cases.IJKField:
return concat_where(k == 0, boundary, interior)
return concat_where(KDim == 0, boundary, interior)

k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())()
interior = cases.allocate(cartesian_case, testee, "interior")()
Expand Down Expand Up @@ -55,6 +56,22 @@ def testee(
cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref)


def test_dimension_different_dims(cartesian_case):
@gtx.field_operator
def testee(j: cases.JField, interior: cases.IJField, boundary: cases.JField) -> cases.IJField:
return concat_where(IDim >= 2, boundary, interior)

j = cases.allocate(cartesian_case, testee, "j", strategy=cases.IndexInitializer())()
interior = cases.allocate(cartesian_case, testee, "interior")()
boundary = cases.allocate(cartesian_case, testee, "boundary")()
out = cases.allocate(cartesian_case, testee, cases.RETURN)()

ref = np.where(
j.asnumpy()[:, np.newaxis] >= 2, boundary.asnumpy()[np.newaxis, :], interior.asnumpy()
)
cases.verify(cartesian_case, testee, j, interior, boundary, out=out, ref=ref)


def test_dimension_two_nested_conditions(cartesian_case):
@gtx.field_operator
def testee(
Expand Down Expand Up @@ -90,6 +107,20 @@ def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> c
cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref)


def test_dimension_two_conditions_eq(cartesian_case):
@gtx.field_operator
def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> cases.KField:
return concat_where((KDim == 2), interior, boundary)

k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())()
interior = cases.allocate(cartesian_case, testee, "interior")()
boundary = cases.allocate(cartesian_case, testee, "boundary")()
out = cases.allocate(cartesian_case, testee, cases.RETURN)()

ref = np.where(k.asnumpy() == 2, interior.asnumpy(), boundary.asnumpy())
cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref)


def test_dimension_two_conditions_or(cartesian_case):
@gtx.field_operator
def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> cases.KField:
Expand All @@ -109,7 +140,7 @@ def test_boundary_horizontal_slice(cartesian_case):
def testee(
k: cases.KField, interior: cases.IJKField, boundary: cases.IJField
) -> cases.IJKField:
return concat_where(k == 0, boundary, interior)
return concat_where(KDim == 0, boundary, interior)

k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())()
interior = cases.allocate(cartesian_case, testee, "interior")()
Expand All @@ -130,7 +161,7 @@ def test_boundary_single_layer(cartesian_case):
def testee(
k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField
) -> cases.IJKField:
return concat_where(k == 0, boundary, interior)
return concat_where(KDim == 0, boundary, interior)

k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())()
interior = cases.allocate(cartesian_case, testee, "interior")()
Expand All @@ -147,18 +178,22 @@ def testee(


def test_alternating_mask(cartesian_case):
@gtx.field_operator
def testee(k: cases.KField, f0: cases.IJKField, f1: cases.IJKField) -> cases.IJKField:
return concat_where(k % 2 == 0, f1, f0)
with pytest.raises(
errors.DSLError, match=("Type 'ts.OffsetType' can not be used in operator 'mod'.")
):

k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())()
f0 = cases.allocate(cartesian_case, testee, "f0")()
f1 = cases.allocate(cartesian_case, testee, "f1")()
out = cases.allocate(cartesian_case, testee, cases.RETURN)()
@gtx.field_operator
def testee(k: cases.KField, f0: cases.IJKField, f1: cases.IJKField) -> cases.IJKField:
return concat_where(KDim % 2 == 0, f1, f0)

k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())()
f0 = cases.allocate(cartesian_case, testee, "f0")()
f1 = cases.allocate(cartesian_case, testee, "f1")()
out = cases.allocate(cartesian_case, testee, cases.RETURN)()

ref = np.where(k.asnumpy()[np.newaxis, np.newaxis, :] % 2 == 0, f1.asnumpy(), f0.asnumpy())
ref = np.where(k.asnumpy()[np.newaxis, np.newaxis, :] % 2 == 0, f1.asnumpy(), f0.asnumpy())

cases.verify(cartesian_case, testee, k, f0, f1, out=out, ref=ref)
cases.verify(cartesian_case, testee, k, f0, f1, out=out, ref=ref)


@pytest.mark.uses_tuple_returns
Expand All @@ -171,7 +206,7 @@ def testee(
interior1: cases.IJKField,
boundary1: cases.IJField,
) -> Tuple[cases.IJKField, cases.IJKField]:
return concat_where(k == 0, (boundary0, boundary1), (interior0, interior1))
return concat_where(KDim == 0, (boundary0, boundary1), (interior0, interior1))

k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())()
interior0 = cases.allocate(cartesian_case, testee, "interior0")()
Expand Down

0 comments on commit f90329e

Please sign in to comment.