diff --git a/python/sgl_kernel_npu/sgl_kernel_npu/activation/__init__.py b/python/sgl_kernel_npu/sgl_kernel_npu/activation/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/sgl_kernel_npu/sgl_kernel_npu/activation/swiglu_quant.py b/python/sgl_kernel_npu/sgl_kernel_npu/activation/swiglu_quant.py
new file mode 100644
index 00000000..9759e830
--- /dev/null
+++ b/python/sgl_kernel_npu/sgl_kernel_npu/activation/swiglu_quant.py
@@ -0,0 +1,114 @@
+import torch
+import triton
+import triton.language as tl
+import triton.runtime.driver as driver
+
+
+def get_npu_properties():
+    device = torch.npu.current_device()
+    return driver.active.utils.get_device_properties(device)
+
+
+@triton.jit
+def _swiglu_quant_kernel(
+    x_ptr,
+    group_list_ptr,
+    out_ptr,
+    scale_ptr,
+    TOTAL_COLS: tl.constexpr,
+    HALF_COLS: tl.constexpr,
+    COL_BLOCK_SIZE: tl.constexpr,
+    NUM_EXPERTS: tl.constexpr,
+    NUM_EXPERTS_ALIGN: tl.constexpr,
+    GROUP_LIST_TYPE: tl.constexpr,
+    NUM_CORES: tl.constexpr,
+    DTYPE_MAX: tl.constexpr,
+    SCALE: tl.constexpr,
+):
+    # compute the number of valid rows described by group_list
+    if GROUP_LIST_TYPE == 0:  # cumsum: the last entry holds the total row count
+        total_rows = tl.load(group_list_ptr + NUM_EXPERTS - 1).to(tl.int32)
+    else:  # per-expert counts
+        gl_offsets = tl.arange(0, NUM_EXPERTS_ALIGN)
+        gl_mask = gl_offsets < NUM_EXPERTS
+        group_list = tl.load(group_list_ptr + gl_offsets, gl_mask, other=0).to(tl.int32)
+        total_rows = tl.sum(group_list)
+
+    block_size = (total_rows - 1) // NUM_CORES + 1
+    pid = tl.program_id(0)
+    row_begin = pid * block_size
+    if row_begin >= total_rows:
+        return
+    row_end = tl.minimum((pid + 1) * block_size, total_rows)
+
+    for row_idx in range(row_begin, row_end):
+        # swiglu: out = x1 * sigmoid(x1) * x2
+        x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS)
+        cur_x = tl.load(x_ptr + x_offsets)
+        x1 = tl.extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
+        x2 = tl.extract_slice(
+            cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,)
+        )
+        out = x1 * tl.sigmoid(x1) * x2
+
+        # per-row dynamic quantization
+        if SCALE:
+            scale = tl.max(tl.abs(out)).to(tl.float32) / DTYPE_MAX
+            # store the per-row scale
+            tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty))
+            # quantize in COL_BLOCK_SIZE chunks to avoid unified buffer (UB) overflow
+            for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE):
+                tmp_out = tl.extract_slice(
+                    out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,)
+                )
+                tmp_out = (tmp_out.to(tl.float32) / scale).to(x_ptr.dtype.element_ty)
+                # |out / scale| <= DTYPE_MAX by construction, so no extra clamp is needed
+                tmp_out = tl.math.rint(tmp_out)
+
+                o_offsets = (
+                    row_idx * HALF_COLS + col_blk_idx + tl.arange(0, COL_BLOCK_SIZE)
+                )
+                tl.store(out_ptr + o_offsets, tmp_out.to(out_ptr.dtype.element_ty))
+        else:
+            # store the unquantized swiglu output
+            o_offsets = row_idx * HALF_COLS + tl.arange(0, HALF_COLS)
+            tl.store(out_ptr + o_offsets, out.to(out_ptr.dtype.element_ty))
+
+
+def swiglu_quant(x, group_list, group_list_type, need_quant=True):
+    # group_list_type: 0 means cumulative sum (cumsum), 1 means per-expert counts
+    if group_list_type not in [0, 1]:
+        raise ValueError(f"group_list_type must be 0 or 1, but got {group_list_type}")
+    s, h = x.shape
+    out_dtype = torch.int8 if need_quant else x.dtype
+    out = torch.empty((s, h // 2), dtype=out_dtype, device=x.device)
+    scale = torch.empty((s,), dtype=torch.float32, device=x.device)
+    num_experts = group_list.shape[0]
+    # pad the expert count so the group_list load stays aligned for the NPU unified buffer (UB)
+    if group_list.dtype == torch.int64:
+        num_experts_align = (num_experts + 7) // 8 * 8
+    elif group_list.dtype == torch.int32:
+        num_experts_align = (num_experts + 15) // 16 * 16
+    else:
+        raise ValueError(
+            f"group_list dtype must be torch.int32 or torch.int64, but got {group_list.dtype}"
+        )
+
+    num_cores = get_npu_properties()["num_vectorcore"]
+    _swiglu_quant_kernel[(num_cores,)](
+        x,
+        group_list,
+        out,
+        scale,
+        TOTAL_COLS=h,
+        HALF_COLS=h // 2,
+        COL_BLOCK_SIZE=1024,
+        NUM_EXPERTS=num_experts,
+        NUM_EXPERTS_ALIGN=num_experts_align,
+        GROUP_LIST_TYPE=group_list_type,
+        NUM_CORES=num_cores,
+        DTYPE_MAX=127,
+        SCALE=need_quant,
+        multibuffer=True,
+    )
+    return out, scale
diff --git a/tests/python/sgl_kernel_npu/test_swiglu_quant.py b/tests/python/sgl_kernel_npu/test_swiglu_quant.py
new file mode 100644
index 00000000..3195dd52
--- /dev/null
+++ b/tests/python/sgl_kernel_npu/test_swiglu_quant.py
@@ -0,0 +1,42 @@
+import numpy as np
+import torch
+import torch_npu
+from sgl_kernel_npu.activation.swiglu_quant import swiglu_quant
+
+
+def test_swiglu_quant():
+    def to_numpy(x: torch.Tensor) -> np.ndarray:
+        return x.detach().cpu().numpy()
+
+    # create inputs
+    s, h = 4096, 4096
+    x = torch.randn((s, h), dtype=torch.bfloat16).npu()
+    group_list = (
+        torch.Tensor([0, 32, 0, 0, 10, 0, 0, 0, 100, 0, 0, 5, 5, 5, 0, 0])
+        .npu()
+        .to(torch.int64)
+    )
+    # reference: separate native NPU ops
+    swiglu_out = torch_npu.npu_swiglu(x)
+    ans1, ans2 = torch_npu.npu_dynamic_quant(swiglu_out)
+    # fused Triton kernel
+    res1, res2 = swiglu_quant(x, group_list, group_list_type=1)
+
+    real_tokens = int(group_list.sum())
+    diff = res1[:real_tokens].to(torch.int32) - ans1[:real_tokens].to(torch.int32)
+    max_diff = torch.max(torch.abs(diff))
+    assert max_diff <= 1
+
+    mean_abs_diff = torch.sum(torch.abs(diff)) / (real_tokens * h // 2)
+    assert mean_abs_diff < 2e-2
+
+    # assert_allclose raises on mismatch, so no extra assert wrapper is needed
+    np.testing.assert_allclose(
+        to_numpy(res2[:real_tokens]),
+        to_numpy(ans2[:real_tokens]),
+        rtol=5e-3,
+    )
+
+
+if __name__ == "__main__":
+    test_swiglu_quant()
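
For reviewers: a minimal usage sketch of the new swiglu_quant entry point, assuming a torch_npu / Ascend NPU environment with the Triton backend available; the shapes and per-expert token counts below are illustrative, not taken from the patch.

import torch
import torch_npu  # noqa: F401  (registers the .npu() device methods)

from sgl_kernel_npu.activation.swiglu_quant import swiglu_quant

# 64 routed tokens, hidden size 4096 with the gate and up halves packed along the last dim
x = torch.randn((64, 4096), dtype=torch.bfloat16).npu()
# per-expert token counts (group_list_type=1); only the first sum(group_list) rows of x are valid
group_list = torch.tensor([16, 0, 48, 0], dtype=torch.int64).npu()

# int8 path: (64, 2048) quantized activations plus (64,) float32 per-row scales
q_out, scales = swiglu_quant(x, group_list, group_list_type=1)

# bf16 path: skip quantization; the returned scale tensor is unused in this case
full_out, _ = swiglu_quant(x, group_list, group_list_type=1, need_quant=False)

Because the outputs are allocated with torch.empty and the kernel only writes the first sum(group_list) rows, callers should read only those rows, as the test does.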