From 21d93c140d0a97af5f0c59e660cf04bd417fd424 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Wed, 13 Dec 2023 23:55:07 -0800
Subject: [PATCH] Optimize Mixtral with expert parallelism (#2090)

---
 Dockerfile                              |  14 +-
 README.md                               |   4 -
 docs/source/models/supported_models.rst |   3 +-
 vllm/config.py                          |  16 +-
 vllm/model_executor/models/__init__.py  |   4 +-
 vllm/model_executor/models/mixtral.py   | 514 ++++++++++--------------
 6 files changed, 221 insertions(+), 334 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index f41753aeb52a6..6ef03b843f457 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -41,14 +41,6 @@ ENV NVCC_THREADS=$nvcc_threads
 
 RUN python3 setup.py build_ext --inplace
 
-# Build the megablocks library as wheel because it doesn't publish pre-built wheels.
-# https://github.com/stanford-futuredata/megablocks/commit/5897cd6f254b7b3edf7a708a3a3314ecb54b6f78
-RUN apt-get install -y git && \
-    git clone https://github.com/stanford-futuredata/megablocks.git && \
-    cd megablocks && \
-    git checkout 5897cd6f254b7b3edf7a708a3a3314ecb54b6f78 && \
-    MAX_JOBS=8 NVCC_THREADS=8 python3 setup.py bdist_wheel
-
 # image to run unit testing suite
 FROM dev AS test
 
@@ -85,12 +77,8 @@ FROM vllm-base AS vllm-openai
 RUN --mount=type=cache,target=/root/.cache/pip \
     pip install accelerate
 
-COPY vllm vllm
 COPY --from=build /workspace/vllm/*.so /workspace/vllm/
-COPY --from=build /workspace/megablocks/dist/*.whl /tmp/
-RUN --mount=type=cache,target=/root/.cache/pip \
-    pip install /tmp/megablocks-0.5.0-cp310-cp310-linux_x86_64.whl && \
-    rm /tmp/megablocks-0.5.0-cp310-cp310-linux_x86_64.whl
+COPY vllm vllm
 
 ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
 
diff --git a/README.md b/README.md
index 84cadee4839fc..e4b3b50260182 100644
--- a/README.md
+++ b/README.md
@@ -72,10 +72,6 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get
 ```bash
 pip install vllm
 ```
-**NOTE:** The Mixtral model additionally requires `megablocks` which can be installed with pip or [from source](https://github.com/stanford-futuredata/megablocks):
-```bash
-pip install megablocks
-```
 
 ## Getting Started
 
diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index e21cdd65d1e4f..44e4fe5ead988 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -74,8 +74,7 @@ Otherwise, please refer to :ref:`Adding a New Model ` for in
 Alternatively, you can raise an issue on our `GitHub `_ project.
 
 .. note::
-    Currently, the ROCm version of vLLM does not support Mixtral.
-    Additionally, it only supports Mistral for context lengths up to 4096.
+    Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
 
 .. tip::
     The easiest way to check if your model is supported is to run the program below:
diff --git a/vllm/config.py b/vllm/config.py
index 6bafa73c7a981..eb1fee0f258b3 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -120,14 +120,16 @@ def _verify_load_format(self) -> None:
         if load_format == "auto":
             load_format = "pt"
 
-        # FIXME(woosuk): This is a temporary hack. Support safetensor weights.
+        # TODO: Remove this check once HF updates the pt weights of Mixtral.
         architectures = getattr(self.hf_config, "architectures", [])
-        if "MixtralForCausalLM" in architectures and load_format != "pt":
-            logger.info(
-                "Currently, only 'pt' format is supported for Mixtral. "
-                "Changing the format to 'pt'. This may re-download the "
-                "weights if you have downloaded the safetensor weights.")
-            load_format = "pt"
+        if "MixtralForCausalLM" in architectures:
+            if load_format == "pt":
+                raise ValueError(
+                    "Currently, the 'pt' format is not supported for Mixtral. "
+                    "Please use the 'safetensors' format instead. ")
+            elif load_format == "auto":
+                # Do not fall back to pt weights.
+                load_format = "safetensors"
 
         self.load_format = load_format
 
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 5596884f3af89..ab9a1636ad13f 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -39,13 +39,15 @@
 }
 
 # Models not supported by ROCm.
-_ROCM_UNSUPPORTED_MODELS = ["MixtralForCausalLM"]
+_ROCM_UNSUPPORTED_MODELS = []
 
 # Models partially supported by ROCm.
 # Architecture -> Reason.
 _ROCM_PARTIALLY_SUPPORTED_MODELS = {
     "MistralForCausalLM":
     "Sliding window attention is not yet supported in ROCm's flash attention",
+    "MixtralForCausalLM":
+    "Sliding window attention is not yet supported in ROCm's flash attention",
 }
 
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 8e0a094c78353..b11e3713fd4da 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -31,22 +31,11 @@
 from torch import nn
 from transformers import MixtralConfig
 
-try:
-    import megablocks.ops as ops
-except ImportError as e:
-    raise ImportError("MegaBlocks not found. "
-                      "Please install it by `pip install megablocks`.") from e
-try:
-    import stk
-except ImportError as e:
-    raise ImportError(
-        "STK not found. "
-        "Please install it by `pip install stanford-stk`.") from e
-
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (LinearMethodBase,
+                                               ReplicatedLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -66,8 +55,134 @@
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-def promote_scalar(x: torch.Tensor) -> torch.Tensor:
-    return x.view(1) if len(x.size()) == 0 else x
+class MixtralMLP(nn.Module):
+
+    def __init__(
+        self,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.num_experts = num_experts
+        self.ffn_dim = intermediate_size
+        self.hidden_dim = hidden_size
+
+        self.w1 = ReplicatedLinear(self.hidden_dim,
+                                   self.ffn_dim,
+                                   bias=False,
+                                   linear_method=linear_method)
+        self.w2 = ReplicatedLinear(self.ffn_dim,
+                                   self.hidden_dim,
+                                   bias=False,
+                                   linear_method=linear_method)
+        self.w3 = ReplicatedLinear(self.hidden_dim,
+                                   self.ffn_dim,
+                                   bias=False,
+                                   linear_method=linear_method)
+
+        # TODO: Use vllm's SiluAndMul
+        self.act_fn = nn.SiLU()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        w1_out, _ = self.w1(hidden_states)
+        w1_out = self.act_fn(w1_out)
+        w3_out, _ = self.w3(hidden_states)
+        current_hidden_states = w1_out * w3_out
+        current_hidden_states, _ = self.w2(current_hidden_states)
+        return current_hidden_states
+
+
+class DummyModule(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+        self.w1 = nn.Linear(0, 0, bias=False)
+        self.w2 = nn.Linear(0, 0, bias=False)
+        self.w3 = nn.Linear(0, 0, bias=False)
+
+        set_weight_attrs(self.w1.weight,
+                         {"weight_loader": self.dummy_weight_loader})
+        set_weight_attrs(self.w2.weight,
+                         {"weight_loader": self.dummy_weight_loader})
+        set_weight_attrs(self.w3.weight,
+                         {"weight_loader": self.dummy_weight_loader})
+
+    def forward(self, *args, **kwargs) -> None:
+        raise NotImplementedError()
+
+    def dummy_weight_loader(self, *args, **kwargs) -> None:  # pylint: disable=unused-argument
+        # Noop
+        return
+
+
+class MixtralMoE(nn.Module):
+
+    def __init__(
+        self,
+        config: MixtralConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+        if self.tp_size > self.num_total_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {self.num_total_experts}.")
+        # Split experts equally between ranks
+        self.expert_indicies = np.array_split(range(
+            self.num_total_experts), self.tp_size)[self.rank].tolist()
+        if not self.expert_indicies:
+            raise ValueError(
+                f"Rank {self.rank} has no experts assigned to it.")
+
+        self.experts = nn.ModuleList([
+            MixtralMLP(self.num_total_experts,
+                       config.hidden_size,
+                       config.intermediate_size,
+                       linear_method=linear_method)
+            if idx in self.expert_indicies else DummyModule()
+            for idx in range(self.num_total_experts)
+        ])
+        self.gate = ReplicatedLinear(config.hidden_size,
+                                     self.num_total_experts,
+                                     bias=False,
+                                     linear_method=linear_method)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights,
+                                                       self.top_k,
+                                                       dim=-1)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+        final_hidden_states = None
+        for expert_idx in self.expert_indicies:
+            expert_layer = self.experts[expert_idx]
+            expert_mask = (selected_experts == expert_idx)
+            expert_weights = (routing_weights * expert_mask).sum(dim=-1,
+                                                                 keepdim=True)
+
+            current_hidden_states = expert_layer(hidden_states).mul_(
+                expert_weights)
+            if final_hidden_states is None:
+                final_hidden_states = current_hidden_states
+            else:
+                final_hidden_states.add_(current_hidden_states)
+
+        return tensor_model_parallel_all_reduce(final_hidden_states).view(
+            batch_size, sequence_length, hidden_dim)
 
 
 class MixtralAttention(nn.Module):
@@ -78,6 +193,7 @@ def __init__(self,
                  num_kv_heads: int,
                  max_position: int = 4096 * 32,
                  rope_theta: float = 10000,
+                 linear_method: Optional[LinearMethodBase] = None,
                  sliding_window: Optional[int] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -102,24 +218,26 @@ def __init__(self,
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
 
-        self.wqkv = QKVParallelLinear(
+        self.qkv_proj = QKVParallelLinear(
             hidden_size,
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
+            linear_method=linear_method,
         )
-        self.wo = RowParallelLinear(
+        self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
+            linear_method=linear_method,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position,
             base=int(self.rope_theta),
-            is_neox_style=False,  # weights not in HF format
+            is_neox_style=True,
         )
         self.attn = PagedAttention(
             self.num_heads,
@@ -137,310 +255,74 @@ def forward(
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        qkv, _ = self.wqkv(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
         attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                 cache_event)
-        output, _ = self.wo(attn_output)
+        output, _ = self.o_proj(attn_output)
         return output
 
 
-class BlockSparseMoE(nn.Module):
-    """
-    Built on the paper and library Megablocks as described in
-    https://arxiv.org/abs/2211.15841. This implementation is
-    strictly equivalent to standard MoE with full capacity (no
-    dropped tokens). It's faster since it formulates MoE operations
-    in terms of block-sparse operations to accomodate imbalanced
-    assignments of tokens to experts, whereas standard MoE either
-    (1) drop tokens at the cost of reduced performance or (2) set
-    capacity factor to number of experts and thus waste computation
-    and memory on padding.
-    """
-
-    def __init__(self, hidden_dim: int, ffn_dim: int, num_experts: int,
-                 top_k: int):
-        super().__init__()
-        self.hidden_dim = hidden_dim
-        self.ffn_dim = ffn_dim
-        self.num_experts = num_experts
-        self.top_k = top_k
-
-        # gating
-        self.gate = nn.Linear(self.hidden_dim,
-                              self.num_experts,
-                              bias=False,
-                              device=torch.cuda.current_device())
-
-        tp_size = get_tensor_model_parallel_world_size()
-        assert self.ffn_dim % tp_size == 0
-        self.ffn_dim_per_partition = self.ffn_dim // tp_size
-        # merged expert weights, all of size (ffn_dim * n_experts, model_dim)
-        self.w1 = nn.Parameter(
-            torch.empty(self.ffn_dim_per_partition * self.num_experts,
-                        self.hidden_dim,
-                        device=torch.cuda.current_device()))
-        set_weight_attrs(self.w1, {"weight_loader": self.moe_weight_loader})
-        self.w2 = nn.Parameter(
-            torch.empty(self.ffn_dim_per_partition * self.num_experts,
-                        self.hidden_dim,
-                        device=torch.cuda.current_device()))
-        set_weight_attrs(self.w2, {"weight_loader": self.moe_weight_loader})
-        self.w3 = nn.Parameter(
-            torch.empty(self.ffn_dim_per_partition * self.num_experts,
-                        self.hidden_dim,
-                        device=torch.cuda.current_device()))
-        set_weight_attrs(self.w3, {"weight_loader": self.moe_weight_loader})
-
-        # Calculate the number of bits needed to represent the expert indices
-        # so that we can pass it to radix sort.
-        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
-        self.blocking = 128
-        self.quantize_scatter_num_bits = -1
-
-        # Calculate the number of bits needed to represent the column indices
-        # in the intermediate sparse matrix.
-        max_column_index = (self.ffn_dim * self.num_experts) // self.blocking
-        self.transpose_sort_end_bit = max(
-            int(np.ceil(np.log2(max_column_index))), 1)
-
-    def moe_weight_loader(self, param: nn.Parameter,
-                          loaded_weight: torch.Tensor) -> None:
-        """
-        Load the weights for the MoE linear layer.
-        """
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_size = self.ffn_dim_per_partition
-        loaded_weight = loaded_weight.view(self.num_experts, self.ffn_dim, -1)
-        loaded_weight = loaded_weight[:, shard_size * tp_rank:shard_size *
-                                      (tp_rank + 1)]
-        loaded_weight = loaded_weight.reshape_as(param)
-        param.data.copy_(loaded_weight)
-
-    def sparse_transpose(
-            self, size: int, row_indices,
-            column_indices) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        block_columns = size[1] // self.blocking
-
-        # Sort row indices by column indices to get the transposed matrix's
-        # column indices.
-        #
-        # NOTE: Our sort operation uses the same width indices as the input
-        # values. To avoid overflow when we have large activation matrices
-        # we cast to 32-bit before sorting.
-        _, gather_indices = ops.sort(column_indices.int(),
-                                     self.transpose_sort_end_bit)
-
-        # There are a constant number of blocks in every row of the sparse
-        # matrix. A blocks offset is:
-        #
-        # row_index * blocks_per_row + column_index % blocks_per_row
-        #
-        # Once we have the block offsets ordered for transposition we can
-        # divide by blocks_per_row to get the transposed column indices.
-        column_indices_t = row_indices.gather(0, gather_indices.long())
-        block_offsets_t = gather_indices.int()
-
-        zero = torch.zeros((1, ), dtype=torch.int32, device=row_indices.device)
-        nnz_per_column = ops.histogram(column_indices, block_columns)
-        nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
-        offsets_t = torch.cat([zero, nnz_per_column])
-        return column_indices_t, offsets_t, block_offsets_t
-
-    def topology(self, x: torch.Tensor,
-                 padded_bins: torch.Tensor) -> "stk.Matrix":
-        padded_tokens, _ = x.size()
-        assert padded_tokens % self.blocking == 0
-        assert self.ffn_dim_per_partition % self.blocking == 0
-
-        # Offsets for the sparse matrix. All rows have the
-        # same number of nonzero blocks dictated by the
-        # dimensionality of a single expert.
-        block_rows = padded_tokens // self.blocking
-        blocks_per_row = self.ffn_dim_per_partition // self.blocking
-        offsets = torch.arange(
-            0,
-            block_rows * blocks_per_row + 1,
-            blocks_per_row,
-            dtype=torch.int32,
-            device=x.device,
-        )
-
-        # Indices for the sparse matrix. The indices for
-        # the intermediate matrix are dynamic depending
-        # on the mapping of tokens to experts.
-        column_indices = ops.topology(padded_bins, self.blocking, block_rows,
-                                      blocks_per_row)
-
-        # TODO(tgale): This is unused. Remove the need for this in stk.
-        # For now, use meta init to save the device memory.
-        data = torch.empty(
-            column_indices.numel(),
-            self.blocking,
-            self.blocking,
-            dtype=x.dtype,
-            device="meta",
-        )
-        shape = (padded_tokens, self.ffn_dim_per_partition * self.num_experts)
-        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
-        column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
-            shape, row_indices, column_indices)
-        return stk.Matrix(
-            shape,
-            data,
-            row_indices,
-            column_indices,
-            offsets,
-            column_indices_t,
-            offsets_t,
-            block_offsets_t,
-        )
-
-    def indices_and_padded_bins(
-        self, selected_experts: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
-               torch.Tensor]:
-        # Sort the expert ids to produce the scatter/gather
-        # indices for the permutation.
-        selected_experts = selected_experts.int()
-        bin_ids, indices = ops.sort(selected_experts, self.sort_end_bit)
-
-        # Histogram the expert ids to identify the number of
-        # tokens routed to each expert.
-        tokens_per_expert = ops.histogram(selected_experts, self.num_experts)
-
-        # Round the token counts up to the block size used in
-        # the matrix muliplications. Caculate the starting
-        # position of each bin.
-        padded_tokens_per_expert = ops.round_up(tokens_per_expert,
-                                                self.blocking)
-        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
-        padded_bins = promote_scalar(padded_bins)
-
-        # Calculate the bin bounds for the sorted tokens.
-        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
-        bins = promote_scalar(bins)
-        return indices, bin_ids, bins, padded_bins, tokens_per_expert
-
-    @torch.inference_mode()
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        x: (sequence_length, model_dim)
-        gate_logits: (sequence_length, n_experts)
-        """
-        # optional reshape
-        input_shape = x.shape
-        x = x.view(-1, input_shape[-1])
-
-        # gate_logits: (sequence_length, n_experts)
-        gate_logits = self.gate(x)
-        # all_probs: (sequence_length, n_experts) and upcast for softmax
-        all_probs = F.softmax(gate_logits, dim=1, dtype=torch.float)
-        # weights, selected_experts: (sequence_length, top-k)
-        weights, selected_experts = torch.topk(all_probs, self.top_k, dim=-1)
-        weights /= weights.sum(dim=-1, keepdim=True)
-        weights = weights.flatten().to(x.dtype)
-        selected_experts = selected_experts.flatten()
-
-        (indices, bin_ids, bins, padded_bins,
-         _) = self.indices_and_padded_bins(selected_experts)
-
-        # Permute tokens and pad to prepare expert computation
-        # (top_k * sequence_length + padding, model_dim)
-        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins,
-                              self.top_k)
-
-        # Create the sparse matrix topology
-        with torch.no_grad():
-            topo = self.topology(x, padded_bins)
-
-        # Perform the expert computation
-        # First Dense x Dense -> Sparse for w1 and w3,
-        # (top_k * sequence_length + padding, ffn_dim * n_experts)
-        x = stk.Matrix(
-            topo.size(),
-            F.silu(stk.ops.sdd(x, self.w1.t(), topo).data) *
-            stk.ops.sdd(x, self.w3.t(), topo).data,
-            topo.row_indices,
-            topo.column_indices,
-            topo.offsets,
-            topo.column_indices_t,
-            topo.offsets_t,
-            topo.block_offsets_t,
-        )
-
-        # Then Sparse x Dense -> Dense for w2
-        # (top_k * sequence_length + padding, model_dim)
-        x = stk.ops.dsd(x, self.w2)
-
-        x = tensor_model_parallel_all_reduce(x)
-
-        # Permute back and remove padding
-        # (top_k * sequence_length, model_dim)
-        x = ops.padded_scatter(
-            x,
-            indices,
-            bin_ids,
-            weights,
-            bins,
-            padded_bins,
-            self.top_k,
-            self.quantize_scatter_num_bits,
-        )
-        return x.view(*input_shape)
-
-
 class MixtralDecoderLayer(nn.Module):
 
     def __init__(
         self,
         config: MixtralConfig,
+        linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
-        self.attention = MixtralAttention(
+        self.self_attn = MixtralAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
-            sliding_window=config.sliding_window)
-        self.block_sparse_moe = BlockSparseMoE(
-            hidden_dim=self.hidden_size,
-            ffn_dim=config.intermediate_size,
-            num_experts=config.num_local_experts,
-            top_k=config.num_experts_per_tok,
-        )
-        self.attention_norm = RMSNorm(config.hidden_size,
-                                      eps=config.rms_norm_eps)
-        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            sliding_window=config.sliding_window,
+            linear_method=linear_method)
+        self.block_sparse_moe = MixtralMoE(config=config,
+                                           linear_method=linear_method)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
 
     def forward(
         self,
         positions: torch.Tensor,
-        x: torch.Tensor,
+        hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
+        residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        r = self.attention(
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
             positions=positions,
-            hidden_states=self.attention_norm(x),
+            hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
             cache_event=cache_event,
         )
-        h = x + r
-        r = self.block_sparse_moe(self.ffn_norm(h))
-        out = h + r
-        return out
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.block_sparse_moe(hidden_states)
+        return hidden_states, residual
 
 
-class MixtralForCausalLM(nn.Module):
+class MixtralModel(nn.Module):
 
     def __init__(
         self,
@@ -448,23 +330,18 @@ def __init__(
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
-        self.config = config
-        assert linear_method is None
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.tok_embeddings = VocabParallelEmbedding(
+
+        self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
         )
-
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.sampler = Sampler(config.vocab_size)
-
         self.layers = nn.ModuleList([
-            MixtralDecoderLayer(config)
+            MixtralDecoderLayer(config, linear_method=linear_method)
             for _ in range(config.num_hidden_layers)
         ])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -474,20 +351,42 @@ def forward(
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> SamplerOutput:
-        hidden_states = self.tok_embeddings(input_ids)
-
-        # forward
+        hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             cache_event = None if cache_events is None else cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(
-                positions,
-                hidden_states,
-                kv_caches[i],
-                input_metadata,
-                cache_event,
-            )
-        hidden_states = self.norm(hidden_states)
+            hidden_states, residual = layer(positions, hidden_states,
+                                            kv_caches[i], input_metadata,
+                                            cache_event, residual)
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class MixtralForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: MixtralConfig,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = MixtralModel(config, linear_method)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   input_metadata, cache_events)
         return hidden_states
 
     def sample(
@@ -495,7 +394,7 @@ def sample(
         hidden_states: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        next_tokens = self.sampler(self.output.weight, hidden_states,
+        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
 
@@ -506,10 +405,11 @@ def load_weights(self,
                      revision: Optional[str] = None):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("wqkv", "wq", "q"),
-            ("wqkv", "wk", "k"),
-            ("wqkv", "wv", "v"),
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
         ]
+        params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
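For reference, the token-routing scheme used by the new `MixtralMoE.forward` can be sketched on a single device with plain PyTorch. The sketch below is illustrative only: the sizes, the `Expert` class, and the dense loop over all experts are stand-ins, while the patched layer shards `self.experts` across tensor-parallel ranks (non-local slots become `DummyModule`) and combines the per-rank partial sums with `tensor_model_parallel_all_reduce`.

```python
# Minimal single-device sketch of top-k expert routing (illustrative names/sizes).
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
num_experts, top_k, hidden_size, ffn_size, num_tokens = 8, 2, 16, 32, 4


class Expert(nn.Module):
    """Small gated-SiLU MLP mirroring MixtralMLP (w1/w3 up-projections, w2 down)."""

    def __init__(self) -> None:
        super().__init__()
        self.w1 = nn.Linear(hidden_size, ffn_size, bias=False)
        self.w2 = nn.Linear(ffn_size, hidden_size, bias=False)
        self.w3 = nn.Linear(hidden_size, ffn_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


experts = nn.ModuleList(Expert() for _ in range(num_experts))
gate = nn.Linear(hidden_size, num_experts, bias=False)

x = torch.randn(num_tokens, hidden_size)
router_logits = gate(x)                                    # (tokens, num_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights = (routing_weights /
                   routing_weights.sum(dim=-1, keepdim=True)).to(x.dtype)

out = torch.zeros_like(x)
# A rank would iterate only over its local expert indices here.
for expert_idx in range(num_experts):
    expert_mask = (selected_experts == expert_idx)         # (tokens, top_k)
    expert_weight = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
    out += experts[expert_idx](x) * expert_weight          # unrouted tokens add zero

# In the sharded layer, `out` is a partial sum that is all-reduced across ranks
# and reshaped back to (batch, seq_len, hidden).
print(out.shape)  # torch.Size([4, 16])
```

Because a token's weight for any expert it was not routed to is zero, each rank only needs to evaluate its local experts and the all-reduce reassembles the full top-k mixture, which is what lets this patch drop the megablocks block-sparse kernels.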