faster embedding
grimoire committed Oct 7, 2023
1 parent 6edc78c commit 996ac70
Showing 2 changed files with 22 additions and 9 deletions.
8 changes: 8 additions & 0 deletions lmdeploy/pytorch_poc/engine/engine.py
@@ -106,12 +106,16 @@ def __init__(
         block_offsets: List[List[int]],
         history_lengths: List[int],
         position_ids: torch.Tensor,
+        q_start_loc: torch.Tensor,
+        seq_length: torch.Tensor,
         world_size: int = 1,
         device='cuda',
     ):
         self.block_offsets_list = block_offsets
         self.history_lengths = history_lengths
         self.position_ids = position_ids
+        self.q_start_loc = q_start_loc
+        self.seq_length = seq_length
         self.world_size = world_size
 
         # padding zero
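
For orientation, the two new ModelContext fields describe the flattened-batch layout used downstream: seq_length holds the number of query tokens per sequence, and q_start_loc is its exclusive prefix sum, i.e. the offset of each sequence's first token in the flattened batch (exactly the quantity the attention patch used to recompute, see functional.py below). A minimal sketch of the relationship, with made-up lengths:

import torch

# Hypothetical values; the engine fills these from the batched requests.
seq_length = torch.tensor([3, 5, 2])   # query tokens per sequence
q_start_loc = torch.cat(
    [seq_length.new_zeros(1), seq_length.cumsum(0)[:-1]])
print(q_start_loc)                     # tensor([0, 3, 8])
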
@@ -374,6 +378,8 @@ def _tp_model_loop(
                 block_offsets=inputs['block_offsets'],
                 history_lengths=inputs['history_lengths'],
                 position_ids=inputs['position_ids'],
+                q_start_loc=inputs['q_start_loc'],
+                seq_length=inputs['seq_length'],
                 world_size=world_size,
             ),
             q_seq_info=(inputs['q_start_loc'], inputs['seq_length']),
@@ -717,6 +723,8 @@ def _model_forward(self, inputs: Dict, swap_in_map: Dict[int, int],
                 block_offsets=inputs['block_offsets'],
                 history_lengths=inputs['history_lengths'],
                 position_ids=inputs['position_ids'],
+                q_start_loc=inputs['q_start_loc'],
+                seq_length=inputs['seq_length'],
             ),
             q_seq_info=(inputs['q_start_loc'], inputs['seq_length']),
         )
23 changes: 14 additions & 9 deletions lmdeploy/pytorch_poc/patch/functional.py
@@ -37,13 +37,11 @@ def apply_rotary_pos_emb(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor,
     sin = sin.to(q.device)
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids]  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids]  # [bs, 1, seq_len, dim]
     seq_length = position_ids[..., -1] + 1
-    cos = [s[:l] for s, l in zip(cos, seq_length)]
-    sin = [s[:l] for s, l in zip(sin, seq_length)]
-    cos = torch.cat(cos, 0).unsqueeze(1)
-    sin = torch.cat(sin, 0).unsqueeze(1)
+    position_ids_1d = [ids[:l] for ids, l in zip(position_ids, seq_length)]
+    position_ids_1d = torch.cat(position_ids_1d)
+    cos = cos[position_ids_1d].unsqueeze(1)
+    sin = sin[position_ids_1d].unsqueeze(1)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
 
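
The speed-up in apply_rotary_pos_emb comes from the order of indexing: the old path gathered a full [bs, padded_len, dim] cos/sin table (padding included) and only then sliced and concatenated the valid rows, while the new path flattens the valid position ids first and gathers exactly the rows it needs in one shot. A standalone sketch with toy shapes (illustration only, not the patched function) showing the two paths agree:

import torch

# Assumed shapes: after the squeezes above, cos is [max_seq_len, dim];
# position_ids is [bs, padded_len] with right-padded rows.
cos = torch.randn(16, 8)
position_ids = torch.tensor([[0, 1, 2, 2, 2],
                             [0, 1, 2, 3, 4]])
seq_length = position_ids[..., -1] + 1   # valid length per row: tensor([3, 5])

# Old path: gather per row (padding and all), then slice and concatenate.
old = torch.cat([cos[ids][:l] for ids, l in zip(position_ids, seq_length)], 0)

# New path: flatten only the valid position ids, then gather once.
position_ids_1d = torch.cat(
    [ids[:l] for ids, l in zip(position_ids, seq_length)])
new = cos[position_ids_1d]

assert torch.equal(old, new)
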
@@ -164,9 +162,16 @@ def attention_forward_with_paged_attention(
         query_states, key_states, value_states)
 
     kv_seq_length = position_ids[..., -1] + 1
-    q_seq_length = kv_seq_length - kv_seq_length.new_tensor(history_lengths)
-    q_start_loc = q_seq_length.cumsum(0)
-    q_start_loc = torch.cat([q_start_loc.new_zeros(1), q_start_loc[:-1]])
+
+    q_seq_length = getattr(context, 'seq_length', None)
+    if q_seq_length is None:
+        q_seq_length = kv_seq_length - kv_seq_length.new_tensor(
+            history_lengths)
+
+    q_start_loc = getattr(context, 'q_start_loc', None)
+    if q_start_loc is None:
+        q_start_loc = q_seq_length.cumsum(0)
+        q_start_loc = torch.cat([q_start_loc.new_zeros(1), q_start_loc[:-1]])
 
     fill_kv_cache(key_states,
                   value_states,
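
The getattr fallback keeps the attention patch backward compatible: when the context carries the precomputed seq_length and q_start_loc (as the engine now provides), the per-layer recomputation is skipped; when it does not, the original derivation from kv_seq_length and history_lengths still runs. A self-contained sketch of that behaviour, with resolve_q_seq_info as a hypothetical stand-in for the logic above:

import torch
from types import SimpleNamespace

def resolve_q_seq_info(context, kv_seq_length, history_lengths):
    # Prefer the tensors precomputed by the engine; otherwise rebuild them.
    q_seq_length = getattr(context, 'seq_length', None)
    if q_seq_length is None:
        q_seq_length = kv_seq_length - kv_seq_length.new_tensor(history_lengths)
    q_start_loc = getattr(context, 'q_start_loc', None)
    if q_start_loc is None:
        q_start_loc = q_seq_length.cumsum(0)
        q_start_loc = torch.cat([q_start_loc.new_zeros(1), q_start_loc[:-1]])
    return q_start_loc, q_seq_length

kv_seq_length = torch.tensor([4, 7])
# Fallback path: an old-style context without the new fields.
print(resolve_q_seq_info(SimpleNamespace(), kv_seq_length, [1, 2]))
# Fast path: the engine supplied both tensors, so nothing is recomputed.
ctx = SimpleNamespace(seq_length=torch.tensor([3, 5]),
                      q_start_loc=torch.tensor([0, 3]))
print(resolve_q_seq_info(ctx, kv_seq_length, [1, 2]))
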
