From 56c1ef21e773c4919d8f8bc3275d291e501ab291 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 2 Dec 2024 19:20:54 -0500 Subject: [PATCH] merge where alu when one branch is identity element const this is another case that does not increase alu count --- test/unit/test_uop_symbolic.py | 5 +++-- tinygrad/ops.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 0f182a56653f0..d65cbaf23881c 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -508,8 +508,9 @@ def test_where_combine(self): ba = cond.where(b, a) self.helper_test_variable(ab+ba, 0, 6, "((a if (x<2) else b)+(b if (x<2) else a))") - # not combining # TODO: can combine if one is identity element const - self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))") + # combining if one is identity element const + self.helper_test_variable(aa+ab, 0, 6, "((a*2) if (x<2) else b)") + self.helper_test_variable(bb*ab, 0, 9, "((a*b) if (x<2) else b)") def test_symbolic_div(self): # from symbolic arange diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 247ad6ea29018..edce6caf56bed 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1250,7 +1250,8 @@ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1), # alu of two where with same conds can combine, only do if true branch or false branch is const (UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \ - lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None), + lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST or \ + (alu.op in (Ops.MAX, Ops.ADD, Ops.MUL) and any(b.op is Ops.CONST and b.arg==identity_element(alu.op, b.dtype) for b in (t,tt,f,ff))) else None), # ALU min==max -> CONST (slow!) (UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None), # max folding