Skip to content

Commit

Permalink
use is_pow
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Sep 17, 2024
1 parent f34cb4e commit f1b81aa
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
11 changes: 7 additions & 4 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ def _binary_op(
np.dtype[Any]] = _np_result_dtype,
reverse: bool = False,
cast_to_result_dtype: bool = True,
is_pow = False,
) -> Array:

# {{{ sanity checks
Expand All @@ -601,14 +602,16 @@ def _binary_op(
get_result_type,
tags=tags,
non_equality_tags=non_equality_tags,
cast_to_result_dtype=cast_to_result_dtype)
cast_to_result_dtype=cast_to_result_dtype,
is_pow=is_pow)
else:
result = utils.broadcast_binary_op(
self, other, op,
get_result_type,
tags=tags,
non_equality_tags=non_equality_tags,
cast_to_result_dtype=cast_to_result_dtype)
cast_to_result_dtype=cast_to_result_dtype,
is_pow=is_pow)

assert isinstance(result, Array)
return result
Expand Down Expand Up @@ -648,8 +651,8 @@ def _unary_op(self, op: Any) -> Array:
__rtruediv__ = partialmethod(_binary_op, operator.truediv,
get_result_type=_truediv_result_type, reverse=True)

__pow__ = partialmethod(_binary_op, operator.pow)
__rpow__ = partialmethod(_binary_op, operator.pow, reverse=True)
__pow__ = partialmethod(_binary_op, operator.pow, is_pow=True)
__rpow__ = partialmethod(_binary_op, operator.pow, reverse=True, is_pow=True)

__neg__ = partialmethod(_unary_op, operator.neg)

Expand Down
18 changes: 9 additions & 9 deletions pytato/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar,
tags: frozenset[Tag],
non_equality_tags: frozenset[Tag],
cast_to_result_dtype: bool,
is_pow: bool,
) -> ArrayOrScalar:
from pytato.array import _get_default_axes

Expand All @@ -220,20 +221,19 @@ def cast_to_result_type(
array: ArrayOrScalar,
expr: ScalarExpression
) -> ScalarExpression:
if ((isinstance(array, Array) or isinstance(array, np.generic))
and array.dtype != result_dtype):
# Loopy's type casts don't like casting to bool
assert result_dtype != np.bool_

expr = TypeCast(result_dtype, expr)
elif isinstance(expr, SCALAR_CLASSES):
import operator
if isinstance(expr, SCALAR_CLASSES):
# See https://github.com/inducer/pytato/issues/542
# on why pow() + integers is not typecast to float or complex.
if not ((op == prim.Power or op == operator.pow)
if not (is_pow
and np.issubdtype(type(expr), np.integer)
and not np.issubdtype(result_dtype, np.integer)):
expr = result_dtype.type(expr)
elif ((isinstance(array, Array) or isinstance(array, np.generic))
and array.dtype != result_dtype):
# Loopy's type casts don't like casting to bool
assert result_dtype != np.bool_

expr = TypeCast(result_dtype, expr)

return expr

Expand Down

0 comments on commit f1b81aa

Please sign in to comment.