diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py
index 6b03403c84..52c03003e9 100644
--- a/lmdeploy/pytorch/backends/dlinfer/attention.py
+++ b/lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -133,6 +133,7 @@ def forward(
             is_decoding=is_decoding,
             block_size=block_size,
             attn_mask=attn_mask,
+            softmax_scale=self.scale,
             is_unpaged_prefill=is_unpaged_prefill,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
diff --git a/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py b/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py
index ed807d66b0..68ca78f2e4 100644
--- a/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py
+++ b/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py
@@ -4,7 +4,8 @@
 import torch
 from torch import nn
 
-from ..default.rotary_embedding import LlamaDynamicNTKScalingRotaryEmbedding
+from ..default.rotary_embedding import (LlamaDynamicNTKScalingRotaryEmbedding,
+                                        YarnRotaryEmbeddingImpl)
 from ..rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters,
                                 RopeType, RotaryEmbeddingBuilder,
                                 RotaryEmbeddingImpl, YarnParameters)
@@ -151,6 +152,30 @@ def __init__(
         self.register_buffer('inv_freq', inv_freq_llama)
 
 
+class DlinferYarnRotaryEmbeddingImpl(YarnRotaryEmbeddingImpl):
+    """yarn rotary embedding implementation."""
+
+    def __init__(self,
+                 dim: int,
+                 base: int = 10000,
+                 scaling_factor: float = 1.0,
+                 original_max_position_embeddings: int = 4096,
+                 yarn_params: YarnParameters = None):
+        super().__init__(dim, base, scaling_factor,
+                         original_max_position_embeddings, yarn_params)
+
+    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
+        """forward."""
+        dtype = x.dtype
+        if self.inv_freq.device != x.device:
+            self.inv_freq = self.inv_freq.to(x.device)
+        return _rotary_embedding_fwd(position_ids,
+                                     self.inv_freq,
+                                     scaling_factor=1.0,
+                                     mscale=self.mscale,
+                                     dtype=dtype)
+
+
 class DlinferRotaryEmbeddingBuilder(RotaryEmbeddingBuilder):
     """rotary embedding dlinfer builder."""
 
@@ -175,6 +200,12 @@ def build(
             return DlinferLlama3RotaryEmbeddingImpl(
                 dim, base, scaling_factor, llama3_params.low_freq_factor,
                 llama3_params.high_freq_factor, max_position_embeddings)
+        elif emb_type == RopeType.Yarn:
+            return DlinferYarnRotaryEmbeddingImpl(dim,
+                                                  base,
+                                                  scaling_factor,
+                                                  max_position_embeddings,
+                                                  yarn_params=yarn_params)
         else:
             raise NotImplementedError(
                 f'Unsupported embedding type: {emb_type}')
diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
index ded85d476d..89f5796fc4 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
@@ -18,6 +18,7 @@ def prefill_attention(
     max_q_seq_len: int,
     block_size: int,
     attn_mask: Sequence[Optional[Tensor]],
+    softmax_scale: Optional[float],
     is_unpaged_prefill: Optional[bool],
     kv_scales: Optional[Tensor],
     kv_zeros: Optional[Tensor],
@@ -37,6 +38,7 @@ def prefill_attention(
             num_q_heads,
             num_kv_heads,
             attn_mask,
+            softmax_scale=softmax_scale,
             attn_output=attn_output,
         )
     else:
@@ -55,6 +57,7 @@
             num_q_heads,
             num_kv_heads,
             attn_mask,
+            softmax_scale=softmax_scale,
             attn_output=attn_output,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
@@ -71,6 +74,7 @@ def paged_token_attention(
     max_kv_seq_len,
     block_offsets,
     block_size,
+    softmax_scale: Optional[float],
     kv_scales: Optional[Tensor],
     kv_zeros: Optional[Tensor],
     quant_bits: Optional[int],
@@ -87,6 +91,7 @@
         max_kv_seq_len,
         num_q_heads,
         num_kv_heads,
+        softmax_scale=softmax_scale,
         attn_output=attn_output,
         kv_scales=kv_scales,
         kv_zeros=kv_zeros,
@@ -110,6 +115,7 @@ def paged_attention_fwd(
     is_decoding: bool,
     block_size: int,
     attn_mask: Sequence[Optional[Tensor]] = (),
+    softmax_scale: Optional[float] = None,
     is_unpaged_prefill: Optional[bool] = None,
     kv_scales: Optional[Tensor] = None,
     kv_zeros: Optional[Tensor] = None,
@@ -130,6 +136,7 @@ def paged_attention_fwd(
             max_q_seq_len,
             block_size,
             attn_mask,
+            softmax_scale,
             is_unpaged_prefill,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
@@ -145,6 +152,7 @@
             max_kv_seq_len,
             block_offsets,
             block_size,
+            softmax_scale=softmax_scale,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
             quant_bits=quant_bits,
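
For reference, the `softmax_scale` plumbed through these kernels is the multiplier applied to the query-key logits before the softmax; attention kernels commonly fall back to `1 / sqrt(head_dim)` when it is left as `None`. Below is a minimal PyTorch sketch of where such a scale enters plain scaled-dot-product attention; the function name and tensor layout are illustrative assumptions, not the dlinfer ext_ops interface.

```python
# Illustrative sketch only: shows where a softmax scale enters attention.
# `naive_attention` and the [batch, heads, seq, head_dim] layout are
# assumptions for this example, not the dlinfer kernel API.
import math

import torch


def naive_attention(q: torch.Tensor,
                    k: torch.Tensor,
                    v: torch.Tensor,
                    softmax_scale: float = None) -> torch.Tensor:
    if softmax_scale is None:
        # Common default when no explicit scale is provided.
        softmax_scale = 1.0 / math.sqrt(q.size(-1))
    # Scale the query-key logits before the softmax.
    scores = torch.matmul(q, k.transpose(-1, -2)) * softmax_scale
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, v)
```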