support internlm
grimoire committed Sep 22, 2023
1 parent 18768d2 commit 6252b99
Showing 2 changed files with 114 additions and 1 deletion.
98 changes: 98 additions & 0 deletions lmdeploy/pytorch_poc/patch/internlm.py
@@ -0,0 +1,98 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._tensor import DeviceMesh

from lmdeploy.pytorch_poc.dist_utils import (colwise_parallelize_linear_fn,
                                             rowwise_parallelize_linear_fn)
from lmdeploy.pytorch_poc.patch.functional import (
    apply_rotary_pos_emb, attention_forward_with_paged_attention)


class PatchedInternLMAttention(nn.Module):
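    # Tensor-parallel layout: q/k/v projections are split column-wise so each
    # rank keeps a slice of the attention heads, while o_proj is split
    # row-wise; its per-rank partial outputs are summed by the all_reduce in
    # _distribute_output_fn.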

    @classmethod
    def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
                                 device_mesh: DeviceMesh):
        """Distribution partition callback."""
        if mod_name in ['q_proj', 'k_proj', 'v_proj']:
            colwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)
        elif mod_name in ['o_proj']:
            rowwise_parallelize_linear_fn(mod,
                                          device_mesh=device_mesh,
                                          to_local=True)

    @classmethod
    def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
        """Distribution output hook."""
        dist.all_reduce(outputs[0])
        return outputs

    def _contiguous_batching_forward_impl(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        world_size: int = 1,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        """Rewrite implementation of InternLMAttention.forward.

        Adds continuous batching, paged attention and tensor-parallelism
        support.
        """
        assert not output_attentions
        context = self.context.context
        history_lengths = context.history_lengths

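        # Build cos/sin tables long enough to cover the cached history plus
        # the new tokens, then rotate q/k at their absolute positions.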
        def _rotary_emb_fn(query_states, key_states, value_states):
            max_seq_len = position_ids.size(-1)
            kv_seq_len = max_seq_len + max(history_lengths)
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin, position_ids)
            return query_states, key_states, value_states

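        # InternLM uses standard multi-head attention, so the per-rank KV head
        # count passed below equals the per-rank query head count.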
        attn_output = attention_forward_with_paged_attention(
            hidden_states,
            history_lengths=history_lengths,
            block_offsets=context.block_offsets,
            num_heads=self.num_heads // world_size,
            num_kv_heads=self.num_heads // world_size,
            head_dim=self.head_dim,
            position_ids=position_ids,
            past_key_value=past_key_value,
            q_proj=self.q_proj,
            k_proj=self.k_proj,
            v_proj=self.v_proj,
            o_proj=self.o_proj,
            rotary_emb_fn=_rotary_emb_fn,
        )
        return attn_output, None, past_key_value

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor]]]:
        world_size = 1
        if dist.is_initialized():
            world_size = dist.get_world_size()
        return self._contiguous_batching_forward_impl(
            hidden_states,
            position_ids,
            past_key_value,
            output_attentions,
            world_size=world_size,
        )
17 changes: 16 additions & 1 deletion lmdeploy/pytorch_poc/patch/patch.py
@@ -13,13 +13,18 @@
from lmdeploy.pytorch_poc.dist_utils import partition_module, replicate_module
from lmdeploy.utils import get_logger

# llama
MODULE_MAP = {
    'transformers.models.llama.modeling_llama.LlamaAttention':
    'lmdeploy.pytorch_poc.patch.llama.LlamaAttention',
    'transformers.models.llama.modeling_llama.LlamaModel':
    'lmdeploy.pytorch_poc.patch.llama.LlamaModel',
    'transformers.models.llama.modeling_llama.LlamaMLP':
    'lmdeploy.pytorch_poc.patch.llama.LlamaMLP',
}

# baichuan
MODULE_MAP.update({
    'modeling_baichuan.Model':
    'lmdeploy.pytorch_poc.patch.llama.LlamaModel',  # noqa
    'modeling_baichuan.BaichuanModel':
@@ -30,7 +35,17 @@
    'lmdeploy.pytorch_poc.patch.baichuan.BaichuanAttention',  # noqa
    'modeling_baichuan.MLP':
    'lmdeploy.pytorch_poc.patch.llama.LlamaMLP',  # noqa
}
})

# internlm
MODULE_MAP.update({
    'modeling_internlm.InternLMAttention':
    'lmdeploy.pytorch_poc.patch.internlm.PatchedInternLMAttention',
    'modeling_internlm.InternLMModel':
    'lmdeploy.pytorch_poc.patch.llama.LlamaModel',
    'modeling_internlm.InternLMMLP':
    'lmdeploy.pytorch_poc.patch.llama.LlamaMLP',
})


def _get_rewrite_qualname(origin_qualname: str) -> str:
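For context, the patching entry point resolves a model's module qualnames against MODULE_MAP (the body of _get_rewrite_qualname is collapsed in this view). The following is a minimal sketch of that kind of lookup, under the assumption that matching is done on the trailing part of the qualified name; it is not the actual lmdeploy implementation, and _find_rewrite is an illustrative name.

# Hypothetical sketch of a MODULE_MAP lookup; not part of lmdeploy.
MODULE_MAP = {
    'modeling_internlm.InternLMAttention':
    'lmdeploy.pytorch_poc.patch.internlm.PatchedInternLMAttention',
}


def _find_rewrite(origin_qualname: str):
    """Return the rewrite qualname for a module, or None if it is unpatched."""
    for origin, rewrite in MODULE_MAP.items():
        if origin_qualname.endswith(origin):
            return rewrite
    return None


print(_find_rewrite('remote_code.modeling_internlm.InternLMAttention'))
# -> lmdeploy.pytorch_poc.patch.internlm.PatchedInternLMAttention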
