From c5b0a3143a807f28e8a19321354a2fb94cdf32cc Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 5 Mar 2024 11:36:29 +0800 Subject: [PATCH 1/2] remove unused kernel in pytorch engine (#1237) --- lmdeploy/pytorch/engine/engine.py | 14 + lmdeploy/pytorch/kernels/__init__.py | 5 +- .../pytorch/kernels/biased_pagedattention.py | 240 ------------------ .../pytorch/kernels/flashattention_nopad.py | 199 --------------- tests/pytorch/kernel/test_paged_attention.py | 30 --- 5 files changed, 15 insertions(+), 473 deletions(-) delete mode 100644 lmdeploy/pytorch/kernels/biased_pagedattention.py delete mode 100644 lmdeploy/pytorch/kernels/flashattention_nopad.py diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index fe5bf7e0c..c59bfed29 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -834,6 +834,20 @@ async def _add_messages(session_ids, token_ids): output_token_len = [len(token_ids) for token_ids in output_token_ids] return (status, output_token_ids, output_token_len) + def batched_infer(self, + session_ids: List[int], + token_ids: List[List[int]] = None, + gen_config: EngineGenerationConfig = None, + adapter_names: List[str] = None, + keep_cache: bool = False): + """batched infer.""" + coro = self.async_batched_infer(session_ids, + token_ids, + gen_config=gen_config, + adapter_names=adapter_names, + keep_cache=keep_cache) + return self.req_sender.run_until_complete(coro) + def decode(self, input_ids, steps: List[int] = None, diff --git a/lmdeploy/pytorch/kernels/__init__.py b/lmdeploy/pytorch/kernels/__init__.py index a1e2ead43..31d2e4c0d 100644 --- a/lmdeploy/pytorch/kernels/__init__.py +++ b/lmdeploy/pytorch/kernels/__init__.py @@ -1,9 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .alibi_pagedattention import alibi_paged_attention_fwd from .apply_rotary_pos_emb import apply_rotary_pos_emb -from .biased_pagedattention import biased_paged_attention_fwd from .fill_kv_cache import fill_kv_cache -from .flashattention_nopad import context_attention_fwd from .fused_rotary_emb import fused_rotary_emb from .multinomial_sampling import multinomial_sampling from .pagedattention import paged_attention_fwd @@ -11,8 +9,7 @@ from .rms_norm import rms_norm __all__ = [ - 'apply_rotary_pos_emb', 'context_attention_fwd', 'fused_rotary_emb', - 'paged_attention_fwd', 'biased_paged_attention_fwd', + 'apply_rotary_pos_emb', 'fused_rotary_emb', 'paged_attention_fwd', 'alibi_paged_attention_fwd', 'fill_kv_cache', 'multinomial_sampling', 'rms_norm', 'rerope_attention_fwd' ] diff --git a/lmdeploy/pytorch/kernels/biased_pagedattention.py b/lmdeploy/pytorch/kernels/biased_pagedattention.py deleted file mode 100644 index 1270c17e7..000000000 --- a/lmdeploy/pytorch/kernels/biased_pagedattention.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-# modify from: https://github.com/ModelTC/lightllm -import torch -import triton -import triton.language as tl -from torch import Tensor - -assert triton.__version__ >= '2.1.0' - -_NV_CAP = torch.cuda.get_device_capability() -if _NV_CAP[0] >= 8: - - @triton.jit - def _convert_pv(p, v): - """convert pv.""" - p = p.to(v.dtype) - return p, v -else: - - @triton.jit - def _convert_pv(p, v): - """convert pv.""" - v = v.to(p.dtype) - return p, v - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - Bias, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_kvlen, - Block_offsets, - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_biasbs, - stride_biash, - stride_biasq, - stride_biask, - stride_obs, - stride_oh, - stride_od, - stride_boffb, - kv_group_num, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - """biased paged attention kernel.""" - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_kv_len = tl.load(B_kvlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - off_bias = (cur_batch * stride_biasbs + cur_head * stride_biash + - offs_m[:, None] * stride_biasq + - offs_n[None, :] * stride_biask) - - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - bias_ptrs = Bias + off_bias - - block_offset_ptrs = Block_offsets + cur_batch * stride_boffb - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * cur_batch_kv_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - - start_block_id = start_n // BLOCK_N - b_offset = tl.load(block_offset_ptrs + start_block_id) - - # -- compute qk ---- - k = tl.load( - k_ptrs + b_offset * BLOCK_N * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_kv_len, - other=0.0, - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - - bias = tl.load( - bias_ptrs + start_n, - mask=(start_n + offs_n[None, :]) < cur_batch_kv_len - and (offs_m[:, None] < cur_batch_seq_len), - other=-1e30, - ) - qk += bias - - # -- compute p, m_i and l_i - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - p = tl.exp(qk - m_i_new[:, None]) - alpha = tl.exp(m_i - m_i_new) - l_i_new = alpha * l_i + tl.sum(p, 1) - # -- update output accumulator -- - # scale acc - acc = acc * alpha[:, None] - # update acc - v = tl.load( - v_ptrs + b_offset * BLOCK_N * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_kv_len, - other=0.0, - ) - - p, v = _convert_pv(p, v) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - - 
acc = acc / l_i[:, None] - # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - - -@torch.no_grad() -def biased_paged_attention_fwd( - q: Tensor, - k: Tensor, - v: Tensor, - bias: Tensor, - o: Tensor, - block_offsets: Tensor, - b_start_loc: Tensor, - b_seq_len: Tensor, - b_kv_seq_len: Tensor, - max_input_len: int, - BLOCK: int = 64, -): - """Paged attention forward with custom bias. - - Args: - q (Tensor): Query state. - k (Tensor): Key state caches. - v (Tensor): Value state caches. - bias (Tensor): Bias of the QK. - o (Tensor): Output state. - block_offsets (Tensor): The block offset of key and value. - b_start_loc (Tensor): Start token location of each data in batch. - b_seq_len (Tensor): Query length for each data in batch. - b_kv_seq_len (Tensor): Key/Value length for each data in batch. - max_input_len (int): The max input length. - BLOCK (int): The kernel block size. - """ - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - assert bias.dtype == torch.float32 - - if bias.dim() == 2: - bias = bias.unsqueeze(0) - - if bias.dim() == 3: - bias = bias.unsqueeze(1) - - sm_scale = 1.0 / (Lq**0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[-2] - kv_group_num = q.shape[-2] // k[0].shape[-2] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - bias_head_stride = 0 if bias.size(1) == 1 else bias.stride(-3) - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel[grid]( - q, - k, - v, - bias, - sm_scale, - b_start_loc, - b_seq_len, - b_kv_seq_len, - block_offsets, - o, - q.stride(-3), - q.stride(-2), - q.stride(-1), - k.stride(-3), - k.stride(-2), - k.stride(-1), - v.stride(-3), - v.stride(-2), - v.stride(-1), - bias.stride(-4), - bias_head_stride, - bias.stride(-2), - bias.stride(-1), - o.stride(-3), - o.stride(-2), - o.stride(-1), - block_offsets.stride(0), - kv_group_num=kv_group_num, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) diff --git a/lmdeploy/pytorch/kernels/flashattention_nopad.py b/lmdeploy/pytorch/kernels/flashattention_nopad.py deleted file mode 100644 index 567a74496..000000000 --- a/lmdeploy/pytorch/kernels/flashattention_nopad.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-# modify from: https://github.com/ModelTC/lightllm -import torch -import triton -import triton.language as tl -from torch import Tensor - -assert triton.__version__ >= '2.1.0' - - -@triton.jit -def _fwd_kernel( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, - B_kv_start_loc, - B_kvlen, - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - """flash attention forward triton kernel.""" - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_kv_len = tl.load(B_kvlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_in_kv_start_index = tl.load(B_kv_start_loc + cur_batch) - history_len = cur_batch_kv_len - cur_batch_seq_len - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) - - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * cur_batch_kv_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_kv_len, - other=0.0, - ) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - # NOTE: inf - inf = nan, and nan will leads to error - qk = tl.where( - (history_len + offs_m[:, None]) >= (start_n + offs_n[None, :]), - qk, - float(-1e30), - ) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_kv_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_kv_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - - -@torch.no_grad() -def context_attention_fwd( - q: Tensor, - k: Tensor, - v: 
Tensor, - o: Tensor, - b_start_loc: Tensor, - b_seq_len: Tensor, - b_kv_start_loc: Tensor, - b_kv_seq_len: Tensor, - max_input_len: int, - BLOCK: int = 64, -): - """Context Attention forward. - - Args: - q (Tensor): Query state. - k (Tensor): Key state caches. - v (Tensor): Value state caches. - o (Tensor): Output state. - b_start_loc (Tensor): Start token location of each data in batch. - b_seq_len (Tensor): Query length for each data in batch. - b_kv_start_loc (Tensor): Start token location of kv in each data - in batch. - b_kv_seq_len (Tensor): Key/Value length for each data in batch. - max_input_len (int): The max input length. - BLOCK (int): The kernel block size. - """ - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - - sm_scale = 1.0 / (Lq**0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - b_kv_start_loc, - b_kv_seq_len, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - kv_group_num=kv_group_num, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return diff --git a/tests/pytorch/kernel/test_paged_attention.py b/tests/pytorch/kernel/test_paged_attention.py index 55175aee5..3cd208f05 100644 --- a/tests/pytorch/kernel/test_paged_attention.py +++ b/tests/pytorch/kernel/test_paged_attention.py @@ -248,36 +248,6 @@ def test_paged_attention(self, conti_q, blocked_kv, block_offsets, max_seqlen=max_seq_len) torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5) - @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(4, 2)], - indirect=True) - @pytest.mark.parametrize(['seq_lens', 'history_lens'], - [([30, 50, 70, 90], [50, 40, 30, 20])], - indirect=True) - @pytest.mark.parametrize('block_size', [16], indirect=True) - def test_biased_paged_attention(self, conti_q, blocked_kv, block_offsets, - start_loc, seq_lens, history_lens, - block_size, mask, conti_gt): - from lmdeploy.pytorch.kernels import biased_paged_attention_fwd - kv_seq_lens = seq_lens + history_lens - max_seq_len = seq_lens.max().item() - - blocked_k, blocked_v = blocked_kv - out = torch.empty_like(conti_q) - - biased_paged_attention_fwd(conti_q, - blocked_k, - blocked_v, - mask, - out, - block_offsets=block_offsets, - b_start_loc=start_loc, - b_seq_len=seq_lens, - b_kv_seq_len=kv_seq_lens, - max_input_len=max_seq_len, - BLOCK=block_size) - - torch.testing.assert_close(out, conti_gt) - @pytest.fixture def win_size(self, request): yield request.param From 4bec832028bba10f2216137b251be56b905728d5 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 5 Mar 2024 11:56:57 +0800 Subject: [PATCH 2/2] reduce torchengine prefill mem usage (#1240) * reduce mem usage * remove pdb * del to pop --- lmdeploy/pytorch/engine/engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index c59bfed29..e4b242fc7 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -673,11 +673,13 @@ async def __long_context_forward(inputs): if token_count == 0 and slen > max_prefill_token_num: tmp_out = await 
__long_context_single_forward(inputs, idx) logits_gather.gather(tmp_out) + tmp_out.pop('logits', None) idx += 1 elif token_count + slen > max_prefill_token_num: tmp_out = await __long_context_batched_forward( inputs, indices[0], idx) logits_gather.gather(tmp_out) + tmp_out.pop('logits', None) indices = [] token_count = 0 else:
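
Editor's note on the first patch's engine.py hunk: the new Engine.batched_infer is a sync-over-async facade, i.e. the blocking method only builds the coroutine and hands it to the request sender's event-loop runner. The sketch below is a minimal, hypothetical stand-in (TinyEngine and asyncio.run substitute for lmdeploy's Engine and req_sender.run_until_complete), not the engine's actual implementation.

import asyncio
from typing import List, Optional


class TinyEngine:
    """Hypothetical toy engine; not lmdeploy's Engine."""

    async def async_batched_infer(self,
                                  session_ids: List[int],
                                  token_ids: Optional[List[List[int]]] = None,
                                  keep_cache: bool = False):
        """Pretend async batched inference: echo each prompt reversed."""
        await asyncio.sleep(0)
        return [list(reversed(ids)) for ids in (token_ids or [])]

    def batched_infer(self,
                      session_ids: List[int],
                      token_ids: Optional[List[List[int]]] = None,
                      keep_cache: bool = False):
        """Blocking facade: build the coroutine, run it to completion."""
        coro = self.async_batched_infer(session_ids,
                                        token_ids,
                                        keep_cache=keep_cache)
        # The real method uses self.req_sender.run_until_complete(coro);
        # asyncio.run is only a stand-in here.
        return asyncio.run(coro)


if __name__ == '__main__':
    print(TinyEngine().batched_infer([0], [[1, 2, 3]]))  # -> [[3, 2, 1]]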
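For reference, the two kernels deleted in the first patch (biased_paged_attention_fwd and context_attention_fwd) both compute scaled-dot-product attention over the concatenation of cached history and new tokens, with grouped-query head sharing (kv_group_num) and either an explicit additive bias or an implicit causal mask. The dense PyTorch sketch below is a naive, unpaged reference of that computation under those assumptions; it deliberately ignores the block-table layout and Triton tiling of the removed kernels and is not a drop-in replacement.

import math

import torch


def naive_context_attention(q: torch.Tensor, k: torch.Tensor,
                            v: torch.Tensor,
                            bias: torch.Tensor) -> torch.Tensor:
    """Dense attention reference (single sequence, no paging).

    q: (q_len, num_heads, head_dim)
    k, v: (kv_len, num_kv_heads, head_dim), kv_len = history_len + q_len
    bias: (num_heads, q_len, kv_len) additive scores; a 0 / -inf causal
        mask mimics context_attention_fwd, any other additive bias mimics
        biased_paged_attention_fwd.
    """
    head_dim = q.shape[-1]
    group = q.shape[1] // k.shape[1]
    # expand kv heads for grouped-query attention (kv_group_num)
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)
    scores = torch.einsum('qhd,khd->hqk', q, k) / math.sqrt(head_dim) + bias
    return torch.einsum('hqk,khd->qhd', scores.softmax(dim=-1), v)


if __name__ == '__main__':
    q_len, history_len, heads, kv_heads, dim = 4, 3, 4, 2, 16
    kv_len = history_len + q_len
    q = torch.randn(q_len, heads, dim)
    k = torch.randn(kv_len, kv_heads, dim)
    v = torch.randn(kv_len, kv_heads, dim)
    # causal mask: query i (absolute position history_len + i) may attend
    # to keys 0 .. history_len + i
    pos_q = torch.arange(q_len)[:, None] + history_len
    pos_k = torch.arange(kv_len)[None, :]
    mask = torch.zeros(q_len, kv_len).masked_fill(pos_k > pos_q,
                                                  float('-inf'))
    out = naive_context_attention(q, k, v, mask.expand(heads, -1, -1))
    print(out.shape)  # torch.Size([4, 4, 16])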
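The second patch's tmp_out.pop('logits', None) lines shrink prefill memory by dropping each chunk's logits as soon as they have been gathered, instead of keeping a handle to them for the rest of the long-context loop. Below is a minimal sketch of that pattern with a hypothetical LogitsGather that copies chunk logits into its own buffer (the real gatherer's behavior is assumed here, it is not shown in this diff).

from typing import Dict, List

import torch


class LogitsGather:
    """Hypothetical stand-in for the engine's logits gatherer."""

    def __init__(self) -> None:
        self._chunks: List[torch.Tensor] = []

    def gather(self, step_out: Dict[str, torch.Tensor]) -> None:
        # Copy the chunk's logits to CPU so the original (e.g. GPU) tensor
        # has no further use in the prefill loop.
        self._chunks.append(step_out['logits'].detach().cpu())

    def get_logits(self) -> torch.Tensor:
        return torch.cat(self._chunks, dim=1)


def long_context_prefill(chunk_lens: List[int], vocab: int = 8) -> torch.Tensor:
    gather = LogitsGather()
    for length in chunk_lens:
        # Stand-in for one chunked forward pass over `length` tokens.
        step_out = {'logits': torch.randn(1, length, vocab)}
        gather.gather(step_out)
        # As in the patch: drop the chunk's own reference right after
        # gathering, so only the gathered copy stays reachable while the
        # remaining chunks run.
        step_out.pop('logits', None)
    return gather.get_logits()


if __name__ == '__main__':
    print(long_context_prefill([4, 3, 5]).shape)  # torch.Size([1, 12, 8])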