Commit d54224c
impl fused_swiglu_quant with group_list for deepep-low-latency (#155)
* impl fused_swiglu_quant with group_list for deepep-low-latency
* add a comment
* cleancode
* cleancode
1 parent b489f67 commit d54224c
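
The commit adds a single Python entry point, swiglu_quant, that fuses the SwiGLU activation with per-row dynamic int8 quantization and only processes the rows that group_list marks as valid. As a minimal call sketch (shapes and expert counts are illustrative, not taken from the commit; the import path matches the test file below):

    import torch
    import torch_npu  # required for .npu() tensors
    from sgl_kernel_npu.activation.swiglu_quant import swiglu_quant

    # hypothetical inputs: 8 experts, hidden size 4096 (gate and up halves concatenated)
    x = torch.randn((256, 4096), dtype=torch.bfloat16).npu()
    group_list = torch.tensor([32, 0, 64, 0, 16, 0, 0, 144], dtype=torch.int64).npu()

    # group_list_type=1 means group_list holds per-expert token counts
    out_int8, per_row_scale = swiglu_quant(x, group_list, group_list_type=1)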

File tree

3 files changed: +158 −0 lines changed

python/sgl_kernel_npu/sgl_kernel_npu/activation/__init__.py

Whitespace-only changes.
Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
import torch
import triton
import triton.language as tl
import triton.runtime.driver as driver


def get_npu_properties():
    device = torch.npu.current_device()
    return driver.active.utils.get_device_properties(device)


@triton.jit
def _swiglu_quant_kernel(
    x_ptr,
    group_list_ptr,
    out_ptr,
    scale_ptr,
    TOTAL_COLS: tl.constexpr,
    HALF_COLS: tl.constexpr,
    COL_BLOCK_SIZE: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    NUM_EXPERTS_ALGIN: tl.constexpr,
    GROUP_LIST_TYPE: tl.constexpr,
    NUM_CORES: tl.constexpr,
    DTYPE_MAX: tl.constexpr,
    SCALE: tl.constexpr,
):
    # compute the real number of valid rows from group_list
    if GROUP_LIST_TYPE == 0:  # cumsum
        total_rows = tl.load(group_list_ptr + NUM_EXPERTS).to(tl.int32)
    else:  # count
        gl_offsets = tl.arange(0, NUM_EXPERTS_ALGIN)
        gl_mask = gl_offsets < NUM_EXPERTS
        group_list = tl.load(group_list_ptr + gl_offsets, gl_mask, other=0).to(tl.int32)
        total_rows = tl.sum(group_list)

    # split the valid rows evenly across the vector cores
    block_size = (total_rows - 1) // NUM_CORES + 1
    pid = tl.program_id(0)
    row_begin = pid * block_size
    if row_begin >= total_rows:
        return
    row_end = tl.minimum((pid + 1) * block_size, total_rows)

    for row_idx in range(row_begin, row_end):
        # swiglu
        x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS)
        cur_x = tl.load(x_ptr + x_offsets)
        x1 = tl.extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
        x2 = tl.extract_slice(
            cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,)
        )
        out = x1 * tl.sigmoid(x1) * x2

        # quant
        if SCALE:
            scale = tl.max(tl.abs(out)).to(tl.float32) / DTYPE_MAX
            # store scale
            tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty))
            # out = tl.math.rint(out / scale.reshape(SUB_BLOCK_SIZE, 1))  # ub overflow
            for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE):
                tmp_out = tl.extract_slice(
                    out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,)
                )
                tmp_out = (tmp_out.to(tl.float32) / scale).to(x_ptr.dtype.element_ty)
                # tmp_out = tl.clamp(tmp_out, -128, 127)
                tmp_out = tl.math.rint(tmp_out)

                o_offsets = (
                    row_idx * HALF_COLS + col_blk_idx + tl.arange(0, COL_BLOCK_SIZE)
                )
                tl.store(out_ptr + o_offsets, tmp_out.to(out_ptr.dtype.element_ty))
        else:
            # store out
            o_offsets = row_idx * HALF_COLS + tl.arange(0, HALF_COLS)
            tl.store(out_ptr + o_offsets, out.to(out_ptr.dtype.element_ty))


def swiglu_quant(x, group_list, group_list_type, need_quant=True):
    # group_list_type must be 0 (cumsum) or 1 (count)
    if group_list_type not in [0, 1]:
        raise ValueError(f"group_list_type must be 0 or 1, but got {group_list_type}")
    s, h = x.shape
    out_dtype = torch.int8 if need_quant else x.dtype
    out = torch.empty((s, h // 2), dtype=out_dtype, device=x.device)
    scale = torch.empty((s,), dtype=torch.float32, device=x.device)
    num_experts = group_list.shape[0]
    # ub must be 32-byte aligned on npu
    if group_list.dtype == torch.int64:
        num_experts_algin = (num_experts + 7) // 8 * 8
    elif group_list.dtype == torch.int32:
        num_experts_algin = (num_experts + 15) // 16 * 16
    else:
        raise ValueError(
            f"group_list dtype must be torch.int32 or torch.int64, but got {group_list.dtype}"
        )

    num_cores = get_npu_properties()["num_vectorcore"]
    _swiglu_quant_kernel[(num_cores,)](
        x,
        group_list,
        out,
        scale,
        TOTAL_COLS=h,
        HALF_COLS=h // 2,
        COL_BLOCK_SIZE=1024,
        NUM_EXPERTS=num_experts,
        NUM_EXPERTS_ALGIN=num_experts_algin,
        GROUP_LIST_TYPE=group_list_type,
        NUM_CORES=num_cores,
        DTYPE_MAX=127,
        SCALE=need_quant,
        multibuffer=True,
    )
    return out, scale
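
For reference, the per-row math the kernel implements (SwiGLU followed by dynamic symmetric int8 quantization) can be written in plain PyTorch. This is a readability sketch only, not part of the commit; it ignores the group_list masking and processes every row:

    import torch

    def swiglu_quant_reference(x: torch.Tensor):
        # split the concatenated gate/up halves, as the kernel does with extract_slice
        x1, x2 = x.chunk(2, dim=-1)
        out = x1 * torch.sigmoid(x1) * x2
        # per-row dynamic scale: max(|out|) / 127, matching DTYPE_MAX=127 in the launch
        scale = out.abs().amax(dim=-1, keepdim=True).to(torch.float32) / 127
        quant = torch.round(out.to(torch.float32) / scale).to(torch.int8)
        return quant, scale.squeeze(-1)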
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import numpy as np
import torch
import torch_npu
from sgl_kernel_npu.activation.swiglu_quant import swiglu_quant


def test_swiglu_quant():
    def to_numpy(x: torch.Tensor) -> np.ndarray:
        return x.detach().cpu().numpy()

    # create inputs
    s, h = 4096, 4096
    x = torch.randn((s, h), dtype=torch.bfloat16).npu()
    group_list = (
        torch.Tensor([0, 32, 0, 0, 10, 0, 0, 0, 100, 0, 0, 5, 5, 5, 0, 0])
        .npu()
        .to(torch.int64)
    )
    # torch native
    swglu_out = torch_npu.npu_swiglu(x)
    ans1, ans2 = torch_npu.npu_dynamic_quant(swglu_out)
    # fused triton kernel
    res1, res2 = swiglu_quant(x, group_list, group_list_type=1)

    real_tokens = torch.sum(group_list)
    diff = res1[:real_tokens, :] - ans1[:real_tokens, :]
    max_diff = torch.max(torch.abs(diff))
    assert max_diff <= 1

    diff_rate = torch.sum(torch.abs(diff)) / (real_tokens * h // 2)
    assert diff_rate < 2e-2

    assert (
        np.testing.assert_allclose(
            to_numpy(res2[:real_tokens]),
            to_numpy(ans2[:real_tokens]),
            rtol=5e-3,
        )
        is None
    )


if __name__ == "__main__":
    test_swiglu_quant()
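
One detail of the wrapper worth noting: with need_quant=False the kernel skips quantization, the output keeps the input dtype, and the returned scale tensor is allocated with torch.empty but never written, so it should be ignored. A small sketch under that assumption:

    # unquantized path: out_bf16 has shape (s, h // 2) and dtype x.dtype
    out_bf16, unused_scale = swiglu_quant(x, group_list, group_list_type=1, need_quant=False)
    # unused_scale holds uninitialized memory in this path and carries no meaning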
