Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable type casting for pow() int arguments #543

Merged
merged 11 commits into from
Sep 18, 2024
16 changes: 11 additions & 5 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ def _binary_op(
np.dtype[Any]] = _np_result_dtype,
reverse: bool = False,
cast_to_result_dtype: bool = True,
is_pow: bool = False,
) -> Array:

# {{{ sanity checks
Expand All @@ -601,14 +602,16 @@ def _binary_op(
get_result_type,
tags=tags,
non_equality_tags=non_equality_tags,
cast_to_result_dtype=cast_to_result_dtype)
cast_to_result_dtype=cast_to_result_dtype,
is_pow=is_pow)
else:
result = utils.broadcast_binary_op(
self, other, op,
get_result_type,
tags=tags,
non_equality_tags=non_equality_tags,
cast_to_result_dtype=cast_to_result_dtype)
cast_to_result_dtype=cast_to_result_dtype,
is_pow=is_pow)

assert isinstance(result, Array)
return result
Expand Down Expand Up @@ -648,8 +651,8 @@ def _unary_op(self, op: Any) -> Array:
__rtruediv__ = partialmethod(_binary_op, operator.truediv,
get_result_type=_truediv_result_type, reverse=True)

__pow__ = partialmethod(_binary_op, operator.pow)
__rpow__ = partialmethod(_binary_op, operator.pow, reverse=True)
__pow__ = partialmethod(_binary_op, operator.pow, is_pow=True)
__rpow__ = partialmethod(_binary_op, operator.pow, reverse=True, is_pow=True)

__neg__ = partialmethod(_unary_op, operator.neg)

Expand Down Expand Up @@ -2403,7 +2406,8 @@ def _compare(x1: ArrayOrScalar, x2: ArrayOrScalar, which: str) -> Array | bool:
lambda x, y: np.dtype(np.bool_),
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(stacklevel=2),
cast_to_result_dtype=False
cast_to_result_dtype=False,
is_pow=False,
inducer marked this conversation as resolved.
Show resolved Hide resolved
) # type: ignore[return-value]


Expand Down Expand Up @@ -2467,6 +2471,7 @@ def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool:
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(),
cast_to_result_dtype=False,
is_pow=False,
) # type: ignore[return-value]


Expand All @@ -2484,6 +2489,7 @@ def logical_and(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool:
tags=_get_default_tags(),
non_equality_tags=_get_created_at_tag(),
cast_to_result_dtype=False,
is_pow=False,
) # type: ignore[return-value]


Expand Down
15 changes: 13 additions & 2 deletions pytato/utils.py
inducer marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar,
tags: frozenset[Tag],
non_equality_tags: frozenset[Tag],
cast_to_result_dtype: bool,
is_pow: bool,
) -> ArrayOrScalar:
from pytato.array import _get_default_axes

Expand Down Expand Up @@ -225,9 +226,19 @@ def cast_to_result_type(
# Loopy's type casts don't like casting to bool
assert result_dtype != np.bool_

expr = TypeCast(result_dtype, expr)
# See https://github.com/inducer/pytato/issues/542
# on why pow() + integers is not typecast to float or complex.
if not (is_pow
and np.issubdtype(array.dtype, np.integer)
and not np.issubdtype(result_dtype, np.integer)):
expr = TypeCast(result_dtype, expr)
elif isinstance(expr, SCALAR_CLASSES):
expr = result_dtype.type(expr)
# See https://github.com/inducer/pytato/issues/542
# on why pow() + integers is not typecast to float or complex.
if not (is_pow
and np.issubdtype(type(expr), np.integer)
kaushikcfd marked this conversation as resolved.
Show resolved Hide resolved
and not np.issubdtype(result_dtype, np.integer)):
expr = result_dtype.type(expr)

return expr

Expand Down
62 changes: 59 additions & 3 deletions test/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,6 @@ def test_scalar_array_binary_arith(ctx_factory, which, reverse):
"logical_and"))
@pytest.mark.parametrize("reverse", (False, True))
def test_array_array_binary_arith(ctx_factory, which, reverse):
if which == "sub":
pytest.skip("https://github.com/inducer/loopy/issues/131")

cl_ctx = ctx_factory()
queue = cl.CommandQueue(cl_ctx)
not_valid_in_complex = which in ["equal", "not_equal", "less", "less_equal",
Expand Down Expand Up @@ -2008,6 +2005,65 @@ def call_bar(tracer, x, y):
np.testing.assert_allclose(result_out[k], expect_out[k])


def test_pow_arg_casting(ctx_factory):
# Check that pow() arguments are not typecast from int
ctx = ctx_factory()
cq = cl.CommandQueue(ctx)

types = (int, np.int32, np.int64, float, np.float32, np.float64)

for base_scalar in (True, False):
for exponent_scalar in (True, False):
if base_scalar and exponent_scalar:
# Not supported in pytato
continue

for base_tp in types:
if base_scalar:
base_np = base_tp(2)
first = base_np
else:
base_np = np.array([1, 2, 3], base_tp)
first = pt.make_data_wrapper(base_np)

for exponent_tp in types:
if exponent_scalar:
exponent_np = exponent_tp(2)
second = exponent_np
else:
exponent_np = np.array([1, 2, 3], exponent_tp)
second = pt.make_data_wrapper(exponent_np)

out = first ** second

_, (pt_out,) = pt.generate_loopy(out)(cq)

np_out = np.power(base_np, exponent_np)

assert pt_out.dtype == np_out.dtype
np.testing.assert_allclose(np_out, pt_out)

if np.issubdtype(exponent_tp, np.integer):
assert exponent_tp in (int, np.int32, np.int64)

if exponent_scalar:
# We do cast between different int types
assert (type(out.expr.exponent) in
(int, np.int32, np.int64)
or out.expr.exponent.dtype == np_out.dtype)
else:
assert out.bindings["_in1"].dtype in \
(int, np.int32, np.int64)
else:
assert exponent_tp in (float, np.float32, np.float64)
if exponent_scalar:
assert type(out.expr.exponent) == np_out.dtype \
or out.expr.exponent.dtype == np_out.dtype
else:
assert out.bindings["_in1"].dtype in \
(float, np.float32, np.float64)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down
Loading