|
| 1 | +import torch |
| 2 | +import triton |
| 3 | +import triton.language as tl |
| 4 | +import triton.runtime.driver as driver |
| 5 | + |
| 6 | + |
| 7 | +def get_npu_properties(): |
| 8 | + device = torch.npu.current_device() |
| 9 | + return driver.active.utils.get_device_properties(device) |
| 10 | + |
| 11 | + |
| 12 | +@triton.jit |
| 13 | +def _swiglu_quant_kernel( |
| 14 | + x_ptr, |
| 15 | + group_list_ptr, |
| 16 | + out_ptr, |
| 17 | + scale_ptr, |
| 18 | + TOTAL_COLS: tl.constexpr, |
| 19 | + HALF_COLS: tl.constexpr, |
| 20 | + COL_BLOCK_SIZE: tl.constexpr, |
| 21 | + NUM_EXPERTS: tl.constexpr, |
| 22 | + NUM_EXPERTS_ALGIN: tl.constexpr, |
| 23 | + GROUP_LIST_TYPE: tl.constexpr, |
| 24 | + NUM_CORES: tl.constexpr, |
| 25 | + DTYPE_MAX: tl.constexpr, |
| 26 | + SCALE: tl.constexpr, |
| 27 | +): |
| 28 | + # calc real total_rows |
| 29 | + if GROUP_LIST_TYPE == 0: # cusum |
| 30 | + total_rows = tl.load(group_list_ptr + NUM_EXPERTS).to(tl.int32) |
| 31 | + else: |
| 32 | + gl_offsets = tl.arange(0, NUM_EXPERTS_ALGIN) |
| 33 | + gl_mask = gl_offsets < NUM_EXPERTS |
| 34 | + group_list = tl.load(group_list_ptr + gl_offsets, gl_mask, other=0).to(tl.int32) |
| 35 | + total_rows = tl.sum(group_list) |
| 36 | + |
| 37 | + block_size = (total_rows - 1) // NUM_CORES + 1 |
| 38 | + pid = tl.program_id(0) |
| 39 | + row_begin = pid * block_size |
| 40 | + if row_begin >= total_rows: |
| 41 | + return |
| 42 | + row_end = tl.minimum((pid + 1) * block_size, total_rows) |
| 43 | + |
| 44 | + for row_idx in range(row_begin, row_end): |
| 45 | + # swiglu |
| 46 | + x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS) |
| 47 | + cur_x = tl.load(x_ptr + x_offsets) |
| 48 | + x1 = tl.extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,)) |
| 49 | + x2 = tl.extract_slice( |
| 50 | + cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,) |
| 51 | + ) |
| 52 | + out = x1 * tl.sigmoid(x1) * x2 |
| 53 | + |
| 54 | + # quant |
| 55 | + if SCALE: |
| 56 | + scale = tl.max(tl.abs(out)).to(tl.float32) / DTYPE_MAX |
| 57 | + # store scale |
| 58 | + tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty)) |
| 59 | + # out = tl.math.rint(out / scale.reshape(SUB_BLOCK_SIZE, 1)) # ub overflow |
| 60 | + for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE): |
| 61 | + tmp_out = tl.extract_slice( |
| 62 | + out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,) |
| 63 | + ) |
| 64 | + tmp_out = (tmp_out.to(tl.float32) / scale).to(x_ptr.dtype.element_ty) |
| 65 | + # tmp_out = tl.clamp(tmp_out, -128, 127) |
| 66 | + tmp_out = tl.math.rint(tmp_out) |
| 67 | + |
| 68 | + o_offsets = ( |
| 69 | + row_idx * HALF_COLS + col_blk_idx + tl.arange(0, COL_BLOCK_SIZE) |
| 70 | + ) |
| 71 | + tl.store(out_ptr + o_offsets, tmp_out.to(out_ptr.dtype.element_ty)) |
| 72 | + else: |
| 73 | + # store out |
| 74 | + o_offsets = row_idx * HALF_COLS + tl.arange(0, HALF_COLS) |
| 75 | + tl.store(out_ptr + o_offsets, out.to(out_ptr.dtype.element_ty)) |
| 76 | + |
| 77 | + |
| 78 | +def swiglu_quant(x, group_list, group_list_type, need_quant=True): |
| 79 | + # group_list_type must be 0 cusum or 1 count |
| 80 | + if group_list_type not in [0, 1]: |
| 81 | + raise ValueError(f"group_list_type must be 0 or 1, but got {group_list_type}") |
| 82 | + s, h = x.shape |
| 83 | + out_dtype = torch.int8 if need_quant else x.dtype |
| 84 | + out = torch.empty((s, h // 2), dtype=out_dtype, device=x.device) |
| 85 | + scale = torch.empty((s,), dtype=torch.float32, device=x.device) |
| 86 | + num_experts = group_list.shape[0] |
| 87 | + # ub must be 32-byte aligned on npu |
| 88 | + if group_list.dtype == torch.int64: |
| 89 | + num_experts_algin = (num_experts + 7) // 8 * 8 |
| 90 | + elif group_list.dtype == torch.int32: |
| 91 | + num_experts_algin = (num_experts + 15) // 16 * 16 |
| 92 | + else: |
| 93 | + raise ValueError( |
| 94 | + f"group_list dtype must be torch.int32 or torch.int64, but got {group_list.dtype}" |
| 95 | + ) |
| 96 | + |
| 97 | + num_cores = get_npu_properties()["num_vectorcore"] |
| 98 | + _swiglu_quant_kernel[(num_cores,)]( |
| 99 | + x, |
| 100 | + group_list, |
| 101 | + out, |
| 102 | + scale, |
| 103 | + TOTAL_COLS=h, |
| 104 | + HALF_COLS=h // 2, |
| 105 | + COL_BLOCK_SIZE=1024, |
| 106 | + NUM_EXPERTS=num_experts, |
| 107 | + NUM_EXPERTS_ALGIN=num_experts_algin, |
| 108 | + GROUP_LIST_TYPE=group_list_type, |
| 109 | + NUM_CORES=num_cores, |
| 110 | + DTYPE_MAX=127, |
| 111 | + SCALE=need_quant, |
| 112 | + multibuffer=True, |
| 113 | + ) |
| 114 | + return out, scale |
0 commit comments