From 4cfc02efd413a91a162da6cb23bb614ee9cd87b8 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 10 Oct 2024 21:27:46 -0700 Subject: [PATCH 1/6] Refactor llama / mixtral / grok for shared features Many of these features can toggle between depending on architecture. Replumbing the configurations separately allows better reuse and understanding of how models vary between eachother. grok uses a softcap, plumbing a value enables `sc * tanh( v / sc)` grok has some hardcoded values that have better representations, e.g. `sqrt(6144)` and `sqrt(3)`. output normalization is optional but used by mixtral. Presence of the tensor is sufficient for performing the normalization. --- .../sharktank/export_layer/export_moe.py | 5 +- sharktank/sharktank/layers/ffn_moe_block.py | 9 +- .../layers/mixture_of_experts_block.py | 16 ++-- .../layers/paged_llama_attention_block.py | 36 ++++---- sharktank/sharktank/models/grok/grok.py | 84 +++++++++---------- sharktank/sharktank/models/mixtral/mixtral.py | 75 +++++++++-------- .../sharktank/models/mixtral/mixtral_ref.py | 40 ++++----- .../tests/models/llama/moe_block_test.py | 1 - 8 files changed, 137 insertions(+), 129 deletions(-) diff --git a/sharktank/sharktank/export_layer/export_moe.py b/sharktank/sharktank/export_layer/export_moe.py index f2c10c4b4..b0033b2b2 100644 --- a/sharktank/sharktank/export_layer/export_moe.py +++ b/sharktank/sharktank/export_layer/export_moe.py @@ -5,7 +5,10 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import torch +import torch.nn.functional as F + from iree.turbine.aot import * + from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock from ..utils import cli @@ -51,7 +54,7 @@ def main(): expert_count=8, expert_used_count=2, rms_epsilon=1e-5, - use_grok=args.use_grok, + moe_activation=F.gelu if args.use_grok else F.silu, ) fxb = FxProgramsBuilder(model) input = make_rand_torch((bs, 32, 6144)) diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py index 0536302cf..73fea9a9e 100644 --- a/sharktank/sharktank/layers/ffn_moe_block.py +++ b/sharktank/sharktank/layers/ffn_moe_block.py @@ -24,11 +24,10 @@ class PreGatherFFNMOE(ThetaLayer): def __init__( self, theta: Theta, - use_grok: bool = False, + activation=F.silu, ): super().__init__(theta) - self.use_grok = use_grok self.ffn_gate = theta.tensor("ffn_gate_exps", "weight") self.ffn_up = theta.tensor("ffn_up_exps", "weight") @@ -63,10 +62,8 @@ def forward( experts: torch.Tensor, expert_gate: torch.Tensor, ): - if self.use_grok: - ffn_gate = F.gelu(self.pre_matmul_gather(h, self.ffn_gate, experts)) - else: - ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate, experts)) + ffn_gate = self.pre_matmul_gather(h, self.ffn_gate, experts) + ffn_gate = ops.elementwise(self.activation, ffn_gate) ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts) ffn_down = self.pre_matmul_gather( diff --git a/sharktank/sharktank/layers/mixture_of_experts_block.py b/sharktank/sharktank/layers/mixture_of_experts_block.py index f788d06f0..5b62c0d48 100644 --- a/sharktank/sharktank/layers/mixture_of_experts_block.py +++ b/sharktank/sharktank/layers/mixture_of_experts_block.py @@ -110,8 +110,8 @@ def forward( current_expert = current_expert.reshape(-1, feature_dim) moe_output.index_add_(0, token_idx, current_expert.to(ffn_input.dtype)) - moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim) + moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim) moe_output = self.layer_output_norm(moe_output) return h + moe_output @@ -130,13 +130,12 @@ def __init__( expert_count: int, expert_used_count: int, rms_epsilon: float, - use_grok: Optional[bool] = False, + moe_activation=F.silu, ): super().__init__(theta) self.expert_count = expert_count self.expert_used_count = expert_used_count - self.use_grok = use_grok # Add router gate self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp"))) @@ -146,15 +145,17 @@ def __init__( "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon) ) - # Add FFN output norm layer for Grok - if self.use_grok: + # Add optional FFN output norm layer + if theta.optional_tensor("layer_output_norm") is not None: self.add_module( "layer_output_norm", RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon), ) + else: + self.add_module("layer_output_norm", torch.nn.Identity()) # Add expert_count x FFN - self.experts = PreGatherFFNMOE(theta, use_grok=self.use_grok) + self.experts = PreGatherFFNMOE(theta, activation=moe_activation) def forward( self, @@ -180,7 +181,6 @@ def forward( moe_output = self.experts(ffn_input, top_k_experts, expert_gate) moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim) - if self.use_grok: - moe_output = self.layer_output_norm(moe_output) + moe_output = self.layer_output_norm(moe_output) return h + moe_output diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 123c43be7..5bc045608 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -37,7 +37,8 @@ def __init__( head_dim: int, head_count_kv: int, rms_epsilon: float, - use_grok: Optional[bool] = False, + attention_scale: Optional[float] = None, + softcap: Optional[float] = None, ): super().__init__(theta) @@ -46,7 +47,8 @@ def __init__( self.head_count = head_count self.head_dim = head_dim self.head_count_kv = head_count_kv - self.use_grok = use_grok + self.attention_scale = attention_scale + self.softcap = softcap self.add_module( "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon) @@ -56,7 +58,12 @@ def __init__( self.add_module("attn_v", LinearLayer(theta("attn_v"))) self.add_module("attn_output", LinearLayer(theta("attn_output"))) - if self.use_grok: + if theta.optional_tensor("attn_output_norm") is None: + self.add_module( + "attn_output_norm", + torch.nn.Identity(), + ) + else: self.add_module( "attn_output_norm", RMSNormLayer(theta("attn_output_norm"), epsilon=rms_epsilon), @@ -147,16 +154,16 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: keys = xk.transpose(1, 2) values = xv.transpose(1, 2) + attn_weights = ops.matmul(xq, keys.transpose(2, 3)) + if self.attention_scale is None: + attn_weights = attn_weights / math.sqrt(self.head_dim) + else: + attn_weights = attn_weights * self.attention_scale + # Flash attention. - if not self.use_grok: - attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt( - self.head_dim - ) - elif self.use_grok: - attn_weights = ops.matmul(xq, keys.transpose(2, 3)) - attn_weights = 30.0 * torch.tanh( - attn_weights * (0.08838834764831845 / 30.0) - ) + if self.softcap is not None: + attn_weights = self.softcap * torch.tanh(attn_weights / self.softcap) + self.assert_not_nan(attn_weights) # Apply attention mask. @@ -172,12 +179,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: # Project. attn_output = self.attn_output(attn_output) - - if self.use_grok: - attn_output = self.attn_output_norm(attn_output) + attn_output = self.attn_output_norm(attn_output) h = h + attn_output - return h def transact_cache_direct( diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py index debeb30c7..f7e244e1f 100644 --- a/sharktank/sharktank/models/grok/grok.py +++ b/sharktank/sharktank/models/grok/grok.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from ...layers import * @@ -82,6 +83,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): self.add_module("output_lm_head", LinearLayer(theta("output"))) self.attn_blocks = nn.ModuleList() + self.moe_blocks = nn.ModuleList() for n in range(hp.block_count): self.attn_blocks.append( @@ -93,16 +95,15 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): head_dim=hp.attn_head_dim, head_count_kv=hp.attention_head_count_kv, rms_epsilon=hp.attention_layer_norm_rms_epsilon, - use_grok=True, ) ) - self.attn_blocks.append( + self.moe_blocks.append( PreGatherMoeBlock( theta("blk", n), expert_count=hp.expert_count, expert_used_count=hp.expert_used_count, rms_epsilon=hp.attention_layer_norm_rms_epsilon, - use_grok=True, + activation=F.gelu, ) ) @@ -122,33 +123,32 @@ def prefill( self._assert_device(seq_block_ids) self._assert_device(*cache_state, dtype=self.activation_dtype) h = self.token_embedding(tokens) - h *= 78.38367176906169 + h *= math.sqrt(h.shape[-1]) self.trace_tensor("grok.token_embedding", h) # Iterate over attention blocks. - for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): if block_idx == 0: self.trace_tensor(f"grok.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "PagedLlamaAttentionBlock": - h = block( - h, - embedding=self.attention_embedding, - start_index=0, - attention_mask=attention_mask, - cache_state=cache_state, - seq_block_ids=seq_block_ids, - ) - self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "PreGatherMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) + h = attn_block( + h, + embedding=self.attention_embedding, + start_index=0, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + ) + self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) + + h = moe_block(h) + self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) - logits = logits * 0.5773502691896257 + logits = logits / math.sqrt(3.0) return logits def decode( @@ -200,34 +200,34 @@ def decode( ) h = self.token_embedding(tokens) - h *= 78.38367176906169 + h *= math.sqrt(h.shape[-1]) self.trace_tensor("grok.token_embedding", h) # Iterate over attention blocks. - for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + self.attn_blocks, self.moe_blocks + ): if block_idx == 0: self.trace_tensor(f"grok.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "PagedLlamaAttentionBlock": - h = block( - h, - start_positions=start_positions, - embedding=self.attention_embedding, - embedding_batch_mask=embedding_batch_mask, - attention_mask=attention_mask, - cache_state=cache_state, - seq_block_ids=seq_block_ids, - xk_temp=xk_temp, - xv_temp=xv_temp, - ) - self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "PreGatherMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) + h = attn_block( + h, + start_positions=start_positions, + embedding=self.attention_embedding, + embedding_batch_mask=embedding_batch_mask, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + softcap=30.0, + ) + self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) + + h = moe_block(h) + self.trace_tensor(f"grok.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) - logits = logits * 0.5773502691896257 + logits = logits / math.sqrt(3.0) return logits diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py index 1fc86f87d..34afdd6cd 100644 --- a/sharktank/sharktank/models/mixtral/mixtral.py +++ b/sharktank/sharktank/models/mixtral/mixtral.py @@ -85,6 +85,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): self.add_module("output_lm_head", LinearLayer(theta("output"))) self.attn_blocks = nn.ModuleList() + self.moe_blocks = nn.ModuleList() for n in range(hp.block_count): self.attn_blocks.append( @@ -98,7 +99,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): rms_epsilon=hp.attention_layer_norm_rms_epsilon, ) ) - self.attn_blocks.append( + self.moe_blocks.append( SparseMoeBlock( theta("blk", n), expert_count=hp.expert_count, @@ -126,25 +127,26 @@ def prefill( self.trace_tensor("mixtral.token_embedding", h) # Iterate over attention blocks. - for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): if block_idx == 0: self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "PagedLlamaAttentionBlock": - h = block( - h, - embedding=self.attention_embedding, - start_index=0, - attention_mask=attention_mask, - cache_state=cache_state, - seq_block_ids=seq_block_ids, - ) - self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "SparseMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) + h = attn_block( + h, + embedding=self.attention_embedding, + start_index=0, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + ) + self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) + + h = moe_block( + h, + ) + self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) @@ -202,28 +204,29 @@ def decode( self.trace_tensor("mixtral.token_embedding", h) # Iterate over attention blocks. - for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): if block_idx == 0: self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "PagedLlamaAttentionBlock": - h = block( - h, - start_positions=start_positions, - embedding=self.attention_embedding, - embedding_batch_mask=embedding_batch_mask, - attention_mask=attention_mask, - cache_state=cache_state, - seq_block_ids=seq_block_ids, - xk_temp=xk_temp, - xv_temp=xv_temp, - ) - self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "SparseMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) + h = attn_block( + h, + start_positions=start_positions, + embedding=self.attention_embedding, + embedding_batch_mask=embedding_batch_mask, + attention_mask=attention_mask, + cache_state=cache_state, + seq_block_ids=seq_block_ids, + xk_temp=xk_temp, + xv_temp=xv_temp, + ) + self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) + + h = moe_block( + h, + ) + self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) diff --git a/sharktank/sharktank/models/mixtral/mixtral_ref.py b/sharktank/sharktank/models/mixtral/mixtral_ref.py index 392f60a25..779f7480c 100644 --- a/sharktank/sharktank/models/mixtral/mixtral_ref.py +++ b/sharktank/sharktank/models/mixtral/mixtral_ref.py @@ -66,6 +66,7 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig): self.add_module("output_lm_head", LinearLayer(theta("output"))) self.attn_blocks = nn.ModuleList() + self.moe_blocks = nn.ModuleList() for n in range(hp.block_count): self.attn_blocks.append( @@ -78,7 +79,7 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig): rms_epsilon=hp.attention_layer_norm_rms_epsilon, ) ) - self.attn_blocks.append( + self.moe_blocks.append( SparseMoeBlock( theta("blk", n), expert_count=hp.expert_count, @@ -130,28 +131,29 @@ def forward( block_count = len(self.attn_blocks) // 2 # print('local_kv_cache, #attn_blocks', len(local_kv_cache), block_count) # Iterate over attention + MoE blocks. - for block_idx, block in enumerate(self.attn_blocks): + for block_idx, (attn_block, moe_block) in enumerate( + zip(self.attn_blocks, self.moe_blocks) + ): # print("block_idx, block", block_idx, block) if block_idx == 0: self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h) - if block.__class__.__name__ == "LlamaAttentionBlock": - attn_block_idx = block_idx // 2 - block_cache_k = local_kv_cache[attn_block_idx] - block_cache_v = local_kv_cache[block_count + attn_block_idx] - h = block( - h, - cache_k=block_cache_k, - cache_v=block_cache_v, - start_index=start_index, - attention_mask=attention_mask, - ) - self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) - elif block.__class__.__name__ == "SparseMoeBlock": - h = block( - h, - ) - self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) + attn_block_idx = block_idx // 2 + block_cache_k = local_kv_cache[attn_block_idx] + block_cache_v = local_kv_cache[block_count + attn_block_idx] + h = attn_block( + h, + cache_k=block_cache_k, + cache_v=block_cache_v, + start_index=start_index, + attention_mask=attention_mask, + ) + self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h) + + h = attn_block( + h, + ) + self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h) h = self.output_norm(h) logits = self.output_lm_head(h) diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py index edf1d9d97..2f3bd3cf0 100644 --- a/sharktank/tests/models/llama/moe_block_test.py +++ b/sharktank/tests/models/llama/moe_block_test.py @@ -21,7 +21,6 @@ def test(self): expert_count=8, expert_used_count=2, rms_epsilon=1e-5, - use_grok=False, ) fxb = FxProgramsBuilder(model) input = make_rand_torch((2, 32, 6144)) From 53aaa413d74d1d26d6100d50d5eae08d62ad703f Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 10 Oct 2024 21:27:46 -0700 Subject: [PATCH 2/6] Refactor llama / mixtral / grok for shared features Many of these features can toggle between depending on architecture. Replumbing the configurations separately allows better reuse and understanding of how models vary between eachother. grok uses a softcap, plumbing a value enables `sc * tanh( v / sc)` grok has some hardcoded values that have better representations, e.g. `sqrt(6144)` and `sqrt(3)`. output normalization is optional but used by mixtral. Presence of the tensor is sufficient for performing the normalization. We remove the sparse moe block as we now know it will not be used due to poor performance. --- .../sharktank/export_layer/export_moe.py | 4 +- sharktank/sharktank/layers/__init__.py | 2 +- .../layers/mixture_of_experts_block.py | 100 +----------------- sharktank/sharktank/models/grok/grok.py | 2 +- sharktank/sharktank/models/mixtral/mixtral.py | 2 +- .../sharktank/models/mixtral/mixtral_ref.py | 2 +- .../tests/models/llama/moe_block_test.py | 4 +- 7 files changed, 10 insertions(+), 106 deletions(-) diff --git a/sharktank/sharktank/export_layer/export_moe.py b/sharktank/sharktank/export_layer/export_moe.py index b0033b2b2..3db190a5e 100644 --- a/sharktank/sharktank/export_layer/export_moe.py +++ b/sharktank/sharktank/export_layer/export_moe.py @@ -10,7 +10,7 @@ from iree.turbine.aot import * from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch -from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock +from sharktank.layers.mixture_of_experts_block import MoeBlock from ..utils import cli @@ -49,7 +49,7 @@ def main(): bs = args.batch_size - model = PreGatherMoeBlock( + model = MoeBlock( theta=make_moe_block_theta()("blk.0"), expert_count=8, expert_used_count=2, diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py index a90def3a9..fd56ec872 100644 --- a/sharktank/sharktank/layers/__init__.py +++ b/sharktank/sharktank/layers/__init__.py @@ -16,6 +16,6 @@ from .paged_llama_attention_block import PagedLlamaAttentionBlock from .ffn_block import FFN from .ffn_moe_block import FFNMOE -from .mixture_of_experts_block import SparseMoeBlock, PreGatherMoeBlock +from .mixture_of_experts_block import MoeBlock from .configs import * diff --git a/sharktank/sharktank/layers/mixture_of_experts_block.py b/sharktank/sharktank/layers/mixture_of_experts_block.py index 5b62c0d48..ddce16c55 100644 --- a/sharktank/sharktank/layers/mixture_of_experts_block.py +++ b/sharktank/sharktank/layers/mixture_of_experts_block.py @@ -16,107 +16,11 @@ from .ffn_moe_block import FFNMOE, PreGatherFFNMOE __all__ = [ - "SparseMoeBlock", - "PreGatherMoeBlock", + "MoeBlock", ] -class SparseMoeBlock(ThetaLayer): - """ - This implementation considers MoE operations as block-sparse - operations to support imbalanced token assignments to experts. - This enables the MoE to operate at a faster rate and in full capacity without any dropped tokens - (or reduced performance). - """ - - def __init__( - self, - theta: Theta, - expert_count: int, - expert_used_count: int, - rms_epsilon: float, - ): - super().__init__(theta) - - # Add router gate - self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp"))) - - # Add FFN norm - self.add_module( - "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon) - ) - - # Add FFN output norm - self.add_module( - "layer_output_norm", - RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon), - ) - - # Add expert_count x FFN - self.experts = nn.ModuleList( - [FFNMOE(theta, expert_idx=i) for i in range(expert_count)] - ) - - self.expert_count = expert_count - self.expert_used_count = expert_used_count - - def forward( - self, - h: torch.Tensor, - ): - ffn_input = self.ffn_norm(h) - batch_size, sequence_length, feature_dim = ffn_input.shape - ffn_input = ffn_input.view(-1, feature_dim) - - # For each token, the router calculates the router weights for all experts - # router_logits: (batch_size * sequence_length, expert_count) - router_logits = self.ffn_gate_inp(ffn_input) - router_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - - # Select top k experts from router weights - router_weights, top_k_experts = torch.topk( - router_weights, self.expert_used_count, dim=-1 - ) - router_weights /= router_weights.sum(dim=-1, keepdim=True) - router_weights = router_weights.to(ffn_input.dtype) - - moe_output = torch.zeros( - (batch_size * sequence_length, feature_dim), dtype=ffn_input.dtype - ) - - # Create an expert mask by one hot encoding the selected top k experts - # used to index which expert is to be invoked for each token - # expert_mask: (expert_count, expert_used_count, sequence_length) - expert_mask = F.one_hot(top_k_experts, num_classes=self.expert_count).permute( - 2, 1, 0 - ) - - # Iterate over all experts in the model - for expert_idx in range(self.expert_count): - expert_layer = self.experts[expert_idx] - top_k_expert_idx, token_idx = torch.where(expert_mask[expert_idx]) - - # Given the hidden states, index the tokens assigned to this expert - # and calculate the current expert's hidden state and weigh the - # output expert hidden states by the router weights, based on the - # appropriate tokens - current_expert_tokens = ffn_input[None, token_idx] - - current_expert = ( - expert_layer(current_expert_tokens) - * router_weights[token_idx, top_k_expert_idx, None] - ) - - current_expert = current_expert.reshape(-1, feature_dim) - - moe_output.index_add_(0, token_idx, current_expert.to(ffn_input.dtype)) - - moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim) - moe_output = self.layer_output_norm(moe_output) - return h + moe_output - - -class PreGatherMoeBlock(ThetaLayer): +class MoeBlock(ThetaLayer): """ This implementation considers MoE operations as block-sparse operations to support imbalanced token assignments to experts. diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py index f7e244e1f..1abbf7113 100644 --- a/sharktank/sharktank/models/grok/grok.py +++ b/sharktank/sharktank/models/grok/grok.py @@ -98,7 +98,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): ) ) self.moe_blocks.append( - PreGatherMoeBlock( + MoeBlock( theta("blk", n), expert_count=hp.expert_count, expert_used_count=hp.expert_used_count, diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py index 34afdd6cd..e2995dfde 100644 --- a/sharktank/sharktank/models/mixtral/mixtral.py +++ b/sharktank/sharktank/models/mixtral/mixtral.py @@ -100,7 +100,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): ) ) self.moe_blocks.append( - SparseMoeBlock( + MoeBlock( theta("blk", n), expert_count=hp.expert_count, expert_used_count=hp.expert_used_count, diff --git a/sharktank/sharktank/models/mixtral/mixtral_ref.py b/sharktank/sharktank/models/mixtral/mixtral_ref.py index 779f7480c..70a9b9cf8 100644 --- a/sharktank/sharktank/models/mixtral/mixtral_ref.py +++ b/sharktank/sharktank/models/mixtral/mixtral_ref.py @@ -80,7 +80,7 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig): ) ) self.moe_blocks.append( - SparseMoeBlock( + MoeBlock( theta("blk", n), expert_count=hp.expert_count, expert_used_count=hp.expert_used_count, diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py index 2f3bd3cf0..dd8a19649 100644 --- a/sharktank/tests/models/llama/moe_block_test.py +++ b/sharktank/tests/models/llama/moe_block_test.py @@ -10,11 +10,11 @@ import torch from iree.turbine.aot import * from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch -from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock +from sharktank.layers.mixture_of_experts_block import MoeBlock from sharktank import ops -class SparseMoeBlockTest(unittest.TestCase): +class MoeBlockTest(unittest.TestCase): def test(self): model = PreGatherMoeBlock( theta=make_moe_block_theta()("blk.0"), From da770583553e45a30df70adf04355b6846539e23 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 11 Oct 2024 12:26:19 -0700 Subject: [PATCH 3/6] fix tests --- sharktank/sharktank/layers/ffn_moe_block.py | 5 +++-- sharktank/tests/models/llama/moe_block_test.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py index 73fea9a9e..0746f0fa0 100644 --- a/sharktank/sharktank/layers/ffn_moe_block.py +++ b/sharktank/sharktank/layers/ffn_moe_block.py @@ -12,7 +12,7 @@ from .base import ThetaLayer from .linear import LinearLayer from ..types import Theta, DefaultPrimitiveTensor -from ..ops import einsum_2args +from ..ops import einsum_2args, elementwise __all__ = [ "FFNMOE", @@ -32,6 +32,7 @@ def __init__( self.ffn_gate = theta.tensor("ffn_gate_exps", "weight") self.ffn_up = theta.tensor("ffn_up_exps", "weight") self.ffn_down = theta.tensor("ffn_down_exps", "weight") + self.activation = activation def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"): inputs = inputs[:, :] @@ -63,7 +64,7 @@ def forward( expert_gate: torch.Tensor, ): ffn_gate = self.pre_matmul_gather(h, self.ffn_gate, experts) - ffn_gate = ops.elementwise(self.activation, ffn_gate) + ffn_gate = elementwise(self.activation, ffn_gate) ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts) ffn_down = self.pre_matmul_gather( diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py index dd8a19649..9b3daabdf 100644 --- a/sharktank/tests/models/llama/moe_block_test.py +++ b/sharktank/tests/models/llama/moe_block_test.py @@ -16,7 +16,7 @@ class MoeBlockTest(unittest.TestCase): def test(self): - model = PreGatherMoeBlock( + model = MoeBlock( theta=make_moe_block_theta()("blk.0"), expert_count=8, expert_used_count=2, From 7d6657a5e015d1e8eb02b73875e4dd75522f2d4b Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 14 Oct 2024 13:19:29 -0700 Subject: [PATCH 4/6] rework for review comments --- sharktank/sharktank/export_layer/export_moe.py | 6 +++--- sharktank/sharktank/layers/llama_attention_block.py | 4 +++- .../sharktank/layers/paged_llama_attention_block.py | 2 +- sharktank/sharktank/models/grok/grok.py | 12 ++++++------ sharktank/sharktank/models/llama/llama_ref.py | 4 +++- 5 files changed, 16 insertions(+), 12 deletions(-) diff --git a/sharktank/sharktank/export_layer/export_moe.py b/sharktank/sharktank/export_layer/export_moe.py index 3db190a5e..af4ed51d2 100644 --- a/sharktank/sharktank/export_layer/export_moe.py +++ b/sharktank/sharktank/export_layer/export_moe.py @@ -40,8 +40,8 @@ def main(): action="store_true", ) parser.add_argument( - "--use-grok", - help="Enable to export Grok model's version of MOE block", + "--use-gelu", + help="Enable to use gelu for moe activation", action="store_true", ) @@ -54,7 +54,7 @@ def main(): expert_count=8, expert_used_count=2, rms_epsilon=1e-5, - moe_activation=F.gelu if args.use_grok else F.silu, + moe_activation=F.gelu if args.use_gelu else F.silu, ) fxb = FxProgramsBuilder(model) input = make_rand_torch((bs, 32, 6144)) diff --git a/sharktank/sharktank/layers/llama_attention_block.py b/sharktank/sharktank/layers/llama_attention_block.py index 7be8c7366..e42550026 100644 --- a/sharktank/sharktank/layers/llama_attention_block.py +++ b/sharktank/sharktank/layers/llama_attention_block.py @@ -110,7 +110,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: values = values.transpose(1, 2) # Flash attention. - attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt( + self.head_dim + ) # Apply attention mask. if attention_mask is not None: diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 5bc045608..a1d05f790 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -156,7 +156,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: attn_weights = ops.matmul(xq, keys.transpose(2, 3)) if self.attention_scale is None: - attn_weights = attn_weights / math.sqrt(self.head_dim) + attn_weights = attn_weights / torch.sqrt(self.head_dim) else: attn_weights = attn_weights * self.attention_scale diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py index 1abbf7113..6fa2cfca7 100644 --- a/sharktank/sharktank/models/grok/grok.py +++ b/sharktank/sharktank/models/grok/grok.py @@ -103,7 +103,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): expert_count=hp.expert_count, expert_used_count=hp.expert_used_count, rms_epsilon=hp.attention_layer_norm_rms_epsilon, - activation=F.gelu, + moe_activation=F.gelu, ) ) @@ -123,7 +123,7 @@ def prefill( self._assert_device(seq_block_ids) self._assert_device(*cache_state, dtype=self.activation_dtype) h = self.token_embedding(tokens) - h *= math.sqrt(h.shape[-1]) + h *= tanh.sqrt(h.shape[-1]) self.trace_tensor("grok.token_embedding", h) # Iterate over attention blocks. @@ -148,7 +148,7 @@ def prefill( h = self.output_norm(h) logits = self.output_lm_head(h) - logits = logits / math.sqrt(3.0) + logits = logits / torch.sqrt(3.0) return logits def decode( @@ -200,7 +200,7 @@ def decode( ) h = self.token_embedding(tokens) - h *= math.sqrt(h.shape[-1]) + h *= torch.sqrt(h.shape[-1]) self.trace_tensor("grok.token_embedding", h) # Iterate over attention blocks. @@ -220,7 +220,7 @@ def decode( seq_block_ids=seq_block_ids, xk_temp=xk_temp, xv_temp=xv_temp, - softcap=30.0, + softcap=30.0, # https://github.com/xai-org/grok-1/blob/7050ed204b8206bb8645c7b7bbef7252f79561b0/model.py#L864 ) self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) @@ -229,5 +229,5 @@ def decode( h = self.output_norm(h) logits = self.output_lm_head(h) - logits = logits / math.sqrt(3.0) + logits = logits / torch.sqrt(3.0) return logits diff --git a/sharktank/sharktank/models/llama/llama_ref.py b/sharktank/sharktank/models/llama/llama_ref.py index 74ed9e8e0..1400a44ad 100644 --- a/sharktank/sharktank/models/llama/llama_ref.py +++ b/sharktank/sharktank/models/llama/llama_ref.py @@ -230,7 +230,9 @@ def forward( values = values.transpose(1, 2) # Flash attention. - attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt( + self.head_dim + ) # Apply attention mask. if attention_mask is not None: From 933e18bc7988b47a17b07331ae673c97cff5f261 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 15 Oct 2024 17:02:29 -0700 Subject: [PATCH 5/6] Archana comments --- sharktank/sharktank/layers/llama_attention_block.py | 2 -- sharktank/sharktank/layers/paged_llama_attention_block.py | 2 +- sharktank/sharktank/models/grok/grok.py | 4 ++-- sharktank/sharktank/models/llama/llama.py | 1 - sharktank/sharktank/models/llama/llama_ref.py | 1 - sharktank/sharktank/utils/tokenizer.py | 1 + 6 files changed, 4 insertions(+), 7 deletions(-) diff --git a/sharktank/sharktank/layers/llama_attention_block.py b/sharktank/sharktank/layers/llama_attention_block.py index e42550026..0cdb5d713 100644 --- a/sharktank/sharktank/layers/llama_attention_block.py +++ b/sharktank/sharktank/layers/llama_attention_block.py @@ -6,8 +6,6 @@ from typing import Optional -import math - import torch import torch.nn.functional as F diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index a1d05f790..5bc045608 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -156,7 +156,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: attn_weights = ops.matmul(xq, keys.transpose(2, 3)) if self.attention_scale is None: - attn_weights = attn_weights / torch.sqrt(self.head_dim) + attn_weights = attn_weights / math.sqrt(self.head_dim) else: attn_weights = attn_weights * self.attention_scale diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py index 6fa2cfca7..4dfd7ce5c 100644 --- a/sharktank/sharktank/models/grok/grok.py +++ b/sharktank/sharktank/models/grok/grok.py @@ -95,6 +95,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): head_dim=hp.attn_head_dim, head_count_kv=hp.attention_head_count_kv, rms_epsilon=hp.attention_layer_norm_rms_epsilon, + softcap=30.0, # https://github.com/xai-org/grok-1/blob/7050ed204b8206bb8645c7b7bbef7252f79561b0/model.py#L864 ) ) self.moe_blocks.append( @@ -123,7 +124,7 @@ def prefill( self._assert_device(seq_block_ids) self._assert_device(*cache_state, dtype=self.activation_dtype) h = self.token_embedding(tokens) - h *= tanh.sqrt(h.shape[-1]) + h *= torch.sqrt(h.shape[-1]) self.trace_tensor("grok.token_embedding", h) # Iterate over attention blocks. @@ -220,7 +221,6 @@ def decode( seq_block_ids=seq_block_ids, xk_temp=xk_temp, xv_temp=xv_temp, - softcap=30.0, # https://github.com/xai-org/grok-1/blob/7050ed204b8206bb8645c7b7bbef7252f79561b0/model.py#L864 ) self.trace_tensor(f"grok.attn_block.{block_idx}.output", h) diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index c324a79d5..344976ead 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -7,7 +7,6 @@ from typing import Optional from dataclasses import dataclass -import math from typing import Union import torch diff --git a/sharktank/sharktank/models/llama/llama_ref.py b/sharktank/sharktank/models/llama/llama_ref.py index 1400a44ad..9f77daa40 100644 --- a/sharktank/sharktank/models/llama/llama_ref.py +++ b/sharktank/sharktank/models/llama/llama_ref.py @@ -7,7 +7,6 @@ from typing import Optional from dataclasses import dataclass -import math import torch import torch.nn as nn diff --git a/sharktank/sharktank/utils/tokenizer.py b/sharktank/sharktank/utils/tokenizer.py index 29a57f958..a6b0980a0 100644 --- a/sharktank/sharktank/utils/tokenizer.py +++ b/sharktank/sharktank/utils/tokenizer.py @@ -31,6 +31,7 @@ def encode( raw_rows = self._encode(texts) max_length = 0 lengths: list[int] = [] + raw_rows = [row[1:] for row in raw_rows] for row in raw_rows: lengths.append(len(row)) max_length = max(max_length, len(row)) From 80891fed26af737f21ecb064209796b5aa9dfa6d Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 15 Oct 2024 17:07:39 -0700 Subject: [PATCH 6/6] fix zip --- sharktank/sharktank/models/grok/grok.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py index 4dfd7ce5c..077e4e064 100644 --- a/sharktank/sharktank/models/grok/grok.py +++ b/sharktank/sharktank/models/grok/grok.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import math import torch import torch.nn as nn @@ -124,7 +125,7 @@ def prefill( self._assert_device(seq_block_ids) self._assert_device(*cache_state, dtype=self.activation_dtype) h = self.token_embedding(tokens) - h *= torch.sqrt(h.shape[-1]) + h *= math.sqrt(h.shape[-1]) self.trace_tensor("grok.token_embedding", h) # Iterate over attention blocks. @@ -149,7 +150,7 @@ def prefill( h = self.output_norm(h) logits = self.output_lm_head(h) - logits = logits / torch.sqrt(3.0) + logits = logits / math.sqrt(3.0) return logits def decode( @@ -201,12 +202,12 @@ def decode( ) h = self.token_embedding(tokens) - h *= torch.sqrt(h.shape[-1]) + h *= math.sqrt(h.shape[-1]) self.trace_tensor("grok.token_embedding", h) # Iterate over attention blocks. for block_idx, (attn_block, moe_block) in enumerate( - self.attn_blocks, self.moe_blocks + zip(self.attn_blocks, self.moe_blocks) ): if block_idx == 0: self.trace_tensor(f"grok.attn_block.{block_idx}.input", h) @@ -229,5 +230,5 @@ def decode( h = self.output_norm(h) logits = self.output_lm_head(h) - logits = logits / torch.sqrt(3.0) + logits = logits / math.sqrt(3.0) return logits