Skip to content

Commit

Permalink
Added support for YARN scaling (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair authored Nov 20, 2023
1 parent 097b725 commit 2b2752d
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 7 deletions.
2 changes: 1 addition & 1 deletion server/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ install: gen-server install-torch

run-dev:
# SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve meta-llama/Llama-2-7b-hf --sharded
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve mistralai/Mistral-7B-Instruct-v0.1 --sharded
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve NousResearch/Yarn-Mistral-7b-128k --sharded
# SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 lorax_server/cli.py serve alexsherstinsky/Mistral-7B-v0.1-sharded --sharded

export-requirements:
Expand Down
114 changes: 108 additions & 6 deletions server/lorax_server/utils/layers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import os
import torch
import torch.distributed
Expand Down Expand Up @@ -549,20 +550,30 @@ def static(cls, config, dim, base, device):
scaling_factor = None
rope_scaling = _get_rope_config(config)
if rope_scaling is not None:
rope_scaling = rope_scaling.copy()
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "linear":
rope_type = rope_scaling.pop("type")
if rope_type == "linear":
pass
elif rope_scaling["type"] == "dynamic":
elif rope_type == "dynamic":
return DynamicPositionRotaryEmbedding(
dim=dim,
max_position_embeddings=config.max_position_embeddings,
base=base,
device=inv_freq.device,
scaling_factor=scaling_factor,
)
elif rope_type == "yarn":
return YarnPositionRotaryEmbedding(
dim=dim,
max_position_embeddings=config.max_position_embeddings,
base=base,
device=inv_freq.device,
**rope_scaling,
)
else:
raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
f"rope scaling type {rope_type} is not implemented or invalid"
)
return cls(inv_freq, scaling_factor)

Expand All @@ -577,20 +588,30 @@ def load(cls, config, prefix, weights):
scaling_factor = None
rope_scaling = _get_rope_config(config)
if rope_scaling is not None:
rope_scaling = rope_scaling.copy()
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "linear":
rope_type = rope_scaling.pop("type")
if rope_type == "linear":
pass
elif rope_scaling["type"] == "dynamic":
elif rope_type == "dynamic":
return DynamicPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0],
max_position_embeddings=config.max_position_embeddings,
base=10000.0,
device=inv_freq.device,
scaling_factor=scaling_factor,
)
elif rope_type == "yarn":
return YarnPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0],
max_position_embeddings=config.max_position_embeddings,
base=10000.0,
device=inv_freq.device,
**rope_scaling,
)
else:
raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
f"rope scaling type {rope_type} is not implemented or invalid"
)
return cls(inv_freq, scaling_factor)

Expand Down Expand Up @@ -667,5 +688,86 @@ def _update_cos_sin_cache(self, dtype, device, seqlen):
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)

class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
"""https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py"""

def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
factor=1,
original_max_position_embeddings=2048,
extrapolation_factor=1,
attn_factor=1,
beta_fast=32,
beta_slow=1,
finetuned=True,
device=None,
):
super().__init__(_create_inv_freq(dim, base, device), factor)
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.original_max_position_embeddings = original_max_position_embeddings
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.finetuned = finetuned

self.yarn(device)

def _update_cos_sin_cache(self, dtype, device, seqlen):
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen

t = torch.arange(self._seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))

self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)
self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)

def yarn(self, device):
pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (self.scaling_factor * pos_freqs)

low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings)
inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask

self.inv_freq = inv_freq
self.mscale = float(get_mscale(self.scaling_factor) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation

# Inverse dim formula to find dim based on number of rotations
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base))

# Find dim range bounds based on rotations
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
low = math.floor(find_correction_dim(
low_rot, dim, base, max_position_embeddings))
high = math.ceil(find_correction_dim(
high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim-1) # Clamp values just in case

def linear_ramp_mask(min, max, dim):
if min == max:
max += 0.001 # Prevent singularity

linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func

def get_mscale(scale=1):
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0

except ImportError:
pass

0 comments on commit 2b2752d

Please sign in to comment.