Skip to content

Commit

Permalink
move more pow const to rewrite (tinygrad#8916)
Browse files Browse the repository at this point in the history
* move more pow const to rewrite

one less use of _to_const_val

* fix
  • Loading branch information
chenyuxyz authored Feb 6, 2025
1 parent 7667138 commit 488200f
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 12 deletions.
21 changes: 21 additions & 0 deletions test/test_const_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,27 @@ def test_2_pow_is_exp2(self):
alu = [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU]
self.assertEqual(alu, [Ops.EXP2])

def test_pow_05_is_sqrt(self):
  # raising to the 0.5 power should lower to a single SQRT op in the scheduled ast
  out = Tensor([1.0, 2.0, 3.0]) ** 0.5
  sinks = [si for si in out.schedule() if si.ast.op is Ops.SINK]
  self.assertEqual(len(sinks), 1)
  alu_ops = [u.op for u in sinks[0].ast.toposort if u.op in GroupOp.ALU]
  self.assertEqual(alu_ops, [Ops.SQRT])

def test_pow_neg_05_is_rsqrt(self):
  # raising to the -0.5 power should lower to RECIP followed by SQRT (rsqrt)
  out = Tensor([1.0, 2.0, 3.0]) ** -0.5
  sinks = [si for si in out.schedule() if si.ast.op is Ops.SINK]
  self.assertEqual(len(sinks), 1)
  alu_ops = [u.op for u in sinks[0].ast.toposort if u.op in GroupOp.ALU]
  self.assertEqual(alu_ops, [Ops.RECIP, Ops.SQRT])

def test_pow_8_has_3_muls(self):
  # integer pow uses square-and-multiply: x**8 = ((x^2)^2)^2 -> exactly three MULs
  out = Tensor([1.0, 2.0, 3.0]) ** 8
  sinks = [si for si in out.schedule() if si.ast.op is Ops.SINK]
  self.assertEqual(len(sinks), 1)
  alu_ops = [u.op for u in sinks[0].ast.toposort if u.op in GroupOp.ALU]
  self.assertEqual(alu_ops, [Ops.MUL, Ops.MUL, Ops.MUL])

# folds advance indexing into basic indexing
class TestIndexingConstFolding(unittest.TestCase):
def test_scalar_index(self):
Expand Down
4 changes: 2 additions & 2 deletions test/test_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,12 +562,12 @@ def test_double_from(self):
def test_pow_const_tensor_simplified(self):
x = Tensor([1,2,3,4])
# NOTE: this does not test ** Tensor(2) is simpler in ast than ** Tensor(2.5)
out = x ** Tensor(2)
out = x ** Tensor(2.0)
check_schedule(out, 1)

def test_pow_const_tensor_to_zero(self):
x = Tensor([1,2,3,4])
out = x ** Tensor(0)
out = x ** Tensor(0.0)
# NOTE: this is ConstBuffer 0 + ConstBuffer 1
check_schedule(out, 0)

Expand Down
8 changes: 8 additions & 0 deletions tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,6 +1112,13 @@ def simplify_valid(valid:UOp) -> UOp|None:
if ret[-1] is not stmt: something_changed = True
return functools.reduce(operator.and_, ret) if something_changed else None

def simplify_pow(x:UOp, c:UOp) -> UOp|None:
  """Rewrite x**c for a scalar constant exponent c into cheaper ops, or None if no rule applies."""
  exp = c.arg
  # negative exponent: x**(-n) == (1/x)**n
  if exp < 0: return x.reciprocal().pow(-c)
  # anything to the zeroth power is one
  if exp == 0: return x.const_like(1)
  # half-integer exponent: x**(k+0.5) == x**k * sqrt(x)
  if int(exp-0.5)+0.5 == exp: return x.pow(c.const_like(exp-0.5)) * x.sqrt()
  # integer exponent: one square-and-multiply step, x**n == (x**(n//2))**2 * (x if n is odd)
  if int(exp) == exp:
    half = x.pow(c.const_like(exp//2))
    return half * half * (x if exp % 2 == 1 else 1)
  # non-special exponent (e.g. 2.5): leave for the generic pow lowering
  return None

# def max_var_const(x:UOp, c1:UOp, c2:UOp):
# if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
# if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
Expand Down Expand Up @@ -1156,6 +1163,7 @@ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype,
(UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
(UPat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
# ** pow **
(UPat.var("x").alu(Ops.POW, UPat.cvar("c", vec=False)), simplify_pow),
# positive const ** x
(UPat.cvar("c", vec=False).alu(Ops.POW, UPat.var("x")), lambda c,x: c if c.arg == 1 else (x*math.log2(c.arg)).exp2() if c.arg > 0 else None),
])
Expand Down
11 changes: 1 addition & 10 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3289,7 +3289,7 @@ def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
Equivalent to `self ** x`.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-1, 2, 3]).pow(2).numpy())
print(Tensor([-1, 2, 3]).pow(2.0).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-1, 2, 3]).pow(Tensor([-1.5, 0.5, 1.5])).numpy())
Expand All @@ -3298,15 +3298,6 @@ def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
print((2.0 ** Tensor([-1, 2, 3])).numpy())
```
"""
x = self._to_const_val(x)
if not isinstance(x, Tensor) and not reverse:
# simple pow identities
if x < 0: return self.reciprocal().pow(-x).cast(self.dtype)
if x == 0: return 1 + self * 0
# rewrite pow 0.5 to sqrt
if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)

base, exponent = self._broadcasted(x, reverse=reverse)
# TODO: int pow
if not base.is_floating_point(): raise RuntimeError("base needs to be float")
Expand Down

0 comments on commit 488200f

Please sign in to comment.