
Commit

remove SQRT hack in llvm (tinygrad#9067)
replaced with xpow 0.5 in transcendental. fixed sqrt(0) backward
chenyuxyz authored Feb 13, 2025
1 parent 947c97e commit e02e3b9
Showing 5 changed files with 8 additions and 13 deletions.
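
Context for the "fixed sqrt(0) backward" note: the analytic derivative d/dx sqrt(x) = 1/(2·sqrt(x)) is singular at x = 0, so a naive backward blows up there; the previously skipped sqrt(0) case in test_ops.py below exercises exactly that. A minimal plain-Python sketch of the issue (illustrative only, not tinygrad code):

```python
import math

def sqrt_backward_naive(x: float, grad_out: float = 1.0) -> float:
    # analytic derivative: d/dx sqrt(x) = 1 / (2 * sqrt(x)), singular at x == 0
    return grad_out / (2.0 * math.sqrt(x))

print(sqrt_backward_naive(4.0))   # 0.25
try:
    print(sqrt_backward_naive(0.0))
except ZeroDivisionError:
    # in a float kernel this shows up as inf or nan instead of an exception,
    # which is what the previously skipped sqrt(0) backward test checks
    print("singular at x = 0")
```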
4 changes: 2 additions & 2 deletions test/imported/test_indexing.py
@@ -1054,7 +1054,7 @@ def test_getitem_scalars(self):
one = Tensor(1, dtype=dtypes.int64)

# non-scalar indexed with scalars
-    a = Tensor.randn(2, 3)
+    a = Tensor.randn(2, 3).realize()
numpy_testing_assert_equal_helper(a[0], a[zero])
numpy_testing_assert_equal_helper(a[0][1], a[zero][one])
numpy_testing_assert_equal_helper(a[0, 1], a[zero, one])
@@ -1066,7 +1066,7 @@ def test_getitem_scalars(self):
numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int16)])

# scalar indexed with scalar
-    r = Tensor.randn()
+    r = Tensor.randn().realize()
with self.assertRaises(IndexError):
r[:]
with self.assertRaises(IndexError):
4 changes: 2 additions & 2 deletions test/test_nn.py
@@ -563,8 +563,8 @@ def test_load_state_dict_sharded_model(self):
layer.weight.shard_(devices, 3)
layer.bias.shard_(devices, None)
state_dict = {
-      'weight': Tensor.randn(5, 3, 3, 3),
-      'bias': Tensor.randn(5),
+      'weight': Tensor.randn(5, 3, 3, 3).realize(),
+      'bias': Tensor.randn(5).realize(),
}
load_state_dict(layer, state_dict)

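The .realize() calls added to the randn tensors in the two test files above force the lazy random-number graph to be computed and stored once up front; a minimal usage sketch of that pattern (the stated motivation is an assumption, not taken from the commit):

```python
from tinygrad import Tensor

# .realize() executes the pending lazy graph now and keeps the result,
# so later indexing/loading reads an already-materialized buffer
a = Tensor.randn(2, 3).realize()
print(a[0].numpy())
print(a[0, 1].numpy())
```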
6 changes: 2 additions & 4 deletions test/test_ops.py
@@ -641,9 +641,7 @@ def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, ex

def test_sqrt(self):
helper_test_op([(45,65)], lambda x: x.sqrt())
-    if Device.DEFAULT not in ("LLVM", "DSP"):
-      # TODO: fix backward
-      helper_test_op(None, lambda x: x.sqrt(), vals=[[0.0]])
+    helper_test_op(None, lambda x: x.sqrt(), vals=[[0.0]])
helper_test_op([()], lambda x: x.sqrt())
def test_rsqrt(self):
helper_test_op([(45,65)], lambda x: x.rsqrt())
@@ -1406,7 +1404,7 @@ def test_hardtanh(self):
def test_asinh(self):
helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6)
# NOTE: this one has larger atol
-    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, grad_atol=1e-6, low=-300, high=-297)
+    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, rtol=2e-2, grad_atol=1e-6, low=-300, high=-297)
helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6, low=300, high=303)
def test_acosh(self):
helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-6)
2 changes: 2 additions & 0 deletions tinygrad/codegen/rewriter.py
@@ -129,6 +129,8 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
def get_late_rewrite_patterns(ops, force_transcendental=False):
pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
+  # rewrite SQRT to xpow 0.5
+  if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
if Ops.AND in ops:
pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
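The new rewrite above replaces a SQRT UOp with xpow(d, 0.5) whenever the backend does not support SQRT natively. A rough standalone sketch of the identity behind it — sqrt(x) = x**0.5 = 2**(0.5·log2(x)) — which is roughly how an exp2/log2-based xpow can evaluate it; this is illustrative plain Python, not the actual transcendental.py implementation, which also handles signs, NaN/inf, and other special values:

```python
import math

def sqrt_via_pow(x: float) -> float:
    # sqrt(x) == x ** 0.5 == 2 ** (0.5 * log2(x)); 0 is special-cased so we
    # never evaluate log2(0) = -inf
    if x == 0.0:
        return 0.0
    return 2.0 ** (0.5 * math.log2(x))

assert math.isclose(sqrt_via_pow(2.0), math.sqrt(2.0))
assert sqrt_via_pow(0.0) == 0.0
```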
5 changes: 0 additions & 5 deletions tinygrad/renderer/llvmir.py
@@ -77,8 +77,6 @@ def AMX(op, gpr): return f'call void asm sideeffect ".word (0x201000+($0<<5)+0$1
f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}" if isinstance(x.dtype, PtrDType) else None),

# unary/binary/ternary ops
-  (UPat(Ops.SQRT, name="x"), lambda ctx,x:
-   f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(GroupOp.Binary, name="x"), lambda ctx,x:
@@ -152,9 +150,6 @@ def render(self, name: str, uops: list[UOp]) -> str:
f" {r[u]}_ptr_amx{i} = ptrtoint {ldt(dtype.ptr())} {r[u]}_amx{i} to i64"]

for u in uops:
-      # hack for defining sqrt function (TODO: can we get a transcendental for this?)
-      if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None
-
if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
# NOTE: MallocAllocator promises 0x20 alignment
