Skip to content

Commit

Permalink
Add support for a rope extension method (vllm-project#6553)
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-mo authored and gshtras committed Aug 12, 2024
1 parent 8608888 commit be817de
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 3 deletions.
26 changes: 25 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,26 @@ def __init__(
code_revision, rope_scaling)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
and getattr(self.hf_config, "rope_scaling", None) is None):
# Note(simon): this is a special case for a model that doesn't
# supply rope_scaling. We should remove this once the model is
# updated.
self.hf_config.update({"rope_scaling": {
"type": "extended",
}})

if (not self.disable_sliding_window
and self.hf_text_config.model_type == "gemma2"
and self.hf_text_config.sliding_window is not None):
print_warning_once(
"Gemma 2 uses sliding window attention for every odd layer, "
"which is currently not supported by vLLM. Disabling sliding "
"window and capping the max length to the sliding window size "
f"({self.hf_text_config.sliding_window}).")
self.disable_sliding_window = True

self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,
max_model_len=max_model_len,
Expand Down Expand Up @@ -1225,7 +1245,11 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len

rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None and rope_scaling["type"] != "su":
# The correct one should be "longrope", kept "su" here
# to be backward compatible
if rope_scaling is not None and rope_scaling["type"] not in {
"su", "longrope", "extended"
}:
if disable_sliding_window:
# TODO(robertgshaw): Find a model that supports rope_scaling
# with sliding window to see if this case should be allowed.
Expand Down
163 changes: 161 additions & 2 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,159 @@ def forward(
return query.flatten(-2), key.flatten(-2)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0


class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with YaRN method.
Credits to Peng et al. github.com/jquesnelle/yarn
"""

def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)

def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
self.rotary_dim, self.base,
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq

def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda",
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
sin = (freqs.sin() * self.mscale)
cache = torch.cat((cos, sin), dim=-1)
print("Cache shape", cache.shape)
return cache

def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]

self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin

if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot
return query, key


class GemmaRotaryEmbedding(RotaryEmbedding):

def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
inv_freq = 1.0 / (base**(
torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
self.rotary_dim))
return inv_freq


class ExtendedRotaryEmbedding(RotaryEmbedding):

def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
inv_freqs = super()._compute_inv_freq(base)
return self.apply_scaling(inv_freqs)

def apply_scaling(self, freqs: torch.Tensor):
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor)
new_freqs.append((1 - smooth) * freq / scale_factor +
smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


Expand Down Expand Up @@ -535,9 +688,15 @@ def get_rope(
is_neox_style, dtype)
else:
scaling_type = rope_scaling["type"]
if scaling_type != "su":
# The correct one should be "longrope" but keep "su" here
# for backward compatible
if scaling_type not in {"su", "longrope", "extended"}:
scaling_factor = rope_scaling["factor"]
if scaling_type == "linear":
if scaling_type == "extended":
rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style, dtype)
elif scaling_type == "linear":
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style,
Expand Down

0 comments on commit be817de

Please sign in to comment.