Commit e37a76d

Add deep gemm with tma pre allocated (InternLM#3287)
* add deep gemm with tma pre allocated
* add comment
* add comment
* dispatch
* no use_deep_gemm arg
* remove DeepGemmBlockedF8
* missed op type
* latest get_best_config
* add a line of debug
1 parent f6e7ec7 commit e37a76d

File tree

lmdeploy/pytorch/backends/cuda/blockedf8_modules.py
lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py

2 files changed: +258 -2 lines changed


lmdeploy/pytorch/backends/cuda/blockedf8_modules.py

Lines changed: 41 additions & 2 deletions
@@ -4,10 +4,13 @@
 import torch

 import lmdeploy.pytorch.distributed as dist
-from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import blocked_gemm_fp8, quant_fp8
+from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import blocked_gemm_fp8, deep_gemm_fp8, quant_fp8, quant_fp8_tma
+from lmdeploy.utils import get_logger

 from ..blockedf8_modules import LinearBlockedF8Builder, LinearBlockedF8Impl

+logger = get_logger('lmdeploy')
+

 def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int]):
     """reduce scatter."""
@@ -60,4 +63,40 @@ class TritonLinearBlockedF8Builder(LinearBlockedF8Builder):
     @staticmethod
     def build(in_features: int, out_features: int, block_size: int = 128, bias: bool = True, dtype: torch.dtype = None):
         """build."""
-        return TritonLinearBlockedF8Impl(in_features, out_features, block_size, dtype)
+        try:
+            import deep_gemm  # noqa
+            logger.debug('build with DeepGemmLinearBlockedF8Impl')
+            return DeepGemmLinearBlockedF8Impl(in_features, out_features, block_size, dtype)
+        except:  # noqa
+            return TritonLinearBlockedF8Impl(in_features, out_features, block_size, dtype)
+
+
+class DeepGemmLinearBlockedF8Impl(LinearBlockedF8Impl):
+    """Deep gemm blocked f8 implementation."""
+
+    def __init__(self, in_features: int, out_features: int, block_size: int, out_dtype: torch.dtype = torch.float16):
+        self.in_features = in_features
+        self.out_features = out_features
+        self.out_dtype = out_dtype
+        self.block_size = block_size
+
+    def forward(self,
+                x,
+                weight: torch.Tensor,
+                scale: torch.Tensor,
+                bias: Optional[torch.Tensor] = None,
+                all_reduce: bool = False):
+        """forward."""
+        x_shape = x.shape
+        x = x.flatten(0, -2)
+        input_quant, input_scale = quant_fp8_tma(x, self.block_size, dtype=weight.dtype)
+
+        out = deep_gemm_fp8(input_quant, input_scale, weight, scale, out_dtype=x.dtype)
+        if bias is not None:
+            out += bias
+
+        if all_reduce:
+            dist.all_reduce(out)
+
+        out = out.unflatten(0, x_shape[:-1])
+        return out
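
The builder above now prefers the DeepGEMM-backed implementation and silently falls back to the Triton path when `deep_gemm` cannot be imported. A minimal sketch of the same availability check, using a hypothetical `has_deep_gemm` helper that is not part of this commit:

# Hypothetical helper mirroring the try/except dispatch in TritonLinearBlockedF8Builder.build.
def has_deep_gemm() -> bool:
    try:
        import deep_gemm  # noqa: F401
        return True
    except Exception:
        return False

print('DeepGEMM path available:', has_deep_gemm())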

lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py

Lines changed: 217 additions & 0 deletions
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
 import torch
 import triton
 import triton.language as tl
@@ -97,6 +99,100 @@ def quant_fp8(A: Tensor, group_size: int, dtype: torch.dtype = torch.float8_e4m3
     return out, scales


+# adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/utils.py#L46
+def get_tma_aligned_size(x: int, element_size: int) -> int:
+    """Global memory address of TMA must be 16-byte aligned. Since we use
+    column-major layout for the LHS scaling tensor, the M-axis of the LHS
+    scaling tensor needs to be padded to a multiple of 16 bytes.
+
+    Arguments:
+        x: original M-axis shape of the LHS scaling tensor.
+        element_size: element size of the LHS scaling tensor.
+
+    Returns:
+        M-axis shape of the LHS scaling tensor after padding.
+    """
+    tma_alignment_bytes = 16
+    assert tma_alignment_bytes % element_size == 0
+    alignment = tma_alignment_bytes // element_size
+    return triton.cdiv(x, alignment) * alignment
+
+
+@triton.jit
+def _quant_fp8_tma_kernel(
+    a_ptr,
+    out_ptr,
+    scale_ptr,
+    fp8_min: tl.constexpr,
+    fp8_max: tl.constexpr,
+    stride_am,
+    stride_ak: tl.constexpr,
+    stride_om,
+    stride_ok: tl.constexpr,
+    stride_sg,
+    stride_sm: tl.constexpr,
+    GROUP_SIZE: tl.constexpr,
+):
+    """quant fp8 kernel."""
+    group_id = tl.program_id(0)
+    m_id = tl.program_id(1)
+
+    g_offs = group_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
+
+    a_ptrs = a_ptr + m_id * stride_am + g_offs * stride_ak
+    o_ptrs = out_ptr + m_id * stride_om + g_offs * stride_ok
+    s_ptr = scale_ptr + m_id * stride_sm + group_id * stride_sg
+
+    rfp8_max = 1 / fp8_max
+
+    a = tl.load(a_ptrs).to(tl.float32)
+    scale = tl.max(tl.abs(a)) * rfp8_max
+    out = a / scale
+
+    out = tl.clamp(out, fp8_min, fp8_max)
+    out = out.to(out_ptr.dtype.element_ty)
+
+    tl.store(o_ptrs, out)
+    tl.store(s_ptr, scale)
+
+
+def quant_fp8_tma(A: Tensor, group_size: int, dtype: torch.dtype = torch.float8_e4m3fn):
+    """quant online."""
+    assert A.dim() == 2
+    M, K = A.shape
+    assert K % group_size == 0
+    num_groups = K // group_size
+
+    finfo = torch.finfo(dtype)
+    fmin = finfo.min
+    fmax = finfo.max
+
+    out = torch.empty_like(A, dtype=dtype)
+    aligned_M = get_tma_aligned_size(M, torch.float32.itemsize)
+    scales = A.new_empty(num_groups, aligned_M, dtype=torch.float32)
+    grid = (num_groups, M)
+    num_warps = 4
+    num_stages = 1
+    _quant_fp8_tma_kernel[grid](
+        A,
+        out,
+        scales,
+        fp8_min=fmin,
+        fp8_max=fmax,
+        stride_am=A.stride(0),
+        stride_ak=A.stride(1),
+        stride_om=out.stride(0),
+        stride_ok=out.stride(1),
+        stride_sg=scales.stride(0),
+        stride_sm=scales.stride(1),
+        GROUP_SIZE=group_size,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+
+    return out, scales.transpose(0, 1)
+
+
 @triton.autotune(configs=[
     triton.Config({
         'BLOCK_M': 64,
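
Because the scales added above are float32 (4-byte elements), `get_tma_aligned_size` pads the M axis to a multiple of 16 / 4 = 4 rows, and `quant_fp8_tma` returns the scales as a transposed view of a `[num_groups, aligned_M]` buffer, i.e. column-major along M. A small sketch of the resulting layout, assuming a CUDA device with FP8 support; the sizes are illustrative and not taken from the commit:

# Illustrative only: inspect the TMA-aligned scale layout produced by quant_fp8_tma.
import torch

from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import get_tma_aligned_size, quant_fp8_tma

group_size = 128
M, K = 1022, 512                       # K must be a multiple of group_size
x = torch.randn(M, K, device='cuda', dtype=torch.bfloat16)

x_fp8, scales = quant_fp8_tma(x, group_size)

aligned_M = get_tma_aligned_size(M, torch.float32.itemsize)    # 1024 for M = 1022
assert x_fp8.shape == (M, K) and x_fp8.dtype == torch.float8_e4m3fn
assert scales.shape == (aligned_M, K // group_size)            # M axis padded for TMA
assert scales.stride() == (1, aligned_M)                       # column-major along M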
@@ -246,3 +342,124 @@ def grid(META):
     )

     return C
+
+
+# adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/utils.py#L77
+def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+    """Returns the TMA-aligned transposed format of the input tensor.
+    `torch.transpose` will be called if necessary. If the input tensor is
+    already in column-major layout and 16-byte aligned along the M axis (thus
+    meeting the requirement of the LHS scaling tensor in DeepGEMM), this
+    function will do nothing.
+
+    Arguments:
+        x: usually the LHS scaling tensor in GEMM.
+
+    Returns:
+        The LHS scaling tensor in TMA-aligned transposed format.
+    """
+    # NOTES: for extreme performance, you may rewrite/fuse this function in CUDA
+    assert x.dim() in (2, 3)
+    remove_dim = False
+    if x.dim() == 2:
+        x, remove_dim = x.unsqueeze(0), True
+
+    b, m, n = x.shape
+    aligned_m = get_tma_aligned_size(m, x.element_size())
+
+    # The last kernel gives a column-major TMA aligned layout
+    # NOTE we modified the check to stride(0) == aligned_m from stride(0) == aligned_m * n
+    if x.stride(0) == aligned_m and x.stride(1) == 1 and x.stride(2) == aligned_m:
+        return x.squeeze(0) if remove_dim else x
+
+    # Normal layout requires transposing
+    aligned_x = torch.transpose(torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
+    aligned_x[:, :m, :] = x
+    aligned_x = aligned_x[:, :m, :]
+    return aligned_x.squeeze(0) if remove_dim else aligned_x
+
+
+# adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/gemm.py#L114
+def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor],
+                         out: torch.Tensor) -> None:
+    """Do a normal GEMM with FP8 inputs and BF16 output, with 1x128 LHS scaling
+    and 128x128 RHS scaling. LHS, RHS, RHS scaling factors, and output tensors
+    must be in contiguous format. RHS and RHS scaling factors are required to
+    be transposed. The LHS scaling tensor requires TMA-aligned transposed
+    format; if your input does not match the requirement, this function will do
+    the transposing with a set of slow PyTorch operations.
+
+    Arguments:
+        lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
+            the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
+        rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
+            the second element is an FP32 128x128 scaling tensor for RHS of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
+        out: the BF16 output tensor of shape `[m, n]`, representing the result.
+    """
+    lhs, lhs_scales = lhs
+    rhs, rhs_scales = rhs
+    m, k = lhs.shape
+    n, k_ = rhs.shape
+    m_, n_ = out.shape
+
+    assert n % 64 == 0 and k % 128 == 0
+
+    # Type and shape checks
+    assert m == m_ and n == n_ and k == k_
+    assert n > 0 and k > 0
+    # NOTE This is modified to skip the shape[0] check
+    assert lhs_scales.shape[-1] == (k + 127) // 128
+    assert rhs_scales.shape == ((n + 127) // 128, (k + 127) // 128)
+    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
+    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
+    assert out.dtype == torch.bfloat16
+    assert lhs.is_contiguous() and rhs.is_contiguous() and out.is_contiguous()
+
+    # LHS scales must be transposed for TMA load, but not for RHS scales
+    # NOTES: `get_tma_aligned_lhs_scales` may launch a kernel if not processed by previous kernels
+    lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
+    assert rhs_scales.is_contiguous()
+
+    # Do nothing if `m` is zero
+    if m == 0:
+        return
+
+    # Auto-tuning with compilation
+    from deep_gemm.jit_kernels.gemm import get_best_configs, get_num_sms, includes, jit_tuner, template
+    num_sms = get_num_sms()
+    num_sms, block_m, block_n, num_stages, num_tma_multicast, smem_size = get_best_configs(m, n, k, 1, num_sms)
+    args = (lhs, lhs_scales, rhs, rhs_scales, out, m, torch.cuda.current_stream(), num_sms, smem_size)
+    runtime = jit_tuner.compile_and_tune(name='gemm_fp8_fp8_bf16_nt',
+                                         keys={
+                                             'N': n,
+                                             'K': k,
+                                             'BLOCK_M': block_m,
+                                             'BLOCK_N': block_n,
+                                             'NUM_STAGES': num_stages,
+                                             'NUM_TMA_MULTICAST': num_tma_multicast
+                                         },
+                                         space=(),
+                                         includes=includes,
+                                         arg_defs=(('lhs', torch.float8_e4m3fn), ('lhs_scales', torch.float),
+                                                   ('rhs', torch.float8_e4m3fn), ('rhs_scales', torch.float),
+                                                   ('out', torch.bfloat16), ('m', int), ('stream', torch.cuda.Stream),
+                                                   ('num_sms', int), ('smem_size', int)),
+                                         template=template,
+                                         args=args)
+
+    # Run the kernel
+    runtime(*args)
+
+
+def deep_gemm_fp8(A: Tensor,
+                  A_scale: Tensor,
+                  B: Tensor,
+                  B_scale: torch.Tensor,
+                  out_dtype: torch.dtype = torch.bfloat16):
+    """deepgemm fp8."""
+    M, K = A.shape
+    N, _ = B.shape
+    assert out_dtype == torch.bfloat16, 'DeepGemm requires bf16 output.'
+    C = A.new_empty(M, N, dtype=out_dtype)
+    gemm_fp8_fp8_bf16_nt((A, A_scale), (B, B_scale), C)
+    return C
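
Taken together, the new kernels let the blocked-FP8 linear layer route through DeepGEMM's JIT-tuned FP8 GEMM: 1x128-scaled FP8 activations against a 128x128-scaled FP8 weight, with a bf16 result. A minimal end-to-end sketch under assumed sizes (a Hopper GPU with the `deep_gemm` package installed is assumed; the weight and its scales are random placeholders standing in for a real blocked-FP8 checkpoint):

# Illustrative only: shapes and dtypes expected by deep_gemm_fp8.
import torch

from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import deep_gemm_fp8, quant_fp8_tma

block_size = 128
m, k, n = 256, 4096, 2048              # n % 64 == 0 and k % 128 == 0, as asserted by the GEMM

x = torch.randn(m, k, device='cuda', dtype=torch.bfloat16)
x_fp8, x_scale = quant_fp8_tma(x, block_size)          # [m, k] fp8 plus column-major, TMA-aligned scales

# Placeholder blocked-FP8 weight: [n, k] fp8 values with [n / 128, k / 128] float32 scales.
w_fp8 = torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn)
w_scale = torch.rand(n // block_size, k // block_size, device='cuda', dtype=torch.float32)

out = deep_gemm_fp8(x_fp8, x_scale, w_fp8, w_scale, out_dtype=torch.bfloat16)
assert out.shape == (m, n) and out.dtype == torch.bfloat16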
