diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 812e1709b..ad7909d71 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -129,6 +129,14 @@ sparse_matmul: True capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default load_balance_loss_weight: 0.01 # weight for the load balance loss +# deepseek moe +base_moe_dim: 7168 # intermediate dimension at MoE layer (use base_mlp_dim if not DeepSeek style) +first_num_dense_layers: 0 # number of initial dense layers in the model +shared_experts: 1 +routed_scaling_factor: 1.0 # scaling factor for routing scores +routed_score_func: "" # scoring function for routing +routed_bias: False # a flag if a bias term is added for routing + # pipeline parallelism # The number of decoder layers is equal to the product of num_stages, num_layers_per_pipeline_stage and num_pipeline_repeats. # There is a tradeoff between the num_layers_per_pipeline_stage and num_pipeline_repeats: The more layers per stage the easier diff --git a/MaxText/configs/models/deepseek3-671b.yml b/MaxText/configs/models/deepseek3-671b.yml new file mode 100644 index 000000000..96e612a74 --- /dev/null +++ b/MaxText/configs/models/deepseek3-671b.yml @@ -0,0 +1,37 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# model config for DeepSeek V3 - 671B + +base_emb_dim: 7168 +base_num_query_heads: 128 +base_num_kv_heads: 128 +base_mlp_dim: 18432 +base_moe_dim: 2048 +base_num_decoder_layers: 61 +first_num_dense_layers: 3 +head_dim: 128 +mlp_activations: ["silu","linear"] +vocab_size: 32000 # TODO(b/394635939): update after adding tokenizer +enable_dropout: False +logits_via_embedding: False +normalization_layer_epsilon: 1.0e-6 +num_experts: 256 +num_experts_per_tok: 8 +shared_experts: 1 +routed_scaling_factor: 2.5 +routed_score_func: "sigmoid" +routed_bias: True +rope_max_timescale: 10_000 +decoder_block: "deepseek" diff --git a/MaxText/layers/deepseek.py b/MaxText/layers/deepseek.py new file mode 100644 index 000000000..39cad4f6f --- /dev/null +++ b/MaxText/layers/deepseek.py @@ -0,0 +1,213 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Transformer model definition.""" +# pylint: disable=arguments-differ +# pylint: disable=no-name-in-module + + +from typing import Optional +from layers import quantizations +from layers import linears +from layers import initializers +import jax +from jax.ad_checkpoint import checkpoint_name +from jax.sharding import Mesh +from flax import linen as nn +import jax.numpy as jnp +from layers import attentions +from layers import embeddings +from layers import normalizations +from layers import models +import common_types +import max_logging + +Array = common_types.Array +Config = common_types.Config +DType = common_types.DType +Mesh = common_types.Mesh +ScanIn = common_types.ScanIn + +Embed = embeddings.Embed +Attention = attentions.Attention +RMSNorm = normalizations.RMSNorm +Quant = quantizations.AqtQuantization + +# ----------------------------------------- +# The Decoder Layer for DeepSeek v3 +# ----------------------------------------- + + +def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, decoder_positions, deterministic, model_mode): + # Normalization + lnx_rms = models.RMSNorm( + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="pre_self_attention_layer_norm", + kernel_axes=("norm",), + epsilon=cfg.normalization_layer_epsilon, + ) + lnx = lnx_rms(inputs) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed")) + + # Self-attention + attention_layer = Attention( + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + quant=quant, + kv_quant=quantizations.configure_kv_quant(cfg), + ) + + attention_lnx = attention_layer( + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) + + attention_lnx = nn.with_logical_constraint( + attention_lnx, ("activation_batch", "activation_norm_length", "activation_embed") + ) + intermediate_inputs = inputs + attention_lnx + + # Normalization + hidden_states = models.RMSNorm( + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="post_self_attention_layer_norm", + kernel_axes=("norm",), + epsilon=cfg.normalization_layer_epsilon, + )(intermediate_inputs) + hidden_states = nn.with_logical_constraint( + hidden_states, ("activation_batch", "activation_norm_length", "activation_embed") + ) + return hidden_states, intermediate_inputs + + +def post_process(cfg, layer_output): + if cfg.record_internal_nn_metrics: + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) + self.sow( + "intermediates", + "activation_fraction_zero", + jnp.sum(layer_output == 0) / jnp.size(layer_output), + ) + + if cfg.scan_layers: + return layer_output, None + else: + return layer_output + + +class DeepSeekDenseLayer(nn.Module): + """DeepSeek-style dense layer.""" + + config: models.Config + mesh: Mesh + quant: Optional[Quant] = None + + @nn.compact + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): + cfg = self.config + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) + inputs = checkpoint_name(inputs, "decoder_layer_input") + + hidden_states, intermediate_inputs = self_attention_with_norm( + inputs, cfg, self.mesh, self.quant, decoder_segment_ids, decoder_positions, deterministic, model_mode + ) + mlp_lnx = linears.MlpBlock( + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="mlp", + config=cfg, + quant=self.quant, + )(hidden_states, deterministic=deterministic) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) + + layer_output = mlp_lnx + intermediate_inputs + layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = nn.with_logical_constraint( + layer_output, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return post_process(self.config, layer_output) + + +class DeepSeekMoELayer(nn.Module): + """DeepSeek-style MoE layer.""" + + config: models.Config + mesh: Mesh + quant: Optional[Quant] = None + + @nn.compact + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): + cfg = self.config + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) + inputs = checkpoint_name(inputs, "decoder_layer_input") + + hidden_states, intermediate_inputs = self_attention_with_norm( + inputs, self.config, self.mesh, self.quant, decoder_segment_ids, decoder_positions, deterministic, model_mode + ) + + mlp_lnx = linears.DeepSeekMoeBlock( + config=cfg, + mesh=self.mesh, + kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes=("embed", None), + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + quant=self.quant, + )(hidden_states) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) + + layer_output = mlp_lnx + intermediate_inputs + layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = nn.with_logical_constraint( + layer_output, + ("activation_batch", "activation_norm_length", "activation_embed"), + ) + return post_process(cfg, layer_output) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 1820a26bc..49bb8416f 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -91,7 +91,8 @@ class DenseGeneral(nn.Module): weight_dtype: the dtype of the weights (default: float32). dtype: the dtype of the computation (default: float32). kernel_init: initializer function for the weight matrix. - use_bias: whether to add bias in linear transformation + use_bias: whether to add bias in linear transformation. + bias_norm: whether to add normalization before adding bias. quant: quantization config, defaults to None implying no quantization. """ @@ -103,6 +104,7 @@ class DenseGeneral(nn.Module): kernel_axes: Tuple[Optional[str], ...] = () quant: Optional[Quant] = None use_bias: bool = False + bias_norm: str = "" matmul_precision: str = "default" @nn.compact @@ -165,6 +167,9 @@ def compute_dot_general(inputs, kernel, axis, contract_ind): self.weight_dtype, ) bias = jnp.asarray(bias, self.dtype) + + if self.bias_norm: + output = _convert_to_activation_function(self.bias_norm)(output) output += bias return output @@ -198,7 +203,7 @@ class MlpBlock(nn.Module): quant: Optional[Quant] = None def get_norm_layer(self): - if self.config.decoder_block in ("default", "llama2", "mistral", "gemma"): + if self.config.decoder_block in ("default", "llama2", "mistral", "gemma", "deepseek"): return RMSNorm elif self.config.decoder_block == "gpt3": from layers import gpt3 @@ -292,6 +297,7 @@ class MoeBlock(nn.Module): mesh: Mesh, device mesh. kernel_init: Kernel function, passed to the dense layers. kernel_axes: Tuple with axes to apply kernel function. + intermediate_dim: Intermediate dimension of MoE. weight_dtype: Type for the weights. dtype: Type for the dense layer. quant: Optional quantization config, no quantization if None. @@ -303,6 +309,7 @@ class MoeBlock(nn.Module): mesh: Mesh kernel_init: NdInitializer kernel_axes: Tuple[Optional[str], ...] + intermediate_dim: int = 2048 weight_dtype: DType = jnp.float32 dtype: DType = jnp.float32 quant: Optional[Quant] = None @@ -371,7 +378,11 @@ def permute(self, inputs, gate_logits): inputs_shape = inputs.shape inputs_2d = jnp.reshape(inputs, (inputs_shape[0] * inputs_shape[1], inputs_shape[2])) weights, selected_experts = jax.lax.top_k(gate_logits, self.num_experts_per_tok) - weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1).astype(self.dtype) + if self.config.decoder_block == "deepseek": + weights /= weights.sum(-1, keepdims=True) + weights *= self.config.routed_scaling_factor + else: + weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1).astype(self.dtype) flatten_selected_experts = jnp.ravel(selected_experts) sorted_selected_experts = jnp.argsort(flatten_selected_experts) sorted_indices = sorted_selected_experts // self.num_experts_per_tok @@ -603,14 +614,21 @@ def maybe_all_gather_kernel_weight_in_expert_parallelism(self, kernel, kernel_ax def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): # gate_logits: batch, length, expert gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_length", None)) - softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype) # shape of top_k_weights & top_k_indices: (batch, sequence, num_experts_per_tok) - top_k_weights, top_k_indices = jax.lax.top_k(softmax_probs, self.num_experts_per_tok) + top_k_weights, top_k_indices = jax.lax.top_k(gate_logits, self.num_experts_per_tok) + + if self.config.decoder_block == "deepseek": + top_k_weights /= top_k_weights.sum(-1, keepdims=True) + top_k_weights *= self.config.routed_scaling_factor + else: + top_k_weights = jax.nn.softmax(top_k_weights.astype(jnp.float32), axis=-1).astype(self.dtype) + + weights = self.reshape_and_update_weights(top_k_weights, top_k_indices) matmul_precision = lax.Precision(self.config.matmul_precision) if self.config.capacity_factor > 0: # token dropping if needed - dispatch_mask, combine_mask = self.generate_masks(top_k_indices, softmax_probs) + dispatch_mask, combine_mask = self.generate_masks(top_k_indices, weights) mask_axes = ("activation_batch", "activation_length", None, None) dispatch_mask = nn.with_logical_constraint(dispatch_mask, mask_axes) combine_mask = nn.with_logical_constraint(combine_mask, mask_axes) @@ -678,8 +696,6 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): ).astype(self.dtype) return output, loss else: - top_k_weights /= top_k_weights.sum(-1, keepdims=True) - weights = self.reshape_and_update_weights(top_k_weights, top_k_indices) inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) with jax.named_scope("wi_0"): layer_w0 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)( @@ -740,9 +756,12 @@ def __call__(self, inputs): kernel_init=self.kernel_init, kernel_axes=self.kernel_axes, name="gate", + use_bias=self.config.routed_bias, + bias_norm=self.config.routed_score_func, matmul_precision=self.config.matmul_precision, )(inputs) - w0_kernel, w1_kernel, wo_kernel = self.generate_kernels(cfg.num_experts, cfg.emb_dim, cfg.mlp_dim) + + w0_kernel, w1_kernel, wo_kernel = self.generate_kernels(cfg.num_experts, cfg.emb_dim, self.intermediate_dim) if cfg.sparse_matmul: max_logging.log("Running MoE sparse matmul implementation.") if quantizations.in_serve_mode(self.quant): @@ -753,3 +772,58 @@ def __call__(self, inputs): else: max_logging.log("Running MoE dense matmul implementation.") return self.dense_matmul(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel) + + +class DeepSeekMoeBlock(nn.Module): + """DeepSeek MoE block. + + Attributes: + config: Model configs. + mesh: Mesh, device mesh. + kernel_init: Kernel function, passed to the dense layers. + kernel_axes: Tuple with axes to apply kernel function. + weight_dtype: Type for the weights. + dtype: Type for the dense layer. + quant: Optional quantization config, no quantization if None. + """ + + config: Config + mesh: Mesh + kernel_init: NdInitializer + kernel_axes: Tuple[Optional[str], ...] + weight_dtype: DType = jnp.float32 + dtype: DType = jnp.float32 + quant: Optional[Quant] = None + + @nn.compact + def __call__(self, inputs): + cfg = self.config + routed_experts, _ = MoeBlock( + config=cfg, + num_experts=cfg.num_experts, + num_experts_per_tok=cfg.num_experts_per_tok, + mesh=self.mesh, + kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes=("embed", None), + intermediate_dim=cfg.moe_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + quant=self.quant, + )(inputs) + + shared_experts = jax.numpy.zeros_like(inputs) + for index in range(cfg.shared_experts): + current_expert = MlpBlock( + intermediate_dim=cfg.moe_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name=f"shared_exp_{index}", + config=cfg, + quant=self.quant, + )(inputs) + shared_experts = shared_experts + current_expert + # average if multiple shared experts + shared_experts = shared_experts / cfg.shared_experts + return routed_experts + shared_experts diff --git a/MaxText/layers/mistral.py b/MaxText/layers/mistral.py index 5fbb9e1e1..9cb3f59c1 100644 --- a/MaxText/layers/mistral.py +++ b/MaxText/layers/mistral.py @@ -136,6 +136,7 @@ def __call__( mesh=mesh, kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), kernel_axes=("embed", None), + intermediate_dim=cfg.mlp_dim, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, quant=self.quant, diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index 993b82ba2..8024f8661 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -259,51 +259,59 @@ def get_remat_policy(self): policy = None return policy - def set_remat_policy(self, block_layer, policy): - return nn.remat( # pylint: disable=invalid-name - block_layer, - prevent_cse=not self.config.scan_layers, - policy=policy, - static_argnums=(4, 5), # Deterministic and model mode are static arguments. - ) + def set_remat_policy(self, block_layers, policy): + RemattedBlockLayers = [] + for block_layer in block_layers: + layer = nn.remat( # pylint: disable=invalid-name + block_layer, + prevent_cse=not self.config.scan_layers, + policy=policy, + static_argnums=(4, 5), # Deterministic and model mode are static arguments. + ) + RemattedBlockLayers.append(layer) + return RemattedBlockLayers def get_decoder_layer(self): if self.config.decoder_block == "default": - return DecoderLayer + return [DecoderLayer] elif self.config.decoder_block == "llama2": from layers import llama2 - return llama2.LlamaDecoderLayer + return [llama2.LlamaDecoderLayer] elif self.config.decoder_block == "mistral": # TODO(ranran): update to Mistral with sliding window attention from layers import mistral - return mistral.MistralDecoderLayer + return [mistral.MistralDecoderLayer] + elif self.config.decoder_block == "deepseek": + from layers import deepseek + + return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer] elif self.config.decoder_block == "gemma": from layers import gemma - return gemma.GemmaDecoderLayer + return [gemma.GemmaDecoderLayer] elif self.config.decoder_block == "gemma2": from layers import gemma2 - return gemma2.Gemma2DecoderLayer + return [gemma2.Gemma2DecoderLayer] elif self.config.decoder_block == "gpt3": from layers import gpt3 - return gpt3.Gpt3DecoderLayer + return [gpt3.Gpt3DecoderLayer] elif self.config.decoder_block == "simple": from layers import simple_layer - return simple_layer.SimpleDecoderLayer + return [simple_layer.SimpleDecoderLayer] elif self.config.decoder_block == "simple_mlp": from layers import simple_layer - return simple_layer.SimpleMlpDecoderLayer + return [simple_layer.SimpleMlpDecoderLayer] else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") def get_norm_layer(self): - if self.config.decoder_block in ("default", "llama2", "mistral", "gemma", "gemma2", "simple", "simple_mlp"): + if self.config.decoder_block in ("default", "llama2", "mistral", "deepseek", "gemma", "gemma2", "simple", "simple_mlp"): return RMSNorm elif self.config.decoder_block == "gpt3": from layers import gpt3 @@ -338,13 +346,13 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metdata_axis_name, mes length=length, metadata_params={nn.PARTITION_NAME: metdata_axis_name}, ) - return scan_fn(config=cfg, mesh=mesh, name="layers", quant=self.quant) + return scan_fn(config=cfg, mesh=mesh, name=metdata_axis_name, quant=self.quant) def get_pipeline_stage_module(self, base_stage): cfg = self.config if cfg.set_remat_policy_on_layers_per_stage: policy = self.get_remat_policy() - base_stage = self.set_remat_policy(base_stage, policy) + base_stage = self.set_remat_policy([base_stage], policy)[0] if cfg.num_layers_per_pipeline_stage == 1: stage_module = base_stage(config=cfg, mesh=self.mesh, quant=self.quant) elif cfg.scan_layers: @@ -393,33 +401,80 @@ def __call__( )(decoder_positions) policy = self.get_remat_policy() - RemattedBlockLayer = self.set_remat_policy(self.decoder_layer, policy) + RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy) if cfg.using_pipeline_parallelism: - partition_spec = self.pipeline_module.get_weight_sharding( - y, decoder_segment_ids, decoder_positions, deterministic, model_mode - ) - y = self.pipeline_module( - y, decoder_segment_ids, decoder_positions, deterministic, model_mode, partition_spec=partition_spec + RemattedBlockLayer = RemattedBlockLayers[0] + base_stage = RemattedBlockLayer if cfg.set_remat_policy_on_layers_per_stage else BlockLayer + stage_module = self.get_pipeline_stage_module(base_stage, cfg, mesh) + y = pipeline.Pipeline(config=cfg, mesh=mesh, layers=stage_module, remat_policy=policy)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, ) else: if cfg.scan_layers: - y, _ = self.scan_decoder_layers(cfg, RemattedBlockLayer, cfg.num_decoder_layers, "layers", mesh)( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ) - else: - for lyr in range(cfg.num_decoder_layers): - y = RemattedBlockLayer(config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant)( + if cfg.decoder_block == "deepseek": + assert len(RemattedBlockLayers) == 2, f"Scanned layers must have a length of 2 using deepseek." + dense_layer = RemattedBlockLayers[0] + moe_layer = RemattedBlockLayers[1] + y, _ = self.scan_decoder_layers(cfg, dense_layer, cfg.first_num_dense_layers, "dense_layers", mesh)( y, decoder_segment_ids, decoder_positions, deterministic, model_mode, ) + num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers + y, _ = self.scan_decoder_layers(cfg, moe_layer, num_moe_layers, "moe_layers", mesh)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) + else: + RemattedBlockLayer = RemattedBlockLayers[0] + y, _ = self.scan_decoder_layers(cfg, RemattedBlockLayer, cfg.num_decoder_layers, "layers", mesh)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) + else: + if cfg.decoder_block == "deepseek": + assert len(RemattedBlockLayers) == 2, f"Unscanned layers must have a length of 2 using deepseek." + dense_layer = RemattedBlockLayers[0] + moe_layer = RemattedBlockLayers[1] + for lyr in range(cfg.first_num_dense_layers): + y = dense_layer(config=cfg, mesh=mesh, name=f"dense_layers_{lyr}", quant=self.quant)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) + num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers + for lyr in range(num_moe_layers): + y = moe_layer(config=cfg, mesh=mesh, name=f"moe_layers_{lyr}", quant=self.quant)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) + else: + for lyr in range(cfg.num_decoder_layers): + y = RemattedBlockLayer(config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) y = self.get_norm_layer()( dtype=cfg.dtype, diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 67ca89cbe..ca17da976 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -149,6 +149,7 @@ def validate_keys(keys): validate_multiple_slices(keys) if keys["num_experts"] > 1: validate_megablox_parallelism(keys) + validate_deepseek_moe(keys) def validate_data_input(keys): @@ -202,6 +203,7 @@ def validate_model_name(s: str) -> bool: "mistral-7b", "mixtral-8x7b", "mixtral-8x22b", + "deepseek3-671b", "gemma-7b", "gemma-2b", "gemma2-2b", @@ -417,6 +419,7 @@ def user_init(raw_keys): raw_keys["num_query_heads"] = 2**num_head_scale * raw_keys["base_num_query_heads"] raw_keys["num_kv_heads"] = 2**num_head_scale * raw_keys["base_num_kv_heads"] raw_keys["mlp_dim"] = 2**mlp_dim_scale * raw_keys["base_mlp_dim"] + raw_keys["moe_dim"] = 2**mlp_dim_scale * raw_keys["base_moe_dim"] raw_keys["num_decoder_layers"] = 2**layer_scale * raw_keys["base_num_decoder_layers"] # This is the first command that initializes the backend - it calls @@ -671,6 +674,11 @@ def pipeline_first_axis(raw_keys): return raw_keys +def validate_deepseek_moe(raw_keys): + if raw_keys["decoder_block"] == "deepseek" and using_pipeline_parallelism(raw_keys): + raise ValueError("Currently we do not support DeepSeek MoE with pipeline parallelism.") + + def validate_megablox_parallelism(raw_keys): if ( raw_keys["sparse_matmul"] diff --git a/MaxText/tests/train_compile_test.py b/MaxText/tests/train_compile_test.py index 2e3ef8e59..31d37bc76 100644 --- a/MaxText/tests/train_compile_test.py +++ b/MaxText/tests/train_compile_test.py @@ -434,3 +434,25 @@ def test_moe_pp_bf16(self): "num_layers_per_pipeline_stage=1", ) ) + + @pytest.mark.tpu_only + def test_moe_deepseek_bf16(self): + compiled_trainstep_file = "/tmp/test_moe_deepseek_bf16.pickle" + train_compile_main( + ( + None, + "configs/base.yml", + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-256", + "use_iota_embed=true", + "compile_topology_num_slices=1", + "model_name=deepseek3-671b", + "sparse_matmul=True", + "megablox=False", + "per_device_batch_size=4", + "max_target_length=1024", + "attention=flash", + "dtype=bfloat16", + "weight_dtype=bfloat16", + ) + )