From eaae55aa7755e8a19851257bf24eea4697480f5b Mon Sep 17 00:00:00 2001 From: Chen-Yu Yang Date: Sat, 15 Feb 2025 16:37:06 -0500 Subject: [PATCH] minor tweak to transcendental pow also added more pow with const test cases --- test/test_ops.py | 8 ++++++++ tinygrad/codegen/transcendental.py | 8 ++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 0350b635aa252..80202a5c2e902 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -610,7 +610,14 @@ def test_pow_const(self): helper_test_op([(45,65)], lambda x: x**0.0) helper_test_op([(45,65)], lambda x: x**1.0) helper_test_op([(45,65)], lambda x: x**-1.0) + helper_test_op([(45,65)], lambda x: x**8.0) + helper_test_op([(45,65)], lambda x: x**5.5) + helper_test_op([(45,65)], lambda x: x**-5.5) + # helper_test_op([(45,65)], lambda x: x**-8.0) # TODO: fix this helper_test_op([(45,65)], lambda x: 1.0**x) + helper_test_op([(45,65)], lambda x: 5.5**x) + helper_test_op([(45,65)], lambda x: (-5.5)**x) + helper_test_op([(45,65)], lambda x: 8.0**x) helper_test_op([(45,65)], lambda x: x**2.0) helper_test_op([(45,65)], lambda x: 2.0**x) helper_test_op([()], lambda x: x**2.0) @@ -628,6 +635,7 @@ def test_pow_zero_const(self): helper_test_op(None, lambda x: x**0.3, vals=[[0.0]]) helper_test_op(None, lambda x: x**0.0, vals=[[0.0]]) helper_test_op(None, lambda x: x**-0.3, vals=[[0.0]]) + helper_test_op(None, lambda x: x**-1.0, vals=[[-1.0, 0.0, 1.0]]) @unittest.skip("not supported") def test_pow_int(self): diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py index 5611cae1d1f9d..a15209ba9ac32 100644 --- a/tinygrad/codegen/transcendental.py +++ b/tinygrad/codegen/transcendental.py @@ -259,8 +259,8 @@ def xpow(base:UOp, exponent:UOp) -> UOp: # start with b ** e = exp2(e * log2(b)) ret = (base < 0).where(-base, base).log2().mul(exponent).exp2() # negative base adjustment: nan for non-integer exponent and -1 for odd exponent - adj = (base < 0).where((exponent != exponent.cast(dtypes.int32).cast(exponent.dtype)).where( - ret.const_like(math.nan), - (exponent.cast(dtypes.int32).cast(dtypes.uint32)%2).eq(1).where(ret.const_like(-1), ret.const_like(1))), ret.const_like(1)) + non_int = exponent != exponent.cast(dtypes.int32).cast(exponent.dtype) + adj = non_int.where(ret.const_like(math.nan), + (exponent < 0).where(-exponent, exponent).cast(dtypes.int32).mod(2).cast(dtypes.bool).where(ret.const_like(-1), ret.const_like(1))) # fix 0 ** 0 = 1 - return (base.eq(0) & exponent.eq(0)).where(ret.const_like(1), ret * adj) + return (base.eq(0) & exponent.eq(0)).where(ret.const_like(1), ret * (base < 0).where(adj, ret.const_like(1)))