optimize fill kv cache (#523)
* optimize fill kv cache

* update internlm

* faster embedding

* fix bias tp

* fix baichuan2

* fix fill kv cache

* fix lint

---------
q.yao authored Oct 9, 2023
1 parent 8085fbc commit 668e30d
Showing 9 changed files with 230 additions and 103 deletions.
3 changes: 3 additions & 0 deletions lmdeploy/pytorch_poc/dist_utils.py
@@ -58,6 +58,9 @@ def rowwise_parallelize_linear_fn(module: nn.Module,
dist_tensor = distribute_tensor(param, device_mesh, dist_spec)
if to_local:
dist_tensor = try_to_local(dist_tensor)
if name == 'bias':
# rowwise linear would add bias more than once.
dist_tensor /= device_mesh.size()
dist_param = torch.nn.Parameter(dist_tensor)
module.register_parameter(name, dist_param)

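The bias handling above compensates for double counting in row-wise tensor parallelism: every rank adds the full bias to its partial matmul result, and the following all-reduce sums those partials, so the bias would otherwise be applied `world_size` times. A minimal sketch of the effect, with plain tensor chunks standing in for the distributed all-reduce (not the repo's API):

```python
import torch

world_size = 2
x = torch.randn(4, 8)
w = torch.randn(8, 8)
b = torch.randn(8)

reference = x @ w + b  # single-device result

# Row-wise parallel linear: shard the input features and the weight rows;
# the cross-rank all-reduce is simulated here by summing the partials.
x_shards = x.chunk(world_size, dim=1)
w_shards = w.chunk(world_size, dim=0)

# Each rank adding the full bias over-counts it world_size times ...
wrong = sum(xs @ ws + b for xs, ws in zip(x_shards, w_shards))
# ... while dividing the sharded bias by world_size recovers the reference.
fixed = sum(xs @ ws + b / world_size for xs, ws in zip(x_shards, w_shards))

assert not torch.allclose(wrong, reference)
assert torch.allclose(fixed, reference, atol=1e-5)
```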
28 changes: 21 additions & 7 deletions lmdeploy/pytorch_poc/engine/engine.py
@@ -18,7 +18,7 @@
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper)
-from transformers.utils import WEIGHTS_INDEX_NAME, cached_file
+from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, cached_file

from lmdeploy.pytorch.accel import LoadNoInit
from lmdeploy.pytorch_poc.config import (CacheConfig, ModelConfig,
@@ -106,12 +106,16 @@ def __init__(
block_offsets: List[List[int]],
history_lengths: List[int],
position_ids: torch.Tensor,
q_start_loc: torch.Tensor,
seq_length: torch.Tensor,
world_size: int = 1,
device='cuda',
):
self.block_offsets_list = block_offsets
self.history_lengths = history_lengths
self.position_ids = position_ids
self.q_start_loc = q_start_loc
self.seq_length = seq_length
self.world_size = world_size

# padding zero
@@ -257,14 +261,20 @@ def _tp_model_loop(
torch_dtype=torch_dtype,
trust_remote_code=True)

-torch_model_json_path = cached_file(model_path, WEIGHTS_INDEX_NAME)
-with open(torch_model_json_path, mode='r') as f:
-    torch_model_json = json.load(f)
+try:
+    torch_model_json_path = cached_file(model_path, WEIGHTS_INDEX_NAME)
+    with open(torch_model_json_path, mode='r') as f:
+        torch_model_json = json.load(f)

-weight_map = torch_model_json['weight_map']
+    weight_map = torch_model_json['weight_map']

-checkpoints = list(set(weight_map.values()))
-checkpoints = [cached_file(model_path, ckpt) for ckpt in checkpoints]
+    checkpoints = list(set(weight_map.values()))
+    checkpoints = [
+        cached_file(model_path, ckpt) for ckpt in checkpoints
+    ]
+except Exception:
+    logger.warning(f'load failed, try load from {WEIGHTS_NAME}.')
+    checkpoints = [cached_file(model_path, WEIGHTS_NAME)]
patched_model = patch(
model,
extra_args=extra_args,
@@ -374,6 +384,8 @@ def _tp_model_loop(
block_offsets=inputs['block_offsets'],
history_lengths=inputs['history_lengths'],
position_ids=inputs['position_ids'],
q_start_loc=inputs['q_start_loc'],
seq_length=inputs['seq_length'],
world_size=world_size,
),
q_seq_info=(inputs['q_start_loc'], inputs['seq_length']),
@@ -717,6 +729,8 @@ def _model_forward(self, inputs: Dict, swap_in_map: Dict[int, int],
block_offsets=inputs['block_offsets'],
history_lengths=inputs['history_lengths'],
position_ids=inputs['position_ids'],
q_start_loc=inputs['q_start_loc'],
seq_length=inputs['seq_length'],
),
q_seq_info=(inputs['q_start_loc'], inputs['seq_length']),
)
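For context, the `q_start_loc` and `seq_length` tensors now carried by the model context describe how the requests of a batch are packed along a single token dimension, which is what the new fill-kv-cache kernel uses to locate each request's tokens. An illustrative sketch with made-up shapes:

```python
import torch

# Three requests contribute 3, 1 and 2 new tokens, concatenated into one batch.
seq_length = torch.tensor([3, 1, 2])
q_start_loc = torch.cumsum(seq_length, dim=0) - seq_length  # tensor([0, 3, 4])

hidden_size = 16
packed_states = torch.randn(int(seq_length.sum()), hidden_size)

# Slice the packed tensor back into per-request chunks.
per_request = [
    packed_states[start:start + length]
    for start, length in zip(q_start_loc.tolist(), seq_length.tolist())
]
assert [t.size(0) for t in per_request] == [3, 1, 2]
```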
7 changes: 3 additions & 4 deletions lmdeploy/pytorch_poc/kernels/__init__.py
@@ -1,12 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .alibi_pagedattention import alibi_paged_attention_fwd
from .biased_pagedattention import biased_paged_attention_fwd
+from .fill_kv_cache import fill_kv_cache
from .flashattention_nopad import context_attention_fwd
from .pagedattention import paged_attention_fwd

__all__ = [
-'context_attention_fwd',
-'paged_attention_fwd',
-'biased_paged_attention_fwd',
-'alibi_paged_attention_fwd',
+'context_attention_fwd', 'paged_attention_fwd',
+'biased_paged_attention_fwd', 'alibi_paged_attention_fwd', 'fill_kv_cache'
]
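A short usage sketch of the newly exported kernel (the implementation is added in fill_kv_cache.py below). Shapes and values here are hypothetical, and a CUDA device is required since the kernel is written in Triton:

```python
import torch
from lmdeploy.pytorch_poc.kernels import fill_kv_cache

num_heads, head_dim, block_size, num_blocks = 8, 64, 16, 32
k_states = torch.randn(6, num_heads, head_dim, device='cuda')  # 6 new tokens
v_states = torch.randn_like(k_states)
k_caches = torch.zeros(num_blocks, block_size, num_heads, head_dim, device='cuda')
v_caches = torch.zeros_like(k_caches)

fill_kv_cache(
    k_states, v_states, k_caches, v_caches,
    start_loc=torch.tensor([0, 3, 4], device='cuda'),   # packed offset per request
    seq_length=torch.tensor([3, 1, 2], device='cuda'),  # new tokens per request
    block_offsets=torch.tensor([[0, 1], [2, 3], [4, 5]], device='cuda'),
    history_lengths=[0, 0, 0],                          # nothing cached yet
)
```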
165 changes: 165 additions & 0 deletions lmdeploy/pytorch_poc/kernels/fill_kv_cache.py
@@ -0,0 +1,165 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Sequence

import torch
import triton
import triton.language as tl
from torch import Tensor


@triton.jit
def _fill_kv_cache_kernel(
k_states,
v_states,
k_caches,
v_caches,
state_start,
state_len,
cache_start,
block_offsets1d,
stride_kss, # stride of key state token
stride_vss, # stride of value state token
stride_kcs: tl.constexpr, # stride of key cache token
stride_vcs: tl.constexpr, # stride of value cache token
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
prog_id = tl.program_id(0)

stride_kb = stride_kcs * BLOCK_M
stride_vb = stride_vcs * BLOCK_M

sstart = tl.load(state_start + prog_id)
slen = tl.load(state_len + prog_id)
cstart = tl.load(cache_start + prog_id)
boffset = tl.load(block_offsets1d + prog_id)

off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)

ks_ptrs = k_states + (sstart +
off_m[:, None]) * stride_kss + off_n[None, :]
vs_ptrs = v_states + (sstart +
off_m[:, None]) * stride_vss + off_n[None, :]
kc_ptrs = k_caches + boffset * stride_kb + (
cstart + off_m[:, None]) * stride_kcs + off_n[None, :]
vc_ptrs = v_caches + boffset * stride_vb + (
cstart + off_m[:, None]) * stride_vcs + off_n[None, :]

mask = off_m[:, None] < slen

for idx in range(0, stride_kcs, BLOCK_N):
ks = tl.load(ks_ptrs + idx, mask=mask)
tl.store(kc_ptrs + idx, ks, mask=mask)

for idx in range(0, stride_vcs, BLOCK_N):
vs = tl.load(vs_ptrs + idx, mask=mask)
tl.store(vc_ptrs + idx, vs, mask=mask)


def fill_kv_cache(k_states: Tensor,
v_states: Tensor,
k_caches: Tensor,
v_caches: Tensor,
start_loc: Tensor,
seq_length: Tensor,
block_offsets: Tensor,
history_lengths: Sequence,
context: Any = None):
"""fill kv cache for paged attention."""
fill_cache_info = getattr(context, 'fill_cache_info', None)

if fill_cache_info is None:
batch_size = block_offsets.size(0)
block_size = k_caches.size(1)

if not isinstance(history_lengths, Tensor):
history_lengths = torch.tensor(history_lengths,
device=k_states.device)

batch_ids = torch.arange(batch_size, device=k_states.device)

first_block_ids = history_lengths // block_size
block_offsets1d = block_offsets[batch_ids, first_block_ids]

token_ids_start = history_lengths % block_size
first_seq_len = torch.minimum(seq_length, block_size - token_ids_start)

state_start = start_loc[:batch_size]
state_len = first_seq_len
cache_start = token_ids_start

# middle + last = remain
remain_seq_len = torch.maximum(seq_length.new_zeros(1),
seq_length - first_seq_len)
last_seq_len = remain_seq_len % block_size
middle_seq_len = remain_seq_len - last_seq_len
middle_block_nums = middle_seq_len // block_size
remain_block_nums = (remain_seq_len / block_size).ceil().long()

remain_state_start = [
ss + slen +
torch.arange(0, rlen, block_size, device=k_states.device)
for ss, slen, rlen in zip(state_start, first_seq_len,
remain_seq_len)
]
remain_seq_lens = [
torch.full((mid, ), block_size, device=k_states.device)
for mid in middle_block_nums
]
remain_seq_lens = [
(torch.cat([slen, last]) if last != 0 else slen)
for slen, last in zip(remain_seq_lens, last_seq_len.unsqueeze(-1))
]
remain_block_offsets1d = [
block_offsets[bid, ids:ids + ids_len]
for bid, ids, ids_len in zip(range(batch_size), first_block_ids +
1, remain_block_nums)
]

# state_start stores the index in the state tensor where each block write begins
# state_len stores the number of tokens to write into each block
# cache_start stores the first slot to write inside each block
# block_offsets1d stores the index of each block in the caches
state_start = torch.cat([state_start] + remain_state_start)
state_len = torch.cat([state_len] + remain_seq_lens)
cache_start = torch.cat(
[cache_start] +
[state_start.new_zeros(state_start.size(0) - batch_size)])
block_offsets1d = torch.cat([block_offsets1d] + remain_block_offsets1d)

if context is not None:
fill_cache_info = dict()
fill_cache_info['state_start'] = state_start
fill_cache_info['state_len'] = state_len
fill_cache_info['cache_start'] = cache_start
fill_cache_info['block_offsets1d'] = block_offsets1d
context.fill_cache_info = fill_cache_info
else:
state_start = fill_cache_info['state_start']
state_len = fill_cache_info['state_len']
cache_start = fill_cache_info['cache_start']
block_offsets1d = fill_cache_info['block_offsets1d']

grid = (state_start.size(0), )
BLOCK_M = k_caches.size(-3)
BLOCK_N = min(128, k_caches.stride(-3), v_caches.stride(-3))

_fill_kv_cache_kernel[grid](
k_states,
v_states,
k_caches,
v_caches,
state_start=state_start,
state_len=state_len,
cache_start=cache_start,
block_offsets1d=block_offsets1d,
stride_kss=k_states.stride(-3),
stride_vss=v_states.stride(-3),
stride_kcs=k_caches.stride(-3),
stride_vcs=v_caches.stride(-3),
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=4,
num_stages=1,
)
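As a rough mental model of what the Triton kernel does, here is a pure-PyTorch sketch (illustrative only, looping where the kernel parallelizes): each request's new key/value tokens are appended to its paged cache, starting right after the `history_lengths[i]` tokens already cached for that request, with `block_offsets[i]` mapping logical block indices to physical blocks.

```python
import torch


def fill_kv_cache_reference(k_states, v_states, k_caches, v_caches,
                            start_loc, seq_length, block_offsets,
                            history_lengths):
    """Naive reference: k/v states are [num_tokens, heads, dim] packed over the
    batch; caches are [num_blocks, block_size, heads, dim]."""
    block_size = k_caches.size(1)
    for i in range(len(seq_length)):
        s = int(start_loc[i])
        for j in range(int(seq_length[i])):
            pos = int(history_lengths[i]) + j                 # logical cache position
            block = int(block_offsets[i, pos // block_size])  # physical block id
            slot = pos % block_size                           # offset inside the block
            k_caches[block, slot] = k_states[s + j]
            v_caches[block, slot] = v_states[s + j]
```

In the Triton version the same work is split so that each kernel program copies one contiguous chunk into a single cache block; the host-side code above builds `state_start`, `state_len`, `cache_start` and `block_offsets1d` to describe those chunks and stores them on the context so later calls can skip the computation.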
2 changes: 2 additions & 0 deletions lmdeploy/pytorch_poc/models/baichuan.py
@@ -119,6 +119,7 @@ def _rotary_emb_fn(query_states, key_states, value_states):
head_dim=self.head_dim,
position_ids=position_ids,
past_key_value=past_key_value,
context=context,
qkv_proj=_qkv_proj,
o_proj=self.o_proj,
rotary_emb_fn=_rotary_emb_fn,
@@ -204,6 +205,7 @@ def _qkv_proj(hidden_states):
head_dim=self.head_dim,
position_ids=position_ids,
past_key_value=past_key_value,
context=context,
qkv_proj=_qkv_proj,
o_proj=self.o_proj,
rotary_emb_fn=_rotary_emb_fn,
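Threading `context=context` into the attention forward matters for the kernel above: `fill_kv_cache` stores its index tensors on the context as `fill_cache_info` the first time it runs, and subsequent calls that receive the same context reuse them instead of recomputing. A hypothetical sketch of that reuse pattern (only `fill_cache_info` is a real attribute name; the rest is made up):

```python
class StepContext:
    """Stand-in for the per-step context object passed through the layers."""
    fill_cache_info = None


def compute_or_reuse_fill_info(context):
    info = getattr(context, 'fill_cache_info', None)
    if info is None:
        # First call of the step: build the index tensors once (host-side work).
        info = {'state_start': [0], 'state_len': [1],
                'cache_start': [0], 'block_offsets1d': [0]}
        context.fill_cache_info = info
    return info


ctx = StepContext()
first = compute_or_reuse_fill_info(ctx)   # computed on the first call
second = compute_or_reuse_fill_info(ctx)  # reused by every later call
assert first is second
```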
(Diffs for the remaining four changed files are not shown.)
