Internal change
PiperOrigin-RevId: 667917349
Change-Id: Id06f6e2783009db5a74511d7fa0374eebf47a309
Sax Authors authored and copybara-github committed Aug 27, 2024
1 parent 8ee3f1e commit 22c76fe
Showing 2 changed files with 51 additions and 11 deletions.
1 change: 1 addition & 0 deletions saxml/server/pax/lm/BUILD
@@ -243,5 +243,6 @@ pytype_strict_library(
":layers",
"//third_party/py/praxis:pax_fiddle",
"//third_party/py/praxis/layers",
"//third_party/py/praxis/layers:multi_query_attention",
],
)
61 changes: 50 additions & 11 deletions saxml/server/pax/lm/transformer_models.py
@@ -13,19 +13,29 @@
# limitations under the License.
"""Customize transformer models for sax."""

from typing import Optional, Sequence, Union

from praxis import layers
from praxis import pax_fiddle
from praxis.layers import multi_query_attention
from saxml.server.pax.lm import layers as sax_layers


def gemma(
vocab_size,
model_dims,
hidden_dims,
num_layers,
num_heads,
dim_per_head,
use_mqa,
vocab_size: int,
model_dims: int,
hidden_dims: int,
num_layers: int,
num_heads: int,
dim_per_head: int,
use_mqa: bool,
num_kv_heads: int = 1,
attn_logit_softcap: float | None = None,
final_logit_softcap: float | None = None,
use_post_attn_norm: bool = False,
use_post_ffn_norm: bool = False,
scale_query_by_dim_per_head: bool = True,
sliding_window_sizes: Optional[Union[int, Sequence[Optional[int]]]] = None,
chunked_one_step_attn_num_seq_split=1,
chunked_ffn_num_seq_split=1,
) -> pax_fiddle.Config[layers.TransformerLm]:
@@ -36,22 +36,34 @@ def gemma(
model_dims: Model dimension.
hidden_dims: Hidden dimension for the ffw layer.
num_layers: Number of layers.
num_heads: Number of heads.
num_heads: Number of heads for query.
dim_per_head: Dimension per head.
use_mqa: Whether use Multi-Query Attention.
num_kv_heads: Number of heads for key and value.
attn_logit_softcap: Softcap for attention logit.
final_logit_softcap: Softcap for final logit.
use_post_attn_norm: Whether to use post attention norm.
use_post_ffn_norm: Whether to use post ffn norm.
scale_query_by_dim_per_head: Whether to scale query by dim_per_head.
Otherwise, it is scaled by hidden_dim // num_heads.
sliding_window_sizes: Sliding window sizes for local attention.
chunked_one_step_attn_num_seq_split: Split the attention computation into chunks.
chunked_ffn_num_seq_split: Chunk the feed-forward weight computation.
Returns:
TransformerLm for Gemma.
"""
if num_kv_heads > 1:
assert use_mqa, 'num_kv_heads > 1 is only supported with MQA.'

model_p = pax_fiddle.Config(layers.TransformerLm)
model_p.vocab_size = vocab_size
model_p.model_dims = model_dims
model_p.softmax_tpl = pax_fiddle.Config(
layers.embedding_softmax.NClassMajorSharedEmbeddingSoftmax,
scale_sqrt_depth=True,
use_bias=False,
soft_cap_logits=final_logit_softcap,
)
model_p.position_emb_tpl = None
ln_tpl = pax_fiddle.Config(
@@ -67,13 +89,24 @@
stacked_transformer_tpl.num_layers = num_layers
stacked_transformer_tpl.num_heads = num_heads
stacked_transformer_tpl.dim_per_head = dim_per_head
if sliding_window_sizes is not None:
stacked_transformer_tpl.local_window_size = sliding_window_sizes
transformer_layer_p = pax_fiddle.Config(layers.Transformer)
transformer_layer_p.ln_tpl = ln_tpl.clone()
transformer_layer_p.norm_policy = (
'primer_hybrid' if use_post_attn_norm else 'pre'
)
# Attention Layer.
if use_mqa:
if sliding_window_sizes is not None:
transformer_layer_p.tr_atten_tpl = pax_fiddle.Config(
multi_query_attention.MultiQueryDotProductAttention,
num_kv_heads=num_kv_heads,
chunked_attn_num_seq_split=chunked_one_step_attn_num_seq_split,
)
elif use_mqa:
transformer_layer_p.tr_atten_tpl = pax_fiddle.Config(
sax_layers.ChunkedMQA,
num_kv_heads=1,
num_kv_heads=num_kv_heads,
chunked_one_step_attn_num_seq_split=chunked_one_step_attn_num_seq_split,
)
else:
@@ -85,7 +118,10 @@
transformer_layer_p.tr_atten_tpl.use_bias = False
transformer_layer_p.tr_atten_tpl.use_rotary_position_emb = True
transformer_layer_p.tr_atten_tpl.consolidate_rope_key_state = True
transformer_layer_p.tr_atten_tpl.scale_query_by_dim_per_head = True
transformer_layer_p.tr_atten_tpl.scale_query_by_dim_per_head = (
scale_query_by_dim_per_head
)
transformer_layer_p.tr_atten_tpl.atten_logit_cap = attn_logit_softcap
# FeedForward
transformer_layer_p.tr_fflayer_tpl = pax_fiddle.Config(
sax_layers.TransformerFeedForwardWithSeqSplit,
@@ -97,6 +133,9 @@
transformer_layer_p.tr_fflayer_tpl.activation_tpl = pax_fiddle.Config(
layers.activations.GELU,
)
transformer_layer_p.tr_fflayer_tpl.norm_policy = (
'primer_hybrid' if use_post_ffn_norm else 'pre'
)

stacked_transformer_tpl.transformer_layer_params_tpl = transformer_layer_p
model_p.stacked_transformer_tpl = stacked_transformer_tpl
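For reference, here is a minimal usage sketch of the extended gemma() builder. It is not part of the commit: all parameter values are illustrative assumptions, and the closing comment about wiring the returned config into a servable model describes typical Sax usage rather than anything shown in this diff.

# Illustrative only: build a Gemma-style TransformerLm config exercising the
# options added in this change (grouped KV heads, logit softcaps, post-norms,
# and sliding-window local attention). Values are not tied to any checkpoint.
from saxml.server.pax.lm import transformer_models

model_p = transformer_models.gemma(
    vocab_size=256_128,
    model_dims=2048,
    hidden_dims=16_384,
    num_layers=26,
    num_heads=8,
    dim_per_head=256,
    use_mqa=True,
    num_kv_heads=4,             # >1 requires use_mqa=True (see the assert above)
    attn_logit_softcap=50.0,    # soft-caps per-layer attention logits
    final_logit_softcap=30.0,   # soft-caps the output softmax logits
    use_post_attn_norm=True,    # selects the 'primer_hybrid' norm policy
    use_post_ffn_norm=True,     # likewise for the feed-forward sublayer
    sliding_window_sizes=4096,  # enables local attention via
                                # MultiQueryDotProductAttention
)

# model_p is a pax_fiddle.Config[layers.TransformerLm]; in a Sax deployment it
# would typically be assigned to the LM template of a servable LM model config.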

