diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 6160197dc19de..a2ce325fd7999 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -7,9 +7,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.fused_moe import fused_moe + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) @@ -22,7 +21,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.dbrx import DbrxConfig @@ -54,13 +52,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return router_logits -class DbrxExperts(nn.Module): - """A tensor-parallel MoE implementation for DBRX. - - Each expert's weights are sharded across all ranks and a fused MoE - kernel is used for the forward pass, and finally we reduce the outputs - across ranks. - """ +class DbrxExperts(FusedMoE): def __init__( self, @@ -68,49 +60,24 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, params_dtype: Optional[torch.dtype] = None, ): - super().__init__() + super().__init__( + num_experts=config.ffn_config.moe_num_experts, + top_k=config.ffn_config.moe_top_k, + hidden_size=config.d_model, + intermediate_size=config.ffn_config.ffn_hidden_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=get_tensor_model_parallel_world_size(), + ) + self.config = config self.tp_size = get_tensor_model_parallel_world_size() - self.num_total_experts = config.ffn_config.moe_num_experts - self.top_k = config.ffn_config.moe_top_k self.d_model = config.d_model - self.intermediate_size = (config.ffn_config.ffn_hidden_size // + self.intermediate_size = (self.config.ffn_config.ffn_hidden_size // self.tp_size) - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype - - self.router = DbrxRouter(config, self.params_dtype) - self.ws = nn.Parameter( - torch.empty( - self.num_total_experts, - 2 * self.intermediate_size, - self.d_model, - device="cuda", - dtype=self.params_dtype, - )) - self.w2s = nn.Parameter( - torch.empty( - self.num_total_experts, - self.d_model, - self.intermediate_size, - device="cuda", - dtype=self.params_dtype, - )) - - set_weight_attrs( - self.ws, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.w2s, - { - "weight_loader": self.weight_loader, - }, - ) - + # Define custom weight loader for dbrx model def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str): tp_rank = get_tensor_model_parallel_rank() @@ -119,13 +86,13 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) # DBRX uses GLU for each experts. # GLU has 3 linear layers: w1, v1 and w2. - if weight_name.endswith("w1"): + if weight_name.endswith("w1."): loaded_weight = torch.reshape( loaded_weight, [-1, self.intermediate_size * self.tp_size, self.d_model], ) param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :] - if weight_name.endswith("v1"): + if weight_name.endswith("v1."): loaded_weight = torch.reshape( loaded_weight, [-1, self.intermediate_size * self.tp_size, self.d_model], @@ -133,33 +100,47 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, param_data[:, shard_size:2 * shard_size, :] = loaded_weight[:, shard, :] - if weight_name.endswith("w2"): + if weight_name.endswith("w2."): loaded_weight = torch.reshape( loaded_weight, [-1, self.intermediate_size * self.tp_size, self.d_model], ).transpose(1, 2) param_data[:] = loaded_weight[:, :, shard] + +class DbrxMoE(nn.Module): + """A tensor-parallel MoE implementation for DBRX. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + config: DbrxConfig, + quant_config: Optional[QuantizationConfig] = None, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.d_model = config.d_model + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.router = DbrxRouter(config, self.params_dtype) + + self.experts = DbrxExperts(config=config, + quant_config=quant_config, + params_dtype=self.params_dtype) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_size = hidden_states.shape + orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.d_model) # router_logits: (num_tokens, n_experts) router_logits = self.router(hidden_states) - final_hidden_states = fused_moe( - hidden_states, - self.ws, - self.w2s, - router_logits, - self.top_k, - renormalize=True, - inplace=True, - ) - - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - - return final_hidden_states.view(num_tokens, hidden_size) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) class DbrxAttention(nn.Module): @@ -288,7 +269,7 @@ def __init__( super().__init__() self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config, quant_config) - self.ffn = DbrxExperts(config, quant_config) + self.ffn = DbrxMoE(config, quant_config) def forward( self, @@ -409,12 +390,15 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + expert_params_mapping = [( - "ws" if weight_name in ["w1", "v1"] else "w2s", - f"experts.mlp.{weight_name}", + "w13_" if weight_name in ["w1", "v1"] else "w2_", + f"mlp.{weight_name}.", ) for weight_name in ["w1", "v1", "w2"]] params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: + if name.endswith(("w1", "v1", "w2")): + name = name + ".weight" for param_name, weight_name in expert_params_mapping: if weight_name not in name: continue