From 689524b26c5d05f348e4811092b3e61e1fbb68fb Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 25 Jan 2024 17:25:55 +0800 Subject: [PATCH 1/6] update alibi attention --- lmdeploy/pytorch/accel.py | 3 + .../pytorch/kernels/alibi_pagedattention.py | 76 ++++++++++++------- lmdeploy/pytorch/models/falcon.py | 4 +- lmdeploy/pytorch/models/functional.py | 27 +++---- 4 files changed, 62 insertions(+), 48 deletions(-) diff --git a/lmdeploy/pytorch/accel.py b/lmdeploy/pytorch/accel.py index e51e0589c8..c304841d57 100644 --- a/lmdeploy/pytorch/accel.py +++ b/lmdeploy/pytorch/accel.py @@ -13,6 +13,7 @@ def __init__(self): self.normal_ = torch.nn.init.normal_ self.kaiming_uniform_ = torch.nn.init.kaiming_uniform_ self.kaiming_normal_ = torch.nn.init.kaiming_normal_ + self.tensor_normal_ = torch.Tensor.normal_ def __enter__(self, *args, **kwargs): """Replace initializers with no-op.""" @@ -24,6 +25,7 @@ def __enter__(self, *args, **kwargs): torch.nn.init.normal_ = lambda *args, **kwargs: None torch.nn.init.kaiming_uniform_ = lambda *args, **kwargs: None torch.nn.init.kaiming_normal_ = lambda *args, **kwargs: None + torch.Tensor.normal_ = lambda *args, **kwargs: None def __exit__(self, *args, **kwargs): """Recover.""" @@ -35,3 +37,4 @@ def __exit__(self, *args, **kwargs): torch.nn.init.normal_ = self.normal_ torch.nn.init.kaiming_uniform_ = self.kaiming_uniform_ torch.nn.init.kaiming_normal_ = self.kaiming_normal_ + torch.Tensor.normal_ = self.tensor_normal_ diff --git a/lmdeploy/pytorch/kernels/alibi_pagedattention.py b/lmdeploy/pytorch/kernels/alibi_pagedattention.py index 54fafdbb99..d4d3148efd 100644 --- a/lmdeploy/pytorch/kernels/alibi_pagedattention.py +++ b/lmdeploy/pytorch/kernels/alibi_pagedattention.py @@ -50,6 +50,20 @@ def get_slope(i, n): 2 * closest_power_of_2) +@triton.jit +def _load_block_offsets(offset_ptr, block_id, num_sub_blocks: tl.constexpr, + BLOCK: tl.constexpr): + if num_sub_blocks > 1: + offs_sub = tl.arange(0, num_sub_blocks) + offs_n = tl.arange(0, BLOCK // num_sub_blocks) + ret = tl.load(offset_ptr + block_id * num_sub_blocks + offs_sub)[ + None, :] * BLOCK // num_sub_blocks + offs_n[:, None] + return tl.ravel(ret) + else: + offs_n = tl.arange(0, BLOCK) + return tl.load(offset_ptr + block_id) * BLOCK + offs_n + + @triton.jit def _fwd_kernel( Q, @@ -78,6 +92,7 @@ def _fwd_kernel( head_offset, num_heads, kv_group_num, + num_sub_blocks: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -104,10 +119,8 @@ def _fwd_kernel( offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd) - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) + off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd) + off_v = (cur_kv_head * stride_vh + offs_d[None, :] * stride_vd) q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) @@ -123,19 +136,28 @@ def _fwd_kernel( block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + b_offset = _load_block_offsets(block_offset_ptrs, 0, num_sub_blocks, + BLOCK_N) for start_n in range(0, block_mask * cur_batch_kv_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - start_block_id = start_n // BLOCK_N - b_offset = tl.load(block_offset_ptrs + start_block_id) - # -- compute qk ---- k = tl.load( - k_ptrs + b_offset * 
BLOCK_N * stride_kbs, + k_ptrs + b_offset[None, :] * stride_kbs, mask=(start_n + offs_n[None, :]) < cur_batch_kv_len, other=0.0, ) + v = tl.load( + v_ptrs + b_offset[:, None] * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_kv_len, + other=0.0, + ) + if start_n + BLOCK_N < cur_batch_kv_len: + start_block_id = start_n // BLOCK_N + 1 + b_offset = _load_block_offsets(block_offset_ptrs, start_block_id, + num_sub_blocks, BLOCK_N) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk *= sm_scale @@ -159,13 +181,8 @@ def _fwd_kernel( # -- update output accumulator -- # scale acc acc = acc * alpha[:, None] - # update acc - v = tl.load( - v_ptrs + b_offset * BLOCK_N * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_kv_len, - other=0.0, - ) + # update acc p = p.to(v.dtype) acc += tl.dot(p, v) # update m_i and l_i @@ -181,21 +198,18 @@ def _fwd_kernel( @torch.no_grad() -def alibi_paged_attention_fwd( - q: Tensor, - k: Tensor, - v: Tensor, - o: Tensor, - block_offsets: Tensor, - b_start_loc: Tensor, - b_seq_len: Tensor, - b_kv_seq_len: Tensor, - max_input_len: int, - head_offset: int = 0, - num_heads: int = -1, - alibi_scale: float = 1.0, - BLOCK: int = 64, -): +def alibi_paged_attention_fwd(q: Tensor, + k: Tensor, + v: Tensor, + o: Tensor, + block_offsets: Tensor, + b_start_loc: Tensor, + b_seq_len: Tensor, + b_kv_seq_len: Tensor, + max_input_len: int, + head_offset: int = 0, + num_heads: int = -1, + alibi_scale: float = 1.0): """Paged attention forward with alibi bias. Args: @@ -225,6 +239,9 @@ def alibi_paged_attention_fwd( if num_heads <= 0: num_heads = head + BLOCK = 64 if k.size(1) < 16 else k.size(1) + num_sub_blocks = BLOCK // k.size(1) + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, num_warps = 4 if Lk <= 64 else 8 @@ -255,6 +272,7 @@ def alibi_paged_attention_fwd( head_offset=head_offset, num_heads=num_heads, kv_group_num=kv_group_num, + num_sub_blocks=num_sub_blocks, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, diff --git a/lmdeploy/pytorch/models/falcon.py b/lmdeploy/pytorch/models/falcon.py index 2b2a4c9c3f..9d18d9eaaa 100644 --- a/lmdeploy/pytorch/models/falcon.py +++ b/lmdeploy/pytorch/models/falcon.py @@ -276,7 +276,6 @@ def _contiguous_batching_forward( attn_output = torch.empty_like(query_layer) block_offsets = context.block_offsets - block_size = past_key.size(1) if alibi is None: paged_attention_fwd(q=query_layer, @@ -299,8 +298,7 @@ def _contiguous_batching_forward( b_seq_len=q_seq_length, b_kv_seq_len=kv_seq_length, max_input_len=max_seq_len, - alibi_scale=self.inv_norm_factor, - BLOCK=block_size) + alibi_scale=self.inv_norm_factor) attn_output = attn_output.reshape(batch_size, query_length, -1) diff --git a/lmdeploy/pytorch/models/functional.py b/lmdeploy/pytorch/models/functional.py index 23d0e3c1a5..615a1e8f09 100644 --- a/lmdeploy/pytorch/models/functional.py +++ b/lmdeploy/pytorch/models/functional.py @@ -178,8 +178,6 @@ def attention_forward_with_paged_attention( attn_output = query_states - block_size = past_key_value[0].size(1) - bias_type = bias_type.lower() if bias_type == 'default': paged_attention_fwd( @@ -202,20 +200,17 @@ def attention_forward_with_paged_attention( rank = dist.get_rank() num_heads_full = num_heads * world_size head_offset = num_heads * rank - alibi_paged_attention_fwd( - query_states, - past_key_value[0], - past_key_value[1], - attn_output, - block_offsets, - b_start_loc=q_start_loc, - b_seq_len=q_seq_length, - b_kv_seq_len=kv_seq_length, - max_input_len=max_seq_len, - 
head_offset=head_offset, - num_heads=num_heads_full, - BLOCK=block_size, - ) + alibi_paged_attention_fwd(query_states, + past_key_value[0], + past_key_value[1], + attn_output, + block_offsets, + b_start_loc=q_start_loc, + b_seq_len=q_seq_length, + b_kv_seq_len=kv_seq_length, + max_input_len=max_seq_len, + head_offset=head_offset, + num_heads=num_heads_full) else: raise ValueError(f'Unknown bias type: {bias_type}') hidden_size = num_heads * head_dim From 1b2fb388d63792904eb3c642ae5ed9cfc4219838 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 25 Jan 2024 18:15:31 +0800 Subject: [PATCH 2/6] fix baichuan tp --- .../pytorch/kernels/alibi_pagedattention.py | 317 ++++++++++++++++-- lmdeploy/pytorch/models/baichuan.py | 28 +- lmdeploy/pytorch/models/peft.py | 2 +- 3 files changed, 310 insertions(+), 37 deletions(-) diff --git a/lmdeploy/pytorch/kernels/alibi_pagedattention.py b/lmdeploy/pytorch/kernels/alibi_pagedattention.py index d4d3148efd..48a1efa1a0 100644 --- a/lmdeploy/pytorch/kernels/alibi_pagedattention.py +++ b/lmdeploy/pytorch/kernels/alibi_pagedattention.py @@ -6,6 +6,7 @@ import triton import triton.language as tl from torch import Tensor +from triton.runtime.jit import get_cuda_stream assert triton.__version__ >= '2.1.0' @@ -64,6 +65,191 @@ def _load_block_offsets(offset_ptr, block_id, num_sub_blocks: tl.constexpr, return tl.load(offset_ptr + block_id) * BLOCK + offs_n +@triton.jit +def _fwd_split_kernel( + Q, + K, + V, + sm_scale, + alibi_scale, + B_kvlen, + Block_offsets, + Acc_out, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_ok, + stride_obs, + stride_oh, + stride_od, + stride_boffb, + head_offset, + num_heads, + kv_group_num, + block_per_cta, + num_sub_blocks: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """first step kernel of split k attention.""" + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_k_id = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = 1 + cur_batch_kv_len = tl.load(B_kvlen + cur_batch) + history_len = cur_batch_kv_len - cur_batch_seq_len + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = (cur_batch * stride_qbs + cur_head * stride_qh + + offs_d * stride_qd) + off_k = (cur_kv_head * stride_kh + offs_d[None, :] * stride_kd) + off_v = (cur_kv_head * stride_vh + offs_d[None, :] * stride_vd) + + q = tl.load(Q + off_q).to(tl.float32) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_offset_ptrs = Block_offsets + cur_batch * stride_boffb + head_slope = get_slope( + cur_head.to(tl.float32) + head_offset, num_heads.to(tl.float32)) + + # initialize pointer to m and l + m_i = -float('inf') + l_i = float(0) + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + kv_len_per_prog = block_per_cta * BLOCK_N + loop_start = kv_len_per_prog * split_k_id + loop_end = tl.minimum(loop_start + kv_len_per_prog, cur_batch_kv_len) + + # load block offset + start_block_id = loop_start // BLOCK_N + b_offset = _load_block_offsets(block_offset_ptrs, start_block_id, + num_sub_blocks, BLOCK_N) + + for start_n in range(loop_start, loop_end, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + mask = (start_n + offs_n[:, None]) < cur_batch_kv_len + + # -- compute qk ---- + k = tl.load( + k_ptrs + b_offset[:, None] * stride_kbs, + mask=mask, + other=0.0, + ) + + v = tl.load( + v_ptrs + b_offset[:, None] * stride_vbs, + mask=mask, + other=0.0, + ) + + # prefetch 
b_offset + if start_n + BLOCK_N < loop_end: + start_block_id += 1 + b_offset = _load_block_offsets(block_offset_ptrs, start_block_id, + num_sub_blocks, BLOCK_N) + + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + mask = start_n + offs_n + bias = mask.to(tl.float32) * (head_slope * alibi_scale) + qk += bias + + # NOTE: inf - inf = nan, and nan will leads to error + qk = tl.where( + history_len >= (start_n + offs_n), + qk, + -float('inf'), + ) + + # -- compute p, m_i and l_i + m_i_new = tl.maximum(m_i, tl.max(qk, 0)) + p = tl.exp(qk - m_i_new) + alpha = tl.exp(m_i - m_i_new) + l_i_new = alpha * l_i + tl.sum(p, 0) + + # -- update output accumulator -- + # scale acc + acc = acc * alpha + + # update acc + p_new = p.to(v.dtype) + acc += tl.sum(p_new[:, None] * v, 0) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # initialize pointers to output + off_acc = (cur_batch * stride_obs + split_k_id * stride_ok + + cur_head * stride_oh + offs_d * stride_od) + tl.store(Acc_out + off_acc, acc) + + off_meta = (cur_batch * stride_obs + split_k_id * stride_ok + + cur_head * stride_oh + BLOCK_DMODEL) + tl.store(Acc_out + off_meta + tl.arange(0, 1), m_i) + tl.store(Acc_out + off_meta + 1 + tl.arange(0, 1), l_i) + + +@triton.jit +def _reduce_split_kernel( + Acc, + Out, + stride_ak, + stride_abs, + stride_ah, + stride_ad, + stride_obs, + stride_oh, + stride_od, + SPLIT_K: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + """second step kernel of split k attention.""" + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + # initialize offsets + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_k = tl.arange(0, SPLIT_K) + + offs_acc = (cur_batch * stride_abs + cur_head * stride_ah + + offs_k[:, None] * stride_ak + offs_d[None, :] * stride_ad) + offs_mi = (cur_batch * stride_abs + cur_head * stride_ah + + stride_ak * offs_k + BLOCK_DMODEL) + + acc_k = tl.load(Acc + offs_acc) + m_k = tl.load(Acc + offs_mi) + l_k = tl.load(Acc + offs_mi + 1) + + m_max = tl.max(m_k, 0) + alpha = tl.exp(m_k - m_max) + acc_k = acc_k * alpha[:, None] + l_k = l_k * alpha + + acc = tl.sum(acc_k, 0) + l_sum = tl.sum(l_k, 0) + acc = acc / l_sum + + out_offs = (cur_batch * stride_obs + cur_head * stride_oh + + offs_d * stride_od) + tl.store(Out + out_offs, acc) + + @triton.jit def _fwd_kernel( Q, @@ -228,6 +414,14 @@ def alibi_paged_attention_fwd(q: Tensor, tensor parallel inference. BLOCK (int): The kernel block size. 
""" + + def _kernel_meta(): + device = q.device + device_idx = device.index + device_type = device.type + stream = get_cuda_stream(device_idx) + return dict(device=device, device_type=device_type, stream=stream) + # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv @@ -245,37 +439,92 @@ def alibi_paged_attention_fwd(q: Tensor, grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel[grid]( - q, - k, - v, - sm_scale, - alibi_scale, - b_start_loc, - b_seq_len, - b_kv_seq_len, - block_offsets, - o, - q.stride(-3), - q.stride(-2), - q.stride(-1), - k.stride(-3), - k.stride(-2), - k.stride(-1), - v.stride(-3), - v.stride(-2), - v.stride(-1), - o.stride(-3), - o.stride(-2), - o.stride(-1), - block_offsets.stride(0), - head_offset=head_offset, - num_heads=num_heads, - kv_group_num=kv_group_num, - num_sub_blocks=num_sub_blocks, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) + kernel_meta = _kernel_meta() + is_decoding = q.shape[-3] == b_seq_len.size(0) + if not is_decoding: + _fwd_kernel[grid](q, + k, + v, + sm_scale, + alibi_scale, + b_start_loc, + b_seq_len, + b_kv_seq_len, + block_offsets, + o, + q.stride(-3), + q.stride(-2), + q.stride(-1), + k.stride(-3), + k.stride(-2), + k.stride(-1), + v.stride(-3), + v.stride(-2), + v.stride(-1), + o.stride(-3), + o.stride(-2), + o.stride(-1), + block_offsets.stride(0), + head_offset=head_offset, + num_heads=num_heads, + kv_group_num=kv_group_num, + num_sub_blocks=num_sub_blocks, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + **kernel_meta) + else: + SPLIT_K = 4 + grid = (batch, head, SPLIT_K) + block_per_cta = triton.cdiv(block_offsets.size(-1), SPLIT_K) + acc = q.new_empty(batch, head, SPLIT_K, Lq + 2, dtype=torch.float32) + _fwd_split_kernel[grid](q, + k, + v, + sm_scale, + alibi_scale, + b_kv_seq_len, + block_offsets, + acc, + stride_qbs=q.stride(-3), + stride_qh=q.stride(-2), + stride_qd=q.stride(-1), + stride_kbs=k.stride(-3), + stride_kh=k.stride(-2), + stride_kd=k.stride(-1), + stride_vbs=v.stride(-3), + stride_vh=v.stride(-2), + stride_vd=v.stride(-1), + stride_ok=acc.stride(-2), + stride_obs=acc.stride(-4), + stride_oh=acc.stride(-3), + stride_od=acc.stride(-1), + stride_boffb=block_offsets.stride(0), + head_offset=head_offset, + num_heads=num_heads, + kv_group_num=kv_group_num, + block_per_cta=block_per_cta, + num_sub_blocks=num_sub_blocks, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=4, + num_stages=1, + **kernel_meta) + + grid = (batch, head) + _reduce_split_kernel[grid](acc, + o, + stride_ak=acc.stride(-2), + stride_abs=acc.stride(-4), + stride_ah=acc.stride(-3), + stride_ad=acc.stride(-1), + stride_obs=o.stride(-3), + stride_oh=o.stride(-2), + stride_od=o.stride(-1), + SPLIT_K=SPLIT_K, + BLOCK_DMODEL=Lk, + num_warps=num_warps, + num_stages=1, + **kernel_meta) diff --git a/lmdeploy/pytorch/models/baichuan.py b/lmdeploy/pytorch/models/baichuan.py index fd5ad68ab1..5ae4440e0d 100644 --- a/lmdeploy/pytorch/models/baichuan.py +++ b/lmdeploy/pytorch/models/baichuan.py @@ -7,7 +7,8 @@ from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor from transformers.modeling_outputs import BaseModelOutputWithPast -from ..dist_utils import rowwise_parallelize_linear_fn, try_to_local +from ..dist_utils import (colwise_parallelize_linear_fn, + rowwise_parallelize_linear_fn, try_to_local) from .functional import 
attention_forward_with_paged_attention from .llama import apply_rotary_pos_emb @@ -15,7 +16,9 @@ def _attention_partition_fn(mod_name: str, mod: nn.Module, device_mesh: DeviceMesh): """A function for attention partition.""" - if mod_name in ['W_pack']: + + def __w_pack_linear_fn(mod: nn.Module): + """fn for w pack linear.""" for name, param in mod.named_parameters(): param = param.unflatten(0, (3, -1)) dist_tensor = distribute_tensor(param, device_mesh, [Shard(1)]) @@ -23,6 +26,27 @@ def _attention_partition_fn(mod_name: str, mod: nn.Module, dist_tensor = dist_tensor.flatten(0, 1) dist_param = torch.nn.Parameter(dist_tensor) mod.register_parameter(name, dist_param) + + def __w_pack_lora_linear_fn(mod: nn.Module): + """fn for w pack lora linear.""" + mod._tp_mode = 'colwise' + base_layer = mod.base_layer + __w_pack_linear_fn(base_layer) + + for lora_a_mod in mod.lora_A.values(): + colwise_parallelize_linear_fn(lora_a_mod, + device_mesh=device_mesh, + to_local=True) + + for lora_b_mod in mod.lora_B.values(): + __w_pack_linear_fn(lora_b_mod) + + if mod_name in ['W_pack']: + from peft.tuners.lora import Linear as LoraLinear + if isinstance(mod, LoraLinear): + __w_pack_lora_linear_fn(mod) + else: + __w_pack_linear_fn(mod) elif mod_name in ['o_proj']: rowwise_parallelize_linear_fn(mod, device_mesh=device_mesh, diff --git a/lmdeploy/pytorch/models/peft.py b/lmdeploy/pytorch/models/peft.py index 6831cab8dc..dbf16e13a7 100644 --- a/lmdeploy/pytorch/models/peft.py +++ b/lmdeploy/pytorch/models/peft.py @@ -253,7 +253,7 @@ def _lora_forward_tp(self, x): return self._lora_forward_tp_colwise(x) else: assert tp_mode is None, 'tp_mode == None failed.' - return self._lora_forward_tp(x) + return self._lora_forward_local(x) def _lora_forward(self, x): """lora forward.""" From 671a8e55066d5914dfd4abb50ed5c19c7a786225 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 25 Jan 2024 18:44:09 +0800 Subject: [PATCH 3/6] remove func usage --- lmdeploy/pytorch/models/baichuan.py | 130 +++++++++++++++++++++------- 1 file changed, 99 insertions(+), 31 deletions(-) diff --git a/lmdeploy/pytorch/models/baichuan.py b/lmdeploy/pytorch/models/baichuan.py index 5ae4440e0d..aa221361f9 100644 --- a/lmdeploy/pytorch/models/baichuan.py +++ b/lmdeploy/pytorch/models/baichuan.py @@ -9,7 +9,9 @@ from ..dist_utils import (colwise_parallelize_linear_fn, rowwise_parallelize_linear_fn, try_to_local) -from .functional import attention_forward_with_paged_attention +from ..kernels.alibi_pagedattention import alibi_paged_attention_fwd +from ..kernels.fill_kv_cache import fill_kv_cache +from ..kernels.pagedattention import paged_attention_fwd from .llama import apply_rotary_pos_emb @@ -107,8 +109,17 @@ def _contiguous_batching_forward( assert not output_attentions context = self.context.context history_lengths = context.history_lengths + kv_seq_length = context.kv_seq_length + q_seq_length = context.seq_length + q_start_loc = context.q_start_loc + block_offsets = context.block_offsets + + num_heads = self.num_heads // world_size + num_kv_heads = self.num_heads // world_size + head_dim = self.head_dim def _qkv_proj(hidden_states): + """qkv proj.""" proj = self.W_pack(hidden_states) return proj.chunk(3, -1) @@ -122,21 +133,45 @@ def _rotary_emb_fn(query_states, key_states, value_states): getattr(context, 'position_ids_1d', None)) return query_states, key_states, value_states - attn_output = attention_forward_with_paged_attention( - hidden_states, - history_lengths=history_lengths, - block_offsets=context.block_offsets, - 
num_heads=self.num_heads // world_size, - num_kv_heads=self.num_heads // world_size, - head_dim=self.head_dim, - position_ids=position_ids, - past_key_value=past_key_value, - context=context, - qkv_proj=_qkv_proj, - o_proj=self.o_proj, - rotary_emb_fn=_rotary_emb_fn, + query_states, key_states, value_states = _qkv_proj(hidden_states) + + query_states = query_states.view(-1, num_heads, head_dim) + key_states = key_states.view(-1, num_kv_heads, head_dim) + value_states = value_states.view(-1, num_kv_heads, head_dim) + + query_states, key_states, value_states = _rotary_emb_fn( + query_states, key_states, value_states) + + fill_kv_cache(key_states, + value_states, + past_key_value[0], + past_key_value[1], + q_start_loc, + q_seq_length, + block_offsets=block_offsets, + history_lengths=history_lengths, + context=context) + + attn_output = query_states + max_seq_len = position_ids.size(-1) + paged_attention_fwd( + query_states, + past_key_value[0], + past_key_value[1], + attn_output, + block_offsets, + b_start_loc=q_start_loc, + b_seq_len=q_seq_length, + b_kv_seq_len=kv_seq_length, + max_input_len=max_seq_len, ) + hidden_size = num_heads * head_dim + attn_output = attn_output.reshape(*hidden_states.shape[:-1], + hidden_size) + + attn_output = self.o_proj(attn_output) + return attn_output, None, past_key_value @@ -201,28 +236,61 @@ def _contiguous_batching_forward( context = self.context.context position_ids = context.position_ids history_lengths = context.history_lengths + kv_seq_length = context.kv_seq_length + q_seq_length = context.seq_length + q_start_loc = context.q_start_loc + block_offsets = context.block_offsets + + num_heads = self.num_heads // world_size + num_kv_heads = self.num_heads // world_size + head_dim = self.head_dim def _qkv_proj(hidden_states): proj = self.W_pack(hidden_states) return proj.chunk(3, -1) - _rotary_emb_fn = None - - attn_output = attention_forward_with_paged_attention( - hidden_states, - history_lengths=history_lengths, - block_offsets=context.block_offsets, - num_heads=self.num_heads // world_size, - num_kv_heads=self.num_heads // world_size, - head_dim=self.head_dim, - position_ids=position_ids, - past_key_value=past_key_value, - context=context, - qkv_proj=_qkv_proj, - o_proj=self.o_proj, - rotary_emb_fn=_rotary_emb_fn, - bias_type='alibi', - ) + query_states, key_states, value_states = _qkv_proj(hidden_states) + query_states = query_states.view(-1, num_heads, head_dim) + key_states = key_states.view(-1, num_kv_heads, head_dim) + value_states = value_states.view(-1, num_kv_heads, head_dim) + + fill_kv_cache(key_states, + value_states, + past_key_value[0], + past_key_value[1], + q_start_loc, + q_seq_length, + block_offsets=block_offsets, + history_lengths=history_lengths, + context=context) + + attn_output = query_states + + num_heads_full = num_heads + head_offset = 0 + max_seq_len = position_ids.size(-1) + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + num_heads_full = num_heads * world_size + head_offset = num_heads * rank + alibi_paged_attention_fwd(query_states, + past_key_value[0], + past_key_value[1], + attn_output, + block_offsets, + b_start_loc=q_start_loc, + b_seq_len=q_seq_length, + b_kv_seq_len=kv_seq_length, + max_input_len=max_seq_len, + head_offset=head_offset, + num_heads=num_heads_full) + + hidden_size = num_heads * head_dim + attn_output = attn_output.reshape(*hidden_states.shape[:-1], + hidden_size) + + attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value From 
8745c26142e2ff90193b5301a1dca8758a8ef7e7 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 25 Jan 2024 21:28:48 +0800 Subject: [PATCH 4/6] fix lora --- lmdeploy/pytorch/adapter/adapter.py | 5 +++++ lmdeploy/pytorch/engine/engine.py | 6 ++++++ lmdeploy/pytorch/engine/model_agent.py | 9 +++++++++ lmdeploy/pytorch/kernels/mbgmm.py | 5 +++++ lmdeploy/pytorch/kernels/mbgmv.py | 5 +++++ lmdeploy/pytorch/models/peft.py | 8 ++++++++ tests/pytorch/kernel/test_mbgmm.py | 9 +++++++-- tests/pytorch/kernel/test_mbgmv.py | 7 ++++++- 8 files changed, 51 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/adapter/adapter.py b/lmdeploy/pytorch/adapter/adapter.py index 178f681de9..f8186123ae 100644 --- a/lmdeploy/pytorch/adapter/adapter.py +++ b/lmdeploy/pytorch/adapter/adapter.py @@ -266,6 +266,11 @@ def rank(self): """get rank.""" return self.config.r + @property + def scaling(self): + """get scaling.""" + return self.config.lora_alpha / self.rank + def is_actived(self): """check if adapter is active.""" return self._active diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index ebaec203b6..0ace19efe7 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -370,6 +370,7 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList): local_adapter_ids = None global_adapter_ids = None adapter_offsets = None + local_adapter_scalings = None max_rank = 0 if ADAPTER_MANAGER.num_adapters() > 1: local_adapter_ids = _get_adapter_ids(messages, adapters) @@ -380,6 +381,10 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList): global_adapter_ids = seq_length.new_tensor(global_adapter_ids) ranks = [ada.rank for ada in adapters] max_rank = max(ranks) + local_adapter_scalings = [ + adapters[ada_ids].scaling for ada_ids in local_adapter_ids + ] + local_adapter_scalings = torch.tensor(local_adapter_scalings) # add batch dim [bs=1, seq_len] if input_ids.ndim == 1: @@ -396,6 +401,7 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList): local_adapter_ids=local_adapter_ids, global_adapter_ids=global_adapter_ids, adapter_offsets=adapter_offsets, + local_adapter_scalings=local_adapter_scalings, max_rank=max_rank, meta=meta) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index d4ee0cf29f..e28b76f63c 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -82,6 +82,7 @@ class ModelInputs: local_adapter_ids: torch.LongTensor global_adapter_ids: torch.LongTensor adapter_offsets: torch.LongTensor + local_adapter_scalings: torch.Tensor max_rank: int meta: Any @@ -98,8 +99,10 @@ def slice(self, start: int, end: int): history_lengths = self.history_lengths[sli] local_adapter_ids = self.local_adapter_ids + local_adapter_scalings = self.local_adapter_scalings if local_adapter_ids is not None: local_adapter_ids = local_adapter_ids[sli] + local_adapter_scalings = local_adapter_scalings[sli] return ModelInputs(input_ids=input_ids, seq_length=seq_length, @@ -112,6 +115,7 @@ def slice(self, start: int, end: int): local_adapter_ids=local_adapter_ids, global_adapter_ids=self.global_adapter_ids, adapter_offsets=self.adapter_offsets, + local_adapter_scalings=local_adapter_scalings, max_rank=self.max_rank, meta=self.meta) @@ -140,8 +144,10 @@ def split(self, split_size: int, block_size: int): block_end += 1 local_adapter_ids = self.local_adapter_ids + local_adapter_scalings = self.local_adapter_scalings if local_adapter_ids is not 
None: local_adapter_ids = local_adapter_ids[:, start:end] + local_adapter_scalings = local_adapter_scalings[:, start:end] inp = ModelInputs( input_ids=self.input_ids[:, start:end], @@ -155,6 +161,7 @@ def split(self, split_size: int, block_size: int): local_adapter_ids=local_adapter_ids, global_adapter_ids=self.global_adapter_ids, adapter_offsets=self.adapter_offsets, + local_adapter_scalings=local_adapter_scalings, max_rank=self.max_rank, meta=self.meta, ) @@ -198,6 +205,7 @@ class StepContext: local_adapter_ids: torch.LongTensor = None global_adapter_ids: torch.LongTensor = None adapter_offsets: torch.LongTensor = None + local_adapter_scalings: torch.Tensor = None max_rank: int = 0 _outputs: Dict = field(default_factory=dict) @@ -246,6 +254,7 @@ def new( local_adapter_ids=inputs.local_adapter_ids, global_adapter_ids=inputs.global_adapter_ids, adapter_offsets=inputs.adapter_offsets, + local_adapter_scalings=inputs.local_adapter_scalings, max_rank=inputs.max_rank) return ret diff --git a/lmdeploy/pytorch/kernels/mbgmm.py b/lmdeploy/pytorch/kernels/mbgmm.py index 24be1e234e..122d08a705 100644 --- a/lmdeploy/pytorch/kernels/mbgmm.py +++ b/lmdeploy/pytorch/kernels/mbgmm.py @@ -101,6 +101,7 @@ def _acc_b_mm_kernel( B_start_loc, B_seq_lens, B_adapter_id, + B_scaling, Rank_page_table, Rank_page_start, Ranks, @@ -127,6 +128,7 @@ def _acc_b_mm_kernel( start_loc = tl.load(B_start_loc + cur_batch) adapter_id = tl.load(B_adapter_id + cur_batch) + scaling = tl.load(B_scaling + cur_batch) rank = tl.load(Ranks + adapter_id) page_start = tl.load(Rank_page_start + adapter_id) @@ -164,6 +166,7 @@ def _acc_b_mm_kernel( # compute out = tl.dot(acc, lb) out = out.to(lb.dtype) + out = out * scaling # store o oh_off = cur_dm_off * stride_oh @@ -244,6 +247,7 @@ def mbgmm_b(xa: Tensor, b_start_loc: Tensor, b_seq_lens: Tensor, b_adapter_ids: Tensor, + b_scaling: Tensor, rank_page_table: Tensor, ranks: Tensor, rank_page_start: Tensor, @@ -284,6 +288,7 @@ def _kernel_meta(): b_start_loc, b_seq_lens, b_adapter_ids, + b_scaling, Rank_page_table=rank_page_table, Rank_page_start=rank_page_start, Ranks=ranks, diff --git a/lmdeploy/pytorch/kernels/mbgmv.py b/lmdeploy/pytorch/kernels/mbgmv.py index e8c92d63d9..aac963e786 100644 --- a/lmdeploy/pytorch/kernels/mbgmv.py +++ b/lmdeploy/pytorch/kernels/mbgmv.py @@ -81,6 +81,7 @@ def _acc_b_mv_kernel( LoRA_B, Out, B_adapter_id, + B_scaling, Rank_page_table, Rank_page_start, Ranks, @@ -100,6 +101,7 @@ def _acc_b_mv_kernel( r_off = tl.arange(0, BLOCK_R) adapter_id = tl.load(B_adapter_id + cur_batch) + scaling = tl.load(B_scaling + cur_batch) rank = tl.load(Ranks + adapter_id) page_start = tl.load(Rank_page_start + adapter_id) @@ -130,6 +132,7 @@ def _acc_b_mv_kernel( # compute out = tl.sum(acc[:, None] * lb, 0) out = out.to(lb.dtype) + out = out * scaling # store o oh_off = cur_dm_off * stride_oh @@ -198,6 +201,7 @@ def _kernel_meta(): def mbgmv_b(xa: Tensor, lora_b: Tensor, b_adapter_ids: Tensor, + b_scaling: Tensor, rank_page_table: Tensor, ranks: Tensor, rank_page_start: Tensor, @@ -232,6 +236,7 @@ def _kernel_meta(): lora_b, output, b_adapter_ids, + b_scaling, Rank_page_table=rank_page_table, Rank_page_start=rank_page_start, Ranks=ranks, diff --git a/lmdeploy/pytorch/models/peft.py b/lmdeploy/pytorch/models/peft.py index dbf16e13a7..74e5279d11 100644 --- a/lmdeploy/pytorch/models/peft.py +++ b/lmdeploy/pytorch/models/peft.py @@ -17,6 +17,7 @@ class PackedLoRAInput: b_start_loc: torch.Tensor b_seq_lens: torch.Tensor b_adapter_ids: torch.Tensor + b_scaling: torch.Tensor 
rank_page_table: torch.Tensor rank_page_start: torch.Tensor ranks: torch.Tensor @@ -46,6 +47,7 @@ def _make_packed_lora_input(self, x): b_start_loc=context.q_start_loc, b_seq_lens=context.seq_length, b_adapter_ids=context.local_adapter_ids, + b_scaling=context.local_adapter_scalings, rank_page_table=context.adapter_offsets, rank_page_start=block_starts, ranks=ranks, @@ -75,6 +77,7 @@ def _lora_forward_local(self, x): b_start_loc=lora_input.b_start_loc, b_seq_lens=lora_input.b_seq_lens, b_adapter_ids=lora_input.b_adapter_ids, + b_scaling=lora_input.b_scaling, rank_page_table=lora_input.rank_page_table, rank_page_start=lora_input.rank_page_start, ranks=lora_input.ranks, @@ -92,6 +95,7 @@ def _lora_forward_local(self, x): lora_out = mbgmv_b(xa, lora_input.b_cache, b_adapter_ids=lora_input.b_adapter_ids, + b_scaling=lora_input.b_scaling, rank_page_table=lora_input.rank_page_table, rank_page_start=lora_input.rank_page_start, ranks=lora_input.ranks, @@ -127,6 +131,7 @@ def _lora_forward_tp_rowwise(self, x): b_start_loc=lora_input.b_start_loc, b_seq_lens=lora_input.b_seq_lens, b_adapter_ids=lora_input.b_adapter_ids, + b_scaling=lora_input.b_scaling, rank_page_table=lora_input.rank_page_table, rank_page_start=lora_input.rank_page_start, ranks=lora_input.ranks, @@ -144,6 +149,7 @@ def _lora_forward_tp_rowwise(self, x): lora_out = mbgmv_b(xa, lora_input.b_cache, b_adapter_ids=lora_input.b_adapter_ids, + b_scaling=lora_input.b_scaling, rank_page_table=lora_input.rank_page_table, rank_page_start=lora_input.rank_page_start, ranks=lora_input.ranks, @@ -203,6 +209,7 @@ def __gather_xa(xa): b_start_loc=lora_input.b_start_loc, b_seq_lens=lora_input.b_seq_lens, b_adapter_ids=lora_input.b_adapter_ids, + b_scaling=lora_input.b_scaling, rank_page_table=lora_input.rank_page_table, rank_page_start=lora_input.rank_page_start, ranks=lora_input.ranks, @@ -232,6 +239,7 @@ def __gather_xa(xa): lora_out = mbgmv_b(gathered_xa, lora_input.b_cache, b_adapter_ids=lora_input.b_adapter_ids, + b_scaling=lora_input.b_scaling, rank_page_table=lora_input.rank_page_table, rank_page_start=lora_input.rank_page_start, ranks=lora_input.ranks, diff --git a/tests/pytorch/kernel/test_mbgmm.py b/tests/pytorch/kernel/test_mbgmm.py index 8151a9bc72..8b60594b10 100644 --- a/tests/pytorch/kernel/test_mbgmm.py +++ b/tests/pytorch/kernel/test_mbgmm.py @@ -47,6 +47,10 @@ def adapter_ids(self, seq_lens, ranks): ret = torch.randint(0, num_ranks, (num_seqs, )).cuda() yield ret + @pytest.fixture + def scaling(self, adapter_ids): + yield torch.ones(adapter_ids.size(0)).cuda() + @pytest.fixture def lora_a(self, ranks, head_size, dtype): out = [] @@ -99,8 +103,8 @@ def gt(self, input, start_loc, seq_lens, adapter_ids, lora_a, lora_b): yield torch.cat(out) def test_mbgmm(self, input, paged_lora_a, paged_lora_b, out_head_size, - start_loc, seq_lens, adapter_ids, page_table, ranks, - page_start, gt): + start_loc, seq_lens, adapter_ids, scaling, page_table, + ranks, page_start, gt): max_seq_len = max(seq_lens).item() max_rank = page_table.size(-1) @@ -120,6 +124,7 @@ def test_mbgmm(self, input, paged_lora_a, paged_lora_b, out_head_size, b_start_loc=start_loc, b_seq_lens=seq_lens, b_adapter_ids=adapter_ids, + b_scaling=scaling, rank_page_table=page_table, rank_page_start=page_start, ranks=ranks, diff --git a/tests/pytorch/kernel/test_mbgmv.py b/tests/pytorch/kernel/test_mbgmv.py index 2ffadcadbc..4143b00e82 100644 --- a/tests/pytorch/kernel/test_mbgmv.py +++ b/tests/pytorch/kernel/test_mbgmv.py @@ -43,6 +43,10 @@ def adapter_ids(self, batch_size, ranks): 
ret = torch.randint(0, num_ranks, (batch_size, )).cuda() yield ret + @pytest.fixture + def scaling(self, adapter_ids): + yield torch.ones(adapter_ids.size(0)).cuda() + @pytest.fixture def lora_a(self, ranks, head_size, dtype): out = [] @@ -97,7 +101,7 @@ def gt(self, input, adapter_ids, lora_a, lora_b): yield torch.cat(out) def test_mbgmv(self, input, paged_lora_a, paged_lora_b, out_head_size, - adapter_ids, page_table, ranks, page_start, gt): + adapter_ids, scaling, page_table, ranks, page_start, gt): max_rank = page_table.size(-1) xa = mbgmv_a(input, @@ -111,6 +115,7 @@ def test_mbgmv(self, input, paged_lora_a, paged_lora_b, out_head_size, output = mbgmv_b(xa, paged_lora_b[..., :out_head_size], b_adapter_ids=adapter_ids, + b_scaling=scaling, rank_page_table=page_table, rank_page_start=page_start, ranks=ranks, From 77a7b1e732968b2f4f9a4c3d00541dec45a7b954 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 26 Jan 2024 10:39:45 +0800 Subject: [PATCH 5/6] hide local scaling --- lmdeploy/pytorch/adapter/adapter.py | 25 ++++++++++++++++++------- lmdeploy/pytorch/engine/engine.py | 6 ------ lmdeploy/pytorch/engine/model_agent.py | 9 --------- lmdeploy/pytorch/models/peft.py | 3 ++- 4 files changed, 20 insertions(+), 23 deletions(-) diff --git a/lmdeploy/pytorch/adapter/adapter.py b/lmdeploy/pytorch/adapter/adapter.py index f8186123ae..f93ac5541a 100644 --- a/lmdeploy/pytorch/adapter/adapter.py +++ b/lmdeploy/pytorch/adapter/adapter.py @@ -79,10 +79,11 @@ def __get_targets(): all_targets.update(targets) return all_targets - def __get_rank_and_start(target_names): + def __get_linear_meta(target_names): """get rank and start.""" rank_map = dict() start_map = dict() + scaling_map = dict() for target in target_names: ranks = [0] + [ weight_map.target_modules[target].rank @@ -92,15 +93,22 @@ def __get_rank_and_start(target_names): weight_map.target_modules[target].block_start for weight_map in weight_maps ] + scaling = [0] + [ + weight_map.target_modules[target].scaling + for weight_map in weight_maps + ] rank_map[target] = torch.tensor(ranks) start_map[target] = torch.tensor(block_starts) - return rank_map, start_map + scaling_map[target] = torch.tensor(scaling) + return rank_map, start_map, scaling_map - def __update_linear(linear, idx, rank_map, start_map, adapter_names): + def __update_linear(linear, idx, rank_map, start_map, scaling_map, + adapter_names): """update linear.""" linear.layer_idx = idx linear.ranks = rank_map[target].to(device) linear.block_starts = start_map[target].to(device) + linear.scaling = scaling_map[target].to(device) for name in adapter_names: if name in linear.lora_A: linear.lora_A.pop(name) @@ -113,7 +121,7 @@ def __update_linear(linear, idx, rank_map, start_map, adapter_names): for weight_map in weight_maps: weight_map.expand_targets(all_targets) - rank_map, start_map = __get_rank_and_start(all_targets) + rank_map, start_map, scaling_map = __get_linear_meta(all_targets) for idx, lora_linear in lora_linears.items(): for target, linear in lora_linear.items(): @@ -121,6 +129,7 @@ def __update_linear(linear, idx, rank_map, start_map, adapter_names): idx, rank_map=rank_map, start_map=start_map, + scaling_map=scaling_map, adapter_names=adapter_names) @@ -139,6 +148,7 @@ def get_max_lora_weight_size(model: torch.nn.Module): class TargetMeta: rank: int block_start: int + scaling: float @dataclass @@ -149,12 +159,12 @@ class AdapterWeightMap: @classmethod def new(cls, adapter_name: str, rank: int, target_names: List[str], - block_table: Tensor): + block_table: Tensor, scaling: 
float): """create new weightmap.""" block_start = 0 target_modules: Dict[str, TargetMeta] = dict() for name in target_names: - target_modules[name] = TargetMeta(rank, block_start) + target_modules[name] = TargetMeta(rank, block_start, scaling) block_start += rank return AdapterWeightMap(adapter_name, @@ -296,7 +306,8 @@ def build_weight_map(self, block_table: Tensor): return AdapterWeightMap.new(self.name, rank=self.rank, target_names=self.target_modules, - block_table=block_table) + block_table=block_table, + scaling=self.scaling) class AdapterManager: diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 0ace19efe7..ebaec203b6 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -370,7 +370,6 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList): local_adapter_ids = None global_adapter_ids = None adapter_offsets = None - local_adapter_scalings = None max_rank = 0 if ADAPTER_MANAGER.num_adapters() > 1: local_adapter_ids = _get_adapter_ids(messages, adapters) @@ -381,10 +380,6 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList): global_adapter_ids = seq_length.new_tensor(global_adapter_ids) ranks = [ada.rank for ada in adapters] max_rank = max(ranks) - local_adapter_scalings = [ - adapters[ada_ids].scaling for ada_ids in local_adapter_ids - ] - local_adapter_scalings = torch.tensor(local_adapter_scalings) # add batch dim [bs=1, seq_len] if input_ids.ndim == 1: @@ -401,7 +396,6 @@ def create_model_inputs(self, messages: SeqList, adapters: AdapterList): local_adapter_ids=local_adapter_ids, global_adapter_ids=global_adapter_ids, adapter_offsets=adapter_offsets, - local_adapter_scalings=local_adapter_scalings, max_rank=max_rank, meta=meta) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 756d51bdea..f92b2bfb5f 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -82,7 +82,6 @@ class ModelInputs: local_adapter_ids: torch.LongTensor = None global_adapter_ids: torch.LongTensor = None adapter_offsets: torch.LongTensor = None - local_adapter_scalings: torch.Tensor = None max_rank: int = 0 meta: Any = None @@ -99,10 +98,8 @@ def slice(self, start: int, end: int): history_lengths = self.history_lengths[sli] local_adapter_ids = self.local_adapter_ids - local_adapter_scalings = self.local_adapter_scalings if local_adapter_ids is not None: local_adapter_ids = local_adapter_ids[sli] - local_adapter_scalings = local_adapter_scalings[sli] return ModelInputs(input_ids=input_ids, seq_length=seq_length, @@ -115,7 +112,6 @@ def slice(self, start: int, end: int): local_adapter_ids=local_adapter_ids, global_adapter_ids=self.global_adapter_ids, adapter_offsets=self.adapter_offsets, - local_adapter_scalings=local_adapter_scalings, max_rank=self.max_rank, meta=self.meta) @@ -144,10 +140,8 @@ def split(self, split_size: int, block_size: int): block_end += 1 local_adapter_ids = self.local_adapter_ids - local_adapter_scalings = self.local_adapter_scalings if local_adapter_ids is not None: local_adapter_ids = local_adapter_ids[:, start:end] - local_adapter_scalings = local_adapter_scalings[:, start:end] inp = ModelInputs( input_ids=self.input_ids[:, start:end], @@ -161,7 +155,6 @@ def split(self, split_size: int, block_size: int): local_adapter_ids=local_adapter_ids, global_adapter_ids=self.global_adapter_ids, adapter_offsets=self.adapter_offsets, - local_adapter_scalings=local_adapter_scalings, 
max_rank=self.max_rank, meta=self.meta, ) @@ -205,7 +198,6 @@ class StepContext: local_adapter_ids: torch.LongTensor = None global_adapter_ids: torch.LongTensor = None adapter_offsets: torch.LongTensor = None - local_adapter_scalings: torch.Tensor = None max_rank: int = 0 _outputs: Dict = field(default_factory=dict) @@ -254,7 +246,6 @@ def new( local_adapter_ids=inputs.local_adapter_ids, global_adapter_ids=inputs.global_adapter_ids, adapter_offsets=inputs.adapter_offsets, - local_adapter_scalings=inputs.local_adapter_scalings, max_rank=inputs.max_rank) return ret diff --git a/lmdeploy/pytorch/models/peft.py b/lmdeploy/pytorch/models/peft.py index 74e5279d11..52fbc48df1 100644 --- a/lmdeploy/pytorch/models/peft.py +++ b/lmdeploy/pytorch/models/peft.py @@ -36,6 +36,7 @@ def _make_packed_lora_input(self, x): layer_idx = self.layer_idx ranks = self.ranks[global_adapter_ids] block_starts = self.block_starts[global_adapter_ids] + scaling = self.scaling[global_adapter_ids] k_cache, v_cache = context.kv_caches[layer_idx] cache_len = k_cache.size(0) a_cache = k_cache.view(cache_len, -1) @@ -47,7 +48,7 @@ def _make_packed_lora_input(self, x): b_start_loc=context.q_start_loc, b_seq_lens=context.seq_length, b_adapter_ids=context.local_adapter_ids, - b_scaling=context.local_adapter_scalings, + b_scaling=scaling, rank_page_table=context.adapter_offsets, rank_page_start=block_starts, ranks=ranks, From 394ad257c0799cde845a10649e0e76572b3c936a Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 26 Jan 2024 10:47:19 +0800 Subject: [PATCH 6/6] fix --- lmdeploy/pytorch/adapter/adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/adapter/adapter.py b/lmdeploy/pytorch/adapter/adapter.py index f93ac5541a..3746915346 100644 --- a/lmdeploy/pytorch/adapter/adapter.py +++ b/lmdeploy/pytorch/adapter/adapter.py @@ -180,7 +180,7 @@ def expand_targets(self, continue else: raise RuntimeError(f'target {name} exists.') - self.target_modules[name] = TargetMeta(0, 0) + self.target_modules[name] = TargetMeta(0, 0, 0.0) @classmethod def cache_lora_a(cls, cache: Tensor, weight: Tensor, block_table: Tensor):
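
A quick reference for the split-k decode path added in this series (`_fwd_split_kernel` + `_reduce_split_kernel`): each single-token query's KV range is split across SPLIT_K programs, and the partial results are then merged. The sketch below is an illustration only, not part of the patch; it assumes per-split tensors `acc_k` (unnormalized partial outputs), `m_k` (running maxima of qk) and `l_k` (softmax denominators), matching what the first kernel packs into the extra two slots of its accumulator.

    import torch

    def reduce_split_k(acc_k, m_k, l_k):
        # acc_k: [SPLIT_K, D], m_k: [SPLIT_K], l_k: [SPLIT_K]
        m_max = m_k.max()
        alpha = torch.exp(m_k - m_max)          # rescale every split to the global max
        acc = (acc_k * alpha[:, None]).sum(0)   # merged, still unnormalized, output
        l_sum = (l_k * alpha).sum()             # merged softmax denominator
        return acc / l_sum

    # In the kernel itself, the intermediate buffer has shape
    # [batch, head, SPLIT_K, head_dim + 2], with m and l stored after head_dim.

Splitting the decode along the KV length keeps more SMs busy when the batch is small, and the merge is exact because the per-split (max, denominator) bookkeeping makes the softmax reduction associative.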