optimize fill kv cache #523

Merged · 8 commits · Oct 9, 2023
3 changes: 3 additions & 0 deletions lmdeploy/pytorch_poc/dist_utils.py
@@ -58,6 +58,9 @@ def rowwise_parallelize_linear_fn(module: nn.Module,
dist_tensor = distribute_tensor(param, device_mesh, dist_spec)
if to_local:
dist_tensor = try_to_local(dist_tensor)
if name == 'bias':
# rowwise parallel linear would add the bias more than once.
dist_tensor /= device_mesh.size()
dist_param = torch.nn.Parameter(dist_tensor)
module.register_parameter(name, dist_param)
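The division above compensates for the bias being applied on every rank of a row-wise parallel linear layer before the partial results are all-reduced. A minimal sketch (plain tensors standing in for the DTensor/all-reduce machinery, shapes chosen for illustration):

```python
import torch

world_size = 2
x = torch.randn(4, 8)   # full input
w = torch.randn(8, 6)   # full weight (in_features x out_features)
b = torch.randn(6)      # bias

ref = x @ w + b         # single-device reference

# row-wise parallelism splits the reduction dimension across ranks;
# each rank adds a scaled bias so the all-reduce restores it exactly once
partials = []
for rank in range(world_size):
    x_shard = x[:, rank * 4:(rank + 1) * 4]
    w_shard = w[rank * 4:(rank + 1) * 4]
    partials.append(x_shard @ w_shard + b / world_size)

out = torch.stack(partials).sum(0)   # stands in for dist.all_reduce
assert torch.allclose(out, ref, atol=1e-5)
```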

28 changes: 21 additions & 7 deletions lmdeploy/pytorch_poc/engine/engine.py
@@ -18,7 +18,7 @@
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper)
from transformers.utils import WEIGHTS_INDEX_NAME, cached_file
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, cached_file

from lmdeploy.pytorch.accel import LoadNoInit
from lmdeploy.pytorch_poc.config import (CacheConfig, ModelConfig,
@@ -106,12 +106,16 @@ def __init__(
block_offsets: List[List[int]],
history_lengths: List[int],
position_ids: torch.Tensor,
q_start_loc: torch.Tensor,
seq_length: torch.Tensor,
world_size: int = 1,
device='cuda',
):
self.block_offsets_list = block_offsets
self.history_lengths = history_lengths
self.position_ids = position_ids
self.q_start_loc = q_start_loc
self.seq_length = seq_length
self.world_size = world_size

# padding zero
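The new q_start_loc and seq_length fields describe the packed (unpadded) batch layout consumed by the kernels: seq_length holds the number of new tokens per sequence, and q_start_loc the offset of each sequence in the flattened token dimension (an exclusive cumulative sum in this illustrative sketch with plain CPU tensors):

```python
import torch

# three sequences of lengths 3, 1 and 2 packed into one token dimension
seq_length = torch.tensor([3, 1, 2])
q_start_loc = seq_length.cumsum(0) - seq_length   # tensor([0, 3, 4])

hidden = torch.arange(seq_length.sum())           # stand-in for hidden states
for start, length in zip(q_start_loc.tolist(), seq_length.tolist()):
    print(hidden[start:start + length])           # per-sequence slice
```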
@@ -257,14 +261,20 @@ def _tp_model_loop(
torch_dtype=torch_dtype,
trust_remote_code=True)

torch_model_json_path = cached_file(model_path, WEIGHTS_INDEX_NAME)
with open(torch_model_json_path, mode='r') as f:
torch_model_json = json.load(f)
try:
torch_model_json_path = cached_file(model_path, WEIGHTS_INDEX_NAME)
with open(torch_model_json_path, mode='r') as f:
torch_model_json = json.load(f)

weight_map = torch_model_json['weight_map']
weight_map = torch_model_json['weight_map']

checkpoints = list(set(weight_map.values()))
checkpoints = [cached_file(model_path, ckpt) for ckpt in checkpoints]
checkpoints = list(set(weight_map.values()))
checkpoints = [
cached_file(model_path, ckpt) for ckpt in checkpoints
]
except Exception:
logger.warning(f'load failed, try load from {WEIGHTS_NAME}.')
checkpoints = [cached_file(model_path, WEIGHTS_NAME)]
patched_model = patch(
model,
extra_args=extra_args,
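The try/except added above degrades gracefully when a model ships a single pytorch_model.bin instead of a sharded checkpoint. For reference, a sharded checkpoint carries a pytorch_model.bin.index.json (WEIGHTS_INDEX_NAME) whose weight_map points each parameter at a shard; the file names and sizes below are illustrative only:

```python
# example contents of pytorch_model.bin.index.json (values are made up)
index = {
    'metadata': {'total_size': 13476839424},
    'weight_map': {
        'model.embed_tokens.weight': 'pytorch_model-00001-of-00002.bin',
        'lm_head.weight': 'pytorch_model-00002-of-00002.bin',
    },
}

# the loader resolves each distinct shard once ...
shards = sorted(set(index['weight_map'].values()))
print(shards)

# ... and when no index file exists, cached_file(model_path, WEIGHTS_INDEX_NAME)
# raises, so the except branch falls back to the single WEIGHTS_NAME file,
# i.e. 'pytorch_model.bin'.
```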
@@ -374,6 +384,8 @@ def _tp_model_loop(
block_offsets=inputs['block_offsets'],
history_lengths=inputs['history_lengths'],
position_ids=inputs['position_ids'],
q_start_loc=inputs['q_start_loc'],
seq_length=inputs['seq_length'],
world_size=world_size,
),
q_seq_info=(inputs['q_start_loc'], inputs['seq_length']),
@@ -717,6 +729,8 @@ def _model_forward(self, inputs: Dict, swap_in_map: Dict[int, int],
block_offsets=inputs['block_offsets'],
history_lengths=inputs['history_lengths'],
position_ids=inputs['position_ids'],
q_start_loc=inputs['q_start_loc'],
seq_length=inputs['seq_length'],
),
q_seq_info=(inputs['q_start_loc'], inputs['seq_length']),
)
7 changes: 3 additions & 4 deletions lmdeploy/pytorch_poc/kernels/__init__.py
@@ -1,12 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .alibi_pagedattention import alibi_paged_attention_fwd
from .biased_pagedattention import biased_paged_attention_fwd
from .fill_kv_cache import fill_kv_cache
from .flashattention_nopad import context_attention_fwd
from .pagedattention import paged_attention_fwd

__all__ = [
'context_attention_fwd',
'paged_attention_fwd',
'biased_paged_attention_fwd',
'alibi_paged_attention_fwd',
'context_attention_fwd', 'paged_attention_fwd',
'biased_paged_attention_fwd', 'alibi_paged_attention_fwd', 'fill_kv_cache'
]
165 changes: 165 additions & 0 deletions lmdeploy/pytorch_poc/kernels/fill_kv_cache.py
@@ -0,0 +1,165 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Sequence

import torch
import triton
import triton.language as tl
from torch import Tensor


@triton.jit
def _fill_kv_cache_kernel(
k_states,
v_states,
k_caches,
v_caches,
state_start,
state_len,
cache_start,
block_offsets1d,
stride_kss, # stride of key state token
stride_vss, # stride of value state token
stride_kcs: tl.constexpr, # stride of key cache token
stride_vcs: tl.constexpr, # stride of value cache token
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
prog_id = tl.program_id(0)

stride_kb = stride_kcs * BLOCK_M
stride_vb = stride_vcs * BLOCK_M

sstart = tl.load(state_start + prog_id)
slen = tl.load(state_len + prog_id)
cstart = tl.load(cache_start + prog_id)
boffset = tl.load(block_offsets1d + prog_id)

off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)

ks_ptrs = k_states + (sstart +
off_m[:, None]) * stride_kss + off_n[None, :]
vs_ptrs = v_states + (sstart +
off_m[:, None]) * stride_vss + off_n[None, :]
kc_ptrs = k_caches + boffset * stride_kb + (
cstart + off_m[:, None]) * stride_kcs + off_n[None, :]
vc_ptrs = v_caches + boffset * stride_vb + (
cstart + off_m[:, None]) * stride_vcs + off_n[None, :]

mask = off_m[:, None] < slen

for idx in range(0, stride_kcs, BLOCK_N):
ks = tl.load(ks_ptrs + idx, mask=mask)
tl.store(kc_ptrs + idx, ks, mask=mask)

for idx in range(0, stride_vcs, BLOCK_N):
vs = tl.load(vs_ptrs + idx, mask=mask)
tl.store(vc_ptrs + idx, vs, mask=mask)


def fill_kv_cache(k_states: Tensor,
v_states: Tensor,
k_caches: Tensor,
v_caches: Tensor,
start_loc: Tensor,
seq_length: Tensor,
block_offsets: Tensor,
history_lengths: Sequence,
context: Any = None):
"""fill kv cache for paged attention."""
fill_cache_info = getattr(context, 'fill_cache_info', None)

if fill_cache_info is None:
batch_size = block_offsets.size(0)
block_size = k_caches.size(1)

if not isinstance(history_lengths, Tensor):
history_lengths = torch.tensor(history_lengths,
device=k_states.device)

batch_ids = torch.arange(batch_size, device=k_states.device)

first_block_ids = history_lengths // block_size
block_offsets1d = block_offsets[batch_ids, first_block_ids]

token_ids_start = history_lengths % block_size
first_seq_len = torch.minimum(seq_length, block_size - token_ids_start)

state_start = start_loc[:batch_size]
state_len = first_seq_len
cache_start = token_ids_start

# middle + last = remain
remain_seq_len = torch.maximum(seq_length.new_zeros(1),
seq_length - first_seq_len)
last_seq_len = remain_seq_len % block_size
middle_seq_len = remain_seq_len - last_seq_len
middle_block_nums = middle_seq_len // block_size
remain_block_nums = (remain_seq_len / block_size).ceil().long()

remain_state_start = [
ss + slen +
torch.arange(0, rlen, block_size, device=k_states.device)
for ss, slen, rlen in zip(state_start, first_seq_len,
remain_seq_len)
]
remain_seq_lens = [
torch.full((mid, ), block_size, device=k_states.device)
for mid in middle_block_nums
]
remain_seq_lens = [
(torch.cat([slen, last]) if last != 0 else slen)
for slen, last in zip(remain_seq_lens, last_seq_len.unsqueeze(-1))
]
remain_block_offsets1d = [
block_offsets[bid, ids:ids + ids_len]
for bid, ids, ids_len in zip(range(batch_size), first_block_ids +
1, remain_block_nums)
]

# state_start stores the start index in the states for each block write
# state_len stores the number of tokens to write into the block
# cache_start stores the first slot in the block to write to
# block_offsets1d stores the index of the target block in the caches
state_start = torch.cat([state_start] + remain_state_start)
state_len = torch.cat([state_len] + remain_seq_lens)
cache_start = torch.cat(
[cache_start] +
[state_start.new_zeros(state_start.size(0) - batch_size)])
block_offsets1d = torch.cat([block_offsets1d] + remain_block_offsets1d)

if context is not None:
fill_cache_info = dict()
fill_cache_info['state_start'] = state_start
fill_cache_info['state_len'] = state_len
fill_cache_info['cache_start'] = cache_start
fill_cache_info['block_offsets1d'] = block_offsets1d
context.fill_cache_info = fill_cache_info
else:
state_start = fill_cache_info['state_start']
state_len = fill_cache_info['state_len']
cache_start = fill_cache_info['cache_start']
block_offsets1d = fill_cache_info['block_offsets1d']

grid = (state_start.size(0), )
BLOCK_M = k_caches.size(-3)
BLOCK_N = min(128, k_caches.stride(-3), v_caches.stride(-3))

_fill_kv_cache_kernel[grid](
k_states,
v_states,
k_caches,
v_caches,
state_start=state_start,
state_len=state_len,
cache_start=cache_start,
block_offsets1d=block_offsets1d,
stride_kss=k_states.stride(-3),
stride_vss=v_states.stride(-3),
stride_kcs=k_caches.stride(-3),
stride_vcs=v_caches.stride(-3),
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=4,
num_stages=1,
)
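A minimal invocation sketch of the new kernel (shapes, block counts and offsets are assumed for illustration; a CUDA device with Triton is required). Each sequence first fills the tail of its current block, then any full middle blocks, then a trailing partial block, which is what the state_start / state_len / cache_start / block_offsets1d preparation above computes:

```python
import torch
from lmdeploy.pytorch_poc.kernels import fill_kv_cache

num_heads, head_dim, block_size, num_blocks = 8, 64, 4, 16
device = 'cuda'

# a packed prefill batch: sequence 0 brings 6 new tokens, sequence 1 brings 2
seq_length = torch.tensor([6, 2], device=device)
start_loc = torch.tensor([0, 6], device=device)       # offsets into k_states
history_lengths = [0, 3]                               # tokens already cached

k_states = torch.randn(8, num_heads, head_dim, device=device)
v_states = torch.randn_like(k_states)
k_caches = torch.zeros(num_blocks, block_size, num_heads, head_dim,
                       device=device)
v_caches = torch.zeros_like(k_caches)

# blocks assigned to each sequence by the scheduler (example values)
block_offsets = torch.tensor([[0, 1], [2, 3]], device=device)

fill_kv_cache(k_states, v_states, k_caches, v_caches, start_loc, seq_length,
              block_offsets, history_lengths)
# sequence 0 fills block 0 completely and the first two slots of block 1;
# sequence 1 appends one token to slot 3 of block 2 and one to slot 0 of block 3.
```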
2 changes: 2 additions & 0 deletions lmdeploy/pytorch_poc/models/baichuan.py
@@ -119,6 +119,7 @@ def _rotary_emb_fn(query_states, key_states, value_states):
head_dim=self.head_dim,
position_ids=position_ids,
past_key_value=past_key_value,
context=context,
qkv_proj=_qkv_proj,
o_proj=self.o_proj,
rotary_emb_fn=_rotary_emb_fn,
@@ -204,6 +205,7 @@ def _qkv_proj(hidden_states):
head_dim=self.head_dim,
position_ids=position_ids,
past_key_value=past_key_value,
context=context,
qkv_proj=_qkv_proj,
o_proj=self.o_proj,
rotary_emb_fn=_rotary_emb_fn,