Skip to content

Commit

Permalink
remove duplicated dtype in DEFINE_GLOBAL args (tinygrad#2768)
Browse files Browse the repository at this point in the history
now DEFINE_GLOBAL uop.arg[1] is always the same as uop.dtype, we can remove the one in arg and just use uop.dtype
  • Loading branch information
chenyuxyz authored Dec 14, 2023
1 parent 5235cde commit 57017c8
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 14 deletions.
2 changes: 1 addition & 1 deletion extra/assembly/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
lang.ins.clear()
lang.tor.clear()
lang.cnts.clear()
buf_to_dtype = {args[0]:args[1] for uop,_,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
global_size, local_size = [], []
skipload_branch = 0
lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
Expand Down
6 changes: 3 additions & 3 deletions extra/triton/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, f
kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
elif uop == UOps.DEFINE_GLOBAL:
bufs.append(args)
signatures.append("*" if isinstance(args[1], PtrDType) else "" + signature_dtypes[args[1]])
r[u] = args[0]
signatures.append("*" if isinstance(dtype, PtrDType) else "" + signature_dtypes[dtype])
r[u] = args
elif uop == UOps.SPECIAL:
dims.append(args[1])
valid.append(f"{args[1]}<{get_max(args[2])}")
Expand All @@ -111,7 +111,7 @@ def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, f
elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
else: raise NotImplementedError(f"unimplemented: {uop}")

prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(f"{buf[0]}" for buf in bufs)+"):\n"
prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(bufs)+"):\n"
for i, line in enumerate(list(filter(lambda line: "tl.arange" in line, kernel))): kernel[kernel.index(line)] += f"[{', '.join([':' if i == j else 'None' for j in range(len(local_size))])}]"
prg += "\n".join(kernel)

Expand Down
6 changes: 3 additions & 3 deletions test/test_uops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], ar

def _test_single_value(vals, op, dtype):
uops = []
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), ('data0', dtype))
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (f'data{i+1}', dtype)) for i in range(len(vals))]
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 'data0')
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), f'data{i+1}') for i in range(len(vals))]
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i in range(len(vals)))
alu = uop(uops, UOps.ALU, dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
Expand All @@ -33,7 +33,7 @@ def _test_single_value(vals, op, dtype):

def _test_single_value_const(vals, op, dtype):
uops = []
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), ('data0', dtype))
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 'data0')
loads = (uop(uops, UOps.CONST, dtype, [], a) for a in vals)
alu = uop(uops, UOps.ALU, dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
Expand Down
4 changes: 2 additions & 2 deletions tinygrad/codegen/linearizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,11 @@ def linearize(self):
# add global buffers
for i,buf in enumerate(self.bufs):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, dtype:=PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", dtype)) # noqa: E501
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), f"data{buf.idx}")
# add var vals
for var in vars_from_ast(self.ast):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes.int32))
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), var.expr)
# define local buffers
for lb in self.local_alias.values():
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size()))
Expand Down
5 changes: 3 additions & 2 deletions tinygrad/renderer/cstyle.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,9 @@ def ssa(u, prefix="t"):
kk(lang.render_local(args[0], args[1]))
r[u] = args[0]
elif uop == UOps.DEFINE_GLOBAL:
bufs.append(args)
r[u] = args[0]
assert dtype is not None
bufs.append((args, dtype))
r[u] = args
elif uop == UOps.GEP:
if cast(DType, vin[0].dtype).sz > 4:
r[u] = f"({r[vin[0]]})[{args}]" # this is correct for HIP
Expand Down
6 changes: 3 additions & 3 deletions tinygrad/renderer/llvmir.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
module = ir.Module(name=__file__)

# extract global buffers
buf_to_dtype = {u.arg[0]:u.arg[1] for u in uops if u.uop == UOps.DEFINE_GLOBAL}
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop == UOps.DEFINE_GLOBAL}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}

# create llvm function
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values()]
func_dtypes = [(dtype_to_llvm_dtype[dtype],dtype) for dtype in buf_to_dtype.values() if dtype is not None]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=function_name) # noqa: E501
for a in func.args:
if a.type.is_pointer: a.add_attribute("noalias")
Expand Down Expand Up @@ -125,7 +125,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), block._block, bb[-1]._block)
if uop == UOps.DEFINE_GLOBAL:
lvars[u] = func.args[buf_index[args[0]]]
lvars[u] = func.args[buf_index[args]]
if uop == UOps.DEFINE_ACC:
lvars[u] = const(args, dtype)
reduce_phis.append(u)
Expand Down

0 comments on commit 57017c8

Please sign in to comment.