From 7dc0a5c7772c35bbb166fa7d6d89caf65a0b86ff Mon Sep 17 00:00:00 2001 From: q yao Date: Fri, 18 Oct 2024 12:31:34 +0800 Subject: [PATCH] optimize paged attention on triton3 (#2553) * optimize paged attention on triton3 * fix w8a8 kernel * optimize prefill * optimize short decoding * optimize sm<8 * optimize short context * fix triton2.2.0 * recovery test * add ut for custom layout * update stride * update ut --- .github/workflows/unit-test.yml | 4 +- lmdeploy/pytorch/check_env/__init__.py | 18 +- lmdeploy/pytorch/engine/engine.py | 2 +- .../pytorch/kernels/cuda/fill_kv_cache.py | 61 ++-- .../pytorch/kernels/cuda/pagedattention.py | 273 +++++++++++------- .../kernels/cuda/w8a8_triton_kernels.py | 11 +- requirements/runtime.txt | 4 +- tests/pytorch/kernel/test_paged_attention.py | 37 ++- 8 files changed, 259 insertions(+), 151 deletions(-) diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 64b3acd52a..ec6db0682d 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -53,7 +53,7 @@ jobs: - name: Install pytorch run: | python3 -m pip cache dir - python3 -m pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118 + python3 -m pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118 - name: Build lmdeploy run: | python3 -m pip install cmake @@ -77,7 +77,7 @@ jobs: run: | python3 -m pip install pynvml packaging protobuf transformers_stream_generator # manually install flash attn - python3 -m pip install /root/packages/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl + python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp38-cp38-linux_x86_64.whl python3 -m pip install -r requirements.txt -r requirements/test.txt python3 -m pip install . - name: Check env diff --git a/lmdeploy/pytorch/check_env/__init__.py b/lmdeploy/pytorch/check_env/__init__.py index ea2dda8e8d..5ace70b53c 100644 --- a/lmdeploy/pytorch/check_env/__init__.py +++ b/lmdeploy/pytorch/check_env/__init__.py @@ -93,15 +93,13 @@ def check_env_triton(device: str): if device == 'cuda': device_cap = torch.cuda.get_device_capability() - TRITON_VER_220 = version.parse('2.2.0') TRITON_VER_231 = version.parse('2.3.1') if device_cap[0] <= 7: - if (triton_version >= TRITON_VER_220 - and triton_version <= TRITON_VER_231): + if triton_version <= TRITON_VER_231: err = RuntimeError( 'Attention triton kernel does not fully support ' - 'triton[2.2.0~2.3.1] on device with capability<8. ' + 'triton<3.0.0 on device with capability<8. 
' 'Please upgrade your triton version.') _handle_exception(err, 'Triton', logger) @@ -142,7 +140,8 @@ def check_awq(hf_config): def check_transformers_version(model_path: str, - trust_remote_code: bool = True): + trust_remote_code: bool = True, + dtype: str = 'auto'): """check transformers version.""" from packaging import version logger = get_logger('lmdeploy') @@ -206,7 +205,8 @@ def __check_model_dtype_support(config): try: model_config = ModelConfig.from_hf_config(config, - model_path=model_path) + model_path=model_path, + dtype=dtype) if model_config.dtype == torch.bfloat16: assert torch.cuda.is_bf16_supported(), ( 'bf16 is not supported on your device') @@ -229,11 +229,13 @@ def __check_model_dtype_support(config): check_awq(config) -def check_model(model_path: str, trust_remote_code: bool = True): +def check_model(model_path: str, + trust_remote_code: bool = True, + dtype: str = 'auto'): """check model requirements.""" logger = get_logger('lmdeploy') logger.info('Checking model.') - check_transformers_version(model_path, trust_remote_code) + check_transformers_version(model_path, trust_remote_code, dtype) def check_adapter(path: str): diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 58d319ef8c..f6ce4c29a1 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -96,7 +96,7 @@ def __init__(self, else: engine_config = copy.deepcopy(engine_config) check_env(engine_config.device_type) - check_model(model_path, trust_remote_code) + check_model(model_path, trust_remote_code, engine_config.dtype) if engine_config.max_batch_size is None: engine_config.max_batch_size = get_max_batch_size( engine_config.device_type) diff --git a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py index a9a6cab010..9ef614fadd 100644 --- a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py +++ b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py @@ -378,12 +378,21 @@ def fill_kv_cache(k_states: Tensor, block_offsets: Tensor, k_scales_zeros: Tensor = None, v_scales_zeros: Tensor = None, - quant_policy: Literal[0, 4, 8] = 0): + quant_policy: Literal[0, 4, 8] = 0, + kv_layout: str = 'bshd'): """fill key/value state to cache for paged attention.""" + if kv_layout == 'bshd': + b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3) + elif kv_layout == 'bhsd': + b_dim, s_dim, h_dim, d_dim = (0, 2, 1, 3) + else: + raise RuntimeError('Unsupported layout.') block_offsets = block_offsets.contiguous() batch_size = block_offsets.size(0) - block_size, num_heads, head_dim = k_caches.size()[1:] + block_size = k_caches.size(s_dim) + num_heads = k_caches.size(h_dim) + head_dim = k_caches.size(d_dim) head_dim_v = v_states.size(-1) max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1 @@ -412,14 +421,14 @@ def fill_kv_cache(k_states: Tensor, stride_vss=v_states.stride(-3), stride_vsh=v_states.stride(-2), stride_vsd=v_states.stride(-1), - stride_kcn=k_caches.stride(0), - stride_kcb=k_caches.stride(1), - stride_kch=k_caches.stride(2), - stride_kcd=k_caches.stride(3), - stride_vcn=v_caches.stride(0), - stride_vcb=v_caches.stride(1), - stride_vch=v_caches.stride(2), - stride_vcd=v_caches.stride(3), + stride_kcn=k_caches.stride(b_dim), + stride_kcb=k_caches.stride(s_dim), + stride_kch=k_caches.stride(h_dim), + stride_kcd=k_caches.stride(d_dim), + stride_vcn=v_caches.stride(b_dim), + stride_vcb=v_caches.stride(s_dim), + stride_vch=v_caches.stride(h_dim), + stride_vcd=v_caches.stride(d_dim), stride_boff=block_offsets.stride(0), 
BLOCK=BLOCK, BLOCK_D=BLOCK_D, @@ -450,22 +459,22 @@ def fill_kv_cache(k_states: Tensor, stride_vss=v_states.stride(-3), stride_vsh=v_states.stride(-2), stride_vsd=v_states.stride(-1), - stride_kcn=k_caches.stride(0), - stride_kcb=k_caches.stride(1), - stride_kch=k_caches.stride(2), - stride_kcd=k_caches.stride(3), - stride_vcn=v_caches.stride(0), - stride_vcb=v_caches.stride(1), - stride_vch=v_caches.stride(2), - stride_vcd=v_caches.stride(3), - stride_kszn=k_scales_zeros.stride(0), - stride_kszb=k_scales_zeros.stride(1), - stride_kszh=k_scales_zeros.stride(2), - stride_kszd=k_scales_zeros.stride(3), - stride_vszn=v_scales_zeros.stride(0), - stride_vszb=v_scales_zeros.stride(1), - stride_vszh=v_scales_zeros.stride(2), - stride_vszd=v_scales_zeros.stride(3), + stride_kcn=k_caches.stride(b_dim), + stride_kcb=k_caches.stride(s_dim), + stride_kch=k_caches.stride(h_dim), + stride_kcd=k_caches.stride(d_dim), + stride_vcn=v_caches.stride(b_dim), + stride_vcb=v_caches.stride(s_dim), + stride_vch=v_caches.stride(h_dim), + stride_vcd=v_caches.stride(d_dim), + stride_kszn=k_scales_zeros.stride(b_dim), + stride_kszb=k_scales_zeros.stride(s_dim), + stride_kszh=k_scales_zeros.stride(h_dim), + stride_kszd=k_scales_zeros.stride(d_dim), + stride_vszn=v_scales_zeros.stride(b_dim), + stride_vszb=v_scales_zeros.stride(s_dim), + stride_vszh=v_scales_zeros.stride(h_dim), + stride_vszd=v_scales_zeros.stride(d_dim), quant_policy=quant_policy, stride_boff=block_offsets.stride(0), BLOCK=BLOCK, diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index aa363d2bd4..d8e6ec5013 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -15,18 +15,15 @@ logger = get_logger('lmdeploy') TRITON_VERSION = version.parse(triton.__version__) +VERSION_300 = version.parse('3.0.0') -assert TRITON_VERSION >= version.parse('2.1.0') +assert TRITON_VERSION >= version.parse('2.2.0') -if TRITON_VERSION >= version.parse('3.0.0'): - - @triton.jit - def tanh(x): - """tanh.""" - return 2 * tl.sigmoid(2 * x) - 1 - - fast_expf = tl.math.exp - fast_dividef = tl.math.fdiv +# TODO: fast op might not work on non-nv device +if TRITON_VERSION >= VERSION_300: + tanh = tl.extra.cuda.libdevice.tanh + fast_expf = tl.extra.cuda.libdevice.fast_expf + fast_dividef = tl.extra.cuda.libdevice.fast_dividef else: tanh = tl.math.tanh fast_expf = tl.math.fast_expf @@ -38,7 +35,9 @@ def tanh(x): triton.Config({}, num_stages=2, num_warps=8), triton.Config({}, num_stages=2, num_warps=4), ], - key=['BLOCK_H', 'BLOCK_N', 'BLOCK_DMODEL', 'BLOCK_DV']) + key=['BLOCK_H', 'BLOCK_N', 'BLOCK_DMODEL', 'BLOCK_DV'], + warmup=10, + rep=25) @wrap_jit_func(type_hint=dict( Q=torch.Tensor, K=torch.Tensor, @@ -235,9 +234,13 @@ def _fwd_grouped_split_kernel( m_i = m_i_new # initialize pointers to output - off_acc = (cur_batch * stride_obs + split_k_id * stride_ok + - cur_head[:, None] * stride_oh + offs_dv[None, :] * stride_od) - tl.store(Acc_out + off_acc, acc, mask=mask_h[:, None] & mask_dv[None, :]) + if loop_end > loop_start: + off_acc = (cur_batch * stride_obs + split_k_id * stride_ok + + cur_head[:, None] * stride_oh + + offs_dv[None, :] * stride_od) + tl.store(Acc_out + off_acc, + acc, + mask=mask_h[:, None] & mask_dv[None, :]) off_meta = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head * stride_oh + head_size_v) @@ -515,9 +518,13 @@ def _fwd_grouped_split_quant_kernel( m_i = m_i_new # initialize pointers to output - off_acc = (cur_batch * stride_obs 
+ split_k_id * stride_ok + - cur_head[:, None] * stride_oh + offs_dv[None, :] * stride_od) - tl.store(Acc_out + off_acc, acc, mask=mask_h[:, None] & mask_dv[None, :]) + if loop_end > loop_start: + off_acc = (cur_batch * stride_obs + split_k_id * stride_ok + + cur_head[:, None] * stride_oh + + offs_dv[None, :] * stride_od) + tl.store(Acc_out + off_acc, + acc, + mask=mask_h[:, None] & mask_dv[None, :]) if quant_policy == 4: off_meta = (cur_batch * stride_obs + split_k_id * stride_ok + @@ -572,9 +579,11 @@ def _reduce_split_kernel( offs_mi = (cur_batch * stride_abs + cur_head * stride_ah + stride_ak * offs_k + head_size_v) - acc_k = tl.load(Acc + offs_acc, mask=mask_dv[None, :], other=0.0) m_k = tl.load(Acc + offs_mi) l_k = tl.load(Acc + offs_mi + 1) + acc_k = tl.load(Acc + offs_acc, + mask=mask_dv[None, :] & (m_k[:, None] > -float('inf')), + other=0.0) m_max = tl.max(m_k, 0) alpha = fast_expf(m_k - m_max) @@ -592,7 +601,8 @@ def _reduce_split_kernel( def _get_convert_pv(nv_capability): """lazy load convert_pv.""" - if nv_capability[0] >= 8: + global TRITON_VERSION, VERSION_300 + if TRITON_VERSION >= VERSION_300 or nv_capability[0] >= 8: @triton.jit def convert_pv(p, v): @@ -620,7 +630,6 @@ def convert_pv(p, v): # triton.Config({}, num_stages=1, num_warps=4), # ], # key=['BLOCK_M', 'BLOCK_N', 'BLOCK_DMODEL', 'BLOCK_DV']) -@wrap_jit_func @triton.jit def _fwd_kernel( Q, @@ -647,7 +656,7 @@ def _fwd_kernel( stride_oh: tl.constexpr, stride_od: tl.constexpr, stride_boffb, - kv_group_num: tl.constexpr, + kv_group_num, window_size: tl.constexpr, head_size: tl.constexpr, head_size_v: tl.constexpr, @@ -660,18 +669,16 @@ def _fwd_kernel( ): """paged attention kernel.""" cur_batch = tl.program_id(2) - cur_head = tl.program_id(1) + cur_kv_head = tl.program_id(1) start_m = tl.program_id(0) - cur_kv_head = cur_head // kv_group_num - q_seqlen = tl.load(Q_seqlens + cur_batch) kv_seqlen = tl.load(KV_seqlens + cur_batch) q_start_loc = tl.load(Q_start_loc + cur_batch) history_len = kv_seqlen - q_seqlen block_start_loc = BLOCK_M * start_m - if block_start_loc >= q_seqlen: + if block_start_loc >= q_seqlen * kv_group_num: return # initialize offsets @@ -682,17 +689,22 @@ def _fwd_kernel( offs_d = offs_d % head_size mask_dv = offs_dv < head_size_v offs_dv = offs_dv % head_size_v - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_mh = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = offs_mh // kv_group_num + cur_head = offs_mh % kv_group_num + cur_kv_head * kv_group_num off_q = ((q_start_loc + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) + cur_head[:, None] * stride_qh + offs_d[None, :] * stride_qd) off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + offs_n[None, :] * stride_kbs) off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd + offs_n[:, None] * stride_vbs) - q = tl.load(Q + off_q, - mask=(offs_m[:, None] < q_seqlen) & mask_d[None, :], - other=0.0) + q = tl.load( + Q + off_q, + mask=(offs_m[:, None] < q_seqlen) & mask_d[None, :], + other=0.0, + eviction_policy='evict_first', + ) k_ptrs = K + off_k v_ptrs = V + off_v @@ -702,7 +714,7 @@ def _fwd_kernel( mask_d1 = offs_d1 < head_size offs_d1 = offs_d1 % head_size off_q1 = ((q_start_loc + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d1[None, :] * stride_qd) + cur_head[:, None] * stride_qh + offs_d1[None, :] * stride_qd) q1 = tl.load(Q + off_q1, mask=(offs_m[:, None] < q_seqlen) & mask_d1) off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd + offs_n[None, :] 
* stride_kbs) @@ -722,7 +734,56 @@ def _fwd_kernel( kv_start_loc = start_block_id * BLOCK_N block_offset_ptrs += start_block_id - for start_n in range(kv_start_loc, kv_seqlen, BLOCK_N): + loop_start = kv_start_loc + loop_end = history_len // BLOCK_N * BLOCK_N + for start_n in range(loop_start, loop_end, BLOCK_N): + b_offset = tl.load(block_offset_ptrs) + block_offset_ptrs += 1 + + # -- compute qk ---- + k = tl.load(k_ptrs + b_offset * stride_kp) + if BLOCK_DMODEL1 != 0: + k1 = tl.load(k1_ptrs + b_offset * stride_kp) + + v = tl.load(v_ptrs + b_offset * stride_vp) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + if BLOCK_DMODEL1 != 0: + qk += tl.dot(q1, k1) + qk *= sm_scale + if logit_softcapping > 0.0: + qk = qk / logit_softcapping + qk = tanh(qk) + qk = qk * logit_softcapping + # NOTE: inf - inf = nan, and nan will leads to error + if window_size > 0: + qk_mask = ((start_n + offs_n[None, :]) >= kv_min_loc[:, None]) + qk = tl.where( + qk_mask, + qk, + float(-1e30), + ) + + # -- compute p, m_i and l_i + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + p = fast_expf(qk - m_i_new[:, None]) + alpha = fast_expf(m_i - m_i_new) + l_i_new = alpha * l_i + tl.sum(p, 1) + # -- update output accumulator -- + # scale acc + acc = acc * alpha[:, None] + + # update acc + p, v = _convert_pv(p, v) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + loop_start = loop_end + loop_end = kv_seqlen + for start_n in range(loop_start, loop_end, BLOCK_N): b_offset = tl.load(block_offset_ptrs) block_offset_ptrs += 1 @@ -773,7 +834,7 @@ def _fwd_kernel( acc = fast_dividef(acc, l_i[:, None]) # initialize pointers to output off_o = ((q_start_loc + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_dv[None, :] * stride_od) + cur_head[:, None] * stride_oh + offs_dv[None, :] * stride_od) out_ptrs = Out + off_o tl.store(out_ptrs, acc, @@ -825,7 +886,7 @@ def _fwd_kernel_quant( stride_oh: tl.constexpr, stride_od: tl.constexpr, stride_boffb, - kv_group_num: tl.constexpr, + kv_group_num, window_size: tl.constexpr, head_size: tl.constexpr, head_size_v: tl.constexpr, @@ -845,18 +906,16 @@ def _fwd_kernel_quant( stride_d: stride of head size dim """ cur_batch = tl.program_id(2) - cur_head = tl.program_id(1) + cur_kv_head = tl.program_id(1) start_m = tl.program_id(0) - cur_kv_head = cur_head // kv_group_num - q_seqlen = tl.load(Q_seqlens + cur_batch) kv_seqlen = tl.load(KV_seqlens + cur_batch) q_start_loc = tl.load(Q_start_loc + cur_batch) history_len = kv_seqlen - q_seqlen block_start_loc = BLOCK_M * start_m - if block_start_loc >= q_seqlen: + if block_start_loc >= q_seqlen * kv_group_num: return # initialize offsets @@ -868,9 +927,11 @@ def _fwd_kernel_quant( offs_d = offs_d % head_size mask_dv = offs_dv < head_size_v offs_dv = offs_dv % head_size_v - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_mh = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = offs_mh // kv_group_num + cur_head = offs_mh % kv_group_num + cur_kv_head * kv_group_num off_q = ((q_start_loc + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) + cur_head[:, None] * stride_qh + offs_d[None, :] * stride_qd) off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + offs_n[None, :] * stride_kbs) off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd + @@ -892,7 +953,7 @@ def _fwd_kernel_quant( mask_d1 = offs_d1 < head_size offs_d1 = offs_d1 % head_size off_q1 = ((q_start_loc + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + 
offs_d1[None, :] * stride_qd) + cur_head[:, None] * stride_qh + offs_d1[None, :] * stride_qd) q1 = tl.load(Q + off_q1, mask=(offs_m[:, None] < q_seqlen) & mask_d1) off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd + offs_n[None, :] * stride_kbs) @@ -999,7 +1060,7 @@ def _fwd_kernel_quant( acc = fast_dividef(acc, l_i[:, None]) # initialize pointers to output off_o = ((q_start_loc + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_dv[None, :] * stride_od) + cur_head[:, None] * stride_oh + offs_dv[None, :] * stride_od) out_ptrs = Out + off_o tl.store(out_ptrs, acc, @@ -1022,6 +1083,7 @@ def paged_attention_fwd( window_size: int = None, sm_scale: float = None, logit_softcapping: float = None, + kv_layout: str = 'bshd', ): """Paged Attention forward. @@ -1042,6 +1104,13 @@ def paged_attention_fwd( nv_cap = torch.cuda.get_device_capability() _convert_pv = _get_convert_pv(nv_cap) + if kv_layout == 'bshd': + b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3) + elif kv_layout == 'bhsd': + b_dim, s_dim, h_dim, d_dim = (0, 2, 1, 3) + else: + raise RuntimeError('Unsupported layout.') + if window_size is None: window_size = -1 @@ -1059,7 +1128,7 @@ def _get_block_d(Lk): return BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + Lq, Lk, Lv = q.shape[-1], k.shape[d_dim], v.shape[d_dim] if quant_policy == 4: assert Lq == Lk * 2 and Lv * 2 == o.shape[-1] else: @@ -1068,9 +1137,9 @@ def _get_block_d(Lk): if sm_scale is None: sm_scale = 1.0 / (Lq**0.5) batch, head = q_seqlens.shape[0], q.shape[-2] - kv_group_num = q.shape[-2] // k.shape[-2] + kv_group_num = q.shape[-2] // k.shape[h_dim] - BLOCK = k.size(1) + BLOCK = k.size(s_dim) assert BLOCK >= 16 if Lq > 512 and BLOCK > 32: logger.warning(f'`head_dim={Lq}` and `block_size={BLOCK}` ' @@ -1084,7 +1153,9 @@ def _get_block_d(Lk): BLOCK_M = max(16, min(BLOCK, 16384 // BLOCK_DMODEL)) num_warps = 4 num_stages = 2 - grid = (triton.cdiv(max_seqlen, BLOCK_M), head, batch) + kv_head = k.shape[h_dim] + grid = (triton.cdiv(max_seqlen * kv_group_num, + BLOCK_M), kv_head, batch) if quant_policy > 0: _fwd_kernel_quant[grid](q, k, @@ -1100,22 +1171,22 @@ def _get_block_d(Lk): stride_qbs=q.stride(-3), stride_qh=q.stride(-2), stride_qd=q.stride(-1), - stride_kp=k.stride(-4), - stride_kbs=k.stride(-3), - stride_kh=k.stride(-2), - stride_kd=k.stride(-1), - stride_vp=v.stride(-4), - stride_vbs=v.stride(-3), - stride_vh=v.stride(-2), - stride_vd=v.stride(-1), - stride_kszp=k_scales_zeros.stride(-4), - stride_kszbs=k_scales_zeros.stride(-3), - stride_kszh=k_scales_zeros.stride(-2), - stride_kszd=k_scales_zeros.stride(-1), - stride_vszp=v_scales_zeros.stride(-4), - stride_vszbs=v_scales_zeros.stride(-3), - stride_vszh=v_scales_zeros.stride(-2), - stride_vszd=v_scales_zeros.stride(-1), + stride_kp=k.stride(b_dim), + stride_kbs=k.stride(s_dim), + stride_kh=k.stride(h_dim), + stride_kd=k.stride(d_dim), + stride_vp=v.stride(b_dim), + stride_vbs=v.stride(s_dim), + stride_vh=v.stride(h_dim), + stride_vd=v.stride(d_dim), + stride_kszp=k_scales_zeros.stride(b_dim), + stride_kszbs=k_scales_zeros.stride(s_dim), + stride_kszh=k_scales_zeros.stride(h_dim), + stride_kszd=k_scales_zeros.stride(d_dim), + stride_vszp=v_scales_zeros.stride(b_dim), + stride_vszbs=v_scales_zeros.stride(s_dim), + stride_vszh=v_scales_zeros.stride(h_dim), + stride_vszd=v_scales_zeros.stride(d_dim), quant_policy=quant_policy, stride_obs=o.stride(-3), stride_oh=o.stride(-2), @@ -1147,14 +1218,14 @@ def _get_block_d(Lk): 
stride_qbs=q.stride(-3), stride_qh=q.stride(-2), stride_qd=q.stride(-1), - stride_kp=k.stride(-4), - stride_kbs=k.stride(-3), - stride_kh=k.stride(-2), - stride_kd=k.stride(-1), - stride_vp=v.stride(-4), - stride_vbs=v.stride(-3), - stride_vh=v.stride(-2), - stride_vd=v.stride(-1), + stride_kp=k.stride(b_dim), + stride_kbs=k.stride(s_dim), + stride_kh=k.stride(h_dim), + stride_kd=k.stride(d_dim), + stride_vp=v.stride(b_dim), + stride_vbs=v.stride(s_dim), + stride_vh=v.stride(h_dim), + stride_vd=v.stride(d_dim), stride_obs=o.stride(-3), stride_oh=o.stride(-2), stride_od=o.stride(-1), @@ -1209,22 +1280,22 @@ def _get_block_d(Lk): stride_qbs=q.stride(-3), stride_qh=q.stride(-2), stride_qd=q.stride(-1), - stride_kp=k.stride(-4), - stride_kbs=k.stride(-3), - stride_kh=k.stride(-2), - stride_kd=k.stride(-1), - stride_vp=v.stride(-4), - stride_vbs=v.stride(-3), - stride_vh=v.stride(-2), - stride_vd=v.stride(-1), - stride_kszp=k_scales_zeros.stride(-4), - stride_kszbs=k_scales_zeros.stride(-3), - stride_kszh=k_scales_zeros.stride(-2), - stride_kszd=k_scales_zeros.stride(-1), - stride_vszp=v_scales_zeros.stride(-4), - stride_vszbs=v_scales_zeros.stride(-3), - stride_vszh=v_scales_zeros.stride(-2), - stride_vszd=v_scales_zeros.stride(-1), + stride_kp=k.stride(b_dim), + stride_kbs=k.stride(s_dim), + stride_kh=k.stride(h_dim), + stride_kd=k.stride(d_dim), + stride_vp=v.stride(b_dim), + stride_vbs=v.stride(s_dim), + stride_vh=v.stride(h_dim), + stride_vd=v.stride(d_dim), + stride_kszp=k_scales_zeros.stride(b_dim), + stride_kszbs=k_scales_zeros.stride(s_dim), + stride_kszh=k_scales_zeros.stride(h_dim), + stride_kszd=k_scales_zeros.stride(d_dim), + stride_vszp=v_scales_zeros.stride(b_dim), + stride_vszbs=v_scales_zeros.stride(s_dim), + stride_vszh=v_scales_zeros.stride(h_dim), + stride_vszd=v_scales_zeros.stride(d_dim), quant_policy=quant_policy, stride_ok=acc.stride(-2), stride_obs=acc.stride(-4), @@ -1257,14 +1328,14 @@ def _get_block_d(Lk): stride_qbs=q.stride(-3), stride_qh=q.stride(-2), stride_qd=q.stride(-1), - stride_kp=k.stride(-4), - stride_kbs=k.stride(-3), - stride_kh=k.stride(-2), - stride_kd=k.stride(-1), - stride_vp=v.stride(-4), - stride_vbs=v.stride(-3), - stride_vh=v.stride(-2), - stride_vd=v.stride(-1), + stride_kp=k.stride(b_dim), + stride_kbs=k.stride(s_dim), + stride_kh=k.stride(h_dim), + stride_kd=k.stride(d_dim), + stride_vp=v.stride(b_dim), + stride_vbs=v.stride(s_dim), + stride_vh=v.stride(h_dim), + stride_vd=v.stride(d_dim), stride_ok=acc.stride(-2), stride_obs=acc.stride(-4), stride_oh=acc.stride(-3), @@ -1291,13 +1362,13 @@ def _get_block_d(Lk): BLOCK_DV *= 2 _reduce_split_kernel[grid](acc, o, - stride_ak=acc.stride(-2), - stride_abs=acc.stride(-4), - stride_ah=acc.stride(-3), - stride_ad=acc.stride(-1), - stride_obs=o.stride(-3), - stride_oh=o.stride(-2), - stride_od=o.stride(-1), + stride_ak=acc.stride(2), + stride_abs=acc.stride(0), + stride_ah=acc.stride(1), + stride_ad=acc.stride(3), + stride_obs=o.stride(0), + stride_oh=o.stride(1), + stride_od=o.stride(2), SPLIT_K=SPLIT_K, head_size_v=Lv, BLOCK_DV=BLOCK_DV, diff --git a/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py b/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py index 5fdaa5802c..0d0e10ec83 100644 --- a/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py +++ b/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py @@ -3,9 +3,16 @@ import torch.nn.functional as F import triton import triton.language as tl +from packaging import version from .triton_utils import get_kernel_meta +TRITON_VERSION = 
version.parse(triton.__version__) +if TRITON_VERSION >= version.parse('3.0.0'): + tl_round = tl.extra.cuda.libdevice.round +else: + tl_round = tl.math.round + def per_channel_quant(x, n_bits, dtype): """Quantize the input tensor 'x' channel-wise using the given number of @@ -305,7 +312,7 @@ def _per_token_quant_int8( # Quant _absmax = tl.maximum(tl.max(tl.abs(y)), eps) y_s = _absmax / 127 - y_q = tl.math.round(y / y_s).to(tl.int8) + y_q = tl_round(y / y_s).to(tl.int8) tl.store(y_q_ptr + cols, y_q, mask=mask) tl.store(y_s_ptr, y_s) @@ -373,7 +380,7 @@ def _rms_norm_fwd_fused_dynamic_symmetric( scale = tl.max(tl.abs(y)).to(tl.float32) / 127 tl.store(Scale + row, scale) - y = tl.math.round(y / scale) + y = tl_round(y / scale) y = tl.minimum(y, 127) y = tl.maximum(y, -128) tl.store(Y + cols, y, mask=mask) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 7e5058c17b..7fb2491014 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -15,8 +15,8 @@ safetensors sentencepiece shortuuid tiktoken -torch<=2.3.1,>=2.0.0 +torch<=2.4.0,>=2.0.0 torchvision<=0.18.1,>=0.15.0 transformers -triton>=2.1.0,<=3.0.0; sys_platform == "linux" +triton>=2.2.0,<=3.0.0; sys_platform == "linux" uvicorn diff --git a/tests/pytorch/kernel/test_paged_attention.py b/tests/pytorch/kernel/test_paged_attention.py index 5d4b024199..7f63b281c5 100644 --- a/tests/pytorch/kernel/test_paged_attention.py +++ b/tests/pytorch/kernel/test_paged_attention.py @@ -26,9 +26,16 @@ def _make_bias(seq_lens, history_lens, neg_val): return mask.float() * neg_val -def _make_blocked_cache(batched_k, batched_v, seq_lens, history_lens, - block_offsets, block_size, num_heads_k, feat_dim, - feat_dim_v): +def _make_blocked_cache(batched_k, + batched_v, + seq_lens, + history_lens, + block_offsets, + block_size, + num_heads_k, + feat_dim, + feat_dim_v, + layout: str = 'bshd'): max_blocks_nums = block_offsets.max() + 1 full_seq_lens = seq_lens + history_lens blocked_k = batched_k.new_zeros(max_blocks_nums, block_size, num_heads_k, @@ -48,6 +55,10 @@ def _make_blocked_cache(batched_k, batched_v, seq_lens, history_lens, blocked_k[block_off, :size] = tmp_k blocked_v[block_off, :size] = tmp_v + if layout == 'bhsd': + blocked_k = blocked_k.transpose(1, 2).contiguous() + blocked_v = blocked_v.transpose(1, 2).contiguous() + return blocked_k, blocked_v @@ -129,6 +140,10 @@ def num_heads_k(self, request): def block_size(self, request): yield request.param + @pytest.fixture + def layout(self, request): + yield request.param + @pytest.fixture def seq_lens(self, request): yield torch.tensor(request.param, device='cuda') @@ -208,11 +223,11 @@ def conti_kv(self, batched_kv, seq_lens, history_lens): @pytest.fixture def blocked_kv(self, batched_kv, seq_lens, history_lens, block_offsets, - block_size, num_heads_k, feat_dim, feat_dim_v): + block_size, num_heads_k, feat_dim, feat_dim_v, layout): batched_k, batched_v = batched_kv yield _make_blocked_cache(batched_k, batched_v, seq_lens, history_lens, block_offsets, block_size, num_heads_k, - feat_dim, feat_dim_v) + feat_dim, feat_dim_v, layout) @pytest.fixture def mask(self, seq_lens, history_lens): @@ -236,9 +251,10 @@ def conti_gt(self, gt, seq_lens): ([1, 1, 1, 1], [50, 40, 30, 20])], indirect=True) @pytest.mark.parametrize('block_size', [16], indirect=True) + @pytest.mark.parametrize('layout', ['bshd', 'bhsd'], indirect=True) def test_paged_attention(self, conti_q, blocked_kv, block_offsets, start_loc, seq_lens, history_lens, feat_dim_v, - conti_gt): + layout, conti_gt): from 
lmdeploy.pytorch.kernels import paged_attention_fwd kv_seq_lens = seq_lens + history_lens max_seq_len = seq_lens.max().item() @@ -254,7 +270,8 @@ def test_paged_attention(self, conti_q, blocked_kv, block_offsets, q_start_loc=start_loc, q_seqlens=seq_lens, kv_seqlens=kv_seq_lens, - max_seqlen=max_seq_len) + max_seqlen=max_seq_len, + kv_layout=layout) torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5) @pytest.fixture @@ -282,9 +299,10 @@ def window_gt(self, conti_q, conti_kv, seq_lens, history_lens, win_size): indirect=True) @pytest.mark.parametrize('win_size', (32, ), indirect=True) @pytest.mark.parametrize('block_size', [16], indirect=True) + @pytest.mark.parametrize('layout', ['bshd'], indirect=True) def test_window_attention(self, conti_q, blocked_kv, block_offsets, start_loc, seq_lens, history_lens, feat_dim_v, - win_size, window_gt): + win_size, layout, window_gt): from lmdeploy.pytorch.kernels import paged_attention_fwd kv_seq_lens = seq_lens + history_lens max_seq_len = seq_lens.max().item() @@ -300,7 +318,8 @@ def test_window_attention(self, conti_q, blocked_kv, block_offsets, q_seqlens=seq_lens, kv_seqlens=kv_seq_lens, max_seqlen=max_seq_len, - window_size=win_size) + window_size=win_size, + kv_layout=layout) torch.testing.assert_close(out, window_gt, atol=1e-3, rtol=1e-5)
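
Note: the new `kv_layout` argument added to `fill_kv_cache` and `paged_attention_fwd` in this patch selects which cache axis holds the block, sequence, head and feature dimension; the kernels themselves only consume sizes and strides looked up through those indices, so 'bshd' and 'bhsd' caches share one code path. Below is a minimal standalone sketch of that mapping. The helper name `kv_layout_dims` and the cache shapes are illustrative assumptions, not part of the library.

import torch


def kv_layout_dims(kv_layout: str = 'bshd'):
    """Return (block, seq, head, dim) axis indices for a paged KV cache."""
    if kv_layout == 'bshd':
        return 0, 1, 2, 3
    elif kv_layout == 'bhsd':
        return 0, 2, 1, 3
    raise RuntimeError('Unsupported layout.')


num_blocks, block_size, num_kv_heads, head_dim = 4, 16, 2, 64
# 'bshd' cache: (num_blocks, block_size, num_kv_heads, head_dim)
k_cache_bshd = torch.empty(num_blocks, block_size, num_kv_heads, head_dim)
# 'bhsd' cache: (num_blocks, num_kv_heads, block_size, head_dim)
k_cache_bhsd = torch.empty(num_blocks, num_kv_heads, block_size, head_dim)

for layout, cache in (('bshd', k_cache_bshd), ('bhsd', k_cache_bhsd)):
    b_dim, s_dim, h_dim, d_dim = kv_layout_dims(layout)
    # Both layouts resolve to the same logical (block_size, num_heads, head_dim)
    # view; only the strides passed to the kernel differ.
    print(layout,
          cache.size(s_dim), cache.size(h_dim), cache.size(d_dim),
          cache.stride(b_dim), cache.stride(s_dim))

At the call sites exercised in the updated tests, `kv_layout='bhsd'` is simply passed alongside the existing keyword arguments, and the default `'bshd'` preserves the previous behaviour.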
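
Note: the prefill kernel `_fwd_kernel` now walks the KV blocks in two loops, first the blocks that lie entirely in the history and then the in-flight tail. Blocks that end at or before `history_len` can never violate the causal mask for the current queries, so the first loop drops that check; both loops perform the same streaming-softmax update of `(m_i, l_i, acc)`. The following is a plain PyTorch reference sketch of that update (a restatement of the math, not the Triton kernel):

import torch


def streaming_attention(q, k, v, block_n=16):
    """q: (M, D), k: (N, D), v: (N, Dv); returns softmax(q @ k.T) @ v."""
    m_i = q.new_full((q.size(0),), float('-inf'))   # running row maximum
    l_i = q.new_zeros(q.size(0))                    # running softmax denominator
    acc = q.new_zeros(q.size(0), v.size(1))         # running weighted sum of V
    for start in range(0, k.size(0), block_n):
        qk = q @ k[start:start + block_n].T
        m_new = torch.maximum(m_i, qk.max(dim=1).values)
        p = torch.exp(qk - m_new[:, None])
        alpha = torch.exp(m_i - m_new)              # rescale previous partials
        l_i = alpha * l_i + p.sum(dim=1)
        acc = acc * alpha[:, None] + p @ v[start:start + block_n]
        m_i = m_new
    return acc / l_i[:, None]


q = torch.randn(8, 32)
k = torch.randn(48, 32)
v = torch.randn(48, 32)
ref = torch.softmax(q @ k.T, dim=-1) @ v
torch.testing.assert_close(streaming_attention(q, k, v), ref,
                           atol=1e-5, rtol=1e-5)

Because the update is associative in this streamed form, splitting the block loop (and likewise skipping the `Acc_out` store when `loop_end <= loop_start` in the split-K kernels) does not change the result, only the per-block work.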