Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable type casting for pow() int arguments #543

Merged
merged 11 commits into from
Sep 18, 2024
8 changes: 7 additions & 1 deletion pytato/utils.py
inducer marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,13 @@ def cast_to_result_type(

expr = TypeCast(result_dtype, expr)
elif isinstance(expr, SCALAR_CLASSES):
expr = result_dtype.type(expr)
import operator
# See https://github.com/inducer/pytato/issues/542
# on why pow() + integers is not typecast to float or complex.
if not ((op == prim.Power or op == operator.pow)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not loving that we're identifying the operator by its callable; that's brittle. (Somebody could just wrap things in a lambda and this would forget what it was doing.)

Copy link
Collaborator Author

@matthiasdiener matthiasdiener Sep 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand, but I'm not sure how to improve this. Do you have a suggestion?

Edit: add an argument to broadcast_binary_op that identifies it as pow.

and np.issubdtype(type(expr), np.integer)
kaushikcfd marked this conversation as resolved.
Show resolved Hide resolved
and not np.issubdtype(result_dtype, np.integer)):
expr = result_dtype.type(expr)

return expr

Expand Down
24 changes: 21 additions & 3 deletions test/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,6 @@ def test_scalar_array_binary_arith(ctx_factory, which, reverse):
"logical_and"))
@pytest.mark.parametrize("reverse", (False, True))
def test_array_array_binary_arith(ctx_factory, which, reverse):
if which == "sub":
pytest.skip("https://github.com/inducer/loopy/issues/131")

cl_ctx = ctx_factory()
queue = cl.CommandQueue(cl_ctx)
not_valid_in_complex = which in ["equal", "not_equal", "less", "less_equal",
Expand Down Expand Up @@ -2008,6 +2005,27 @@ def call_bar(tracer, x, y):
np.testing.assert_allclose(result_out[k], expect_out[k])


def test_pow_arg_casting(ctx_factory):
# Check that pow() arguments are not typecast from int
ctx = ctx_factory()
cq = cl.CommandQueue(ctx)

x_np = np.random.rand(10, 4)
x = pt.make_data_wrapper(x_np)

out = x ** -2
_, (pt_out,) = pt.generate_loopy(out)(cq)
assert isinstance(out.expr.exponent, int)
inducer marked this conversation as resolved.
Show resolved Hide resolved
assert pt_out.dtype == np.float64
np.testing.assert_allclose(np.power(x_np, -2), pt_out)

out = x ** 2.0
_, (pt_out,) = pt.generate_loopy(out)(cq)
assert isinstance(out.expr.exponent, np.float64)
assert pt_out.dtype == np.float64
np.testing.assert_allclose(np.power(x_np, 2.0), pt_out)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down
Loading