diff --git a/vllm/config.py b/vllm/config.py
index 419118375e704..e1578c0c3dbe3 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -151,6 +151,15 @@ def __init__(
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
+        if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
+                and getattr(self.hf_config, "rope_scaling", None) is None):
+            # Note(simon): this is a special case for a model that doesn't
+            # supply rope_scaling. We should remove this once the model is
+            # updated.
+            self.hf_config.update({"rope_scaling": {
+                "type": "extended",
+            }})
+
         if (not self.disable_sliding_window
                 and self.hf_text_config.model_type == "gemma2"
                 and self.hf_text_config.sliding_window is not None):
@@ -1442,8 +1451,9 @@ def _get_and_verify_max_len(
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     # The correct one should be "longrope", kept "su" here
     # to be backward compatible
-    if rope_scaling is not None and rope_scaling["type"] != "su" \
-        and rope_scaling["type"] != "longrope":
+    if rope_scaling is not None and rope_scaling["type"] not in {
+            "su", "longrope", "extended"
+    }:
         if disable_sliding_window:
             # TODO(robertgshaw): Find a model that supports rope_scaling
             # with sliding window to see if this case should be allowed.
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 1285627ec3cc5..3f9573f550341 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -733,6 +733,36 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         return inv_freq
 
 
+class ExtendedRotaryEmbedding(RotaryEmbedding):
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        inv_freqs = super()._compute_inv_freq(base)
+        return self.apply_scaling(inv_freqs)
+
+    def apply_scaling(self, freqs: torch.Tensor):
+        scale_factor = 8
+        low_freq_factor = 1
+        high_freq_factor = 4
+        old_context_len = 8192
+
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+        new_freqs = []
+        for freq in freqs:
+            wavelen = 2 * math.pi / freq
+            if wavelen < high_freq_wavelen:
+                new_freqs.append(freq)
+            elif wavelen > low_freq_wavelen:
+                new_freqs.append(freq / scale_factor)
+            else:
+                assert low_freq_wavelen != high_freq_wavelen
+                smooth = (old_context_len / wavelen - low_freq_factor) / (
+                    high_freq_factor - low_freq_factor)
+                new_freqs.append((1 - smooth) * freq / scale_factor +
+                                 smooth * freq)
+        return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
 
 
@@ -767,9 +797,13 @@ def get_rope(
         scaling_type = rope_scaling["type"]
         # The correct one should be "longrope" but keep "su" here
         # for backward compatible
-        if scaling_type != "su" and scaling_type != "longrope":
+        if scaling_type not in {"su", "longrope", "extended"}:
             scaling_factor = rope_scaling["factor"]
-        if scaling_type == "linear":
+        if scaling_type == "extended":
+            rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
+                                                 max_position, base,
+                                                 is_neox_style, dtype)
+        elif scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
                                                       is_neox_style,
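
For reference, the apply_scaling method added above implements a wavelength-based interpolation of the rotary inverse frequencies: short wavelengths (high frequencies) are kept as-is, long wavelengths are divided by the scale factor, and the band in between is linearly blended. Below is a minimal standalone sketch of that scaling outside vLLM. The helper name scale_inv_freqs and the demo values (base=500000, rotary_dim=128) are illustrative assumptions, not part of the patch.

# Standalone sketch of the scaling done by ExtendedRotaryEmbedding.apply_scaling.
# The constants (scale_factor=8, old_context_len=8192, ...) mirror the patch;
# everything else here is illustrative and not part of vLLM's API.
import math

import torch


def scale_inv_freqs(freqs: torch.Tensor,
                    scale_factor: float = 8,
                    low_freq_factor: float = 1,
                    high_freq_factor: float = 4,
                    old_context_len: int = 8192) -> torch.Tensor:
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency (short-wavelength) components are left untouched.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency components are scaled down by the full factor.
            new_freqs.append(freq / scale_factor)
        else:
            # Frequencies in between are linearly blended between the two cases.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor)
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


if __name__ == "__main__":
    # Inverse frequencies for an illustrative 128-dim rotary head with
    # base 500000, following the usual RoPE construction.
    base, rotary_dim = 500000, 128
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    scaled = scale_inv_freqs(inv_freq)
    # The highest frequency is unchanged (ratio 1.0); the lowest is divided
    # by the scale factor (ratio 0.125).
    print(scaled[0] / inv_freq[0], scaled[-1] / inv_freq[-1])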