Skip to content

Commit

Permalink
add fill cache back
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire committed Sep 21, 2023
1 parent b478b31 commit 18768d2
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions lmdeploy/pytorch_poc/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,63 @@ def get_block_offsets(self):
"""return block offsets."""
return self.block_offsets

def fill_cache(
    self,
    k_states: torch.Tensor,
    v_states: torch.Tensor,
    start_loc: torch.Tensor,
    seq_length: torch.Tensor,
    k_caches: torch.Tensor,
    v_caches: torch.Tensor,
):
    """Copy new key/value states into the paged caches, in place.

    For every sequence ``bid`` the slice
    ``k_states[start_loc[bid]:start_loc[bid] + seq_length[bid]]`` (and the
    matching slice of ``v_states``) is appended after the
    ``self.history_lengths[bid]`` tokens already accounted for. Blocks are
    taken in order from ``self.block_offsets_list[bid]``: the partially
    filled block is topped up first, then the following blocks are filled
    from slot 0.

    Args:
        k_states: new key states; dim 0 is sliced by ``start_loc`` and
            ``seq_length``.
        v_states: new value states, sliced like ``k_states``.
        start_loc: per-sequence start index into dim 0 of the states.
        seq_length: per-sequence number of new tokens.
        k_caches: key cache, indexed ``[block_id][token_slot]``; dim 1 is
            the block size. Modified in place.
        v_caches: value cache, indexed like ``k_caches``. Modified in
            place.

    Raises:
        AssertionError: if a sequence's slice falls outside dim 0 of
            ``k_states``.
    """
    block_size = k_caches.size(1)
    block_offsets = self.block_offsets_list

    # Plain Python ints keep the loop bookkeeping off 0-dim tensors (no
    # accidental in-place edits of a shared tensor, no repeated implicit
    # tensor -> int conversions).
    history_lengths = torch.tensor(self.history_lengths).tolist()

    for bid, history_len in enumerate(history_lengths):
        loc = int(start_loc[bid])
        seq_len = int(seq_length[bid])
        b_offsets = block_offsets[bid]

        assert 0 <= loc <= k_states.size(0)
        assert 0 <= loc + seq_len <= k_states.size(0)

        # Nothing to write. Skipping also avoids indexing one past the end
        # of `b_offsets` when the history exactly fills the blocks this
        # sequence owns.
        if seq_len == 0:
            continue

        k_state = k_states[loc:loc + seq_len]
        v_state = v_states[loc:loc + seq_len]

        # Index (within b_offsets) of the first block with free slots and
        # the first free slot inside that block.
        free_offset, token_offset = divmod(history_len, block_size)

        # fill remain(last non-full block)
        fill_token_num = min(block_size - token_offset, seq_len)

        assert 0 <= fill_token_num <= block_size

        block_id = b_offsets[free_offset]
        token_end = token_offset + fill_token_num
        k_caches[block_id][token_offset:token_end] = k_state[:fill_token_num]
        v_caches[block_id][token_offset:token_end] = v_state[:fill_token_num]

        # The remaining tokens go to the following blocks, each written
        # from slot 0.
        filled = fill_token_num
        free_offset += 1
        while filled < seq_len:
            token_num = min(seq_len - filled, block_size)
            block_id = b_offsets[free_offset]
            k_caches[block_id][:token_num] = k_state[filled:filled +
                                                     token_num]
            v_caches[block_id][:token_num] = v_state[filled:filled +
                                                     token_num]

            filled += token_num
            free_offset += 1


def _update_cache_config(model_config: ModelConfig,
cache_config: CacheConfig,
Expand Down

0 comments on commit 18768d2

Please sign in to comment.