Onboard DeepSeek MoE with shared experts #1242

Open · wants to merge 3 commits into base: main
8 changes: 8 additions & 0 deletions MaxText/configs/base.yml
@@ -129,6 +129,14 @@ sparse_matmul: True
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
load_balance_loss_weight: 0.01 # weight for the load balance loss

# deepseek moe
base_moe_dim: 7168 # intermediate dimension of the MoE layer (use base_mlp_dim if not DeepSeek style)
first_num_dense_layers: 0 # number of initial dense layers in the model
shared_experts: 1 # number of shared experts
routed_scaling_factor: 1.0 # scaling factor for routing scores
routed_score_func: "" # scoring function for routing, e.g. "sigmoid"
routed_bias: False # whether a bias term is added for routing

# pipeline parallelism
# The number of decoder layers is equal to the product of num_stages, num_layers_per_pipeline_stage and num_pipeline_repeats.
# There is a tradeoff between the num_layers_per_pipeline_stage and num_pipeline_repeats: The more layers per stage the easier
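To illustrate how the new routing knobs above relate to each other, here is a minimal, hypothetical sketch of DeepSeek-style expert routing; the function name and exact handling are illustrative assumptions, not the linears.DeepSeekMoeBlock implementation in this PR:

```python
import jax
import jax.numpy as jnp


def sketch_routing(router_logits, routed_score_func="sigmoid", routed_bias=None, routed_scaling_factor=1.0, top_k=8):
  """Hypothetical sketch: map router logits [tokens, num_experts] to top-k expert choices and weights."""
  if routed_score_func == "sigmoid":
    scores = jax.nn.sigmoid(router_logits)  # DeepSeek-V3-style per-expert affinity
  else:
    scores = jax.nn.softmax(router_logits, axis=-1)  # conventional softmax routing
  # An optional bias (routed_bias) can shift which experts are selected without changing the weights.
  selection_scores = scores + routed_bias if routed_bias is not None else scores
  _, top_idx = jax.lax.top_k(selection_scores, top_k)
  top_weights = jnp.take_along_axis(scores, top_idx, axis=-1)
  # routed_scaling_factor rescales the contribution of the routed experts.
  return top_idx, top_weights * routed_scaling_factor
```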
37 changes: 37 additions & 0 deletions MaxText/configs/models/deepseek3-671b.yml
@@ -0,0 +1,37 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for DeepSeek V3 - 671B

base_emb_dim: 7168
base_num_query_heads: 128
base_num_kv_heads: 128
base_mlp_dim: 18432
base_moe_dim: 2048
base_num_decoder_layers: 61
first_num_dense_layers: 3
head_dim: 128
mlp_activations: ["silu","linear"]
vocab_size: 32000 # TODO(b/394635939): update after adding tokenizer
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
num_experts: 256
num_experts_per_tok: 8
shared_experts: 1
routed_scaling_factor: 2.5
routed_score_func: "sigmoid"
routed_bias: True
rope_max_timescale: 10_000
decoder_block: "deepseek"
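As configured above, 3 of the 61 decoder layers (first_num_dense_layers) use the dense MLP and the remaining 58 use the MoE block. A small, hypothetical sketch of that split (illustrative only, not the layer-construction code in this change):

```python
def sketch_layer_plan(base_num_decoder_layers=61, first_num_dense_layers=3):
  """Hypothetical helper: label each decoder layer as dense or MoE, DeepSeek V3 style."""
  return ["dense" if i < first_num_dense_layers else "moe" for i in range(base_num_decoder_layers)]


plan = sketch_layer_plan()
assert plan.count("dense") == 3 and plan.count("moe") == 58  # 3 dense + 58 MoE = 61 layers
```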
213 changes: 213 additions & 0 deletions MaxText/layers/deepseek.py
@@ -0,0 +1,213 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Transformer model definition."""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module


from typing import Optional
from layers import quantizations
from layers import linears
from layers import initializers
import jax
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh
from flax import linen as nn
import jax.numpy as jnp
from layers import attentions
from layers import embeddings
from layers import normalizations
from layers import models
import common_types
import max_logging

Array = common_types.Array
Config = common_types.Config
DType = common_types.DType
Mesh = common_types.Mesh
ScanIn = common_types.ScanIn

Embed = embeddings.Embed
Attention = attentions.Attention
RMSNorm = normalizations.RMSNorm
Quant = quantizations.AqtQuantization

# -----------------------------------------
# The Decoder Layer for DeepSeek v3
# -----------------------------------------


def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, decoder_positions, deterministic, model_mode):
  """Pre-norm self-attention with a residual add; returns (hidden_states, intermediate_inputs)."""
  # Normalization
  lnx_rms = models.RMSNorm(
      dtype=cfg.dtype,
      weight_dtype=cfg.weight_dtype,
      name="pre_self_attention_layer_norm",
      kernel_axes=("norm",),
      epsilon=cfg.normalization_layer_epsilon,
  )
  lnx = lnx_rms(inputs)
  lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

  # Self-attention
  attention_layer = Attention(
      config=cfg,
      num_query_heads=cfg.num_query_heads,
      num_kv_heads=cfg.num_kv_heads,
      head_dim=cfg.head_dim,
      max_target_length=cfg.max_target_length,
      max_prefill_predict_length=cfg.max_prefill_predict_length,
      attention_kernel=cfg.attention,
      mesh=mesh,
      dtype=cfg.dtype,
      weight_dtype=cfg.weight_dtype,
      dropout_rate=cfg.dropout_rate,
      name="self_attention",
      quant=quant,
      kv_quant=quantizations.configure_kv_quant(cfg),
  )

  attention_lnx = attention_layer(
      lnx,
      lnx,
      decoder_positions,
      decoder_segment_ids=decoder_segment_ids,
      deterministic=deterministic,
      model_mode=model_mode,
  )

  attention_lnx = nn.with_logical_constraint(
      attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed")
  )
  intermediate_inputs = inputs + attention_lnx

  # Post-attention normalization
  hidden_states = models.RMSNorm(
      dtype=cfg.dtype,
      weight_dtype=cfg.weight_dtype,
      name="post_self_attention_layer_norm",
      kernel_axes=("norm",),
      epsilon=cfg.normalization_layer_epsilon,
  )(intermediate_inputs)
  hidden_states = nn.with_logical_constraint(
      hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
  )
  return hidden_states, intermediate_inputs


def post_process(cfg, layer_output, module):
  """Optionally sows activation metrics on the calling `module` and adapts the output for scanned layers."""
  if cfg.record_internal_nn_metrics:
    module.sow("intermediates", "activation_mean", jnp.mean(layer_output))
    module.sow("intermediates", "activation_stdev", jnp.std(layer_output))
    module.sow(
        "intermediates",
        "activation_fraction_zero",
        jnp.sum(layer_output == 0) / jnp.size(layer_output),
    )

  if cfg.scan_layers:
    return layer_output, None
  else:
    return layer_output


class DeepSeekDenseLayer(nn.Module):
  """DeepSeek-style dense layer."""

  config: models.Config
  mesh: Mesh
  quant: Optional[Quant] = None

  @nn.compact
  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
  ):
    cfg = self.config
    inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
    inputs = checkpoint_name(inputs, "decoder_layer_input")

    hidden_states, intermediate_inputs = self_attention_with_norm(
        inputs, cfg, self.mesh, self.quant, decoder_segment_ids, decoder_positions, deterministic, model_mode
    )
    mlp_lnx = linears.MlpBlock(
        intermediate_dim=cfg.mlp_dim,
        activations=cfg.mlp_activations,
        intermediate_dropout_rate=cfg.dropout_rate,
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="mlp",
        config=cfg,
        quant=self.quant,
    )(hidden_states, deterministic=deterministic)
    mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

    layer_output = mlp_lnx + intermediate_inputs
    layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
    layer_output = nn.with_logical_constraint(
        layer_output,
        ("activation_batch", "activation_norm_length", "activation_embed"),
    )
    return post_process(cfg, layer_output, self)


class DeepSeekMoELayer(nn.Module):
  """DeepSeek-style MoE layer."""
Reviewer comment (Collaborator): Can you comment on the main differences between DeepSeekMoELayer and the regular MoE layer?


  config: models.Config
  mesh: Mesh
  quant: Optional[Quant] = None

  @nn.compact
  def __call__(
      self,
      inputs,
      decoder_segment_ids,
      decoder_positions,
      deterministic,
      model_mode,
  ):
    cfg = self.config
    inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
    inputs = checkpoint_name(inputs, "decoder_layer_input")

    hidden_states, intermediate_inputs = self_attention_with_norm(
        inputs, self.config, self.mesh, self.quant, decoder_segment_ids, decoder_positions, deterministic, model_mode
    )

    mlp_lnx = linears.DeepSeekMoeBlock(
        config=cfg,
        mesh=self.mesh,
        kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
        kernel_axes=("embed", None),
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        quant=self.quant,
    )(hidden_states)
    mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

    layer_output = mlp_lnx + intermediate_inputs
    layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
    layer_output = nn.with_logical_constraint(
        layer_output,
        ("activation_batch", "activation_norm_length", "activation_embed"),
    )
    return post_process(cfg, layer_output, self)
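On the reviewer question above: the distinguishing feature of a DeepSeek-style MoE block is that, in addition to the top-k routed experts, it runs the shared_experts MLP experts on every token and sums both contributions. A per-token conceptual sketch, purely illustrative and not the linears.DeepSeekMoeBlock implementation:

```python
import jax.numpy as jnp


def sketch_deepseek_moe(x, shared_expert_fns, routed_expert_fns, top_idx, top_weights):
  """Hypothetical per-token sketch: shared experts are always applied; routed experts are weighted top-k."""
  shared_out = sum(fn(x) for fn in shared_expert_fns)  # always active, bypasses routing
  routed_out = jnp.zeros_like(x)
  for slot in range(top_idx.shape[-1]):  # accumulate the k selected experts
    routed_out = routed_out + top_weights[slot] * routed_expert_fns[int(top_idx[slot])](x)
  return shared_out + routed_out  # shared and routed outputs are summed
```

A real implementation operates on batched, sharded activations rather than one token at a time; the sketch only shows how shared experts differ from routed ones.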