From 05d13528b96084e53f64d601e56a03cf17adb45c Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:56:05 +0200 Subject: [PATCH] Add RoPE scaling for Llama3.1 --- exllamav2/config.py | 29 ++++++++++++++++++++++------- exllamav2/model.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/exllamav2/config.py b/exllamav2/config.py index 6296052e..4dbcd5d3 100644 --- a/exllamav2/config.py +++ b/exllamav2/config.py @@ -10,7 +10,9 @@ T = TypeVar('T') no_default = object() -def read(input_dict: dict[str, Any], expected_type: type, keys: str | list[str], default = no_default) -> T: +def read(input_dict: dict[str, Any], expected_type: type | list[type], keys: str | list[str], default = no_default) -> T: + + expected_types = expected_type if isinstance(expected_type, list) else [expected_type] if isinstance(keys, str): keys = [keys] @@ -34,10 +36,10 @@ def read(input_dict: dict[str, Any], expected_type: type, keys: str | list[str], if expected_type == int and isinstance(x, float) and x == int(x): x = int(x) - if isinstance(x, expected_type): - return cast(T, x) - else: - raise TypeError(f"Value for {key} is not of expected type {expected_type}") + for t in expected_types: + if isinstance(x, t): + return cast(T, x) + raise TypeError(f"Value for {key} is not of expected type {expected_type}") if default != no_default: return default raise ValueError(f"Missing any of the following keys: {keys}") @@ -105,7 +107,10 @@ class ExLlamaV2Config: attn_logit_softcapping: float | None sliding_window: int norm_head: int | None - + l3_rope_factor: float | None + l3_rope_low_freq_factor: float | None + l3_rope_high_freq_factor: float | None + l3_rope_original_max_position_embeddings: int | None checkpoint_fused_mlp: bool checkpoint_offset_qzeros: bool @@ -191,10 +196,13 @@ def prepare(self, no_tensors: bool = False): # Vocab params self.bos_token_id = read(read_config, int, "bos_token_id", None) # 1 - self.eos_token_id = read(read_config, int, "eos_token_id", None) # 2 + self.eos_token_id = read(read_config, [int, list], "eos_token_id", None) # 2 self.pad_token_id = read(read_config, int, "pad_token_id", None) # 0 self.vocab_size = read(read_config, int, "vocab_size") + if isinstance(self.eos_token_id, list): + self.eos_token_id = self.eos_token_id[0] # TODO: Figure out a way to maybe use all the EOS tokens somehow + # Standard params self.initializer_range = read(read_config, float, ["initializer_range"]) @@ -287,6 +295,13 @@ def prepare(self, no_tensors: bool = False): self.alt_rope_method = "su" # if scaling_type == "yarn": # self.scale_alpha_value = factor + rope_type = rs.get("rope_type", None) + if rope_type == "llama3": + self.alt_rope_method = "llama3" + self.l3_rope_factor = rs["factor"] + self.l3_rope_low_freq_factor = rs["low_freq_factor"] + self.l3_rope_high_freq_factor = rs["high_freq_factor"] + self.l3_rope_original_max_position_embeddings = rs["original_max_position_embeddings"] # Checkpoint format (for GPTQ models) diff --git a/exllamav2/model.py b/exllamav2/model.py index 4d875f59..065f79f2 100644 --- a/exllamav2/model.py +++ b/exllamav2/model.py @@ -129,6 +129,31 @@ def get_scratch_slice(self, size_bytes): return scratch_slice + @staticmethod + def _apply_scaling( + freqs: torch.Tensor, + scale_factor: float = 8, + low_freq_factor: float = 1, + high_freq_factor: float = 4, + old_context_len: int = 8192, # original llama3 length + ): + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype = freqs.dtype, device = freqs.device) + + def prepare_sincos(self): device = _torch_device(self.device_idx) @@ -163,6 +188,19 @@ def prepare_sincos(self): inv_freq = 1.0 / (ext_factors * base ** (torch.arange(0, head_dim, 2, device = device).float() / head_dim)) + # Llama 3.1 + + elif cfg.alt_rope_method == "llama3": + + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device = device).float() / head_dim)) + inv_freq = self._apply_scaling( + inv_freq, + cfg.l3_rope_factor, + cfg.l3_rope_low_freq_factor, + cfg.l3_rope_high_freq_factor, + cfg.l3_rope_original_max_position_embeddings, + ) + # Regular else: