
Commit

[Bugfix] fix rope error when loading models with different dtypes (vllm-…
jinzhen-lin authored May 17, 2024
1 parent 2614812 commit 33e0823
Showing 2 changed files with 64 additions and 13 deletions.
44 changes: 43 additions & 1 deletion tests/kernels/test_pos_encoding.py
@@ -1,4 +1,4 @@
from itertools import accumulate
from itertools import accumulate, product
from typing import List, Optional

import pytest
@@ -207,3 +207,45 @@ def test_batched_rotary_embedding_multi_lora(
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))


@torch.inference_mode()
def test_rope_module_cache():
MAX_POSITIONS = [123, 1234]
BASES = [10000, 1000000]
ROPE_SCALINGS = [
None, {
"type": "linear",
"factor": (1, )
}, {
"type": "dynamic",
"factor": 1
}
]
settings = [
HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
ROPE_SCALINGS, DTYPES
]
rope_setting_id_map = {}
for setting in product(*settings):
head_size, rotary_dim, max_position, base, \
is_neox_style, rope_scaling, dtype = setting
if rotary_dim is None:
rotary_dim = head_size
rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_style, rope_scaling, dtype)
# different settings cannot share the same rope module
assert id(rope) not in rope_setting_id_map.values()
assert all(x.dtype == dtype for x in rope.buffers())
assert all(x.dtype == dtype for x in rope.parameters())
rope_setting_id_map[str(setting)] = id(rope)

for setting in product(*settings):
head_size, rotary_dim, max_position, base, \
is_neox_style, rope_scaling, dtype = setting
if rotary_dim is None:
rotary_dim = head_size
rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_style, rope_scaling, dtype)
# check that the module cache takes effect
assert id(rope) == rope_setting_id_map[str(setting)]
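
For context (not part of the diff): a minimal usage sketch of the patched get_rope helper. With dtype folded into the cache key, two models loaded with different dtypes no longer share a RoPE module built for the wrong dtype. The head size, base, and dtypes below are illustrative values only.

import torch

from vllm.model_executor.layers.rotary_embedding import get_rope

# Same geometry, different dtypes: the cache now returns distinct modules.
rope_fp16 = get_rope(128, 128, 2048, 10000, True, None, torch.float16)
rope_bf16 = get_rope(128, 128, 2048, 10000, True, None, torch.bfloat16)

assert rope_fp16 is not rope_bf16  # dtype is part of the cache key
assert rope_fp16.cos_sin_cache.dtype == torch.float16
assert rope_bf16.cos_sin_cache.dtype == torch.bfloat16
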
33 changes: 21 additions & 12 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -53,6 +53,7 @@ def __init__(
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
@@ -62,7 +63,7 @@ def __init__(
self.is_neox_style = is_neox_style

cache = self._compute_cos_sin_cache()
cache = cache.to(torch.get_default_dtype())
cache = cache.to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)

def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
@@ -178,12 +179,13 @@ def __init__(
base: int,
is_neox_style: bool,
scaling_factors: Union[List[float], float],
dtype: torch.dtype,
) -> None:
if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors = scaling_factors
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style)
is_neox_style, dtype)

def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
@@ -219,10 +221,11 @@ def __init__(
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
) -> None:
self.scaling_factor = scaling_factor
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style)
is_neox_style, dtype)

def _compute_cos_sin_cache(self) -> torch.Tensor:
# NOTE(woosuk): self.max_position_embeddings is the original
@@ -299,6 +302,7 @@ def __init__(
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
@@ -314,7 +318,7 @@ def __init__(
self.mscale = float(
_yarn_get_mscale(self.scaling_factor) * attn_factor)
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style)
is_neox_style, dtype)

def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(
@@ -359,6 +363,7 @@ def __init__(
original_max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
short_factor: List[float],
long_factor: List[float],
short_mscale: float = 1.1,
@@ -385,14 +390,14 @@ def __init__(

short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale)
short_cache = short_cache.to(torch.get_default_dtype())
short_cache = short_cache.to(dtype)
self.register_buffer("short_cos_sin_cache",
short_cache,
persistent=False)

long_cache = self._compute_cos_sin_cache(max_position_embeddings,
long_factor, long_mscale)
long_cache = long_cache.to(torch.get_default_dtype())
long_cache = long_cache.to(dtype)
self.register_buffer("long_cos_sin_cache",
long_cache,
persistent=False)
@@ -463,7 +468,10 @@ def get_rope(
base: int,
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
@@ -474,12 +482,12 @@
else:
rope_scaling_args = None
key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling_args)
rope_scaling_args, dtype)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if rope_scaling is None:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style)
is_neox_style, dtype)
else:
scaling_type = rope_scaling["type"]
if scaling_type != "su":
@@ -488,11 +496,11 @@
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style,
scaling_factor)
scaling_factor, dtype)
elif scaling_type == "dynamic":
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor)
scaling_factor, dtype)
elif scaling_type == "yarn":
original_max_position = rope_scaling[
"original_max_position_embeddings"]
@@ -505,7 +513,7 @@
rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
original_max_position,
base, is_neox_style,
scaling_factor,
scaling_factor, dtype,
**extra_kwargs)
elif scaling_type == "su":
short_factor = rope_scaling["short_factor"]
@@ -519,7 +527,8 @@
}
rotary_emb = Phi3SuScaledRotaryEmbedding(
head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, short_factor, long_factor, **extra_kwargs)
base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
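
The core of the fix, shown above: the cos/sin cache dtype is now an explicit constructor argument (get_rope defaults it to torch.get_default_dtype()) and it participates in the _ROPE_DICT cache key. A simplified sketch of that caching pattern follows; names other than torch are hypothetical stand-ins, not vLLM API.

from typing import Dict, Optional, Tuple

import torch

# Hypothetical stand-in for vLLM's _ROPE_DICT: because dtype is part of the
# key, a cache built while one default dtype was active is never reused for
# a model that needs another dtype.
_CACHE: Dict[Tuple[int, int, int, torch.dtype], torch.Tensor] = {}


def cos_sin_cache(rotary_dim: int, max_position: int, base: int,
                  dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    if dtype is None:
        dtype = torch.get_default_dtype()
    key = (rotary_dim, max_position, base, dtype)  # dtype included in the key
    if key in _CACHE:
        return _CACHE[key]
    # Compute in float32 for accuracy, then cast once to the requested dtype.
    inv_freq = 1.0 / (base**(
        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
    t = torch.arange(max_position, dtype=torch.float)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).to(dtype)
    _CACHE[key] = cache
    return cache
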
