diff --git a/sharktank/sharktank/export_layer/export_moe.py b/sharktank/sharktank/export_layer/export_moe.py
index f2c10c4b4..b0033b2b2 100644
--- a/sharktank/sharktank/export_layer/export_moe.py
+++ b/sharktank/sharktank/export_layer/export_moe.py
@@ -5,7 +5,10 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 import torch
+import torch.nn.functional as F
+
 from iree.turbine.aot import *
+
 from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
 from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock
 from ..utils import cli
@@ -51,7 +54,7 @@ def main():
         expert_count=8,
         expert_used_count=2,
         rms_epsilon=1e-5,
-        use_grok=args.use_grok,
+        moe_activation=F.gelu if args.use_grok else F.silu,
     )
     fxb = FxProgramsBuilder(model)
     input = make_rand_torch((bs, 32, 6144))
diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py
index 0536302cf..73fea9a9e 100644
--- a/sharktank/sharktank/layers/ffn_moe_block.py
+++ b/sharktank/sharktank/layers/ffn_moe_block.py
@@ -24,11 +24,11 @@ class PreGatherFFNMOE(ThetaLayer):
     def __init__(
         self,
         theta: Theta,
-        use_grok: bool = False,
+        activation=F.silu,
     ):
         super().__init__(theta)
-        self.use_grok = use_grok
+        self.activation = activation
 
         self.ffn_gate = theta.tensor("ffn_gate_exps", "weight")
         self.ffn_up = theta.tensor("ffn_up_exps", "weight")
@@ -63,10 +62,8 @@ def forward(
         experts: torch.Tensor,
         expert_gate: torch.Tensor,
     ):
-        if self.use_grok:
-            ffn_gate = F.gelu(self.pre_matmul_gather(h, self.ffn_gate, experts))
-        else:
-            ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate, experts))
+        ffn_gate = self.pre_matmul_gather(h, self.ffn_gate, experts)
+        ffn_gate = ops.elementwise(self.activation, ffn_gate)
 
         ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
         ffn_down = self.pre_matmul_gather(
diff --git a/sharktank/sharktank/layers/mixture_of_experts_block.py b/sharktank/sharktank/layers/mixture_of_experts_block.py
index f788d06f0..5b62c0d48 100644
--- a/sharktank/sharktank/layers/mixture_of_experts_block.py
+++ b/sharktank/sharktank/layers/mixture_of_experts_block.py
@@ -110,8 +110,8 @@ def forward(
             current_expert = current_expert.reshape(-1, feature_dim)
 
             moe_output.index_add_(0, token_idx, current_expert.to(ffn_input.dtype))
-        moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)
 
+        moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)
         moe_output = self.layer_output_norm(moe_output)
 
         return h + moe_output
@@ -130,13 +130,12 @@ def __init__(
         self,
         expert_count: int,
         expert_used_count: int,
         rms_epsilon: float,
-        use_grok: Optional[bool] = False,
+        moe_activation=F.silu,
     ):
         super().__init__(theta)
         self.expert_count = expert_count
         self.expert_used_count = expert_used_count
-        self.use_grok = use_grok
 
         # Add router gate
         self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp")))
@@ -146,15 +145,17 @@
             "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon)
         )
 
-        # Add FFN output norm layer for Grok
-        if self.use_grok:
+        # Add optional FFN output norm layer
+        if theta.optional_tensor("layer_output_norm") is not None:
             self.add_module(
                 "layer_output_norm",
                 RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon),
             )
+        else:
+            self.add_module("layer_output_norm", torch.nn.Identity())
 
         # Add expert_count x FFN
-        self.experts = PreGatherFFNMOE(theta, use_grok=self.use_grok)
+        self.experts = PreGatherFFNMOE(theta, activation=moe_activation)
 
     def forward(
         self,
@@ -180,7 +181,6 @@ def forward(
         moe_output = self.experts(ffn_input, top_k_experts, expert_gate)
         moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)
 
-        if self.use_grok:
-            moe_output = self.layer_output_norm(moe_output)
+        moe_output = self.layer_output_norm(moe_output)
 
         return h + moe_output
diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py
index 123c43be7..5bc045608 100644
--- a/sharktank/sharktank/layers/paged_llama_attention_block.py
+++ b/sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -37,7 +37,8 @@ def __init__(
         head_dim: int,
         head_count_kv: int,
         rms_epsilon: float,
-        use_grok: Optional[bool] = False,
+        attention_scale: Optional[float] = None,
+        softcap: Optional[float] = None,
     ):
         super().__init__(theta)
 
@@ -46,7 +47,8 @@ def __init__(
         self.head_count = head_count
         self.head_dim = head_dim
         self.head_count_kv = head_count_kv
-        self.use_grok = use_grok
+        self.attention_scale = attention_scale
+        self.softcap = softcap
 
         self.add_module(
             "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon)
@@ -56,7 +58,12 @@ def __init__(
         self.add_module("attn_q", LinearLayer(theta("attn_q")))
         self.add_module("attn_k", LinearLayer(theta("attn_k")))
         self.add_module("attn_v", LinearLayer(theta("attn_v")))
         self.add_module("attn_output", LinearLayer(theta("attn_output")))
 
-        if self.use_grok:
+        if theta.optional_tensor("attn_output_norm") is None:
+            self.add_module(
+                "attn_output_norm",
+                torch.nn.Identity(),
+            )
+        else:
             self.add_module(
                 "attn_output_norm",
                 RMSNormLayer(theta("attn_output_norm"), epsilon=rms_epsilon),
@@ -147,16 +154,16 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
         keys = xk.transpose(1, 2)
         values = xv.transpose(1, 2)
 
+        attn_weights = ops.matmul(xq, keys.transpose(2, 3))
+        if self.attention_scale is None:
+            attn_weights = attn_weights / math.sqrt(self.head_dim)
+        else:
+            attn_weights = attn_weights * self.attention_scale
+
         # Flash attention.
-        if not self.use_grok:
-            attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt(
-                self.head_dim
-            )
-        elif self.use_grok:
-            attn_weights = ops.matmul(xq, keys.transpose(2, 3))
-            attn_weights = 30.0 * torch.tanh(
-                attn_weights * (0.08838834764831845 / 30.0)
-            )
+        if self.softcap is not None:
+            attn_weights = self.softcap * torch.tanh(attn_weights / self.softcap)
+
         self.assert_not_nan(attn_weights)
 
         # Apply attention mask.
@@ -172,12 +179,10 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
 
         # Project.
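+        # attn_output_norm is RMSNorm when the theta has "attn_output_norm" (Grok) and Identity otherwise.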
         attn_output = self.attn_output(attn_output)
-
-        if self.use_grok:
-            attn_output = self.attn_output_norm(attn_output)
+        attn_output = self.attn_output_norm(attn_output)
 
         h = h + attn_output
-
         return h
 
     def transact_cache_direct(
diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py
index debeb30c7..f7e244e1f 100644
--- a/sharktank/sharktank/models/grok/grok.py
+++ b/sharktank/sharktank/models/grok/grok.py
@@ -7,6 +7,9 @@
 
+import math
+
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from ...layers import *
 
@@ -82,6 +83,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
         self.add_module("output_lm_head", LinearLayer(theta("output")))
 
         self.attn_blocks = nn.ModuleList()
+        self.moe_blocks = nn.ModuleList()
 
         for n in range(hp.block_count):
             self.attn_blocks.append(
@@ -93,16 +95,16 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
                     head_dim=hp.attn_head_dim,
                     head_count_kv=hp.attention_head_count_kv,
                     rms_epsilon=hp.attention_layer_norm_rms_epsilon,
-                    use_grok=True,
+                    softcap=30.0,
                 )
             )
-            self.attn_blocks.append(
+            self.moe_blocks.append(
                 PreGatherMoeBlock(
                     theta("blk", n),
                     expert_count=hp.expert_count,
                     expert_used_count=hp.expert_used_count,
                     rms_epsilon=hp.attention_layer_norm_rms_epsilon,
-                    use_grok=True,
+                    moe_activation=F.gelu,
                 )
             )
@@ -122,33 +123,32 @@ def prefill(
         self._assert_device(seq_block_ids)
         self._assert_device(*cache_state, dtype=self.activation_dtype)
         h = self.token_embedding(tokens)
-        h *= 78.38367176906169
+        h *= math.sqrt(h.shape[-1])
         self.trace_tensor("grok.token_embedding", h)
 
         # Iterate over attention blocks.
-        for block_idx, block in enumerate(self.attn_blocks):
+        for block_idx, (attn_block, moe_block) in enumerate(
+            zip(self.attn_blocks, self.moe_blocks)
+        ):
             if block_idx == 0:
                 self.trace_tensor(f"grok.attn_block.{block_idx}.input", h)
-            if block.__class__.__name__ == "PagedLlamaAttentionBlock":
-                h = block(
-                    h,
-                    embedding=self.attention_embedding,
-                    start_index=0,
-                    attention_mask=attention_mask,
-                    cache_state=cache_state,
-                    seq_block_ids=seq_block_ids,
-                )
-                self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)
-            elif block.__class__.__name__ == "PreGatherMoeBlock":
-                h = block(
-                    h,
-                )
-                self.trace_tensor(f"grok.moe_block.{block_idx}.output", h)
+            h = attn_block(
+                h,
+                embedding=self.attention_embedding,
+                start_index=0,
+                attention_mask=attention_mask,
+                cache_state=cache_state,
+                seq_block_ids=seq_block_ids,
+            )
+            self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)
+
+            h = moe_block(h)
+            self.trace_tensor(f"grok.moe_block.{block_idx}.output", h)
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
-        logits = logits * 0.5773502691896257
+        logits = logits / math.sqrt(3.0)
 
         return logits
 
     def decode(
@@ -200,34 +200,34 @@ def decode(
         )
 
         h = self.token_embedding(tokens)
-        h *= 78.38367176906169
+        h *= math.sqrt(h.shape[-1])
         self.trace_tensor("grok.token_embedding", h)
 
         # Iterate over attention blocks.
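+        # Attention and MoE blocks live in parallel ModuleLists and run in lock-step, one pair per layer.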
-        for block_idx, block in enumerate(self.attn_blocks):
+        for block_idx, (attn_block, moe_block) in enumerate(
+            zip(self.attn_blocks, self.moe_blocks)
+        ):
             if block_idx == 0:
                 self.trace_tensor(f"grok.attn_block.{block_idx}.input", h)
-            if block.__class__.__name__ == "PagedLlamaAttentionBlock":
-                h = block(
-                    h,
-                    start_positions=start_positions,
-                    embedding=self.attention_embedding,
-                    embedding_batch_mask=embedding_batch_mask,
-                    attention_mask=attention_mask,
-                    cache_state=cache_state,
-                    seq_block_ids=seq_block_ids,
-                    xk_temp=xk_temp,
-                    xv_temp=xv_temp,
-                )
-                self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)
-            elif block.__class__.__name__ == "PreGatherMoeBlock":
-                h = block(
-                    h,
-                )
-                self.trace_tensor(f"grok.moe_block.{block_idx}.output", h)
+            h = attn_block(
+                h,
+                start_positions=start_positions,
+                embedding=self.attention_embedding,
+                embedding_batch_mask=embedding_batch_mask,
+                attention_mask=attention_mask,
+                cache_state=cache_state,
+                seq_block_ids=seq_block_ids,
+                xk_temp=xk_temp,
+                xv_temp=xv_temp,
+            )
+            self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)
+
+            h = moe_block(h)
+            self.trace_tensor(f"grok.moe_block.{block_idx}.output", h)
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
-        logits = logits * 0.5773502691896257
+        logits = logits / math.sqrt(3.0)
 
         return logits
diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py
index 1fc86f87d..34afdd6cd 100644
--- a/sharktank/sharktank/models/mixtral/mixtral.py
+++ b/sharktank/sharktank/models/mixtral/mixtral.py
@@ -85,6 +85,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
         self.add_module("output_lm_head", LinearLayer(theta("output")))
 
         self.attn_blocks = nn.ModuleList()
+        self.moe_blocks = nn.ModuleList()
 
         for n in range(hp.block_count):
             self.attn_blocks.append(
@@ -98,7 +99,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
                     rms_epsilon=hp.attention_layer_norm_rms_epsilon,
                 )
             )
-            self.attn_blocks.append(
+            self.moe_blocks.append(
                 SparseMoeBlock(
                     theta("blk", n),
                     expert_count=hp.expert_count,
@@ -126,25 +127,26 @@ def prefill(
         self.trace_tensor("mixtral.token_embedding", h)
 
         # Iterate over attention blocks.
-        for block_idx, block in enumerate(self.attn_blocks):
+        for block_idx, (attn_block, moe_block) in enumerate(
+            zip(self.attn_blocks, self.moe_blocks)
+        ):
             if block_idx == 0:
                 self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h)
-            if block.__class__.__name__ == "PagedLlamaAttentionBlock":
-                h = block(
-                    h,
-                    embedding=self.attention_embedding,
-                    start_index=0,
-                    attention_mask=attention_mask,
-                    cache_state=cache_state,
-                    seq_block_ids=seq_block_ids,
-                )
-                self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
-            elif block.__class__.__name__ == "SparseMoeBlock":
-                h = block(
-                    h,
-                )
-                self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
+            h = attn_block(
+                h,
+                embedding=self.attention_embedding,
+                start_index=0,
+                attention_mask=attention_mask,
+                cache_state=cache_state,
+                seq_block_ids=seq_block_ids,
+            )
+            self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
+
+            h = moe_block(
+                h,
+            )
+            self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
@@ -202,28 +204,29 @@ def decode(
         self.trace_tensor("mixtral.token_embedding", h)
 
         # Iterate over attention blocks.
-        for block_idx, block in enumerate(self.attn_blocks):
+        for block_idx, (attn_block, moe_block) in enumerate(
+            zip(self.attn_blocks, self.moe_blocks)
+        ):
             if block_idx == 0:
                 self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h)
-            if block.__class__.__name__ == "PagedLlamaAttentionBlock":
-                h = block(
-                    h,
-                    start_positions=start_positions,
-                    embedding=self.attention_embedding,
-                    embedding_batch_mask=embedding_batch_mask,
-                    attention_mask=attention_mask,
-                    cache_state=cache_state,
-                    seq_block_ids=seq_block_ids,
-                    xk_temp=xk_temp,
-                    xv_temp=xv_temp,
-                )
-                self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
-            elif block.__class__.__name__ == "SparseMoeBlock":
-                h = block(
-                    h,
-                )
-                self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
+            h = attn_block(
+                h,
+                start_positions=start_positions,
+                embedding=self.attention_embedding,
+                embedding_batch_mask=embedding_batch_mask,
+                attention_mask=attention_mask,
+                cache_state=cache_state,
+                seq_block_ids=seq_block_ids,
+                xk_temp=xk_temp,
+                xv_temp=xv_temp,
+            )
+            self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
+
+            h = moe_block(
+                h,
+            )
+            self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
diff --git a/sharktank/sharktank/models/mixtral/mixtral_ref.py b/sharktank/sharktank/models/mixtral/mixtral_ref.py
index 392f60a25..779f7480c 100644
--- a/sharktank/sharktank/models/mixtral/mixtral_ref.py
+++ b/sharktank/sharktank/models/mixtral/mixtral_ref.py
@@ -66,6 +66,7 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig):
         self.add_module("output_lm_head", LinearLayer(theta("output")))
 
         self.attn_blocks = nn.ModuleList()
+        self.moe_blocks = nn.ModuleList()
 
         for n in range(hp.block_count):
             self.attn_blocks.append(
@@ -78,7 +79,7 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig):
                     rms_epsilon=hp.attention_layer_norm_rms_epsilon,
                 )
             )
-            self.attn_blocks.append(
+            self.moe_blocks.append(
                 SparseMoeBlock(
                     theta("blk", n),
                     expert_count=hp.expert_count,
@@ -130,28 +131,29 @@ def forward(
         block_count = len(self.attn_blocks) // 2
         # print('local_kv_cache, #attn_blocks', len(local_kv_cache), block_count)
         # Iterate over attention + MoE blocks.
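+        # attn_blocks/moe_blocks are parallel per-layer lists; local_kv_cache holds all K caches, then all V caches.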
-        for block_idx, block in enumerate(self.attn_blocks):
+        for block_idx, (attn_block, moe_block) in enumerate(
+            zip(self.attn_blocks, self.moe_blocks)
+        ):
             # print("block_idx, block", block_idx, block)
             if block_idx == 0:
                 self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h)
-            if block.__class__.__name__ == "LlamaAttentionBlock":
-                attn_block_idx = block_idx // 2
-                block_cache_k = local_kv_cache[attn_block_idx]
-                block_cache_v = local_kv_cache[block_count + attn_block_idx]
-                h = block(
-                    h,
-                    cache_k=block_cache_k,
-                    cache_v=block_cache_v,
-                    start_index=start_index,
-                    attention_mask=attention_mask,
-                )
-                self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
-            elif block.__class__.__name__ == "SparseMoeBlock":
-                h = block(
-                    h,
-                )
-                self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
+            block_cache_k = local_kv_cache[block_idx]
+            block_cache_v = local_kv_cache[len(self.attn_blocks) + block_idx]
+            h = attn_block(
+                h,
+                cache_k=block_cache_k,
+                cache_v=block_cache_v,
+                start_index=start_index,
+                attention_mask=attention_mask,
+            )
+            self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
+
+            h = moe_block(
+                h,
+            )
+            self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py
index edf1d9d97..2f3bd3cf0 100644
--- a/sharktank/tests/models/llama/moe_block_test.py
+++ b/sharktank/tests/models/llama/moe_block_test.py
@@ -21,7 +21,6 @@ def test(self):
             expert_count=8,
             expert_used_count=2,
             rms_epsilon=1e-5,
-            use_grok=False,
         )
         fxb = FxProgramsBuilder(model)
         input = make_rand_torch((2, 32, 6144))
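Reviewer note: a minimal usage sketch of the generalized constructors introduced above. It assumes the `PreGatherMoeBlock` signature from this diff and reuses the `make_moe_block_theta`/`make_rand_torch` helpers already imported by `export_moe.py` and `moe_block_test.py`; the exact theta expression is an assumption, and the `F.gelu` / `softcap=30.0` values are the ones the Grok call sites pass.

```python
import torch.nn.functional as F

from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock

# Grok-flavoured MoE block: GELU experts instead of the default SiLU.  The block
# adds an RMSNorm output layer only if the theta carries "layer_output_norm";
# otherwise layer_output_norm is a torch.nn.Identity and the call is a no-op.
moe = PreGatherMoeBlock(
    theta=make_moe_block_theta()("blk.0"),  # assumed test theta, as in moe_block_test.py
    expert_count=8,
    expert_used_count=2,
    rms_epsilon=1e-5,
    moe_activation=F.gelu,  # drop this kwarg to get the SiLU default (Llama/Mixtral)
)
out = moe(make_rand_torch((2, 32, 6144)))  # returns h + moe_output, same shape as the input

# PagedLlamaAttentionBlock is parameterized the same way: attention_scale=None keeps
# the 1/sqrt(head_dim) scaling, and softcap=30.0 reproduces Grok's
# 30.0 * tanh(scores / 30.0) logit capping.
```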