From 3bb7e452f5fc79d3e739c0493a0118156c9c4e3f Mon Sep 17 00:00:00 2001
From: simon-mo
Date: Thu, 18 Jul 2024 21:37:08 +0000
Subject: [PATCH 1/9] Add support for a rope extension method

---
 .../model_executor/layers/rotary_embedding.py | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 1285627ec3cc5..37ee6f57cccab 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -733,6 +733,36 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         return inv_freq
 
 
+class ExtendedRotaryEmbedding(RotaryEmbedding):
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        inv_freqs = super()._compute_inv_freq(base)
+        return self.apply_scaling(inv_freqs)
+
+    def apply_scaling(self, freqs: torch.Tensor):
+        scale_factor = 8
+        low_freq_factor = 1
+        high_freq_factor = 4
+        old_context_len = 8192
+
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+        new_freqs = []
+        for freq in freqs:
+            wavelen = 2 * math.pi / freq
+            if wavelen < high_freq_wavelen:
+                new_freqs.append(freq)
+            elif wavelen > low_freq_wavelen:
+                new_freqs.append(freq / scale_factor)
+            else:
+                assert low_freq_wavelen != high_freq_wavelen
+                smooth = (old_context_len / wavelen - low_freq_factor) / (
+                    high_freq_factor - low_freq_factor)
+                new_freqs.append((1 - smooth) * freq / scale_factor +
+                                 smooth * freq)
+        return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
 
 
@@ -761,6 +791,10 @@ def get_rope(
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
     if rope_scaling is None:
+        if max_position == 131072:
+            # Note(simon): this is a special case for a model that doesn't supply
+            # rope_scaling. We should remove this once the model is updated.
+            RotaryEmbedding = ExtendedRotaryEmbedding
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
     else:

From 7793d4cf8f0d0219c0fffed064dd20cf68853e7b Mon Sep 17 00:00:00 2001
From: simon-mo
Date: Thu, 18 Jul 2024 21:38:41 +0000
Subject: [PATCH 2/9] fix lint

---
 vllm/model_executor/layers/rotary_embedding.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 37ee6f57cccab..929847a027111 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -792,8 +792,8 @@ def get_rope(
         return _ROPE_DICT[key]
     if rope_scaling is None:
         if max_position == 131072:
-            # Note(simon): this is a special case for a model that doesn't supply
-            # rope_scaling. We should remove this once the model is updated.
+            # Note(simon): this is a special case for a model that doesn't
+            # supply rope_scaling. We should remove this once the model is updated.
             RotaryEmbedding = ExtendedRotaryEmbedding
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
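The scaling above is a wavelength-based rule: inverse frequencies whose wavelength already fits inside the original 8192-token window are kept as-is, wavelengths longer than the window are divided by scale_factor, and the band in between is blended smoothly. The following self-contained sketch mirrors the patched apply_scaling so the three regimes can be run and inspected; the base and rotary_dim values are illustrative assumptions, not taken from the patch.

import math

import torch


def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
    # Constants mirror ExtendedRotaryEmbedding.apply_scaling above.
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192

    low_freq_wavelen = old_context_len / low_freq_factor    # 8192.0
    high_freq_wavelen = old_context_len / high_freq_factor  # 2048.0
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # Short wavelengths (high frequencies) already fit in the
            # original context window; leave them untouched.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Long wavelengths are stretched by the full scale factor.
            new_freqs.append(freq / scale_factor)
        else:
            # In between, blend the scaled and unscaled frequency. At the
            # band edges smooth hits 0 and 1, so the blend is continuous.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor)
            new_freqs.append((1 - smooth) * freq / scale_factor +
                             smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


# Illustrative inverse frequencies, matching RotaryEmbedding._compute_inv_freq;
# base and rotary_dim here are assumed values, not taken from the patch.
base, rotary_dim = 500000.0, 128
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
scaled = apply_scaling(inv_freq)
print(f"{(scaled != inv_freq).sum().item()} of {inv_freq.numel()} "
      "frequencies were rescaled")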
From 82c5b15b652ce8f3f93cd0095b203737b6f0d655 Mon Sep 17 00:00:00 2001
From: simon-mo
Date: Thu, 18 Jul 2024 21:51:49 +0000
Subject: [PATCH 3/9] fix lint

---
 vllm/model_executor/layers/rotary_embedding.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 929847a027111..58c4d91fe306c 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -792,8 +792,9 @@ def get_rope(
         return _ROPE_DICT[key]
     if rope_scaling is None:
         if max_position == 131072:
-            # Note(simon): this is a special case for a model that doesn't
-            # supply rope_scaling. We should remove this once the model is updated.
+            # Note(simon): this is a special case for a model that doesn't
+            # supply rope_scaling. We should remove this once the model is
+            # updated.
             RotaryEmbedding = ExtendedRotaryEmbedding
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)

From ec11712ce959a62a82b2cbf691ce8bb02030f2dc Mon Sep 17 00:00:00 2001
From: simon-mo
Date: Thu, 18 Jul 2024 22:13:25 +0000
Subject: [PATCH 4/9] move hack to config.py

---
 vllm/config.py                                 |  8 ++++++++
 vllm/model_executor/layers/rotary_embedding.py | 11 +++++------
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index de7bb3943a45f..cebc19e2075f8 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -150,6 +150,14 @@ def __init__(
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
+        if getattr(self.hf_config, "max_position_embeddings", 0) == 131072:
+            # Note(simon): this is a special case for a model that doesn't
+            # supply rope_scaling. We should remove this once the model is
+            # updated.
+            self.hf_config.update({"rope_scaling": {
+                "type": "extended",
+            }})
+
         if (not self.disable_sliding_window
                 and self.hf_text_config.model_type == "gemma2"
                 and self.hf_text_config.sliding_window is not None):
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 58c4d91fe306c..896a83daaa24f 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -791,11 +791,6 @@ def get_rope(
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
     if rope_scaling is None:
-        if max_position == 131072:
-            # Note(simon): this is a special case for a model that doesn't
-            # supply rope_scaling. We should remove this once the model is
-            # updated.
-            RotaryEmbedding = ExtendedRotaryEmbedding
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
     else:
@@ -804,7 +799,11 @@ def get_rope(
         # for backward compatible
         if scaling_type != "su" and scaling_type != "longrope":
             scaling_factor = rope_scaling["factor"]
-        if scaling_type == "linear":
+        if scaling_type == "extended":
+            rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
+                                                 max_position, base,
+                                                 is_neox_style, dtype)
+        elif scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
                                                       is_neox_style,
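Moving the hack into ModelConfig.__init__ turns the special case into an ordinary rope_scaling type that get_rope dispatches on, and it removes the local rebinding of the RotaryEmbedding name that patch 1 relied on. Below is a minimal sketch of the resulting flow, using a hypothetical SimpleNamespace stand-in for the HF config object (the real code calls PretrainedConfig.update):

from types import SimpleNamespace

# Hypothetical config for a model that ships without rope_scaling.
hf_config = SimpleNamespace(max_position_embeddings=131072, rope_scaling=None)

# config.py side: inject the marker scaling type.
if getattr(hf_config, "max_position_embeddings", 0) == 131072:
    hf_config.rope_scaling = {"type": "extended"}

# rotary_embedding.py side: rope_scaling is no longer None, so get_rope()
# takes the else branch and the "extended" case selects
# ExtendedRotaryEmbedding.
scaling_type = hf_config.rope_scaling["type"]
assert scaling_type == "extended"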
From e82954709b612189e8fc76d3ce32476f41ad07aa Mon Sep 17 00:00:00 2001
From: simon-mo
Date: Thu, 18 Jul 2024 22:24:26 +0000
Subject: [PATCH 5/9] comments

---
 vllm/config.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/config.py b/vllm/config.py
index cebc19e2075f8..81bab7610f5b8 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -150,7 +150,8 @@ def __init__(
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
-        if getattr(self.hf_config, "max_position_embeddings", 0) == 131072:
+        if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
+                and getattr(self.hf_config, "role_scaling", None) is None):
             # Note(simon): this is a special case for a model that doesn't
             # supply rope_scaling. We should remove this once the model is
             # updated.

From d4822213320832ff56997fd9139f92f47644849b Mon Sep 17 00:00:00 2001
From: simon-mo
Date: Thu, 18 Jul 2024 23:06:45 +0000
Subject: [PATCH 6/9] fix typo

---
 vllm/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/config.py b/vllm/config.py
index 6472a8d4435c6..979e1422f39e6 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -152,7 +152,7 @@ def __init__(
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
         if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
-                and getattr(self.hf_config, "role_scaling", None) is None):
+                and getattr(self.hf_config, "rope_scaling", None) is None):
             # Note(simon): this is a special case for a model that doesn't
             # supply rope_scaling. We should remove this once the model is
             # updated.
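Patches 5 and 6 together tighten the guard so the override only fires when the model supplies no rope_scaling of its own; note that with the intermediate "role_scaling" typo the second condition was vacuously true, because getattr on a missing attribute always returned the None default. A small self-contained check of the corrected guard, with hypothetical config objects standing in for HF configs:

from types import SimpleNamespace


def needs_extended_rope(cfg) -> bool:
    # Mirrors the corrected condition in ModelConfig.__init__.
    return (getattr(cfg, "max_position_embeddings", 0) == 131072
            and getattr(cfg, "rope_scaling", None) is None)


# Fires: 131072-token context and no rope_scaling supplied.
assert needs_extended_rope(
    SimpleNamespace(max_position_embeddings=131072))
# Does not fire: the model already declares its own scaling.
assert not needs_extended_rope(
    SimpleNamespace(max_position_embeddings=131072,
                    rope_scaling={"type": "linear", "factor": 4.0}))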
From 5533931eb089c43506a89e5b3659532b3a5a8ac9 Mon Sep 17 00:00:00 2001
From: simon-mo
Date: Thu, 18 Jul 2024 23:42:28 +0000
Subject: [PATCH 7/9] skip reading factor

---
 vllm/model_executor/layers/rotary_embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 896a83daaa24f..79b27f28b1402 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -797,7 +797,7 @@ def get_rope(
         scaling_type = rope_scaling["type"]
         # The correct one should be "longrope" but keep "su" here
         # for backward compatible
-        if scaling_type != "su" and scaling_type != "longrope":
+        if not scaling_type in {"su", "longrope", "extended"}:
             scaling_factor = rope_scaling["factor"]
         if scaling_type == "extended":
             rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,

From f70e187a1a900c8d044f8e439d36b2f90b605b72 Mon Sep 17 00:00:00 2001
From: simon-mo
Date: Thu, 18 Jul 2024 23:44:52 +0000
Subject: [PATCH 8/9] fix lint

---
 vllm/model_executor/layers/rotary_embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 79b27f28b1402..3f9573f550341 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -797,7 +797,7 @@ def get_rope(
         scaling_type = rope_scaling["type"]
         # The correct one should be "longrope" but keep "su" here
         # for backward compatible
-        if not scaling_type in {"su", "longrope", "extended"}:
+        if scaling_type not in {"su", "longrope", "extended"}:
             scaling_factor = rope_scaling["factor"]
         if scaling_type == "extended":
             rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,

From f772828b61436b44a770e3f7155713e3d8ddfc85 Mon Sep 17 00:00:00 2001
From: simon-mo
Date: Fri, 19 Jul 2024 00:10:09 +0000
Subject: [PATCH 9/9] fix another spot

---
 vllm/config.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 979e1422f39e6..e1578c0c3dbe3 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1451,8 +1451,9 @@ def _get_and_verify_max_len(
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     # The correct one should be "longrope", kept "su" here
     # to be backward compatible
-    if rope_scaling is not None and rope_scaling["type"] != "su" \
-            and rope_scaling["type"] != "longrope":
+    if rope_scaling is not None and rope_scaling["type"] not in {
+            "su", "longrope", "extended"
+    }:
         if disable_sliding_window:
             # TODO(robertgshaw): Find a model that supports rope_scaling
             # with sliding window to see if this case should be allowed.
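Patches 7 through 9 make the same repair in different spots: the injected {"type": "extended"} dict carries no "factor" key, so every place that reads rope_scaling["factor"] must skip the "extended" type or it would raise a KeyError (patch 8 is purely stylistic, rewriting "not x in s" as "x not in s"). A minimal sketch of the failure mode being avoided:

# Hypothetical rope_scaling dict, as injected by config.py above.
rope_scaling = {"type": "extended"}

if rope_scaling["type"] not in {"su", "longrope", "extended"}:
    # Unreachable for "extended"; indexing "factor" would raise KeyError here.
    scaling_factor = rope_scaling["factor"]
else:
    # "su"/"longrope"/"extended" configure themselves without a "factor" key.
    scaling_factor = None

assert scaling_factor is None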