Support cogvlm model and optimize. #8

Open
wants to merge 8 commits into base: infer_ext
lmdeploy/pytorch/models/cogvlm.py (159 additions, 0 deletions)
@@ -5,6 +5,9 @@
import torch.distributed as dist
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPast
from lmdeploy.pytorch.kernels.ascend.fill_kv_cache import \
    fill_kv_cache as fill_kv_cache_ascend
from lmdeploy.pytorch.kernels.ascend.fused_rotary_emb import \
    fused_rotary_emb as fused_rotary_emb_ascend
from lmdeploy.pytorch.kernels.ascend.paged_attention_fwd import \
    paged_attention_fwd as paged_attention_fwd_ascend

from ..kernels import fill_kv_cache, fused_rotary_emb, paged_attention_fwd
from ..weight_loader.dist_utils import (colwise_split_parallelize_linear,
@@ -238,6 +241,162 @@ def forward(
)


class PatchedVisionExpertAttentionAscend(nn.Module):

def _contiguous_batching_forward_impl(
self,
hidden_states: torch.Tensor,
token_type_ids: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
world_size: int = 1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite implementation of Attention.forward.

Add continuous batching support. Add paged attention support.
"""
context = self.context.context
q_start_loc = context.q_start_loc
q_seq_length = context.q_seq_length
kv_seq_length = context.kv_seq_length
block_offsets = context.block_offsets
max_q_seq_length = context.max_q_seq_length
num_heads = self.config.num_attention_heads // world_size
num_kv_heads = getattr(self.config, 'num_multi_query_heads',
self.config.num_attention_heads) // world_size

head_dim = self.config.hidden_size // self.config.num_attention_heads
hidden_size = num_heads * head_dim
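        # decoding steps only ever generate language tokens, so the
        # language expert path alone is sufficient there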
only_has_language = context.is_decoding
if not context.is_decoding:
# for embedding splitting
if hasattr(context, 'vision_token_mask') and hasattr(
context, 'language_token_mask'):
vision_token_mask = context.vision_token_mask
language_token_mask = context.language_token_mask
only_has_language = vision_token_mask.numel() == 0
else:
only_has_language = True

def __qkv_proj(hidden_states):
"""qkv_proj."""
if only_has_language:
mixed_raw_layer = self.language_expert_query_key_value(
hidden_states)
else:
shape = list(hidden_states.shape)
shape[-1] = hidden_size + head_dim * num_kv_heads * 2
mixed_raw_layer = torch.empty(shape,
dtype=hidden_states.dtype,
device=hidden_states.device)

mixed_raw_layer[:,
vision_token_mask, :] = self.vision_expert_query_key_value(
hidden_states[:, vision_token_mask, :])
mixed_raw_layer[:,
language_token_mask, :] = self.language_expert_query_key_value(
hidden_states[:, language_token_mask, :])
query_states, key_states, value_states = torch.split(
mixed_raw_layer, [
hidden_size, head_dim * num_kv_heads,
head_dim * num_kv_heads
],
dim=-1)
return query_states, key_states, value_states

def __rotary_emb_fn(query_states, key_states, value_states):
"""rotary embedding func."""
scaling_factor = getattr(self.rotary_emb, 'scaling_factor', 1.0)
inv_freq = self.rotary_emb.inv_freq

query_states, key_states = fused_rotary_emb_ascend(
query_states[None],
key_states[None],
position_ids[None],
inv_freq=inv_freq,
scaling_factor=scaling_factor,
out_q=query_states[None],
out_k=key_states[None],
context=context)
return query_states[0], key_states[0], value_states

query_states, key_states, value_states = __qkv_proj(hidden_states)

query_states = query_states.view(-1, num_heads, head_dim)
key_states = key_states.view(-1, num_kv_heads, head_dim)
value_states = value_states.view(-1, num_kv_heads, head_dim)

query_states, key_states, value_states = __rotary_emb_fn(
query_states, key_states, value_states)

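        # write the new key/value states into the paged KV cache blocks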
fill_kv_cache_ascend(
key_states,
value_states,
past_key_value[0],
past_key_value[1],
q_start_loc,
q_seq_length,
kv_seq_length=kv_seq_length,
max_q_seq_length=max_q_seq_length,
block_offsets=block_offsets,
context=context
)

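        # paged attention on the Ascend backend; the output is written into
        # context_layer, which reuses the query_states buffer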
context_layer = query_states
paged_attention_fwd_ascend(
query_states,
key_states,
value_states,
past_key_value[0],
past_key_value[1],
context_layer,
block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_q_seq_length,
context=context
)
context_layer = context_layer.reshape(*hidden_states.shape[:-1], -1)

if only_has_language:
attn_output = self.language_expert_dense(context_layer)
else:
ctx_shape = list(context_layer.shape)
ctx_shape[-1] *= world_size
attn_output = torch.empty(ctx_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)

attn_output[:, vision_token_mask, :] = self.vision_expert_dense(
context_layer[:, vision_token_mask, :])
attn_output[:,
language_token_mask, :] = self.language_expert_dense(
context_layer[:, language_token_mask, :])

return attn_output, None, past_key_value

def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite of forward."""
world_size = 1
if dist.is_initialized():
world_size = dist.get_world_size()
return self._contiguous_batching_forward_impl(
hidden_states,
position_ids=position_ids,
past_key_value=past_key_value,
world_size=world_size,
)


class PatchedCogVLMModel(nn.Module):

def forward(
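
For context, the core idea of PatchedVisionExpertAttentionAscend above is CogVLM's expert routing: vision tokens and language tokens are projected with separate weights and written into one shared buffer. The snippet below is a minimal, standalone sketch of that routing step only; ToyVisionExpertQKV, the masks, and the shapes are illustrative and are not part of lmdeploy or this PR.

import torch
from torch import nn


class ToyVisionExpertQKV(nn.Module):
    """Toy module: separate QKV projections for vision and language tokens."""

    def __init__(self, hidden_size, qkv_size):
        super().__init__()
        self.vision_expert_query_key_value = nn.Linear(hidden_size, qkv_size)
        self.language_expert_query_key_value = nn.Linear(hidden_size, qkv_size)

    def forward(self, hidden_states, vision_token_mask, language_token_mask):
        # allocate one output buffer, then fill it expert by expert, the same
        # pattern as the mixed_raw_layer buffer in the patch above
        shape = list(hidden_states.shape)
        shape[-1] = self.vision_expert_query_key_value.out_features
        mixed = torch.empty(shape,
                            dtype=hidden_states.dtype,
                            device=hidden_states.device)
        mixed[:, vision_token_mask, :] = self.vision_expert_query_key_value(
            hidden_states[:, vision_token_mask, :])
        mixed[:, language_token_mask, :] = self.language_expert_query_key_value(
            hidden_states[:, language_token_mask, :])
        return mixed


# usage: one sequence of 6 tokens, the first 4 of which are image tokens
hidden = torch.randn(1, 6, 32)
vision_mask = torch.tensor([True, True, True, True, False, False])
qkv = ToyVisionExpertQKV(32, 96)(hidden, vision_mask, ~vision_mask)
print(qkv.shape)  # torch.Size([1, 6, 96])
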
lmdeploy/pytorch/models/module_map.py (6 additions, 0 deletions)
@@ -391,6 +391,12 @@
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2AttentionAscend',
})

# ascend cogvlm
ASCEND_MODULE_MAP.update({
'modeling_cogvlm.VisionExpertAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertAttentionAscend',
})

# ascend mixtral
ASCEND_MODULE_MAP.update({
'transformers.models.mixtral.modeling_mixtral.MixtralAttention':
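
The new entry follows the same convention as the other ASCEND_MODULE_MAP registrations: the key is the qualified name of the original class in CogVLM's modeling_cogvlm.py, and the value is the import path of the patched Ascend implementation. As a rough illustration of how such an entry could be looked up and imported (this is not lmdeploy's actual patching machinery, and it assumes LMDEPLOY_PYTORCH_MODEL_PATH points at lmdeploy.pytorch.models), a sketch might be:

import importlib

# hypothetical, simplified copy of the mapping added in this PR
ASCEND_MODULE_MAP = {
    'modeling_cogvlm.VisionExpertAttention':
    'lmdeploy.pytorch.models.cogvlm.PatchedVisionExpertAttentionAscend',
}


def resolve_patched_class(origin_qualname):
    """Look up and import the patched class registered for a model class."""
    target = ASCEND_MODULE_MAP[origin_qualname]
    module_path, _, class_name = target.rpartition('.')
    return getattr(importlib.import_module(module_path), class_name)


# would return PatchedVisionExpertAttentionAscend, assuming a build of
# lmdeploy containing this PR is installed
patched = resolve_patched_class('modeling_cogvlm.VisionExpertAttention')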