Add support for Llama 3 rotary embeddings (#551)
tgaddair authored Jul 23, 2024
1 parent 452ac73 commit 240079b
Showing 1 changed file with 48 additions and 1 deletion.
49 changes: 48 additions & 1 deletion server/lorax_server/utils/layers.py
@@ -428,7 +428,7 @@ def static(cls, config, dim, base, device, dtype):
             rope_scaling = _get_rope_config(config)
             if rope_scaling is not None:
                 rope_scaling = rope_scaling.copy()
-                rope_type = rope_scaling.pop("type")
+                rope_type = rope_scaling.pop("rope_type", rope_scaling.pop("type", None))
                 if rope_type == "linear":
                     pass
                 elif rope_type == "dynamic":
@@ -441,6 +441,23 @@ def static(cls, config, dim, base, device, dtype):
                         dtype=dtype,
                         scaling_factor=scaling_factor,
                     )
+                elif rope_type == "llama3":
+                    inv_freq = apply_llama3_scaling(
+                        inv_freq,
+                        scaling_factor=rope_scaling["factor"],
+                        low_freq_factor=rope_scaling["low_freq_factor"],
+                        high_freq_factor=rope_scaling["high_freq_factor"],
+                        original_max_position_embeddings=rope_scaling[
+                            "original_max_position_embeddings"
+                        ],
+                    )
+                    return cls(
+                        inv_freq,
+                        scaling_factor,
+                        max_position_embeddings=config.max_position_embeddings,
+                        device=inv_freq.device,
+                        dtype=dtype,
+                    )
                 elif rope_type == "yarn":
                     scaling_factor = rope_scaling["factor"]
                     return YarnPositionRotaryEmbedding(
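
The new branch is selected by the `rope_scaling` block of the model config: newer Llama 3.1-style configs name the scheme under `rope_type`, while older configs use `type`, which is why the `pop()` fallback was added above. A minimal sketch of the dict this code path expects; the key names come from the diff, but the numeric values are illustrative and not taken from this commit:

# Illustrative rope_scaling block, as it might appear in a Llama 3.1-style
# config.json. Only the key names are assumed by the code above; the values
# shown here are representative, not pulled from this commit.
rope_scaling = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}

# Mirrors the lookup in static(): prefer "rope_type", fall back to legacy "type".
rope_type = rope_scaling.pop("rope_type", rope_scaling.pop("type", None))
assert rope_type == "llama3"
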
@@ -710,3 +727,33 @@ def get_mscale(scale=1):
 
 except ImportError:
     pass
+
+
+def apply_llama3_scaling(
+    freqs: torch.Tensor,
+    *,
+    scaling_factor: int,
+    low_freq_factor: int,
+    high_freq_factor: int,
+    original_max_position_embeddings: int,
+):
+    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
+    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
+    new_freqs = []
+
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / scaling_factor)
+        else:
+
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
+
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
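
The helper leaves high-frequency components untouched (wavelength shorter than high_freq_wavelen), divides low-frequency components (wavelength longer than low_freq_wavelen) by the scaling factor, and linearly interpolates the band in between via smooth. A small usage sketch, not part of the commit, assuming `apply_llama3_scaling` as defined above is in scope and using illustrative Llama 3-style parameters:

import torch

# Inverse RoPE frequencies for an illustrative head dim of 128 and base 500000.
dim, base = 128, 500000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

scaled = apply_llama3_scaling(
    inv_freq,
    scaling_factor=8.0,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    original_max_position_embeddings=8192,
)

# Highest-frequency component (short wavelength) passes through unchanged;
# lowest-frequency component (long wavelength) is divided by the scaling factor.
assert torch.isclose(scaled[0], inv_freq[0])
assert torch.isclose(scaled[-1], inv_freq[-1] / 8.0)
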
