 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
 import torch
 import triton
 import triton.language as tl
@@ -97,6 +99,100 @@ def quant_fp8(A: Tensor, group_size: int, dtype: torch.dtype = torch.float8_e4m3
     return out, scales


+# adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/utils.py#L46
+def get_tma_aligned_size(x: int, element_size: int) -> int:
+    """Global memory address of TMA must be 16-byte aligned. Since we use
+    column-major layout for the LHS scaling tensor, the M-axis of the LHS
+    scaling tensor needs to be padded to a multiple of 16 bytes.
+
+    Arguments:
+        x: original M-axis shape of the LHS scaling tensor.
+        element_size: element size of the LHS scaling tensor.
+
+    Returns:
+        M-axis shape of the LHS scaling tensor after padding.
+    """
+    tma_alignment_bytes = 16
+    assert tma_alignment_bytes % element_size == 0
+    alignment = tma_alignment_bytes // element_size
+    return triton.cdiv(x, alignment) * alignment
+
+
+@triton.jit
+def _quant_fp8_tma_kernel(
+    a_ptr,
+    out_ptr,
+    scale_ptr,
+    fp8_min: tl.constexpr,
+    fp8_max: tl.constexpr,
+    stride_am,
+    stride_ak: tl.constexpr,
+    stride_om,
+    stride_ok: tl.constexpr,
+    stride_sg,
+    stride_sm: tl.constexpr,
+    GROUP_SIZE: tl.constexpr,
+):
+    """quant fp8 kernel."""
+    group_id = tl.program_id(0)
+    m_id = tl.program_id(1)
+
+    g_offs = group_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
+
+    a_ptrs = a_ptr + m_id * stride_am + g_offs * stride_ak
+    o_ptrs = out_ptr + m_id * stride_om + g_offs * stride_ok
+    s_ptr = scale_ptr + m_id * stride_sm + group_id * stride_sg
+
+    rfp8_max = 1 / fp8_max
+
+    a = tl.load(a_ptrs).to(tl.float32)
+    scale = tl.max(tl.abs(a)) * rfp8_max
+    out = a / scale
+
+    out = tl.clamp(out, fp8_min, fp8_max)
+    out = out.to(out_ptr.dtype.element_ty)
+
+    tl.store(o_ptrs, out)
+    tl.store(s_ptr, scale)
+
+
+def quant_fp8_tma(A: Tensor, group_size: int, dtype: torch.dtype = torch.float8_e4m3fn):
+    """quant online."""
+    assert A.dim() == 2
+    M, K = A.shape
+    assert K % group_size == 0
+    num_groups = K // group_size
+
+    finfo = torch.finfo(dtype)
+    fmin = finfo.min
+    fmax = finfo.max
+
+    out = torch.empty_like(A, dtype=dtype)
+    aligned_M = get_tma_aligned_size(M, torch.float32.itemsize)
+    scales = A.new_empty(num_groups, aligned_M, dtype=torch.float32)
+    grid = (num_groups, M)
+    num_warps = 4
+    num_stages = 1
+    _quant_fp8_tma_kernel[grid](
+        A,
+        out,
+        scales,
+        fp8_min=fmin,
+        fp8_max=fmax,
+        stride_am=A.stride(0),
+        stride_ak=A.stride(1),
+        stride_om=out.stride(0),
+        stride_ok=out.stride(1),
+        stride_sg=scales.stride(0),
+        stride_sm=scales.stride(1),
+        GROUP_SIZE=group_size,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+
+    return out, scales.transpose(0, 1)
+
+
 @triton.autotune(configs=[
     triton.Config({
         'BLOCK_M': 64,
@@ -246,3 +342,124 @@ def grid(META):
     )

     return C
+
+
+# adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/utils.py#L77
+def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+    """Returns the TMA-aligned transposed format of the input tensor.
+    `torch.transpose` will be called if necessary. If the input tensor is
+    already in column-major layout and 16-byte aligned along the M axis (thus
+    meeting the requirement of the LHS scaling tensor in DeepGEMM), this
+    function will do nothing.
+
+    Arguments:
+        x: usually the LHS scaling tensor in GEMM.
+
+    Returns:
+        The LHS scaling tensor in TMA-aligned transposed format.
+    """
+    # NOTES: for extreme performance, you may rewrite/fuse this function in CUDA
+    assert x.dim() in (2, 3)
+    remove_dim = False
+    if x.dim() == 2:
+        x, remove_dim = x.unsqueeze(0), True
+
+    b, m, n = x.shape
+    aligned_m = get_tma_aligned_size(m, x.element_size())
+
+    # The last kernel gives a column-major TMA-aligned layout
+    # NOTE we modified the check below to stride(0) == aligned_m from stride(0) == aligned_m * n
+    if x.stride(0) == aligned_m and x.stride(1) == 1 and x.stride(2) == aligned_m:
+        return x.squeeze(0) if remove_dim else x
+
+    # Normal layout requires transposing
+    aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
+    aligned_x[:, :m, :] = x
+    aligned_x = aligned_x[:, :m, :]
+    return aligned_x.squeeze(0) if remove_dim else aligned_x
+
+
+# adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/gemm.py#L114
+def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor],
+                         out: torch.Tensor) -> None:
+    """Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling
+    and 128x128 RHS scaling. The LHS, RHS, RHS scaling factors, and output
+    tensors must be contiguous. RHS and RHS scaling factors are required to
+    be transposed. The LHS scaling tensor requires a TMA-aligned transposed
+    format; if the input does not meet this requirement, this function
+    transposes it with a set of slow PyTorch operations.
+
+    Arguments:
+        lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
+            the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
+        rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
+            the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
+        out: the BF16 output tensor of shape `[m, n]`, representing the result.
+    """
+    lhs, lhs_scales = lhs
+    rhs, rhs_scales = rhs
+    m, k = lhs.shape
+    n, k_ = rhs.shape
+    m_, n_ = out.shape
+
+    assert n % 64 == 0 and k % 128 == 0
+
+    # Type and shape checks
+    assert m == m_ and n == n_ and k == k_
+    assert n > 0 and k > 0
+    # NOTE this is modified to skip the shape[0] check, since the LHS scales may be padded along M
+    assert lhs_scales.shape[-1] == (k + 127) // 128
+    assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
+    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
+    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
+    assert out.dtype == torch.bfloat16
+    assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
+
+    # LHS scales must be transposed for TMA load, but not for RHS scales
+    # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
+    lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
+    assert rhs_scales.is_contiguous()
+
+    # Do nothing if `m` is zero
+    if m == 0:
+        return
+
+    # Auto-tuning with compilation
+    from deep_gemm.jit_kernels.gemm import get_best_configs, get_num_sms, includes, jit_tuner, template
+    num_sms = get_num_sms()
+    num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
+    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
+    runtime = jit_tuner.compile_and_tune(name='gemm_fp8_fp8_bf16_nt',
+                                         keys={
+                                             'N': n,
+                                             'K': k,
+                                             'BLOCK_M': block_m,
+                                             'BLOCK_N': block_n,
+                                             'NUM_STAGES': num_stages,
+                                             'NUM_TMA_MULTICAST': num_tma_multicast
+                                         },
+                                         space=(),
+                                         includes=includes,
+                                         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
+                                                   ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
+                                                   ('out', torch.bfloat16), ('m', int), ('stream', torch.cuda.Stream),
+                                                   ('num_sms', int), ('smem_size', int)),
+                                         template=template,
+                                         args=args)
+
+    # Run the kernel
+    runtime(*args)
+
+
+def deep_gemm_fp8(A: Tensor,
+                  A_scale: Tensor,
+                  B: Tensor,
+                  B_scale: torch.Tensor,
+                  out_dtype: torch.dtype = torch.bfloat16):
+    """deepgemm fp8."""
+    M, K = A.shape
+    N, _ = B.shape
+    assert out_dtype == torch.bfloat16, 'DeepGemm requires bf16 output.'
+    C = A.new_empty(M, N, dtype=out_dtype)
+    gemm_fp8_fp8_bf16_nt((A, A_scale), (B, B_scale), C)
+    return C
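
A quick illustration of the alignment rule and of the scale layout the new quantizer produces (an editor's sketch, not part of the patch; it assumes a CUDA device with Triton available and that the helpers above are importable from this module). FP32 scales are 4 bytes, so the M axis pads to a multiple of 16 / 4 = 4, and the transposed scales returned by quant_fp8_tma already satisfy the column-major, TMA-aligned layout that get_col_major_tma_aligned_tensor checks for:

import torch

assert get_tma_aligned_size(1022, 4) == 1024  # 1022 rows pad up to the next multiple of 4
assert get_tma_aligned_size(1024, 4) == 1024  # already aligned

A = torch.randn(1022, 256, device='cuda', dtype=torch.bfloat16)
_, scales = quant_fp8_tma(A, group_size=128)
# column-major view over the padded M axis: shape [aligned_M, K // 128], strides (1, aligned_M)
assert scales.shape == (1024, 2) and scales.stride() == (1, 1024)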
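
A minimal end-to-end sketch of the new path (again hypothetical: it assumes a Hopper GPU with DeepGEMM installed, and the random B / B_scale below merely stand in for real 128x128 blockwise weight quantization, which is not part of this file):

import torch

M, K, N = 128, 512, 256  # satisfies n % 64 == 0 and k % 128 == 0, as asserted above

A = torch.randn(M, K, device='cuda', dtype=torch.bfloat16)
A_fp8, A_scale = quant_fp8_tma(A, group_size=128)  # A_scale: [aligned_M, K // 128]

B = torch.randn(N, K, device='cuda').to(torch.float8_e4m3fn)  # placeholder pre-quantized weight
B_scale = torch.rand((N + 127) // 128, (K + 127) // 128, device='cuda', dtype=torch.float32)

C = deep_gemm_fp8(A_fp8, A_scale, B, B_scale)  # bfloat16, shape [M, N]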