support deepseekv2 for maca backend.
Reinerzhou committed Dec 18, 2024
1 parent 1efed79 commit f695787
Showing 3 changed files with 41 additions and 1 deletion.
1 change: 1 addition & 0 deletions lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -133,6 +133,7 @@ def forward(
             is_decoding=is_decoding,
             block_size=block_size,
             attn_mask=attn_mask,
+            softmax_scale=self.scale,
             is_unpaged_prefill=is_unpaged_prefill,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
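
For context on this one-line change: DeepSeek-V2's attention does not use the plain 1/sqrt(head_dim) scale; it folds a squared YaRN mscale correction into it, so the kernel has to receive `self.scale` explicitly instead of recomputing a default. A rough sketch of how such a scale is derived, following the Hugging Face DeepSeek-V2 modeling code rather than lmdeploy itself (names and the formula here are illustrative):

import math

def deepseek_v2_softmax_scale(head_dim: int,
                              scaling_factor: float,
                              mscale_all_dim: float) -> float:
    """Illustrative only: default scale times a squared YaRN mscale."""
    def yarn_get_mscale(scale: float, mscale: float) -> float:
        # YaRN attention-magnitude correction; identity when no scaling.
        if scale <= 1.0:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
    return head_dim ** -0.5 * mscale * mscale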
33 changes: 32 additions & 1 deletion lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py
@@ -4,7 +4,8 @@
 import torch
 from torch import nn

-from ..default.rotary_embedding import LlamaDynamicNTKScalingRotaryEmbedding
+from ..default.rotary_embedding import (LlamaDynamicNTKScalingRotaryEmbedding,
+                                        YarnRotaryEmbeddingImpl)
 from ..rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters,
                                 RopeType, RotaryEmbeddingBuilder,
                                 RotaryEmbeddingImpl, YarnParameters)
@@ -151,6 +152,30 @@ def __init__(
         self.register_buffer('inv_freq', inv_freq_llama)


+class DlinferYarnRotaryEmbeddingImpl(YarnRotaryEmbeddingImpl):
+    """yarn rotary embedding implementation."""
+
+    def __init__(self,
+                 dim: int,
+                 base: int = 10000,
+                 scaling_factor: float = 1.0,
+                 original_max_position_embeddings: int = 4096,
+                 yarn_params: YarnParameters = None):
+        super().__init__(dim, base, scaling_factor,
+                         original_max_position_embeddings, yarn_params)
+
+    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
+        """forward."""
+        dtype = x.dtype
+        if self.inv_freq.device != x.device:
+            self.inv_freq = self.inv_freq.to(x.device)
+        return _rotary_embedding_fwd(position_ids,
+                                     self.inv_freq,
+                                     scaling_factor=1.0,
+                                     mscale=self.mscale,
+                                     dtype=dtype)
+
+
 class DlinferRotaryEmbeddingBuilder(RotaryEmbeddingBuilder):
     """rotary embedding dlinfer builder."""

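Note that forward passes scaling_factor=1.0: YaRN folds its position interpolation into inv_freq when the base class is constructed, so at runtime only the mscale amplitude correction remains. A hedged sketch of what a helper like _rotary_embedding_fwd computes (modeled on the default rotary path; the actual helper in this file may differ):

import torch

def _rotary_embedding_fwd_sketch(position_ids: torch.Tensor,
                                 inv_freq: torch.Tensor,
                                 scaling_factor: float,
                                 mscale: float,
                                 dtype: torch.dtype):
    # position_ids: (batch, seq_len); inv_freq: (dim // 2,)
    position_ids = position_ids.float() / scaling_factor
    inv_freq_expanded = inv_freq[None, :, None].float()
    position_ids_expanded = position_ids[:, None, :]
    # Outer product of positions and inverse frequencies -> (batch, seq, dim//2).
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    # mscale rescales the cos/sin amplitude, which is all YaRN needs at runtime.
    return (emb.cos() * mscale).to(dtype), (emb.sin() * mscale).to(dtype)
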
@@ -175,6 +200,12 @@ def build(
             return DlinferLlama3RotaryEmbeddingImpl(
                 dim, base, scaling_factor, llama3_params.low_freq_factor,
                 llama3_params.high_freq_factor, max_position_embeddings)
+        elif emb_type == RopeType.Yarn:
+            return DlinferYarnRotaryEmbeddingImpl(dim,
+                                                  base,
+                                                  scaling_factor,
+                                                  max_position_embeddings,
+                                                  yarn_params=yarn_params)
         else:
             raise NotImplementedError(
                 f'Unsupported embedding type: {emb_type}')
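
Callers are unaffected; the builder simply accepts one more RopeType. A hypothetical use, assuming the build signature implied by the surrounding branches (parameter values are illustrative, loosely DeepSeek-V2-like, not taken from this commit):

# hypothetical: construct a YaRN rotary embedding through the dlinfer builder
yarn_params = YarnParameters()  # assumed dataclass defaults; real models
                                # populate beta/mscale fields from config
rotary_emb = DlinferRotaryEmbeddingBuilder.build(
    dim=64,
    base=10000,
    scaling_factor=40.0,
    yarn_params=yarn_params,
    max_position_embeddings=4096,
    emb_type=RopeType.Yarn,
)
cos, sin = rotary_emb(query_states, position_ids)  # inputs assumed to exist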
8 changes: 8 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
@@ -18,6 +18,7 @@ def prefill_attention(
     max_q_seq_len: int,
     block_size: int,
     attn_mask: Sequence[Optional[Tensor]],
+    softmax_scale: Optional[float],
     is_unpaged_prefill: Optional[bool],
     kv_scales: Optional[Tensor],
     kv_zeros: Optional[Tensor],
@@ -37,6 +38,7 @@
             num_q_heads,
             num_kv_heads,
             attn_mask,
+            softmax_scale=softmax_scale,
             attn_output=attn_output,
         )
     else:
@@ -55,6 +57,7 @@
             num_q_heads,
             num_kv_heads,
             attn_mask,
+            softmax_scale=softmax_scale,
             attn_output=attn_output,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
@@ -71,6 +74,7 @@ def paged_token_attention(
     max_kv_seq_len,
     block_offsets,
     block_size,
+    softmax_scale: Optional[float],
     kv_scales: Optional[Tensor],
     kv_zeros: Optional[Tensor],
     quant_bits: Optional[int],
@@ -87,6 +91,7 @@
         max_kv_seq_len,
         num_q_heads,
         num_kv_heads,
+        softmax_scale=softmax_scale,
         attn_output=attn_output,
         kv_scales=kv_scales,
         kv_zeros=kv_zeros,
@@ -110,6 +115,7 @@ def paged_attention_fwd(
     is_decoding: bool,
     block_size: int,
     attn_mask: Sequence[Optional[Tensor]] = (),
+    softmax_scale: Optional[float] = None,
     is_unpaged_prefill: Optional[bool] = None,
     kv_scales: Optional[Tensor] = None,
     kv_zeros: Optional[Tensor] = None,
@@ -130,6 +136,7 @@
             max_q_seq_len,
             block_size,
             attn_mask,
+            softmax_scale,
             is_unpaged_prefill,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
@@ -145,6 +152,7 @@
             max_kv_seq_len,
             block_offsets,
             block_size,
+            softmax_scale=softmax_scale,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
             quant_bits=quant_bits,
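
Because softmax_scale defaults to None in paged_attention_fwd, backends and models that never set it are untouched. A kernel consuming it would typically resolve the default along these lines (a sketch of the convention, not dlinfer's actual code):

from typing import Optional

def resolve_softmax_scale(softmax_scale: Optional[float],
                          head_dim: int) -> float:
    # None means "standard scaling"; only models like DeepSeek-V2 that
    # need a corrected scale pass an explicit value down the call chain.
    return head_dim ** -0.5 if softmax_scale is None else softmax_scale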
