add Half-Precision Accumulation Support for Tensor Cores in NV, CUDA, and PTX (tinygrad#8680)

* ptx and nv rendering refactor to work with half acc

* ptx fix!

* use same reg for acc and out

* fix comment

* another fix

* minor change in comment

* fix

---------

Co-authored-by: qazal <[email protected]>
ignaciosica and Qazalin authored Feb 5, 2025
1 parent 17f9b1c commit cad44f5
Showing 2 changed files with 11 additions and 8 deletions.
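
For context, a minimal usage sketch on the Python side (assumptions: a CUDA/NV device with tensor cores, and that the scheduler/search actually takes a WMMA path; the half-accumulating variant added here is a new candidate, not a guaranteed selection):

from tinygrad import Tensor, dtypes

# two half-precision matrices; on NV/CUDA with tensor cores the matmul below can now
# lower to an mma.sync that also accumulates in f16 (these renderers previously only
# emitted f32/tf32 accumulators)
a = Tensor.rand(64, 64, dtype=dtypes.half)
b = Tensor.rand(64, 64, dtype=dtypes.half)
c = (a @ b).realize()
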
17 changes: 10 additions & 7 deletions tinygrad/renderer/cstyle.py
@@ -300,10 +300,11 @@ class CUDARenderer(CStyleLanguage):
local_max = (1024, 1024, 64)
shared_max = 49152
# https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di,dtype_out=do, opts=cuda_tc_opts,
swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float)]]
tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.half, dtype_out=dtypes.float, opts=cuda_tc_opts,
swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5))))]
tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float),
(dtypes.half,dtypes.half)]]
tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]
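
The new (dtypes.half, dtypes.half) entries above reuse the existing shapes and swizzles; only the accumulator dtype changes. A small worked example of how the per-thread element counts map onto 32-bit mma operand registers, mirroring the 4 // itemsize packing used in ptx.py's render_wmma (the counts come from the tc_81616 entry; the script itself is only illustrative):

from tinygrad import dtypes

# dims=(8,16,16), half in / half out: each of the 32 threads holds 8 A, 4 B and 4 C/D
# half values, and every b32 register packs 4 // itemsize = 2 halves
elements_per_thread = (8, 4, 4)
elems_per_reg = 4 // dtypes.half.itemsize  # 2
print([n // elems_per_reg for n in elements_per_thread])
# [4, 2, 2] -> the f16.f16.f16.f16 mma takes 4 A, 2 B and 2 C/D registers per thread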

@@ -344,7 +345,8 @@ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]

dt_map = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
@@ -353,10 +355,11 @@ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):

# mma operands => {c}, {a}, {b}, {c}
prefix.append(f"""__device__ {wmma_dtypes[2]} __{name}({wmma_dtypes[0]} a, {wmma_dtypes[1]} b, {wmma_dtypes[2]} c){{
int *a_pk = (int *)(&a), *b_pk = (int *)(&b);\n asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32"
int *a_pk = (int *)(&a), *b_pk = (int *)(&b), *c_pk = (int *)(&c);
asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}"
"{{{", ".join(operands[:n_operands[2]])}}}, {{{", ".join(operands[n_operands[2]:n_operands[2]+n_operands[0]])}}},"
"{{{", ".join(operands[-n_operands[1]:])}}}, {{{", ".join(operands[:n_operands[2]])}}};"
: {", ".join([f'"+f"(c.{_nms[i]})' for i in range(n_operands[2])])}
: {", ".join([f'"+r"(c_pk[{i}])' for i in range(n_operands[2])])}
: {", ".join([f'"r"(a_pk[{i}])' for i in range(n_operands[0])])}, {", ".join([f'"r"(b_pk[{i}])' for i in range(n_operands[1])])});
return c;\n}}""")
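
As an illustration of what the dt_map_in/dt_map_out split buys, the same format string now names both accumulator variants (the dicts are copied from the diff above; the loop and dims are illustrative):

from tinygrad import dtypes

dt_map_in = {dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16"}
dt_map_out = {dtypes.float: "f32", dtypes.half: "f16"}
N, M, K = 8, 16, 16
for dtype_in, dtype_out in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]:
  print(f"mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}")
# mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
# mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16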

2 changes: 1 addition & 1 deletion tinygrad/renderer/ptx.py
@@ -65,7 +65,7 @@ def render_wmma(ctx: "PTXRenderer", wmma: UOp):
if (elems_per_reg := 4 // src.dtype.scalar().itemsize) == 1: yield f"mov.b32 {reg}, {ctx.r[src][i]};"
else: yield f"mov.b32 {reg}, {{{', '.join(ctx.r[src][i * elems_per_reg : (i+1) * elems_per_reg])}}};"

dt_map_in, dt_map_out = {dtypes.float: "tf32", dtypes.half: "f16"}, {dtypes.float: "f32"}
dt_map_in, dt_map_out = {dtypes.float: "tf32", dtypes.half: "f16"}, {dtypes.float: "f32", dtypes.half: "f16"}
yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}{" "*12}'+\
f'{{{", ".join(ctx.wmma_r[2])}}}, {{{", ".join(ctx.wmma_r[0])}}}, {{{", ".join(ctx.wmma_r[1])}}}, {{{", ".join(ctx.wmma_r[2])}}};'
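
On the PTX side the only change is the extra dt_map_out entry; the mov.b32 packing a few lines up already handles half accumulators via elems_per_reg. A sketch of the moves that branch would produce for an f16 accumulator (register names are made up for illustration):

# mirrors the else-branch of the elems_per_reg check above: two f16 registers are
# packed into each b32 mma operand (register names here are hypothetical)
acc_regs, elems_per_reg = ["%hs0", "%hs1", "%hs2", "%hs3"], 4 // 2
for i, reg in enumerate(["%r0", "%r1"]):
  print(f"mov.b32 {reg}, {{{', '.join(acc_regs[i * elems_per_reg:(i + 1) * elems_per_reg])}}};")
# mov.b32 %r0, {%hs0, %hs1};
# mov.b32 %r1, {%hs2, %hs3};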

