diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index cc4b2863d7..dc8d36af5e 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -571,18 +571,35 @@ def _deduce_compare_type( self, node: foast.Compare, *, left: foast.Expr, right: foast.Expr, **kwargs: Any ) -> Optional[ts.TypeSpec]: # check both types compatible + left_t, right_t = left.type, right.type + integer_kind = getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) if ( - isinstance(left.type, ts.DimensionType) - and isinstance(right.type, ts.ScalarType) - and right.type.kind == getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) + isinstance(left_t, ts.DimensionType) + and isinstance(right_t, ts.ScalarType) + and right_t.kind == integer_kind + ): + return ts.DomainType(dims=[left_t.dim]) + if ( + isinstance(right_t, ts.DimensionType) + and isinstance(left_t, ts.ScalarType) + and left_t.kind == integer_kind ): - return ts.DomainType(dims=[left.type.dim]) + return ts.DomainType(dims=[right_t.dim]) if ( - isinstance(right.type, ts.DimensionType) - and isinstance(left.type, ts.ScalarType) - and left.type.kind == getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) + isinstance(left_t, ts.OffsetType) + and left.op == dialect_ast_enums.BinaryOperator.MOD + and isinstance(right_t, ts.ScalarType) + and right_t.kind == integer_kind + ) or ( + isinstance(right_t, ts.OffsetType) + and right.op == dialect_ast_enums.BinaryOperator.MOD + and isinstance(left_t, ts.ScalarType) + and left_t.kind == integer_kind ): - return ts.DomainType(dims=[right.type.dim]) + raise errors.DSLError( + left.location, "Type 'ts.OffsetType' can not be used in operator 'mod'." + ) + # TODO for arg in (left, right): if not type_info.is_arithmetic(arg.type): @@ -596,13 +613,13 @@ def _deduce_compare_type( # transform operands to have bool dtype and use regular promotion # mechanism to handle dimension promotion return type_info.promote( - with_altered_scalar_kind(left.type, ts.ScalarKind.BOOL), - with_altered_scalar_kind(right.type, ts.ScalarKind.BOOL), + with_altered_scalar_kind(left_t, ts.ScalarKind.BOOL), + with_altered_scalar_kind(right_t, ts.ScalarKind.BOOL), ) except ValueError as ex: raise errors.DSLError( node.location, - f"Could not promote '{left.type}' and '{right.type}' to common type" + f"Could not promote '{left_t}' and '{right_t}' to common type" f" in call to '{node.op}'.", ) from ex @@ -627,10 +644,10 @@ def _deduce_binop_type( dialect_ast_enums.BinaryOperator.BIT_XOR, } - def tmp(arg): + def is_logical_or_domain(arg: ts.TypeSpec) -> bool: return type_info.is_logical(arg) or isinstance(arg, ts.DomainType) - is_compatible = tmp if node.op in logical_ops else type_info.is_arithmetic + is_compatible = is_logical_or_domain if node.op in logical_ops else type_info.is_arithmetic # check both types compatible for arg in (left, right): @@ -641,26 +658,23 @@ def tmp(arg): if isinstance(left.type, (ts.ScalarType, ts.FieldType)) and isinstance( right.type, (ts.ScalarType, ts.FieldType) ): - left_type = cast(ts.FieldType | ts.ScalarType, left.type) - right_type = cast(ts.FieldType | ts.ScalarType, right.type) - if node.op == dialect_ast_enums.BinaryOperator.POW: - return left_type + return left.type if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( - right_type + right.type ): raise errors.DSLError( arg.location, - f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.", + f"Type '{right.type}' can not be used in operator '{node.op}', it only accepts 'int'.", ) try: - return type_info.promote(left_type, right_type) + return type_info.promote(left.type, right.type) except ValueError as ex: raise errors.DSLError( node.location, - f"Could not promote '{left_type}' and '{right_type}' to common type" + f"Could not promote '{left.type}' and '{right.type}' to common type" f" in call to '{node.op}'.", ) from ex elif isinstance(left.type, ts.DomainType) and isinstance(right.type, ts.DomainType): @@ -971,13 +985,10 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: def _visit_concat_where(self, node: foast.Call, **kwargs: Any) -> foast.Call: true_branch_type = node.args[1].type false_branch_type = node.args[2].type - if true_branch_type != false_branch_type: - raise errors.DSLError( - node.location, - f"Incompatible argument in call to '{node.func!s}': expected " - f"'{true_branch_type}' and '{false_branch_type}' to be equal.", - ) - return_type = true_branch_type + true_branch_fieldtype = cast(ts.FieldType, true_branch_type) + false_branch_fieldtype = cast(ts.FieldType, false_branch_type) + promoted_type = type_info.promote(true_branch_fieldtype, false_branch_fieldtype) + return_type = promoted_type return foast.Call( func=node.func, diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 8f29b3ce9c..1d97878257 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -137,7 +137,7 @@ def visit_InfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]: return ["INF"] def visit_NegInfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]: - return ["NEG"] + return ["-INF"] def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str]: return [str(node.value) + "ₒ"] diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 8444f3276b..b3980e70ed 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -16,7 +16,7 @@ class ConstantFolding(PreserveLocationVisitor, NodeTranslator): def apply(cls, node: ir.Node) -> ir.Node: return cls().visit(node) - def visit_FunCall(self, node: ir.FunCall): + def visit_FunCall(self, node: ir.FunCall) -> ir.Node: # visit depth-first such that nested constant expressions (e.g. `(1+2)+3`) are properly folded new_node = self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/infer_domain_ops.py b/src/gt4py/next/iterator/transforms/infer_domain_ops.py index addb4047ef..a5da214ae3 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain_ops.py +++ b/src/gt4py/next/iterator/transforms/infer_domain_ops.py @@ -23,7 +23,7 @@ class InferDomainOps(PreserveLocationVisitor, NodeTranslator): def apply(cls, node: ir.Node): return cls().visit(node) - def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: + def visit_FunCall(self, node: ir.FunCall) -> ir.Node: node = self.generic_visit(node) if ( cpm.is_call_to(node, builtins.BINARY_MATH_COMPARISON_BUILTINS) @@ -33,11 +33,11 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: arg1, arg2 = node.args fun = node.fun if isinstance(arg1, ir.AxisLiteral) and isinstance(arg2, ir.Literal): - dim = common.Dimension(value=arg1.value, kind=common.DimensionKind.VERTICAL) + dim = common.Dimension(value=arg1.value, kind=arg1.kind) value = int(arg2.value) reverse = False elif isinstance(arg1, ir.Literal) and isinstance(arg2, ir.AxisLiteral): - dim = common.Dimension(value=arg2.value, kind=common.DimensionKind.VERTICAL) + dim = common.Dimension(value=arg2.value, kind=arg2.kind) value = int(arg1.value) reverse = True else: diff --git a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py index ee8197579d..258494e0c4 100644 --- a/src/gt4py/next/iterator/transforms/nest_concat_wheres.py +++ b/src/gt4py/next/iterator/transforms/nest_concat_wheres.py @@ -8,15 +8,10 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ( - common_pattern_matcher as cpm, - domain_utils, - ir_makers as im, -) +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im class NestConcatWheres(PreserveLocationVisitor, NodeTranslator): - @classmethod def apply(cls, node: ir.Node): return cls().visit(node) @@ -27,10 +22,17 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.FunCall: cond_expr, field_a, field_b = node.args if cpm.is_call_to(cond_expr, ("and_")): conds = cond_expr.args - return im.concat_where(conds[0], im.concat_where(conds[1],field_a, field_b), field_b) + return im.concat_where( + conds[0], im.concat_where(conds[1], field_a, field_b), field_b + ) if cpm.is_call_to(cond_expr, ("or_")): conds = cond_expr.args - return im.concat_where(conds[0], field_a, im.concat_where(conds[1],field_a, field_b)) - + return im.concat_where( + conds[0], field_a, im.concat_where(conds[1], field_a, field_b) + ) + if cpm.is_call_to(cond_expr, ("eq")): + cond1 = im.less(cond_expr.args[0], cond_expr.args[1]) + cond2 = im.greater(cond_expr.args[0], cond_expr.args[1]) + return im.concat_where(cond1, field_b, im.concat_where(cond2, field_b, field_a)) return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index f733a229be..c6f31d0a51 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -224,8 +224,7 @@ def concat_where( true_field: ts.FieldType | ts.TupleType, false_field: ts.FieldType | ts.TupleType, ) -> ts.FieldType | ts.TupleType: - assert true_field == false_field - return true_field + return type_info.promote(true_field, false_field) @_register_builtin_type_synthesizer diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 89ad556476..66330016ef 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -62,6 +62,7 @@ IField: TypeAlias = gtx.Field[[IDim], np.int32] # type: ignore [valid-type] IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type] IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type] +JField: TypeAlias = gtx.Field[[JDim], np.int32] # type: ignore [valid-type] KField: TypeAlias = gtx.Field[[KDim], np.int32] # type: ignore [valid-type] IJField: TypeAlias = gtx.Field[[IDim, JDim], np.int32] # type: ignore [valid-type] IKField: TypeAlias = gtx.Field[[IDim, KDim], np.int32] # type: ignore [valid-type] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 7d03eb4fc1..7db29bc088 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -9,8 +9,9 @@ import numpy as np from typing import Tuple import pytest -from next_tests.integration_tests.cases import KDim, IDim, cartesian_case +from next_tests.integration_tests.cases import IDim, JDim, KDim, cartesian_case from gt4py import next as gtx +from gt4py.next import errors from gt4py.next.ffront.experimental import concat_where from next_tests.integration_tests import cases from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -23,7 +24,7 @@ def test_boundary_same_size_fields(cartesian_case): def testee( k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + return concat_where(KDim == 0, boundary, interior) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() @@ -55,6 +56,22 @@ def testee( cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) +def test_dimension_different_dims(cartesian_case): + @gtx.field_operator + def testee(j: cases.JField, interior: cases.IJField, boundary: cases.JField) -> cases.IJField: + return concat_where(IDim >= 2, boundary, interior) + + j = cases.allocate(cartesian_case, testee, "j", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + j.asnumpy()[:, np.newaxis] >= 2, boundary.asnumpy()[np.newaxis, :], interior.asnumpy() + ) + cases.verify(cartesian_case, testee, j, interior, boundary, out=out, ref=ref) + + def test_dimension_two_nested_conditions(cartesian_case): @gtx.field_operator def testee( @@ -90,6 +107,20 @@ def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> c cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) +def test_dimension_two_conditions_eq(cartesian_case): + @gtx.field_operator + def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> cases.KField: + return concat_where((KDim == 2), interior, boundary) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where(k.asnumpy() == 2, interior.asnumpy(), boundary.asnumpy()) + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + def test_dimension_two_conditions_or(cartesian_case): @gtx.field_operator def testee(k: cases.KField, interior: cases.KField, boundary: cases.KField) -> cases.KField: @@ -109,7 +140,7 @@ def test_boundary_horizontal_slice(cartesian_case): def testee( k: cases.KField, interior: cases.IJKField, boundary: cases.IJField ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + return concat_where(KDim == 0, boundary, interior) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() @@ -130,7 +161,7 @@ def test_boundary_single_layer(cartesian_case): def testee( k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField ) -> cases.IJKField: - return concat_where(k == 0, boundary, interior) + return concat_where(KDim == 0, boundary, interior) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior = cases.allocate(cartesian_case, testee, "interior")() @@ -147,18 +178,22 @@ def testee( def test_alternating_mask(cartesian_case): - @gtx.field_operator - def testee(k: cases.KField, f0: cases.IJKField, f1: cases.IJKField) -> cases.IJKField: - return concat_where(k % 2 == 0, f1, f0) + with pytest.raises( + errors.DSLError, match=("Type 'ts.OffsetType' can not be used in operator 'mod'.") + ): - k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() - f0 = cases.allocate(cartesian_case, testee, "f0")() - f1 = cases.allocate(cartesian_case, testee, "f1")() - out = cases.allocate(cartesian_case, testee, cases.RETURN)() + @gtx.field_operator + def testee(k: cases.KField, f0: cases.IJKField, f1: cases.IJKField) -> cases.IJKField: + return concat_where(KDim % 2 == 0, f1, f0) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + f0 = cases.allocate(cartesian_case, testee, "f0")() + f1 = cases.allocate(cartesian_case, testee, "f1")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() - ref = np.where(k.asnumpy()[np.newaxis, np.newaxis, :] % 2 == 0, f1.asnumpy(), f0.asnumpy()) + ref = np.where(k.asnumpy()[np.newaxis, np.newaxis, :] % 2 == 0, f1.asnumpy(), f0.asnumpy()) - cases.verify(cartesian_case, testee, k, f0, f1, out=out, ref=ref) + cases.verify(cartesian_case, testee, k, f0, f1, out=out, ref=ref) @pytest.mark.uses_tuple_returns @@ -171,7 +206,7 @@ def testee( interior1: cases.IJKField, boundary1: cases.IJField, ) -> Tuple[cases.IJKField, cases.IJKField]: - return concat_where(k == 0, (boundary0, boundary1), (interior0, interior1)) + return concat_where(KDim == 0, (boundary0, boundary1), (interior0, interior1)) k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() interior0 = cases.allocate(cartesian_case, testee, "interior0")()