From 240079b01d71607cdd3fcbaa1aafa015760b1cc2 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Tue, 23 Jul 2024 10:05:18 -0700
Subject: [PATCH] Add support for Llama 3 rotary embeddings (#551)

---
 server/lorax_server/utils/layers.py | 49 ++++++++++++++++++++++++++++-
 1 file changed, 48 insertions(+), 1 deletion(-)

diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py
index 2c1a5f043..9ac96e871 100644
--- a/server/lorax_server/utils/layers.py
+++ b/server/lorax_server/utils/layers.py
@@ -428,7 +428,7 @@ def static(cls, config, dim, base, device, dtype):
         rope_scaling = _get_rope_config(config)
         if rope_scaling is not None:
             rope_scaling = rope_scaling.copy()
-            rope_type = rope_scaling.pop("type")
+            rope_type = rope_scaling.pop("rope_type", rope_scaling.pop("type", None))
             if rope_type == "linear":
                 pass
             elif rope_type == "dynamic":
@@ -441,6 +441,23 @@ def static(cls, config, dim, base, device, dtype):
                     dtype=dtype,
                     scaling_factor=scaling_factor,
                 )
+            elif rope_type == "llama3":
+                inv_freq = apply_llama3_scaling(
+                    inv_freq,
+                    scaling_factor=rope_scaling["factor"],
+                    low_freq_factor=rope_scaling["low_freq_factor"],
+                    high_freq_factor=rope_scaling["high_freq_factor"],
+                    original_max_position_embeddings=rope_scaling[
+                        "original_max_position_embeddings"
+                    ],
+                )
+                return cls(
+                    inv_freq,
+                    scaling_factor,
+                    max_position_embeddings=config.max_position_embeddings,
+                    device=inv_freq.device,
+                    dtype=dtype,
+                )
             elif rope_type == "yarn":
                 scaling_factor = rope_scaling["factor"]
                 return YarnPositionRotaryEmbedding(
@@ -710,3 +727,33 @@ def get_mscale(scale=1):
 
 except ImportError:
     pass
+
+
+def apply_llama3_scaling(
+    freqs: torch.Tensor,
+    *,
+    scaling_factor: int,
+    low_freq_factor: int,
+    high_freq_factor: int,
+    original_max_position_embeddings: int,
+):
+    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
+    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
+    new_freqs = []
+
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / scaling_factor)
+        else:
+
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
+
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
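
Note: the new "llama3" branch reads four extra keys from the rope_scaling block of
the model's config.json. A minimal sketch of the config shape this path assumes,
with numbers modeled on Llama 3.1 defaults (the values are an assumption, not part
of this patch):

# Hypothetical rope_scaling block as parsed from a Llama 3.1-style config.json.
# The numeric values below are assumed defaults, not taken from this patch.
rope_scaling = {
    "rope_type": "llama3",                     # matched by the new elif branch
    "factor": 8.0,                             # passed as scaling_factor
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,  # pre-extension context length
}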
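For intuition, here is a self-contained, vectorized sketch of the interpolation
apply_llama3_scaling performs: components whose wavelength is shorter than
high_freq_wavelen keep their frequency, components longer than low_freq_wavelen
are divided by scaling_factor, and the band in between is blended linearly. All
names below are local to the sketch, and the rope base of 500000 is an assumed
Llama 3 default, not something this patch sets:

import math

import torch


def llama3_scale_sketch(
    freqs: torch.Tensor,
    scaling_factor: float = 8.0,
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    original_max_position_embeddings: int = 8192,
) -> torch.Tensor:
    # Wavelength thresholds; requires low_freq_factor != high_freq_factor,
    # which the patched helper enforces with an assert.
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor

    wavelen = 2 * math.pi / freqs
    # Linear blend between the fully scaled and the unscaled frequency.
    smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    blended = (1 - smooth) * freqs / scaling_factor + smooth * freqs
    # Short wavelengths pass through; long wavelengths are fully scaled.
    out = torch.where(wavelen < high_freq_wavelen, freqs, blended)
    out = torch.where(wavelen > low_freq_wavelen, freqs / scaling_factor, out)
    return out


# Example usage: inverse frequencies for a 128-dim rotary head, mirroring
# what _create_inv_freq in the patched file is assumed to produce.
dim, base = 128, 500000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
print(llama3_scale_sketch(inv_freq)[:4])

The per-element loop in the patch yields the same values; its assert guards the
division inside smooth.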