Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for a rope extension method #6553

Merged
merged 10 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,15 @@ def __init__(
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
and getattr(self.hf_config, "rope_scaling", None) is None):
# Note(simon): this is a special case for a model that doesn't
# supply rope_scaling. We should remove this once the model is
# updated.
self.hf_config.update({"rope_scaling": {
"type": "extended",
}})

if (not self.disable_sliding_window
and self.hf_text_config.model_type == "gemma2"
and self.hf_text_config.sliding_window is not None):
Expand Down
38 changes: 36 additions & 2 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,36 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
return inv_freq


class ExtendedRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding whose inverse frequencies are rescaled per band.

    The inverse frequencies produced by the parent class are passed through
    :meth:`apply_scaling`, which leaves short-wavelength components alone,
    divides long-wavelength components by a fixed factor, and blends
    linearly between the two in the band in between.
    """

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Return the parent's inverse frequencies after band-wise scaling."""
        inv_freqs = super()._compute_inv_freq(base)
        return self.apply_scaling(inv_freqs)

    def apply_scaling(self, freqs: torch.Tensor) -> torch.Tensor:
        """Rescale ``freqs`` according to each component's wavelength.

        For every element ``f`` the wavelength is ``2 * pi / f``:

        * wavelength < ``old_context_len / high_freq_factor``: ``f`` is kept;
        * wavelength > ``old_context_len / low_freq_factor``: ``f`` is
          divided by ``scale_factor``;
        * otherwise: linear interpolation between ``f / scale_factor`` and
          ``f``, weighted by how far the wavelength lies inside the band.

        Args:
            freqs: Tensor of inverse frequencies.

        Returns:
            A new tensor with the same shape, dtype and device as ``freqs``.
        """
        # NOTE(review): these constants are hard-coded; presumably they match
        # the checkpoint this class was added for -- confirm before reusing
        # the class with a different model.
        scale_factor = 8
        low_freq_factor = 1
        high_freq_factor = 4
        old_context_len = 8192

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        # Evaluate all three candidate results for the whole tensor at once
        # and select per element, instead of looping over 0-d tensors in
        # Python (which costs one host round-trip per element).
        wavelen = 2 * math.pi / freqs
        # The denominator is non-zero because the two factors above differ.
        smooth = (old_context_len / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor)
        blended = (1 - smooth) * freqs / scale_factor + smooth * freqs

        return torch.where(
            wavelen < high_freq_wavelen,
            freqs,
            torch.where(wavelen > low_freq_wavelen, freqs / scale_factor,
                        blended),
        )


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


Expand Down Expand Up @@ -767,9 +797,13 @@ def get_rope(
scaling_type = rope_scaling["type"]
# The correct one should be "longrope" but keep "su" here
# for backward compatible
if scaling_type != "su" and scaling_type != "longrope":
if scaling_type not in {"su", "longrope", "extended"}:
scaling_factor = rope_scaling["factor"]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change actually fails here.

if scaling_type == "linear":
if scaling_type == "extended":
rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style, dtype)
elif scaling_type == "linear":
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style,
Expand Down
Loading