From c45802718560a2fa501acefa6086e013fe4a542b Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Thu, 10 Oct 2024 21:27:46 -0700
Subject: [PATCH] Refactor llama / mixtral / grok for shared features

Many of these features can be toggled on or off depending on the
architecture. Replumbing the configurations separately allows better
reuse and a clearer understanding of how the models vary from each
other.

Grok uses a softcap on the attention logits; plumbing a cap value
enables `sc * tanh(v / sc)`.

Grok also has some hardcoded values that have better symbolic
representations, e.g. `sqrt(6144)` and `sqrt(3)`.

Output normalization is optional but used by mixtral. Presence of the
tensor is sufficient for performing the normalization.
---
 .../sharktank/export_layer/export_moe.py      |  5 +-
 sharktank/sharktank/layers/ffn_moe_block.py   |  9 +-
 .../layers/mixture_of_experts_block.py        | 16 ++--
 .../layers/paged_llama_attention_block.py     | 36 ++++----
 sharktank/sharktank/models/grok/grok.py       | 84 +++++++++----------
 sharktank/sharktank/models/mixtral/mixtral.py | 75 +++++++++--------
 .../sharktank/models/mixtral/mixtral_ref.py   | 40 ++++-----
 .../tests/models/llama/moe_block_test.py      |  1 -
 8 files changed, 137 insertions(+), 129 deletions(-)

diff --git a/sharktank/sharktank/export_layer/export_moe.py b/sharktank/sharktank/export_layer/export_moe.py
index f2c10c4b4..b0033b2b2 100644
--- a/sharktank/sharktank/export_layer/export_moe.py
+++ b/sharktank/sharktank/export_layer/export_moe.py
@@ -5,7 +5,10 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 import torch
+import torch.nn.functional as F
+
 from iree.turbine.aot import *
+
 from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
 from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock
 from ..utils import cli
@@ -51,7 +54,7 @@ def main():
         expert_count=8,
         expert_used_count=2,
         rms_epsilon=1e-5,
-        use_grok=args.use_grok,
+        moe_activation=F.gelu if args.use_grok else F.silu,
     )
     fxb = FxProgramsBuilder(model)
     input = make_rand_torch((bs, 32, 6144))
diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py
index 0536302cf..73fea9a9e 100644
--- a/sharktank/sharktank/layers/ffn_moe_block.py
+++ b/sharktank/sharktank/layers/ffn_moe_block.py
@@ -24,11 +24,11 @@ class PreGatherFFNMOE(ThetaLayer):
     def __init__(
         self,
         theta: Theta,
-        use_grok: bool = False,
+        activation=F.silu,
     ):
         super().__init__(theta)
-        self.use_grok = use_grok
+        self.activation = activation
 
         self.ffn_gate = theta.tensor("ffn_gate_exps", "weight")
         self.ffn_up = theta.tensor("ffn_up_exps", "weight")
@@ -63,10 +63,8 @@ def forward(
         experts: torch.Tensor,
         expert_gate: torch.Tensor,
     ):
-        if self.use_grok:
-            ffn_gate = F.gelu(self.pre_matmul_gather(h, self.ffn_gate, experts))
-        else:
-            ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate, experts))
+        ffn_gate = self.pre_matmul_gather(h, self.ffn_gate, experts)
+        ffn_gate = ops.elementwise(self.activation, ffn_gate)
 
         ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
         ffn_down = self.pre_matmul_gather(
diff --git a/sharktank/sharktank/layers/mixture_of_experts_block.py b/sharktank/sharktank/layers/mixture_of_experts_block.py
index f788d06f0..5b62c0d48 100644
--- a/sharktank/sharktank/layers/mixture_of_experts_block.py
+++ b/sharktank/sharktank/layers/mixture_of_experts_block.py
@@ -110,8 +110,8 @@ def forward(
             current_expert = current_expert.reshape(-1, feature_dim)
 
             moe_output.index_add_(0, token_idx, current_expert.to(ffn_input.dtype))
-
-        moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)
+
+        moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)
+        moe_output = self.layer_output_norm(moe_output)
 
         return h + moe_output
 
@@ -130,13 +130,12 @@ def __init__(
         expert_count: int,
         expert_used_count: int,
         rms_epsilon: float,
-        use_grok: Optional[bool] = False,
+        moe_activation=F.silu,
     ):
         super().__init__(theta)
         self.expert_count = expert_count
         self.expert_used_count = expert_used_count
-        self.use_grok = use_grok
 
         # Add router gate
         self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp")))
@@ -146,15 +145,17 @@ def __init__(
             "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon)
         )
 
-        # Add FFN output norm layer for Grok
-        if self.use_grok:
+        # Add optional FFN output norm layer
+        if theta.optional_tensor("layer_output_norm") is not None:
             self.add_module(
                 "layer_output_norm",
                 RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon),
             )
+        else:
+            self.add_module("layer_output_norm", torch.nn.Identity())
 
         # Add expert_count x FFN
-        self.experts = PreGatherFFNMOE(theta, use_grok=self.use_grok)
+        self.experts = PreGatherFFNMOE(theta, activation=moe_activation)
 
     def forward(
         self,
@@ -180,7 +181,6 @@ def forward(
         moe_output = self.experts(ffn_input, top_k_experts, expert_gate)
         moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)
 
-        if self.use_grok:
-            moe_output = self.layer_output_norm(moe_output)
+        moe_output = self.layer_output_norm(moe_output)
 
         return h + moe_output
diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py
index 123c43be7..5bc045608 100644
--- a/sharktank/sharktank/layers/paged_llama_attention_block.py
+++ b/sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -37,7 +37,8 @@ def __init__(
         head_dim: int,
         head_count_kv: int,
         rms_epsilon: float,
-        use_grok: Optional[bool] = False,
+        attention_scale: Optional[float] = None,
+        softcap: Optional[float] = None,
     ):
         super().__init__(theta)
 
@@ -46,7 +47,8 @@ def __init__(
         self.head_count = head_count
         self.head_dim = head_dim
         self.head_count_kv = head_count_kv
-        self.use_grok = use_grok
+        self.attention_scale = attention_scale
+        self.softcap = softcap
 
         self.add_module(
             "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon)
@@ -56,7 +58,12 @@ def __init__(
         self.add_module("attn_v", LinearLayer(theta("attn_v")))
         self.add_module("attn_output", LinearLayer(theta("attn_output")))
 
-        if self.use_grok:
+        if theta.optional_tensor("attn_output_norm") is None:
+            self.add_module(
+                "attn_output_norm",
+                torch.nn.Identity(),
+            )
+        else:
             self.add_module(
                 "attn_output_norm",
                 RMSNormLayer(theta("attn_output_norm"), epsilon=rms_epsilon),
             )
@@ -147,16 +154,16 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
         keys = xk.transpose(1, 2)
         values = xv.transpose(1, 2)
 
+        attn_weights = ops.matmul(xq, keys.transpose(2, 3))
+        if self.attention_scale is None:
+            attn_weights = attn_weights / math.sqrt(self.head_dim)
+        else:
+            attn_weights = attn_weights * self.attention_scale
+
         # Flash attention.
-        if not self.use_grok:
-            attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt(
-                self.head_dim
-            )
-        elif self.use_grok:
-            attn_weights = ops.matmul(xq, keys.transpose(2, 3))
-            attn_weights = 30.0 * torch.tanh(
-                attn_weights * (0.08838834764831845 / 30.0)
-            )
+        if self.softcap is not None:
+            attn_weights = self.softcap * torch.tanh(attn_weights / self.softcap)
+
         self.assert_not_nan(attn_weights)
 
         # Apply attention mask.
@@ -172,12 +179,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
 
         # Project.
         attn_output = self.attn_output(attn_output)
-
-        if self.use_grok:
-            attn_output = self.attn_output_norm(attn_output)
+        attn_output = self.attn_output_norm(attn_output)
 
         h = h + attn_output
-
         return h
 
     def transact_cache_direct(
diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py
index debeb30c7..f7e244e1f 100644
--- a/sharktank/sharktank/models/grok/grok.py
+++ b/sharktank/sharktank/models/grok/grok.py
@@ -7,6 +7,8 @@
 
+import math
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from ...layers import *
 
@@ -82,6 +83,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
         self.add_module("output_lm_head", LinearLayer(theta("output")))
 
         self.attn_blocks = nn.ModuleList()
+        self.moe_blocks = nn.ModuleList()
 
         for n in range(hp.block_count):
             self.attn_blocks.append(
@@ -93,16 +95,16 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
                     head_dim=hp.attn_head_dim,
                     head_count_kv=hp.attention_head_count_kv,
                     rms_epsilon=hp.attention_layer_norm_rms_epsilon,
-                    use_grok=True,
+                    softcap=30.0,
                 )
             )
-            self.attn_blocks.append(
+            self.moe_blocks.append(
                 PreGatherMoeBlock(
                     theta("blk", n),
                     expert_count=hp.expert_count,
                     expert_used_count=hp.expert_used_count,
                     rms_epsilon=hp.attention_layer_norm_rms_epsilon,
-                    use_grok=True,
+                    moe_activation=F.gelu,
                 )
             )
 
@@ -122,33 +123,32 @@ def prefill(
         self._assert_device(seq_block_ids)
         self._assert_device(*cache_state, dtype=self.activation_dtype)
         h = self.token_embedding(tokens)
-        h *= 78.38367176906169
+        h *= math.sqrt(h.shape[-1])
         self.trace_tensor("grok.token_embedding", h)
 
         # Iterate over attention blocks.
-        for block_idx, block in enumerate(self.attn_blocks):
+        for block_idx, (attn_block, moe_block) in enumerate(
+            zip(self.attn_blocks, self.moe_blocks)
+        ):
             if block_idx == 0:
                 self.trace_tensor(f"grok.attn_block.{block_idx}.input", h)
 
-            if block.__class__.__name__ == "PagedLlamaAttentionBlock":
-                h = block(
-                    h,
-                    embedding=self.attention_embedding,
-                    start_index=0,
-                    attention_mask=attention_mask,
-                    cache_state=cache_state,
-                    seq_block_ids=seq_block_ids,
-                )
-                self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)
-            elif block.__class__.__name__ == "PreGatherMoeBlock":
-                h = block(
-                    h,
-                )
-                self.trace_tensor(f"grok.moe_block.{block_idx}.output", h)
+            h = attn_block(
+                h,
+                embedding=self.attention_embedding,
+                start_index=0,
+                attention_mask=attention_mask,
+                cache_state=cache_state,
+                seq_block_ids=seq_block_ids,
+            )
+            self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)
+
+            h = moe_block(h)
+            self.trace_tensor(f"grok.moe_block.{block_idx}.output", h)
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
-        logits = logits * 0.5773502691896257
+        logits = logits / math.sqrt(3.0)
         return logits
 
     def decode(
@@ -200,34 +200,33 @@ def decode(
         )
 
         h = self.token_embedding(tokens)
-        h *= 78.38367176906169
+        h *= math.sqrt(h.shape[-1])
         self.trace_tensor("grok.token_embedding", h)
 
         # Iterate over attention blocks.
-        for block_idx, block in enumerate(self.attn_blocks):
+        for block_idx, (attn_block, moe_block) in enumerate(
+            zip(self.attn_blocks, self.moe_blocks)
+        ):
             if block_idx == 0:
                 self.trace_tensor(f"grok.attn_block.{block_idx}.input", h)
 
-            if block.__class__.__name__ == "PagedLlamaAttentionBlock":
-                h = block(
-                    h,
-                    start_positions=start_positions,
-                    embedding=self.attention_embedding,
-                    embedding_batch_mask=embedding_batch_mask,
-                    attention_mask=attention_mask,
-                    cache_state=cache_state,
-                    seq_block_ids=seq_block_ids,
-                    xk_temp=xk_temp,
-                    xv_temp=xv_temp,
-                )
-                self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)
-            elif block.__class__.__name__ == "PreGatherMoeBlock":
-                h = block(
-                    h,
-                )
-                self.trace_tensor(f"grok.moe_block.{block_idx}.output", h)
+            h = attn_block(
+                h,
+                start_positions=start_positions,
+                embedding=self.attention_embedding,
+                embedding_batch_mask=embedding_batch_mask,
+                attention_mask=attention_mask,
+                cache_state=cache_state,
+                seq_block_ids=seq_block_ids,
+                xk_temp=xk_temp,
+                xv_temp=xv_temp,
+            )
+            self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)
+
+            h = moe_block(h)
+            self.trace_tensor(f"grok.moe_block.{block_idx}.output", h)
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
-        logits = logits * 0.5773502691896257
+        logits = logits / math.sqrt(3.0)
         return logits
diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py
index 1fc86f87d..34afdd6cd 100644
--- a/sharktank/sharktank/models/mixtral/mixtral.py
+++ b/sharktank/sharktank/models/mixtral/mixtral.py
@@ -85,6 +85,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
         self.add_module("output_lm_head", LinearLayer(theta("output")))
 
         self.attn_blocks = nn.ModuleList()
+        self.moe_blocks = nn.ModuleList()
 
         for n in range(hp.block_count):
             self.attn_blocks.append(
@@ -98,7 +99,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
                     rms_epsilon=hp.attention_layer_norm_rms_epsilon,
                 )
             )
-            self.attn_blocks.append(
+            self.moe_blocks.append(
                 SparseMoeBlock(
                     theta("blk", n),
                     expert_count=hp.expert_count,
@@ -126,25 +127,26 @@ def prefill(
         self.trace_tensor("mixtral.token_embedding", h)
 
         # Iterate over attention blocks.
-        for block_idx, block in enumerate(self.attn_blocks):
+        for block_idx, (attn_block, moe_block) in enumerate(
+            zip(self.attn_blocks, self.moe_blocks)
+        ):
             if block_idx == 0:
                 self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h)
 
-            if block.__class__.__name__ == "PagedLlamaAttentionBlock":
-                h = block(
-                    h,
-                    embedding=self.attention_embedding,
-                    start_index=0,
-                    attention_mask=attention_mask,
-                    cache_state=cache_state,
-                    seq_block_ids=seq_block_ids,
-                )
-                self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
-            elif block.__class__.__name__ == "SparseMoeBlock":
-                h = block(
-                    h,
-                )
-                self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
+            h = attn_block(
+                h,
+                embedding=self.attention_embedding,
+                start_index=0,
+                attention_mask=attention_mask,
+                cache_state=cache_state,
+                seq_block_ids=seq_block_ids,
+            )
+            self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
+
+            h = moe_block(
+                h,
+            )
+            self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
@@ -202,28 +204,29 @@ def decode(
         self.trace_tensor("mixtral.token_embedding", h)
 
         # Iterate over attention blocks.
-        for block_idx, block in enumerate(self.attn_blocks):
+        for block_idx, (attn_block, moe_block) in enumerate(
+            zip(self.attn_blocks, self.moe_blocks)
+        ):
             if block_idx == 0:
                 self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h)
 
-            if block.__class__.__name__ == "PagedLlamaAttentionBlock":
-                h = block(
-                    h,
-                    start_positions=start_positions,
-                    embedding=self.attention_embedding,
-                    embedding_batch_mask=embedding_batch_mask,
-                    attention_mask=attention_mask,
-                    cache_state=cache_state,
-                    seq_block_ids=seq_block_ids,
-                    xk_temp=xk_temp,
-                    xv_temp=xv_temp,
-                )
-                self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
-            elif block.__class__.__name__ == "SparseMoeBlock":
-                h = block(
-                    h,
-                )
-                self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
+            h = attn_block(
+                h,
+                start_positions=start_positions,
+                embedding=self.attention_embedding,
+                embedding_batch_mask=embedding_batch_mask,
+                attention_mask=attention_mask,
+                cache_state=cache_state,
+                seq_block_ids=seq_block_ids,
+                xk_temp=xk_temp,
+                xv_temp=xv_temp,
+            )
+            self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
+
+            h = moe_block(
+                h,
+            )
+            self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
diff --git a/sharktank/sharktank/models/mixtral/mixtral_ref.py b/sharktank/sharktank/models/mixtral/mixtral_ref.py
index 392f60a25..779f7480c 100644
--- a/sharktank/sharktank/models/mixtral/mixtral_ref.py
+++ b/sharktank/sharktank/models/mixtral/mixtral_ref.py
@@ -66,6 +66,7 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig):
         self.add_module("output_lm_head", LinearLayer(theta("output")))
 
         self.attn_blocks = nn.ModuleList()
+        self.moe_blocks = nn.ModuleList()
 
         for n in range(hp.block_count):
             self.attn_blocks.append(
@@ -78,7 +79,7 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig):
                     rms_epsilon=hp.attention_layer_norm_rms_epsilon,
                 )
             )
-            self.attn_blocks.append(
+            self.moe_blocks.append(
                 SparseMoeBlock(
                     theta("blk", n),
                     expert_count=hp.expert_count,
@@ -130,28 +131,29 @@ def forward(
-        block_count = len(self.attn_blocks) // 2
+        block_count = len(self.attn_blocks)
         # print('local_kv_cache, #attn_blocks', len(local_kv_cache), block_count)
         # Iterate over attention + MoE blocks.
-        for block_idx, block in enumerate(self.attn_blocks):
+        for block_idx, (attn_block, moe_block) in enumerate(
+            zip(self.attn_blocks, self.moe_blocks)
+        ):
             # print("block_idx, block", block_idx, block)
             if block_idx == 0:
                 self.trace_tensor(f"mixtral.attn_block.{block_idx}.input", h)
 
-            if block.__class__.__name__ == "LlamaAttentionBlock":
-                attn_block_idx = block_idx // 2
-                block_cache_k = local_kv_cache[attn_block_idx]
-                block_cache_v = local_kv_cache[block_count + attn_block_idx]
-                h = block(
-                    h,
-                    cache_k=block_cache_k,
-                    cache_v=block_cache_v,
-                    start_index=start_index,
-                    attention_mask=attention_mask,
-                )
-                self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
-            elif block.__class__.__name__ == "SparseMoeBlock":
-                h = block(
-                    h,
-                )
-                self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
+            attn_block_idx = block_idx
+            block_cache_k = local_kv_cache[attn_block_idx]
+            block_cache_v = local_kv_cache[block_count + attn_block_idx]
+            h = attn_block(
+                h,
+                cache_k=block_cache_k,
+                cache_v=block_cache_v,
+                start_index=start_index,
+                attention_mask=attention_mask,
+            )
+            self.trace_tensor(f"mixtral.attn_block.{block_idx}.output", h)
+
+            h = moe_block(
+                h,
+            )
+            self.trace_tensor(f"mixtral.moe_block.{block_idx}.output", h)
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py
index edf1d9d97..2f3bd3cf0 100644
--- a/sharktank/tests/models/llama/moe_block_test.py
+++ b/sharktank/tests/models/llama/moe_block_test.py
@@ -21,7 +21,6 @@ def test(self):
             expert_count=8,
             expert_used_count=2,
             rms_epsilon=1e-5,
-            use_grok=False,
         )
         fxb = FxProgramsBuilder(model)
         input = make_rand_torch((2, 32, 6144))
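
Reference sketch: how the plumbed attention_scale and softcap values compose.
This is an illustrative, standalone snippet in plain torch, not code from the
patch; the helper name is hypothetical.

import math
from typing import Optional

import torch


def scale_and_softcap(
    attn_weights: torch.Tensor,
    head_dim: int,
    attention_scale: Optional[float] = None,
    softcap: Optional[float] = None,
) -> torch.Tensor:
    # Default scale is 1 / sqrt(head_dim); an explicit attention_scale overrides it.
    if attention_scale is None:
        attn_weights = attn_weights / math.sqrt(head_dim)
    else:
        attn_weights = attn_weights * attention_scale
    # Grok-style softcap: sc * tanh(v / sc) bounds the logits to (-sc, sc).
    if softcap is not None:
        attn_weights = softcap * torch.tanh(attn_weights / softcap)
    return attn_weights

With head_dim=128 and softcap=30.0 this reproduces the previously hardcoded
30.0 * torch.tanh(attn_weights * (0.08838834764831845 / 30.0)), since
1 / sqrt(128) ~= 0.08838834764831845.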