
Commit

[sax/llama3.1] Add frequency scaling for rotary embeddings and enable for llama3 405b

PiperOrigin-RevId: 671495855
Change-Id: I34be789738f6a79d800bdc72ad0be3748450e77a
rdzhabarov authored and copybara-github committed Sep 5, 2024
1 parent 2c79952 commit 138be5f
Showing 2 changed files with 54 additions and 2 deletions.
41 changes: 41 additions & 0 deletions saxml/server/pax/lm/layers.py
@@ -48,6 +48,41 @@ def reduce_last_dim_for_quantization(t: JTensor) -> tuple[JTensor, JTensor]:
class LLaMARotaryEmbedding(embedding_softmax.RotaryPositionalEmbedding):
  """LLaMA variant of ROPE where inputs are split in a different way."""

  # LLaMA3.1 ROPE scaling, see the original pytorch implementation
  # https://github.com/meta-llama/llama-models/blob/301ca3a2b3b10e94ddcd1fdd2c57e52f812e1cac/models/llama3/reference_impl/model.py#L45C5-L45C18
  use_scale: bool = False

  def _apply_scaling_factor(self, freq):
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * jnp.pi / freq

    def lower_wavelen(freq):
      return freq

    def bigger_or_equal_wavelen(freq):
      def bigger_wavelen(freq):
        return freq / scale_factor

      def equal_wavelen(freq):
        smooth = (old_context_len / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        return (1 - smooth) * freq / scale_factor + smooth * freq

      bigger_wavelen_cond = wavelen > low_freq_wavelen
      return jax.lax.cond(
          bigger_wavelen_cond, bigger_wavelen, equal_wavelen, freq)

    lower_wavelen_cond = wavelen < high_freq_wavelen
    return jax.lax.cond(
        lower_wavelen_cond, lower_wavelen, bigger_or_equal_wavelen, freq)

  def __call__(
      self,  # pytype: disable=signature-mismatch # overriding-parameter-count-checks
      inputs: JTensor,
@@ -94,16 +129,22 @@ def __call__(
    half_embedding_dim = self.embedding_dims // 2
    fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims
    fraction = jnp.repeat(fraction, 2)

    timescale = (
        self.min_timescale
        * (self.max_timescale / self.min_timescale) ** fraction
    )

    if self.use_scale:
      timescale = jax.vmap(self._apply_scaling_factor)(timescale)

    if position is None:
      seq_length = inputs.shape[1]
      position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]
    position = position[:, :, jnp.newaxis, jnp.newaxis]
    timescale = timescale[jnp.newaxis, jnp.newaxis, jnp.newaxis, :]
    sinusoid_inp = position / timescale

    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)
    sign = jnp.sign(
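As a reference for what the scaling does, below is a minimal vectorized sketch of the LLaMA3.1 rule written against angular frequencies, following the Meta reference implementation linked in the comment above. The constant names, the apply_llama31_scaling function, and the example values are illustrative only and not part of this commit, which instead applies its per-element transform to the layer's timescales with jax.vmap.

import jax.numpy as jnp

# Constants mirroring the linked Meta reference implementation (illustrative).
SCALE_FACTOR = 8.0
LOW_FREQ_FACTOR = 1.0
HIGH_FREQ_FACTOR = 4.0
OLD_CONTEXT_LEN = 8192.0  # original llama3 context length


def apply_llama31_scaling(freqs):
  """Scales RoPE angular frequencies following the LLaMA3.1 rule (sketch)."""
  low_freq_wavelen = OLD_CONTEXT_LEN / LOW_FREQ_FACTOR    # 8192
  high_freq_wavelen = OLD_CONTEXT_LEN / HIGH_FREQ_FACTOR  # 2048
  wavelen = 2.0 * jnp.pi / freqs

  # Blend used for wavelengths between the two thresholds.
  smooth = (OLD_CONTEXT_LEN / wavelen - LOW_FREQ_FACTOR) / (
      HIGH_FREQ_FACTOR - LOW_FREQ_FACTOR
  )
  blended = (1.0 - smooth) * freqs / SCALE_FACTOR + smooth * freqs

  # Short wavelengths (high frequencies) pass through unchanged; long
  # wavelengths (low frequencies) are divided by SCALE_FACTOR; the band in
  # between takes the blended value.
  return jnp.where(
      wavelen < high_freq_wavelen,
      freqs,
      jnp.where(wavelen > low_freq_wavelen, freqs / SCALE_FACTOR, blended),
  )


# Example: angular frequencies for a 128-dim head with rope theta 500000.
inv_freq = 1.0 / (500000.0 ** (jnp.arange(0, 128, 2) / 128.0))
scaled = apply_llama31_scaling(inv_freq)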
15 changes: 13 additions & 2 deletions saxml/server/pax/lm/params/lm_cloud.py
@@ -244,6 +244,17 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
    return task_p


class BaseLLaMA31(BaseLLaMA3):
  """Base class for LLaMA3.1 models."""

  def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
    task_p = super().task()
    task_p.model.lm_tpl.stacked_transformer_tpl.transformer_layer_params_tpl.tr_atten_tpl.rotary_position_emb_tpl.use_scale = (
        True
    )
    return task_p


@quantization.for_transformer(quantize_on_the_fly=False)
class BaseLLaMATest(BaseLLaMA):
  """Small BaseLLaMA model for unit tests.
@@ -758,7 +769,7 @@ class LLaMA3_70BFP16x16(BaseLLaMA3):


@servable_model_registry.register
-class LLaMA3_405BFP16x64(BaseLLaMA3):
+class LLaMA31_405BFP16x64(BaseLLaMA31):
  """LLama3 405B FP16 partitioned for 64 chips."""

  VOCABULARY_CLASS = 'LLama3Vocabulary'
@@ -777,7 +788,7 @@ class LLaMA3_405BFP16x64(BaseLLaMA3):

@servable_model_registry.register
@quantization.for_transformer(quantize_on_the_fly=False, linear_only=True)
-class LLaMA3_405BLinearOnlyInt8x32(LLaMA3_405BFP16x64):
+class LLaMA31_405BLinearOnlyInt8x32(LLaMA31_405BFP16x64):
  """LLama3 405B int8 linear only layer quantized partitioned for 32 chips."""

  ICI_MESH_SHAPE = [1, 1, 32]
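With BaseLLaMA31 in place, any further LLaMA3.1 serving config can opt in to the scaled rotary embedding simply by deriving from it, as the two 405B classes above do. A hypothetical sketch follows; the class name and mesh shape are illustrative and not part of this commit.

@servable_model_registry.register
class LLaMA31_ExampleFP16x8(BaseLLaMA31):
  """Illustrative LLaMA3.1 config that inherits the use_scale override."""

  ICI_MESH_SHAPE = [1, 1, 8]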
