Add support for a rope extension method #6553

Merged (10 commits) on Jul 19, 2024

Changes from 3 commits
vllm/model_executor/layers/rotary_embedding.py (35 additions, 0 deletions)

@@ -733,6 +733,36 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        return inv_freq


class ExtendedRotaryEmbedding(RotaryEmbedding):

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        inv_freqs = super()._compute_inv_freq(base)
        return self.apply_scaling(inv_freqs)

    def apply_scaling(self, freqs: torch.Tensor):
        scale_factor = 8
        low_freq_factor = 1
        high_freq_factor = 4
        old_context_len = 8192

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor
        new_freqs = []
        for freq in freqs:
            wavelen = 2 * math.pi / freq
            if wavelen < high_freq_wavelen:
                # High-frequency band: the wavelength already fits in the
                # original context window, so leave it untouched.
                new_freqs.append(freq)
            elif wavelen > low_freq_wavelen:
                # Low-frequency band: scale down by the full factor.
                new_freqs.append(freq / scale_factor)
            else:
                # Intermediate band: interpolate smoothly between the scaled
                # and unscaled frequency.
                assert low_freq_wavelen != high_freq_wavelen
                smooth = (old_context_len / wavelen - low_freq_factor) / (
                    high_freq_factor - low_freq_factor)
                new_freqs.append((1 - smooth) * freq / scale_factor +
                                 smooth * freq)
        return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
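For intuition, here is a small standalone sketch (an editor's illustration, not part of the diff) of which frequency bands apply_scaling touches. It assumes the usual inverse-frequency formula base ** (-2i / rotary_dim); the base and rotary_dim values below are arbitrary placeholders, not taken from any model config.

# Standalone illustration of apply_scaling's three bands; base and rotary_dim
# are placeholder values, not from any particular model.
import math
import torch

base, rotary_dim = 10000, 128
old_context_len, low_freq_factor, high_freq_factor = 8192, 1, 4

inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2,
                                        dtype=torch.float32) / rotary_dim))
wavelen = 2 * math.pi / inv_freq

high_freq_wavelen = old_context_len / high_freq_factor   # 2048.0
low_freq_wavelen = old_context_len / low_freq_factor     # 8192.0

kept = (wavelen < high_freq_wavelen).sum().item()    # left untouched
scaled = (wavelen > low_freq_wavelen).sum().item()   # divided by scale_factor
blended = inv_freq.numel() - kept - scaled           # smoothly interpolated
print(kept, blended, scaled)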


@@ -761,6 +791,11 @@ def get_rope(
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]
    if rope_scaling is None:
        if max_position == 131072:
            # Note(simon): this is a special case for a model that doesn't
            # supply rope_scaling. We should remove this once the model is
            # updated.
            RotaryEmbedding = ExtendedRotaryEmbedding
davidthomas426 commented on Jul 18, 2024:

Why don't you just do

rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim, max_position, base,
                                     is_neox_style, dtype)

This could break speculative decoding, for instance, since you may want to use a different RoPE implementation for the draft and target models.

Also, the key in _ROPE_DICT should probably include a model identifier of some kind to avoid a similar bug.
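To make the caching concern concrete, here is a toy standalone sketch (an editor's illustration, not vLLM code): the right implementation really depends on which model is being loaded, but neither the condition nor the cache key carries a model identifier, so any other model that happens to share the key gets whatever was chosen first.

# Toy illustration of the _ROPE_DICT concern; names and values are made up.
from typing import Dict, Tuple

_rope_cache: Dict[Tuple[int, int], str] = {}

def get_rope_impl(model_id: str, max_position: int, base: int) -> str:
    key = (max_position, base)            # no model identifier in the key
    if key in _rope_cache:
        return _rope_cache[key]
    # Decision keyed only on max_position, mirroring the special case above.
    impl = ("ExtendedRotaryEmbedding"
            if max_position == 131072 else "RotaryEmbedding")
    _rope_cache[key] = impl
    return impl

# A hypothetical draft model that merely shares max_position and base with the
# target also receives the extended scaling:
print(get_rope_impl("target-model", 131072, 10000))  # ExtendedRotaryEmbedding
print(get_rope_impl("draft-model", 131072, 10000))   # ExtendedRotaryEmbedding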

Author (Collaborator):

Good point.

davidthomas426 commented on Jul 18, 2024:

Also, is there something else we can key off of here to avoid false positives? Maybe base?

Author (Collaborator):

I'm going to make the change, but I won't add a model id to the key, since that's pretty intrusive and won't be needed after the proper fix. Do you think that's okay?

Reply:

I don't know how common this particular max_position is, probably not that common, but this check would enable the extended scaling for any model that uses it. Could you also key on base?

        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                     is_neox_style, dtype)
    else:
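For reference, a minimal standalone sketch (an editor's illustration with stubbed-out return values, not vLLM's actual code) of the shape the review converges on: construct ExtendedRotaryEmbedding directly in the special case instead of rebinding the RotaryEmbedding name, and tighten the condition with base. The particular base value checked here is an assumed placeholder; the thread does not state one.

# Sketch only: EXPECTED_BASE is a placeholder assumption, and strings stand in
# for the real embedding classes.
from typing import Optional

EXPECTED_BASE = 500_000  # assumption, not a value given in this thread

def choose_rope(max_position: int, base: int,
                rope_scaling: Optional[dict]) -> str:
    if rope_scaling is None:
        if max_position == 131072 and base == EXPECTED_BASE:
            # Instantiate the extended variant directly rather than rebinding
            # the RotaryEmbedding name.
            return "ExtendedRotaryEmbedding(head_size, rotary_dim, ...)"
        return "RotaryEmbedding(head_size, rotary_dim, ...)"
    return "a scaled RoPE variant chosen from rope_scaling"

print(choose_rope(131072, EXPECTED_BASE, None))  # extended variant
print(choose_rope(131072, 10_000, None))         # plain RoPE; base filters the false positive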