From f9af77515f461729c7cf9ed63f53db0b5a13b8d2 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 11 Sep 2024 16:58:49 -0500 Subject: [PATCH 01/10] Disable arg casting for scalars --- pytato/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytato/utils.py b/pytato/utils.py index f4261685c..63db0742f 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -227,7 +227,9 @@ def cast_to_result_type( expr = TypeCast(result_dtype, expr) elif isinstance(expr, SCALAR_CLASSES): - expr = result_dtype.type(expr) + # Disabled due to https://github.com/inducer/pytato/issues/542 + # expr = result_dtype.type(expr) + return expr return expr From 5e3a1412bd00e57a754b6f1981357319fb976409 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 11 Sep 2024 16:59:16 -0500 Subject: [PATCH 02/10] reenable sub-array test --- test/test_codegen.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/test_codegen.py b/test/test_codegen.py index 5380f0676..e1654ffc0 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -326,9 +326,6 @@ def test_scalar_array_binary_arith(ctx_factory, which, reverse): "logical_and")) @pytest.mark.parametrize("reverse", (False, True)) def test_array_array_binary_arith(ctx_factory, which, reverse): - if which == "sub": - pytest.skip("https://github.com/inducer/loopy/issues/131") - cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) not_valid_in_complex = which in ["equal", "not_equal", "less", "less_equal", From fdd8ebccb79c6da52c74af6179ca497039faf243 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 12 Sep 2024 15:10:45 -0500 Subject: [PATCH 03/10] cast python scalars to numpy types (cf. #247) --- pytato/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytato/utils.py b/pytato/utils.py index 63db0742f..194478da4 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -220,15 +220,17 @@ def cast_to_result_type( array: ArrayOrScalar, expr: ScalarExpression ) -> ScalarExpression: + from pytato.scalar_expr import PYTHON_SCALAR_CLASSES if ((isinstance(array, Array) or isinstance(array, np.generic)) and array.dtype != result_dtype): # Loopy's type casts don't like casting to bool assert result_dtype != np.bool_ expr = TypeCast(result_dtype, expr) - elif isinstance(expr, SCALAR_CLASSES): - # Disabled due to https://github.com/inducer/pytato/issues/542 - # expr = result_dtype.type(expr) + elif isinstance(expr, PYTHON_SCALAR_CLASSES): + # See https://github.com/inducer/pytato/pull/247 and + # https://github.com/inducer/pytato/issues/542 + expr = np.dtype(type(expr)).type(expr) return expr return expr From c3c41cfc982208c0ae9accafa234454d6d577420 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 12 Sep 2024 18:04:36 -0500 Subject: [PATCH 04/10] restrict scalar special handling to complex --- pytato/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/utils.py b/pytato/utils.py index 194478da4..de611f509 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -220,14 +220,13 @@ def cast_to_result_type( array: ArrayOrScalar, expr: ScalarExpression ) -> ScalarExpression: - from pytato.scalar_expr import PYTHON_SCALAR_CLASSES if ((isinstance(array, Array) or isinstance(array, np.generic)) and array.dtype != result_dtype): # Loopy's type casts don't like casting to bool assert result_dtype != np.bool_ expr = TypeCast(result_dtype, expr) - elif isinstance(expr, PYTHON_SCALAR_CLASSES): + elif isinstance(expr, (complex,)) and not isinstance(expr, np.generic): # See https://github.com/inducer/pytato/pull/247 and # https://github.com/inducer/pytato/issues/542 expr = np.dtype(type(expr)).type(expr) From 71eec4e2bf6d51743e684da13065e256df670dd9 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 13 Sep 2024 15:30:05 -0500 Subject: [PATCH 05/10] only handle pow+int specially --- pytato/utils.py | 11 ++++++----- test/test_codegen.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/pytato/utils.py b/pytato/utils.py index de611f509..9bd80c3de 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -226,11 +226,12 @@ def cast_to_result_type( assert result_dtype != np.bool_ expr = TypeCast(result_dtype, expr) - elif isinstance(expr, (complex,)) and not isinstance(expr, np.generic): - # See https://github.com/inducer/pytato/pull/247 and - # https://github.com/inducer/pytato/issues/542 - expr = np.dtype(type(expr)).type(expr) - return expr + elif isinstance(expr, SCALAR_CLASSES): + # See https://github.com/inducer/pytato/issues/542 + # on why pow() + integers is not typecast to float or complex. + if not (op == prim.Power and np.issubdtype(type(expr), np.integer) + and not np.issubdtype(result_dtype, np.integer)): + expr = result_dtype.type(expr) return expr diff --git a/test/test_codegen.py b/test/test_codegen.py index e1654ffc0..4e3cbf904 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -2005,6 +2005,27 @@ def call_bar(tracer, x, y): np.testing.assert_allclose(result_out[k], expect_out[k]) +def test_pow_arg_casting(ctx_factory): + # Check that pow() arguments are not typecast from int + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + x_np = np.random.rand(10, 4) + x = pt.make_data_wrapper(x_np) + + out = x ** -2 + _, (pt_out,) = pt.generate_loopy(out)(cq) + assert isinstance(out.expr.exponent, int) + assert pt_out.dtype == np.float64 + np.testing.assert_allclose(np.power(x_np, -2), pt_out) + + out = x ** 2.0 + _, (pt_out,) = pt.generate_loopy(out)(cq) + assert isinstance(out.expr.exponent, np.float64) + assert pt_out.dtype == np.float64 + np.testing.assert_allclose(np.power(x_np, 2.0), pt_out) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) From f34cb4ef0bfef16ff699b007f977a26980a4ff5d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 13 Sep 2024 16:20:17 -0500 Subject: [PATCH 06/10] make compatible with non-production pytato --- pytato/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytato/utils.py b/pytato/utils.py index 9bd80c3de..aa191f230 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -227,9 +227,11 @@ def cast_to_result_type( expr = TypeCast(result_dtype, expr) elif isinstance(expr, SCALAR_CLASSES): + import operator # See https://github.com/inducer/pytato/issues/542 # on why pow() + integers is not typecast to float or complex. - if not (op == prim.Power and np.issubdtype(type(expr), np.integer) + if not ((op == prim.Power or op == operator.pow) + and np.issubdtype(type(expr), np.integer) and not np.issubdtype(result_dtype, np.integer)): expr = result_dtype.type(expr) From ed572c010e1275b86335071f90ebe0ad77150f94 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 17 Sep 2024 16:03:25 -0500 Subject: [PATCH 07/10] use is_pow --- pytato/array.py | 11 +++++++---- pytato/utils.py | 18 +++++++++--------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 6f9221ae8..3b120fb3e 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -582,6 +582,7 @@ def _binary_op( np.dtype[Any]] = _np_result_dtype, reverse: bool = False, cast_to_result_dtype: bool = True, + is_pow: bool = False, ) -> Array: # {{{ sanity checks @@ -601,14 +602,16 @@ def _binary_op( get_result_type, tags=tags, non_equality_tags=non_equality_tags, - cast_to_result_dtype=cast_to_result_dtype) + cast_to_result_dtype=cast_to_result_dtype, + is_pow=is_pow) else: result = utils.broadcast_binary_op( self, other, op, get_result_type, tags=tags, non_equality_tags=non_equality_tags, - cast_to_result_dtype=cast_to_result_dtype) + cast_to_result_dtype=cast_to_result_dtype, + is_pow=is_pow) assert isinstance(result, Array) return result @@ -648,8 +651,8 @@ def _unary_op(self, op: Any) -> Array: __rtruediv__ = partialmethod(_binary_op, operator.truediv, get_result_type=_truediv_result_type, reverse=True) - __pow__ = partialmethod(_binary_op, operator.pow) - __rpow__ = partialmethod(_binary_op, operator.pow, reverse=True) + __pow__ = partialmethod(_binary_op, operator.pow, is_pow=True) + __rpow__ = partialmethod(_binary_op, operator.pow, reverse=True, is_pow=True) __neg__ = partialmethod(_unary_op, operator.neg) diff --git a/pytato/utils.py b/pytato/utils.py index aa191f230..0c65bc0d3 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -195,6 +195,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, tags: frozenset[Tag], non_equality_tags: frozenset[Tag], cast_to_result_dtype: bool, + is_pow: bool, ) -> ArrayOrScalar: from pytato.array import _get_default_axes @@ -220,20 +221,19 @@ def cast_to_result_type( array: ArrayOrScalar, expr: ScalarExpression ) -> ScalarExpression: - if ((isinstance(array, Array) or isinstance(array, np.generic)) - and array.dtype != result_dtype): - # Loopy's type casts don't like casting to bool - assert result_dtype != np.bool_ - - expr = TypeCast(result_dtype, expr) - elif isinstance(expr, SCALAR_CLASSES): - import operator + if isinstance(expr, SCALAR_CLASSES): # See https://github.com/inducer/pytato/issues/542 # on why pow() + integers is not typecast to float or complex. - if not ((op == prim.Power or op == operator.pow) + if not (is_pow and np.issubdtype(type(expr), np.integer) and not np.issubdtype(result_dtype, np.integer)): expr = result_dtype.type(expr) + elif ((isinstance(array, Array) or isinstance(array, np.generic)) + and array.dtype != result_dtype): + # Loopy's type casts don't like casting to bool + assert result_dtype != np.bool_ + + expr = TypeCast(result_dtype, expr) return expr From e4daf10e8d9f4598c2312eed496bdb81ab0793dd Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 17 Sep 2024 16:46:27 -0500 Subject: [PATCH 08/10] improve test, support exponent array --- pytato/utils.py | 19 ++++++++----- test/test_codegen.py | 64 +++++++++++++++++++++++++++++++++++--------- 2 files changed, 63 insertions(+), 20 deletions(-) diff --git a/pytato/utils.py b/pytato/utils.py index 0c65bc0d3..2e0a7d3d2 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -221,19 +221,24 @@ def cast_to_result_type( array: ArrayOrScalar, expr: ScalarExpression ) -> ScalarExpression: - if isinstance(expr, SCALAR_CLASSES): + if ((isinstance(array, Array) or isinstance(array, np.generic)) + and array.dtype != result_dtype): + # Loopy's type casts don't like casting to bool + assert result_dtype != np.bool_ + + # See https://github.com/inducer/pytato/issues/542 + # on why pow() + integers is not typecast to float or complex. + if not (is_pow + and np.issubdtype(array.dtype, np.integer) + and not np.issubdtype(result_dtype, np.integer)): + expr = TypeCast(result_dtype, expr) + elif isinstance(expr, SCALAR_CLASSES): # See https://github.com/inducer/pytato/issues/542 # on why pow() + integers is not typecast to float or complex. if not (is_pow and np.issubdtype(type(expr), np.integer) and not np.issubdtype(result_dtype, np.integer)): expr = result_dtype.type(expr) - elif ((isinstance(array, Array) or isinstance(array, np.generic)) - and array.dtype != result_dtype): - # Loopy's type casts don't like casting to bool - assert result_dtype != np.bool_ - - expr = TypeCast(result_dtype, expr) return expr diff --git a/test/test_codegen.py b/test/test_codegen.py index 4e3cbf904..79609a673 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -2010,20 +2010,58 @@ def test_pow_arg_casting(ctx_factory): ctx = ctx_factory() cq = cl.CommandQueue(ctx) - x_np = np.random.rand(10, 4) - x = pt.make_data_wrapper(x_np) + types = (int, np.int32, np.int64, float, np.float32, np.float64) + + for base_scalar in (True, False): + for exponent_scalar in (True, False): + if base_scalar and exponent_scalar: + # Not supported in pytato + continue - out = x ** -2 - _, (pt_out,) = pt.generate_loopy(out)(cq) - assert isinstance(out.expr.exponent, int) - assert pt_out.dtype == np.float64 - np.testing.assert_allclose(np.power(x_np, -2), pt_out) - - out = x ** 2.0 - _, (pt_out,) = pt.generate_loopy(out)(cq) - assert isinstance(out.expr.exponent, np.float64) - assert pt_out.dtype == np.float64 - np.testing.assert_allclose(np.power(x_np, 2.0), pt_out) + for base_tp in types: + if base_scalar: + base_np = base_tp(2) + first = base_np + else: + base_np = np.array([1, 2, 3], base_tp) + first = pt.make_data_wrapper(base_np) + + for exponent_tp in types: + if exponent_scalar: + exponent_np = exponent_tp(2) + second = exponent_np + else: + exponent_np = np.array([1, 2, 3], exponent_tp) + second = pt.make_data_wrapper(exponent_np) + + out = first ** second + + _, (pt_out,) = pt.generate_loopy(out)(cq) + + np_out = np.power(base_np, exponent_np) + + assert pt_out.dtype == np_out.dtype + np.testing.assert_allclose(np_out, pt_out) + + if np.issubdtype(exponent_tp, np.integer): + assert exponent_tp in (int, np.int32, np.int64) + + if exponent_scalar: + # We do cast between different int types + assert (type(out.expr.exponent) in + (int, np.int32, np.int64) + or out.expr.exponent.dtype == np_out.dtype) + else: + assert out.bindings["_in1"].dtype in \ + (int, np.int32, np.int64) + else: + assert exponent_tp in (float, np.float32, np.float64) + if exponent_scalar: + assert type(out.expr.exponent) == np_out.dtype \ + or out.expr.exponent.dtype == np_out.dtype + else: + assert out.bindings["_in1"].dtype in \ + (float, np.float32, np.float64) if __name__ == "__main__": From 2fcced72ced2702f7f91eeee8e05c28da642aa59 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 17 Sep 2024 16:48:52 -0500 Subject: [PATCH 09/10] add is_pow arg --- pytato/array.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index 3b120fb3e..619cca77d 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -2406,7 +2406,8 @@ def _compare(x1: ArrayOrScalar, x2: ArrayOrScalar, which: str) -> Array | bool: lambda x, y: np.dtype(np.bool_), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(stacklevel=2), - cast_to_result_dtype=False + cast_to_result_dtype=False, + is_pow=False, ) # type: ignore[return-value] @@ -2470,6 +2471,7 @@ def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), cast_to_result_dtype=False, + is_pow=False, ) # type: ignore[return-value] @@ -2487,6 +2489,7 @@ def logical_and(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), cast_to_result_dtype=False, + is_pow=False, ) # type: ignore[return-value] From b7f182971466a58a637ae87ad4ab1fbf298fbc5e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 18 Sep 2024 13:52:37 -0500 Subject: [PATCH 10/10] minor variable renaming in test --- test/test_codegen.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_codegen.py b/test/test_codegen.py index 79609a673..9aeb60dd8 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -2021,20 +2021,20 @@ def test_pow_arg_casting(ctx_factory): for base_tp in types: if base_scalar: base_np = base_tp(2) - first = base_np + base = base_np else: base_np = np.array([1, 2, 3], base_tp) - first = pt.make_data_wrapper(base_np) + base = pt.make_data_wrapper(base_np) for exponent_tp in types: if exponent_scalar: exponent_np = exponent_tp(2) - second = exponent_np + exponent = exponent_np else: exponent_np = np.array([1, 2, 3], exponent_tp) - second = pt.make_data_wrapper(exponent_np) + exponent = pt.make_data_wrapper(exponent_np) - out = first ** second + out = base ** exponent _, (pt_out,) = pt.generate_loopy(out)(cq)