diff --git a/blogs/deepspeed-fastgen/README.md b/blogs/deepspeed-fastgen/README.md
index 2585fac5d26e..e287af2540ed 100644
--- a/blogs/deepspeed-fastgen/README.md
+++ b/blogs/deepspeed-fastgen/README.md
@@ -233,6 +233,8 @@ We currently support the following model architectures in this alpha release of
 * [Phi-2](https://huggingface.co/models?other=phi-msft)
 * [Phi-3](https://huggingface.co/models?other=phi3)
 * [Qwen](https://huggingface.co/models?other=qwen)
+* [Qwen2](https://huggingface.co/models?other=qwen2)
+* [Qwen2-MoE](https://huggingface.co/models?other=qwen2_moe)
 
 All current models leverage [HuggingFace](https://github.com/huggingface) APIs in our backend to provide both the model weights and the model's corresponding tokenizer.
 
diff --git a/deepspeed/inference/v2/engine_factory.py b/deepspeed/inference/v2/engine_factory.py
index c21affb9a0de..314f7f2f0485 100644
--- a/deepspeed/inference/v2/engine_factory.py
+++ b/deepspeed/inference/v2/engine_factory.py
@@ -23,6 +23,7 @@
     Phi3Policy,
     QwenPolicy,
     Qwen2Policy,
+    Qwen2MoePolicy,
 )
 from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy
 from .model_implementations.flat_model_helpers import make_metadata_filename, ModelMetadata
@@ -126,6 +127,8 @@ def build_hf_engine(path: str,
         policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine)
     elif model_config.model_type == "qwen2":
         policy = Qwen2Policy(model_config, checkpoint_engine=checkpoint_engine)
+    elif model_config.model_type == "qwen2_moe":
+        policy = Qwen2MoePolicy(model_config, checkpoint_engine=checkpoint_engine)
     else:
         raise ValueError(f"Unsupported model type {model_config.model_type}")
 
diff --git a/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h
index abb9e15f8f6f..2cc430ccfe34 100644
--- a/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h
+++ b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h
@@ -11,5 +11,8 @@
     } else if (2 == N_TOP_K) {                 \
         constexpr int CONST_TOP_K = 2;         \
         __VA_ARGS__();                         \
+    } else if (4 == N_TOP_K) {                 \
+        constexpr int CONST_TOP_K = 4;         \
+        __VA_ARGS__();                         \
     }                                          \
 }()
diff --git a/deepspeed/inference/v2/model_implementations/__init__.py b/deepspeed/inference/v2/model_implementations/__init__.py
index e4160ab94949..3483d9348c55 100644
--- a/deepspeed/inference/v2/model_implementations/__init__.py
+++ b/deepspeed/inference/v2/model_implementations/__init__.py
@@ -18,3 +18,4 @@
 from .phi3 import *
 from .qwen import *
 from .qwen_v2 import *
+from .qwen_v2_moe import *
diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/__init__.py b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/__init__.py
new file mode 100644
index 000000000000..23e06a770023
--- /dev/null
+++ b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .policy import Qwen2MoePolicy
diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py
new file mode 100644
index 000000000000..b4621257ff82
--- /dev/null
+++ b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/container.py
@@ -0,0 +1,103 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+# Create a container object to save model-specific tensors using the policy file above.
+
+from ..common_parameters import *
+from ..layer_container_base import LayerContainer
+'''
+ # HF Qwen1.5-MoE-A2.7B model looks like this:
+
+Qwen2MoeForCausalLM(
+  (model): Qwen2MoeModel(
+    (embed_tokens): Embedding(151936, 2048)
+    (layers): ModuleList(
+      (0-23): 24 x Qwen2MoeDecoderLayer(
+        (self_attn): Qwen2MoeSdpaAttention(
+          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
+          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
+          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
+          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
+          (rotary_emb): Qwen2MoeRotaryEmbedding()
+        )
+        (mlp): Qwen2MoeSparseMoeBlock(
+          (gate): Linear(in_features=2048, out_features=60, bias=False)
+          (experts): ModuleList(
+            (0-59): 60 x Qwen2MoeMLP(
+              (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
+              (up_proj): Linear(in_features=2048, out_features=1408, bias=False)
+              (down_proj): Linear(in_features=1408, out_features=2048, bias=False)
+              (act_fn): SiLU()
+            )
+          )
+          (shared_expert): Qwen2MoeMLP(
+            (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
+            (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
+            (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
+            (act_fn): SiLU()
+          )
+          (shared_expert_gate): Linear(in_features=2048, out_features=1, bias=False)
+        )
+        (input_layernorm): Qwen2MoeRMSNorm()
+        (post_attention_layernorm): Qwen2MoeRMSNorm()
+      )
+    )
+    (norm): Qwen2MoeRMSNorm()
+  )
+  (lm_head): Linear(in_features=2048, out_features=151936, bias=False)
+)
+'''
+
+
+class Qwen2MoeTransformerContainer(LayerContainer):
+    """
+    Transformer layer container for the Qwen2Moe model.
+    """
+    qkv_w: UnfusedQKVParameter
+    qkv_b: UnfusedQKVParameter
+    attn_out_w: AttentionOutputParameter
+    moe_gate: MoEGatingWeightParameter
+    moe_mlp_1: UnfusedMoEGatedMLPParameter
+    moe_mlp_2: UnfusedMoEMLP2Parameter
+    shared_moe_mlp_1: GatedMLPParameter
+    shared_moe_mlp_2: MLP2Parameter
+    shared_moe_gate: MoEGatingWeightParameter
+    attn_norm_gamma: NormParameter
+    mlp_norm_gamma: NormParameter
+
+    PARAM_MAPPING = {
+        "self_attn.q_proj.weight": "qkv_w.q_params",
+        "self_attn.k_proj.weight": "qkv_w.k_params",
+        "self_attn.v_proj.weight": "qkv_w.v_params",
+        "self_attn.q_proj.bias": "qkv_b.q_params",
+        "self_attn.k_proj.bias": "qkv_b.k_params",
+        "self_attn.v_proj.bias": "qkv_b.v_params",
+        "self_attn.o_proj.weight": "attn_out_w.params",
+        "mlp.gate.weight": "moe_gate.params",
+        "mlp.experts.*.gate_proj.weight": "moe_mlp_1.gating_experts",
+        "mlp.experts.*.up_proj.weight": "moe_mlp_1.up_experts",
+        "mlp.experts.*.down_proj.weight": "moe_mlp_2.experts",
+        "mlp.shared_expert.gate_proj.weight": "shared_moe_mlp_1.gate_params",
+        "mlp.shared_expert.up_proj.weight": "shared_moe_mlp_1.up_params",
+        "mlp.shared_expert.down_proj.weight": "shared_moe_mlp_2.params",
+        "mlp.shared_expert_gate.weight": "shared_moe_gate.params",
+        "input_layernorm.weight": "attn_norm_gamma.params",
+        "post_attention_layernorm.weight": "mlp_norm_gamma.params",
+    }
+
+
+class Qwen2MoeNonTransformerContainer(LayerContainer):
+    """
+    Non-Transformer layer container for the Qwen2Moe model.
+    """
+    word_emb: EmbeddingParameter
+    word_unembed: UnembedParameter
+    final_norm: NormParameter
+
+    PARAM_MAPPING = {
+        "model.embed_tokens.weight": "word_emb.params",
+        "model.norm.weight": "final_norm.params",
+        "lm_head.weight": "word_unembed.params",
+    }
diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py
new file mode 100644
index 000000000000..7cddbf978369
--- /dev/null
+++ b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/model.py
@@ -0,0 +1,359 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+
+import deepspeed.comm as dist
+
+from ...allocator import empty_from
+from ...config_v2 import RaggedInferenceEngineConfig
+from ...inference_utils import ActivationType, DtypeEnum
+from ...model_implementations import *
+from ...modules.configs import *
+from ...modules.interfaces import *
+from ...modules import heuristics
+from ...ragged import RaggedBatchWrapper
+from ..inference_model_base import (
+    DSModelImplementationConfig,
+    MPType,
+)
+
+from .container import Qwen2MoeNonTransformerContainer, Qwen2MoeTransformerContainer
+
+
+class Qwen2MoeInferenceModel(DSMoETransformerModelBase):
+    """
+    Inference model implementation for Qwen2MoE models.
+    """
+
+    _non_transformer: Optional[Qwen2MoeNonTransformerContainer]
+    """
+    Embed + unembed container. Specializing the type annotation.
+    """
+
+    _transformer: Optional[Iterable[Qwen2MoeTransformerContainer]]
+    """
+    Per-layer transformer container. Specializing the type annotation.
+    """
+    """
+    Properties inherited from `DSInferenceModelBase`
+    """
+
+    @property
+    def max_sequence_length(self) -> int:
+        return self._config.max_position_embeddings
+
+    """
+    Properties inherited from `DSTransformerModelBase`
+    """
+
+    @property
+    def num_layers(self) -> int:
+        return self._config.num_hidden_layers
+
+    @property
+    def model_dim(self) -> int:
+        return self._config.hidden_size
+
+    @property
+    def vocab_size(self) -> int:
+        return self._config.vocab_size
+
+    @property
+    def head_size(self) -> int:
+        return self.model_dim // self.n_heads
+
+    @property
+    def n_heads(self) -> int:
+        return self._config.num_attention_heads
+
+    @property
+    def intermediate_dim(self) -> int:
+        return self._config.intermediate_size
+
+    @property
+    def n_heads_kv(self) -> int:
+        return self._config.num_key_value_heads
+
+    @property
+    def activation_dtype(self) -> DtypeEnum:
+        # TODO(ZonePG): bf16 inference results may be different from huggingface bf16,
+        # because in rms_norm, Qwen still uses float() instead of bf16
+        # if self._config.torch_dtype == torch.float16:
+        #     return DtypeEnum.fp16
+        # elif self._config.torch_dtype == torch.bfloat16:
+        #     return DtypeEnum.bf16
+        # else:
+        #     raise NotImplementedError("Only fp16 and bf16 are supported")
+        return DtypeEnum.fp16
+
+    @property
+    def mlp_activation_fn(self) -> ActivationType:
+        return ActivationType.SiGLU
+
+    @property
+    def norm_type(self) -> NormTypeEnum:
+        return NormTypeEnum.RMSNorm
+
+    @property
+    def positional_embedding_type(self) -> PositionalEmbeddingType:
+        return PositionalEmbeddingType.rotate_half
+
+    @property
+    def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
+        return RotateHalfConfig(theta_base=self._config.rope_theta)
+
+    """
+    Inherited from `DSMoETransformerModelBase`
+    """
+
+    @property
+    def n_experts(self) -> int:
+        return self._config.num_experts
+
+    @property
+    def n_top_k(self) -> int:
+        return self._config.num_experts_per_tok
+
+    @property
+    def normalize_expert_scores(self) -> bool:
+        return self._config.norm_topk_prob
+
+    def make_moe_layer(self) -> None:
+        """
+        Instantiates the MoE layer for the model. This sets the `self.moe` attribute.
+        """
+        sharded_dim = sharded_intermediate_dim(self.intermediate_dim // self.n_top_k, self.tp_size, self.tp_rank)
+
+        moe_config = DSMoEConfig(
+            max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
+            model_dim=self.model_dim,
+            intermediate_features=sharded_dim,
+            activation=self.mlp_activation_fn,
+            n_experts=self.n_experts,
+            top_k=self.n_top_k,
+            input_dtype=self.activation_dtype,
+            output_dtype=self.activation_dtype,
+            normalize_scores=self.normalize_expert_scores,
+        )
+
+        self.moe = heuristics.instantiate_moe(moe_config, self._engine_config)
+
+    ######### MLP 1 #########
+    def make_shared_expert_mlp_1_layer(self) -> None:
+        """
+        Instantiates the linear projection layer for the first MLP of the shared expert.
+        This sets the `self.shared_expert_mlp_1` attribute.
+        """
+        shard_size = sharded_intermediate_dim(self.intermediate_dim, self.tp_size, self.tp_rank)
+
+        linear_config = DSLinearConfig(
+            max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
+            in_channels=self.model_dim,
+            out_channels=shard_size,
+            activation=self.mlp_activation_fn,
+            input_dtype=self.activation_dtype,
+            output_dtype=self.activation_dtype,
+        )
+
+        self.shared_expert_mlp_1 = heuristics.instantiate_linear(linear_config, self._engine_config)
+
+    ######### MLP 2 #########
+    def make_shared_expert_mlp_2_layer(self) -> None:
+        """
+        Instantiates the linear projection layer for the second MLP of the shared expert.
+        This sets the `self.shared_expert_mlp_2` attribute.
+        """
+        shard_size = sharded_intermediate_dim(self.intermediate_dim, self.tp_size, self.tp_rank)
+
+        linear_config = DSLinearConfig(
+            max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
+            in_channels=shard_size,
+            out_channels=self.model_dim,
+            input_dtype=self.activation_dtype,
+            output_dtype=self.activation_dtype,
+        )
+
+        self.shared_expert_mlp_2 = heuristics.instantiate_linear(linear_config, self._engine_config)
+
+    ######### Shared expert gate #########
+    def make_shared_expert_gate_layer(self) -> None:
+        """
+        Instantiates the linear projection layer for the shared expert gate.
+        This sets the `self.shared_expert_gate` attribute.
+        """
+        shard_size = sharded_intermediate_dim(self.model_dim, self.tp_size, self.tp_rank)
+
+        linear_config = DSLinearConfig(
+            max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
+            in_channels=shard_size,
+            out_channels=8,
+            input_dtype=self.activation_dtype,
+            output_dtype=self.activation_dtype,
+        )
+
+        self.shared_expert_gate = heuristics.instantiate_linear(linear_config, self._engine_config)
+
+    def make_norm_layer(self) -> None:
+        """
+        Instantiates the normalization layer for the model. This sets the `self.norm` attribute.
+
+        TODO(cmikeh2): In the future we'll distinguish between the different norm objects,
+        but for now we'll just use the same one for all of them.
+        """
+        norm_config = DSNormConfig(
+            max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
+            type=self.norm_type,
+            channels=self.model_dim,
+            residual_dtype=self.activation_dtype,
+            input_dtype=self.activation_dtype,
+            output_dtype=self.activation_dtype,
+            eps=self._config.rms_norm_eps,
+        )
+
+        self.norm = heuristics.instantiate_pre_norm(norm_config, self._engine_config)
+
+    """
+    Model implementation
+    """
+
+    def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig,
+                 base_mp_group: MPType) -> None:
+        """
+        Base implementation for initialization. By default, this will initialize
+        the traditional components of a transformer model:
+            - Embedding
+            - QKV projection
+            - Self attention
+            - Attention output projection
+            - Feed forward network
+            - Normalization
+            - Unembedding
+
+        Arguments:
+            config (DSModelImplementationConfig): Model-specific configuration. No assumptions
+                should be made about this config that are not closely tied to the specific
+                model implementation.
+            engine_config (RaggedInferenceEngineConfig): Engine configuration.
+            base_mp_group (MPType): Base communication group for Tensor-parallel inference.
+        """
+        super().__init__(config, engine_config, base_mp_group)
+
+        self.make_norm_layer()
+        self.make_qkv_layer()
+        self.make_attn_layer()
+        self.make_attn_out_layer()
+        self.make_moe_layer()
+        self.make_shared_expert_mlp_1_layer()
+        self.make_shared_expert_mlp_2_layer()
+        self.make_shared_expert_gate_layer()
+        self.make_embedding_layer()
+        self.make_unembedding_layer()
+        self._kv_cache_config = None
+
+    """
+    Forward implementations
+    """
+
+    def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor:
+        """
+        Performs the embedding lookup prior to running the transformer of the model.
+
+        Arguments:
+            ragged_batch (RaggedBatchWrapper): The batch to embed.
+
+        Returns:
+            torch.Tensor: The embedded batch.
+        """
+        embed = self.embed(ragged_batch, self._non_transformer.word_emb)
+
+        if embed.shape[-1] != self.model_dim:
+            raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}")
+
+        return embed
+
+    def _forward_transformer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor,
+                             ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Executes one (slightly offset) layer of the transformer. This implementation does a peek-ahead
+        optimization to fuse the layer norm of the next layer into the current layer.
+
+        Arguments:
+            layer_idx (int): The index of the layer to execute.
+            residual (torch.Tensor): The residual tensor from the previous layer.
+            hidden_states (torch.Tensor): The hidden states from the previous layer. This is the
+                hidden states after pre normalization.
+            ragged_batch_info (RaggedBatchWrapper): The batch metadata.
+        """
+        # TODO(cmikeh2): Distribute ragged_batch_info to all modules
+
+        cur_params = self._transformer[layer_idx]
+        kv_cache = self.state_manager.get_cache(layer_idx)
+
+        hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b)
+        hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info)
+        hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None)
+
+        if self.tp_size > 1:
+            dist.all_reduce(hidden_states, group=self._base_mp_group)
+
+        residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma, beta=None)
+
+        shared_expert_output = self.shared_expert_mlp_1(hidden_states, cur_params.shared_moe_mlp_1, b=None)
+        shared_expert_output = self.shared_expert_mlp_2(shared_expert_output, cur_params.shared_moe_mlp_2, b=None)
+        shared_expert_gate_output = self.shared_expert_gate(hidden_states, cur_params.shared_moe_gate, b=None)[..., :1]
+        # shared_expert_gate_output shape[-1] is 1
+        shared_expert_output.mul_(torch.sigmoid(shared_expert_gate_output))
+        hidden_states = self.moe(hidden_states, ragged_batch_info, cur_params.moe_gate, cur_params.moe_mlp_1,
+                                 cur_params.moe_mlp_2)
+        hidden_states.add_(shared_expert_output)
+
+        if self.tp_size > 1:
+            dist.all_reduce(hidden_states, group=self._base_mp_group)
+
+        if layer_idx != self.num_layers - 1:
+            next_params = self._transformer[layer_idx + 1]
+            residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma, beta=None)
+        else:
+            # On the last layer, we just need to perform the residual add. Adding into the residual
+            # here is safe.
+            residual.add_(hidden_states)
+
+        return residual, hidden_states
+
+    def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor:
+        """
+        Performs unembedding of the hidden states to logits. This will only sample the final
+        token of each sequence.
+        """
+        logits = self.unembed(hidden_states,
+                              self._non_transformer.word_unembed,
+                              ragged_batch_info,
+                              gamma=self._non_transformer.final_norm)
+
+        if self.tp_size > 1:
+            comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
+            full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size))
+
+            dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group)
+
+            full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size))
+
+            return full_logits
+        else:
+            return logits
+
+    def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor:
+
+        residual = self._forward_embed(wrapped_batch)
+
+        residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None)
+
+        for layer_idx in range(self.num_layers):
+            residual, hidden_states = self._forward_transformer(layer_idx, residual, hidden_states, wrapped_batch)
+
+        return self._forward_unembed(residual, wrapped_batch)
diff --git a/deepspeed/inference/v2/model_implementations/qwen_v2_moe/policy.py b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/policy.py
new file mode 100644
index 000000000000..630bafe993a8
--- /dev/null
+++ b/deepspeed/inference/v2/model_implementations/qwen_v2_moe/policy.py
@@ -0,0 +1,30 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import Any
+
+from ...config_v2 import RaggedInferenceEngineConfig
+from ..inference_policy_base import ContainerMap, InferenceV2Policy
+from .container import Qwen2MoeNonTransformerContainer, Qwen2MoeTransformerContainer
+from .model import Qwen2MoeInferenceModel
+
+
+class Qwen2MoePolicy(InferenceV2Policy):
+
+    def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> Qwen2MoeInferenceModel:
+        return Qwen2MoeInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group)
+
+    def build_container_map(self) -> ContainerMap:
+        map = ContainerMap()
+
+        transformer_containers = [Qwen2MoeTransformerContainer(self.model) for _ in range(self.model.num_layers)]
+
+        map.set_transformer_params(['model.layers'], transformer_containers)
+
+        map.set_non_transformer_params(Qwen2MoeNonTransformerContainer(self.model))
+
+        map.set_unmapped_params([])
+
+        return map
diff --git a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py
index 38c0000d7f78..bd90cbd5d697 100644
--- a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py
+++ b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py
@@ -42,7 +42,7 @@ def supports_config(config: DSMoEConfig) -> bool:
         if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16:
             return False
 
-        if config.top_k != 1 and config.top_k != 2:
+        if config.top_k != 1 and config.top_k != 2 and config.top_k != 4:
             return False
 
         return True
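
For reference, below is a minimal usage sketch of serving the newly supported architecture through DeepSpeed-FastGen once this change is in place. It assumes the `deepspeed-mii` package is installed and uses `Qwen/Qwen1.5-MoE-A2.7B` as an illustrative checkpoint name; the prompts and generation arguments are examples and are not part of this diff.

```python
# Hypothetical end-to-end sketch (not part of this change): run a Qwen2-MoE
# checkpoint with DeepSpeed-FastGen via DeepSpeed-MII's non-persistent pipeline.
# Under the hood, engine_factory.build_hf_engine now maps
# model_type == "qwen2_moe" to Qwen2MoePolicy.
import mii

pipe = mii.pipeline("Qwen/Qwen1.5-MoE-A2.7B")
response = pipe(["DeepSpeed is", "Mixture-of-experts models are"], max_new_tokens=64)
print(response)
```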