diff --git a/sharktank/sharktank/export_layer/export_moe.py b/sharktank/sharktank/export_layer/export_moe.py index f2c10c4b4..af4ed51d2 100644 --- a/sharktank/sharktank/export_layer/export_moe.py +++ b/sharktank/sharktank/export_layer/export_moe.py @@ -5,9 +5,12 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import torch +import torch.nn.functional as F + from iree.turbine.aot import * + from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch -from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock +from sharktank.layers.mixture_of_experts_block import MoeBlock from ..utils import cli @@ -37,8 +40,8 @@ def main(): action="store_true", ) parser.add_argument( - "--use-grok", - help="Enable to export Grok model's version of MOE block", + "--use-gelu", + help="Enable to use gelu for moe activation", action="store_true", ) @@ -46,12 +49,12 @@ def main(): bs = args.batch_size - model = PreGatherMoeBlock( + model = MoeBlock( theta=make_moe_block_theta()("blk.0"), expert_count=8, expert_used_count=2, rms_epsilon=1e-5, - use_grok=args.use_grok, + moe_activation=F.gelu if args.use_gelu else F.silu, ) fxb = FxProgramsBuilder(model) input = make_rand_torch((bs, 32, 6144)) diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py index a90def3a9..fd56ec872 100644 --- a/sharktank/sharktank/layers/__init__.py +++ b/sharktank/sharktank/layers/__init__.py @@ -16,6 +16,6 @@ from .paged_llama_attention_block import PagedLlamaAttentionBlock from .ffn_block import FFN from .ffn_moe_block import FFNMOE -from .mixture_of_experts_block import SparseMoeBlock, PreGatherMoeBlock +from .mixture_of_experts_block import MoeBlock from .configs import * diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py index 0536302cf..0746f0fa0 100644 --- a/sharktank/sharktank/layers/ffn_moe_block.py +++ b/sharktank/sharktank/layers/ffn_moe_block.py @@ -12,7 +12,7 @@ from .base import ThetaLayer from .linear import LinearLayer from ..types import Theta, DefaultPrimitiveTensor -from ..ops import einsum_2args +from ..ops import einsum_2args, elementwise __all__ = [ "FFNMOE", @@ -24,15 +24,15 @@ class PreGatherFFNMOE(ThetaLayer): def __init__( self, theta: Theta, - use_grok: bool = False, + activation=F.silu, ): super().__init__(theta) - self.use_grok = use_grok self.ffn_gate = theta.tensor("ffn_gate_exps", "weight") self.ffn_up = theta.tensor("ffn_up_exps", "weight") self.ffn_down = theta.tensor("ffn_down_exps", "weight") + self.activation = activation def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"): inputs = inputs[:, :] @@ -63,10 +63,8 @@ def forward( experts: torch.Tensor, expert_gate: torch.Tensor, ): - if self.use_grok: - ffn_gate = F.gelu(self.pre_matmul_gather(h, self.ffn_gate, experts)) - else: - ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate, experts)) + ffn_gate = self.pre_matmul_gather(h, self.ffn_gate, experts) + ffn_gate = elementwise(self.activation, ffn_gate) ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts) ffn_down = self.pre_matmul_gather( diff --git a/sharktank/sharktank/layers/llama_attention_block.py b/sharktank/sharktank/layers/llama_attention_block.py index 7be8c7366..0cdb5d713 100644 --- a/sharktank/sharktank/layers/llama_attention_block.py +++ b/sharktank/sharktank/layers/llama_attention_block.py @@ -6,8 +6,6 @@ from typing import Optional -import math - import torch import torch.nn.functional as F @@ -110,7 +108,9 @@ def 
repeat_kv(x: torch.Tensor) -> torch.Tensor: values = values.transpose(1, 2) # Flash attention. - attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt( + torch.as_tensor(self.head_dim) + ) # Apply attention mask. if attention_mask is not None: diff --git a/sharktank/sharktank/layers/mixture_of_experts_block.py b/sharktank/sharktank/layers/mixture_of_experts_block.py index f788d06f0..ddce16c55 100644 --- a/sharktank/sharktank/layers/mixture_of_experts_block.py +++ b/sharktank/sharktank/layers/mixture_of_experts_block.py @@ -16,12 +16,11 @@ from .ffn_moe_block import FFNMOE, PreGatherFFNMOE __all__ = [ - "SparseMoeBlock", - "PreGatherMoeBlock", + "MoeBlock", ] -class SparseMoeBlock(ThetaLayer): +class MoeBlock(ThetaLayer): """ This implementation considers MoE operations as block-sparse operations to support imbalanced token assignments to experts. @@ -35,108 +34,12 @@ def __init__( self, theta: Theta, expert_count: int, expert_used_count: int, rms_epsilon: float, - ): - super().__init__(theta) - - # Add router gate - self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp"))) - - # Add FFN norm - self.add_module( - "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon) - ) - - # Add FFN output norm - self.add_module( - "layer_output_norm", - RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon), - ) - - # Add expert_count x FFN - self.experts = nn.ModuleList( - [FFNMOE(theta, expert_idx=i) for i in range(expert_count)] - ) - - self.expert_count = expert_count - self.expert_used_count = expert_used_count - - def forward( - self, - h: torch.Tensor, - ): - ffn_input = self.ffn_norm(h) - batch_size, sequence_length, feature_dim = ffn_input.shape - ffn_input = ffn_input.view(-1, feature_dim) - - # For each token, the router calculates the router weights for all experts - # router_logits: (batch_size * sequence_length, expert_count) - router_logits = self.ffn_gate_inp(ffn_input) - router_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - - # Select top k experts from router weights - router_weights, top_k_experts = torch.topk( - router_weights, self.expert_used_count, dim=-1 - ) - router_weights /= router_weights.sum(dim=-1, keepdim=True) - router_weights = router_weights.to(ffn_input.dtype) - - moe_output = torch.zeros( - (batch_size * sequence_length, feature_dim), dtype=ffn_input.dtype - ) - - # Create an expert mask by one hot encoding the selected top k experts - # used to index which expert is to be invoked for each token - # expert_mask: (expert_count, expert_used_count, sequence_length) - expert_mask = F.one_hot(top_k_experts, num_classes=self.expert_count).permute( - 2, 1, 0 - ) - - # Iterate over all experts in the model - for expert_idx in range(self.expert_count): - expert_layer = self.experts[expert_idx] - top_k_expert_idx, token_idx = torch.where(expert_mask[expert_idx]) - - # Given the hidden states, index the tokens assigned to this expert - # and calculate the current expert's hidden state and weigh the - # output expert hidden states by the router weights, based on the - # appropriate tokens - current_expert_tokens = ffn_input[None, token_idx] - - current_expert = ( - expert_layer(current_expert_tokens) - * router_weights[token_idx, top_k_expert_idx, None] - ) - - current_expert = current_expert.reshape(-1, feature_dim) - - moe_output.index_add_(0, token_idx, current_expert.to(ffn_input.dtype)) - moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim) - - moe_output
= self.layer_output_norm(moe_output) - return h + moe_output - - -class PreGatherMoeBlock(ThetaLayer): - """ - This implementation considers MoE operations as block-sparse - operations to support imbalanced token assignments to experts. - This enables the MoE to operate at a faster rate and in full capacity without any dropped tokens - (or reduced performance). - """ - - def __init__( - self, - theta: Theta, - expert_count: int, - expert_used_count: int, - rms_epsilon: float, - use_grok: Optional[bool] = False, + moe_activation=F.silu, ): super().__init__(theta) self.expert_count = expert_count self.expert_used_count = expert_used_count - self.use_grok = use_grok # Add router gate self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp"))) @@ -146,15 +49,17 @@ def __init__( "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon) ) - # Add FFN output norm layer for Grok - if self.use_grok: + # Add optional FFN output norm layer + if theta.optional_tensor("layer_output_norm") is not None: self.add_module( "layer_output_norm", RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon), ) + else: + self.add_module("layer_output_norm", torch.nn.Identity()) # Add expert_count x FFN - self.experts = PreGatherFFNMOE(theta, use_grok=self.use_grok) + self.experts = PreGatherFFNMOE(theta, activation=moe_activation) def forward( self, @@ -180,7 +85,6 @@ def forward( moe_output = self.experts(ffn_input, top_k_experts, expert_gate) moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim) - if self.use_grok: - moe_output = self.layer_output_norm(moe_output) + moe_output = self.layer_output_norm(moe_output) return h + moe_output diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 123c43be7..5bc045608 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -37,7 +37,8 @@ def __init__( head_dim: int, head_count_kv: int, rms_epsilon: float, - use_grok: Optional[bool] = False, + attention_scale: Optional[float] = None, + softcap: Optional[float] = None, ): super().__init__(theta) @@ -46,7 +47,8 @@ def __init__( self.head_count = head_count self.head_dim = head_dim self.head_count_kv = head_count_kv - self.use_grok = use_grok + self.attention_scale = attention_scale + self.softcap = softcap self.add_module( "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon) @@ -56,7 +58,12 @@ def __init__( self.add_module("attn_v", LinearLayer(theta("attn_v"))) self.add_module("attn_output", LinearLayer(theta("attn_output"))) - if self.use_grok: + if theta.optional_tensor("attn_output_norm") is None: + self.add_module( + "attn_output_norm", + torch.nn.Identity(), + ) + else: self.add_module( "attn_output_norm", RMSNormLayer(theta("attn_output_norm"), epsilon=rms_epsilon), @@ -147,16 +154,16 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: keys = xk.transpose(1, 2) values = xv.transpose(1, 2) + attn_weights = ops.matmul(xq, keys.transpose(2, 3)) + if self.attention_scale is None: + attn_weights = attn_weights / math.sqrt(self.head_dim) + else: + attn_weights = attn_weights * self.attention_scale + # Flash attention. 
- if not self.use_grok: - attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt( - self.head_dim - ) - elif self.use_grok: - attn_weights = ops.matmul(xq, keys.transpose(2, 3)) - attn_weights = 30.0 * torch.tanh( - attn_weights * (0.08838834764831845 / 30.0) - ) + if self.softcap is not None: + attn_weights = self.softcap * torch.tanh(attn_weights / self.softcap) + self.assert_not_nan(attn_weights) # Apply attention mask. @@ -172,12 +179,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: # Project. attn_output = self.attn_output(attn_output) - - if self.use_grok: - attn_output = self.attn_output_norm(attn_output) + attn_output = self.attn_output_norm(attn_output) h = h + attn_output - return h def transact_cache_direct( diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py index debeb30c7..077e4e064 100644 --- a/sharktank/sharktank/models/grok/grok.py +++ b/sharktank/sharktank/models/grok/grok.py @@ -4,9 +4,11 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import math import torch import torch.nn as nn +import torch.nn.functional as F from ...layers import * @@ -82,6 +84,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): self.add_module("output_lm_head", LinearLayer(theta("output"))) self.attn_blocks = nn.ModuleList() + self.moe_blocks = nn.ModuleList() for n in range(hp.block_count): self.attn_blocks.append( @@ -93,16 +96,16 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): head_dim=hp.attn_head_dim, head_count_kv=hp.attention_head_count_kv, rms_epsilon=hp.attention_layer_norm_rms_epsilon, - use_grok=True, + softcap=30.0, # https://github.com/xai-org/grok-1/blob/7050ed204b8206bb8645c7b7bbef7252f79561b0/model.py#L864 ) ) - self.attn_blocks.append( - PreGatherMoeBlock( + self.moe_blocks.append( + MoeBlock( theta("blk", n), expert_count=hp.expert_count, expert_used_count=hp.expert_used_count, rms_epsilon=hp.attention_layer_norm_rms_epsilon, - use_grok=True, + moe_activation=F.gelu, ) ) @@ -122,33 +125,32 @@ def prefill( self._assert_device(seq_block_ids) self._assert_device(*cache_state, dtype=self.activation_dtype) h = self.token_embedding(tokens) - h *= 78.38367176906169 + h *= math.sqrt(h.shape[-1]) self.trace_tensor("grok.token_embedding", h) # Iterate over attention blocks. 
- for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): if block_idx == 0: self.trace_tensor(f"grok.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "PagedLlamaAttentionBlock": - h = block( - h, - embedding=self.attention_embedding, - start_index=0, - attention_mask=attention_mask, - cache_state=cache_state, - seq_block_ids=seq_block_ids, - ) - self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "PreGatherMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) + h = attn_block( + h, + embedding=self.attention_embedding, + start_index=0, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + ) + self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) + + h = moe_block(h) + self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) - logits = logits * 0.5773502691896257 + logits = logits / math.sqrt(3.0) return logits def decode( @@ -200,34 +202,33 @@ def decode( ) h = self.token_embedding(tokens) - h *= 78.38367176906169 + h *= math.sqrt(h.shape[-1]) self.trace_tensor("grok.token_embedding", h) # Iterate over attention blocks. - for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): if block_idx == 0: self.trace_tensor(f"grok.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "PagedLlamaAttentionBlock": - h = block( - h, - start_positions=start_positions, - embedding=self.attention_embedding, - embedding_batch_mask=embedding_batch_mask, - attention_mask=attention_mask, - cache_state=cache_state, - seq_block_ids=seq_block_ids, - xk_temp=xk_temp, - xv_temp=xv_temp, - ) - self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "PreGatherMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) + h = attn_block( + h, + start_positions=start_positions, + embedding=self.attention_embedding, + embedding_batch_mask=embedding_batch_mask, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) + + h = moe_block(h) + self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) - logits = logits * 0.5773502691896257 + logits = logits / math.sqrt(3.0) return logits diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index c324a79d5..344976ead 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -7,7 +7,6 @@ from typing import Optional from dataclasses import dataclass -import math from typing import Union import torch diff --git a/sharktank/sharktank/models/llama/llama_ref.py b/sharktank/sharktank/models/llama/llama_ref.py index 74ed9e8e0..9f77daa40 100644 --- a/sharktank/sharktank/models/llama/llama_ref.py +++ b/sharktank/sharktank/models/llama/llama_ref.py @@ -7,7 +7,6 @@ from typing import Optional from dataclasses import dataclass -import math import torch import torch.nn as nn @@ -230,7 +229,9 @@ def forward( values = values.transpose(1, 2) # Flash attention. 
- attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt( + torch.as_tensor(self.head_dim) + ) # Apply attention mask. if attention_mask is not None: diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py index 1fc86f87d..e2995dfde 100644 --- a/sharktank/sharktank/models/mixtral/mixtral.py +++ b/sharktank/sharktank/models/mixtral/mixtral.py @@ -85,6 +85,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): self.add_module("output_lm_head", LinearLayer(theta("output"))) self.attn_blocks = nn.ModuleList() + self.moe_blocks = nn.ModuleList() for n in range(hp.block_count): self.attn_blocks.append( @@ -98,8 +99,8 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): rms_epsilon=hp.attention_layer_norm_rms_epsilon, ) ) - self.attn_blocks.append( - SparseMoeBlock( + self.moe_blocks.append( + MoeBlock( theta("blk", n), expert_count=hp.expert_count, expert_used_count=hp.expert_used_count, @@ -126,25 +127,26 @@ def prefill( self.trace_tensor("mixtral.token_embedding", h) # Iterate over attention blocks. - for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): if block_idx == 0: self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "PagedLlamaAttentionBlock": - h = block( - h, - embedding=self.attention_embedding, - start_index=0, - attention_mask=attention_mask, - cache_state=cache_state, - seq_block_ids=seq_block_ids, - ) - self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "SparseMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) + h = attn_block( + h, + embedding=self.attention_embedding, + start_index=0, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + ) + self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) + + h = moe_block( + h, + ) + self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) @@ -202,28 +204,29 @@ def decode( self.trace_tensor("mixtral.token_embedding", h) # Iterate over attention blocks.
- for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): if block_idx == 0: self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "PagedLlamaAttentionBlock": - h = block( - h, - start_positions=start_positions, - embedding=self.attention_embedding, - embedding_batch_mask=embedding_batch_mask, - attention_mask=attention_mask, - cache_state=cache_state, - seq_block_ids=seq_block_ids, - xk_temp=xk_temp, - xv_temp=xv_temp, - ) - self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "SparseMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) + h = attn_block( + h, + start_positions=start_positions, + embedding=self.attention_embedding, + embedding_batch_mask=embedding_batch_mask, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) + + h = moe_block( + h, + ) + self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) diff --git a/sharktank/sharktank/models/mixtral/mixtral_ref.py b/sharktank/sharktank/models/mixtral/mixtral_ref.py index 392f60a25..70a9b9cf8 100644 --- a/sharktank/sharktank/models/mixtral/mixtral_ref.py +++ b/sharktank/sharktank/models/mixtral/mixtral_ref.py @@ -66,6 +66,7 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig): self.add_module("output_lm_head", LinearLayer(theta("output"))) self.attn_blocks = nn.ModuleList() + self.moe_blocks = nn.ModuleList() for n in range(hp.block_count): self.attn_blocks.append( @@ -78,8 +79,8 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig): rms_epsilon=hp.attention_layer_norm_rms_epsilon, ) ) - self.attn_blocks.append( - SparseMoeBlock( + self.moe_blocks.append( + MoeBlock( theta("blk", n), expert_count=hp.expert_count, expert_used_count=hp.expert_used_count, @@ -130,28 +131,29 @@ def forward( block_count = len(self.attn_blocks) // 2 # print('local_kv_cache, #attn_blocks', len(local_kv_cache), block_count) # Iterate over attention + MoE blocks. 
- for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): # print("block_idx, block", block_idx, block) if block_idx == 0: self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "LlamaAttentionBlock": - attn_block_idx = block_idx // 2 - block_cache_k = local_kv_cache[attn_block_idx] - block_cache_v = local_kv_cache[block_count + attn_block_idx] - h = block( - h, - cache_k=block_cache_k, - cache_v=block_cache_v, - start_index=start_index, - attention_mask=attention_mask, - ) - self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "SparseMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) + attn_block_idx = block_idx + block_cache_k = local_kv_cache[attn_block_idx] + block_cache_v = local_kv_cache[len(self.attn_blocks) + attn_block_idx] + h = attn_block( + h, + cache_k=block_cache_k, + cache_v=block_cache_v, + start_index=start_index, + attention_mask=attention_mask, + ) + self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) + + h = moe_block( + h, + ) + self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) diff --git a/sharktank/sharktank/utils/tokenizer.py b/sharktank/sharktank/utils/tokenizer.py index 29a57f958..a6b0980a0 100644 --- a/sharktank/sharktank/utils/tokenizer.py +++ b/sharktank/sharktank/utils/tokenizer.py @@ -31,6 +31,7 @@ def encode( raw_rows = self._encode(texts) max_length = 0 lengths: list[int] = [] + raw_rows = [row[1:] for row in raw_rows] for row in raw_rows: lengths.append(len(row)) max_length = max(max_length, len(row)) diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py index edf1d9d97..9b3daabdf 100644 --- a/sharktank/tests/models/llama/moe_block_test.py +++ b/sharktank/tests/models/llama/moe_block_test.py @@ -10,18 +10,17 @@ import torch from iree.turbine.aot import * from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch -from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock +from sharktank.layers.mixture_of_experts_block import MoeBlock from sharktank import ops -class SparseMoeBlockTest(unittest.TestCase): +class MoeBlockTest(unittest.TestCase): def test(self): - model = PreGatherMoeBlock( + model = MoeBlock( theta=make_moe_block_theta()("blk.0"), expert_count=8, expert_used_count=2, rms_epsilon=1e-5, - use_grok=False, ) fxb = FxProgramsBuilder(model) input = make_rand_torch((2, 32, 6144))
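Migration sketch: with use_grok removed, callers select behavior through explicit arguments (moe_activation for MoeBlock; attention_scale/softcap for PagedLlamaAttentionBlock). Below is a minimal, illustrative Python sketch of the two MoeBlock configurations, reusing only the theta helper and argument values already shown in moe_block_test.py, export_moe.py, and grok.py above; it is not part of the patch itself.

import torch.nn.functional as F

from sharktank.layers.mixture_of_experts_block import MoeBlock
from sharktank.models.llama.testing import make_moe_block_theta

theta = make_moe_block_theta()("blk.0")

# Mixtral-style block: default SiLU activation; layer_output_norm is taken
# from the theta when present, otherwise an Identity is used.
mixtral_style = MoeBlock(
    theta=theta,
    expert_count=8,
    expert_used_count=2,
    rms_epsilon=1e-5,
)

# Grok-style block: pass the activation explicitly instead of use_grok=True.
grok_style = MoeBlock(
    theta=theta,
    expert_count=8,
    expert_used_count=2,
    rms_epsilon=1e-5,
    moe_activation=F.gelu,
)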