From f1b81aab22ce80b560e52df31f6b686beb41f757 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 17 Sep 2024 16:03:25 -0500 Subject: [PATCH] use is_pow --- pytato/array.py | 11 +++++++---- pytato/utils.py | 18 +++++++++--------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 6f9221ae..1b067f34 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -582,6 +582,7 @@ def _binary_op( np.dtype[Any]] = _np_result_dtype, reverse: bool = False, cast_to_result_dtype: bool = True, + is_pow = False, ) -> Array: # {{{ sanity checks @@ -601,14 +602,16 @@ def _binary_op( get_result_type, tags=tags, non_equality_tags=non_equality_tags, - cast_to_result_dtype=cast_to_result_dtype) + cast_to_result_dtype=cast_to_result_dtype, + is_pow=is_pow) else: result = utils.broadcast_binary_op( self, other, op, get_result_type, tags=tags, non_equality_tags=non_equality_tags, - cast_to_result_dtype=cast_to_result_dtype) + cast_to_result_dtype=cast_to_result_dtype, + is_pow=is_pow) assert isinstance(result, Array) return result @@ -648,8 +651,8 @@ def _unary_op(self, op: Any) -> Array: __rtruediv__ = partialmethod(_binary_op, operator.truediv, get_result_type=_truediv_result_type, reverse=True) - __pow__ = partialmethod(_binary_op, operator.pow) - __rpow__ = partialmethod(_binary_op, operator.pow, reverse=True) + __pow__ = partialmethod(_binary_op, operator.pow, is_pow=True) + __rpow__ = partialmethod(_binary_op, operator.pow, reverse=True, is_pow=True) __neg__ = partialmethod(_unary_op, operator.neg) diff --git a/pytato/utils.py b/pytato/utils.py index aa191f23..0c65bc0d 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -195,6 +195,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, tags: frozenset[Tag], non_equality_tags: frozenset[Tag], cast_to_result_dtype: bool, + is_pow: bool, ) -> ArrayOrScalar: from pytato.array import _get_default_axes @@ -220,20 +221,19 @@ def cast_to_result_type( array: ArrayOrScalar, expr: ScalarExpression ) -> ScalarExpression: - if ((isinstance(array, Array) or isinstance(array, np.generic)) - and array.dtype != result_dtype): - # Loopy's type casts don't like casting to bool - assert result_dtype != np.bool_ - - expr = TypeCast(result_dtype, expr) - elif isinstance(expr, SCALAR_CLASSES): - import operator + if isinstance(expr, SCALAR_CLASSES): # See https://github.com/inducer/pytato/issues/542 # on why pow() + integers is not typecast to float or complex. - if not ((op == prim.Power or op == operator.pow) + if not (is_pow and np.issubdtype(type(expr), np.integer) and not np.issubdtype(result_dtype, np.integer)): expr = result_dtype.type(expr) + elif ((isinstance(array, Array) or isinstance(array, np.generic)) + and array.dtype != result_dtype): + # Loopy's type casts don't like casting to bool + assert result_dtype != np.bool_ + + expr = TypeCast(result_dtype, expr) return expr