From 2e9e006a9194992a1e47affe77874a517daa337f Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 10 Oct 2024 21:27:46 -0700 Subject: [PATCH] Refactor llama / mixtral / grok for shared features Many of these features can toggle between depending on architecture. Replumbing the configurations separately allows better reuse and understanding of how models vary between eachother. grok uses a softcap, plumbing a value enables `sc * tanh( v / sc)` grok has some hardcoded values that have better representations, e.g. `sqrt(6144)` and `sqrt(3)`. output normalization is optional but used by mixtral. Presence of the tensor is sufficient for performing the normalization. We remove the sparse moe block as we now know it will not be used due to poor performance. --- .../sharktank/export_layer/export_moe.py | 4 +- sharktank/sharktank/layers/__init__.py | 2 +- .../layers/mixture_of_experts_block.py | 100 +----------------- sharktank/sharktank/models/grok/grok.py | 2 +- sharktank/sharktank/models/mixtral/mixtral.py | 2 +- .../sharktank/models/mixtral/mixtral_ref.py | 2 +- .../tests/models/llama/moe_block_test.py | 4 +- 7 files changed, 10 insertions(+), 106 deletions(-) diff --git a/sharktank/sharktank/export_layer/export_moe.py b/sharktank/sharktank/export_layer/export_moe.py index b0033b2b2..3db190a5e 100644 --- a/sharktank/sharktank/export_layer/export_moe.py +++ b/sharktank/sharktank/export_layer/export_moe.py @@ -10,7 +10,7 @@ from iree.turbine.aot import * from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch -from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock +from sharktank.layers.mixture_of_experts_block import MoeBlock from ..utils import cli @@ -49,7 +49,7 @@ def main(): bs = args.batch_size - model = PreGatherMoeBlock( + model = MoeBlock( theta=make_moe_block_theta()("blk.0"), expert_count=8, expert_used_count=2, diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py index a90def3a9..fd56ec872 100644 --- a/sharktank/sharktank/layers/__init__.py +++ b/sharktank/sharktank/layers/__init__.py @@ -16,6 +16,6 @@ from .paged_llama_attention_block import PagedLlamaAttentionBlock from .ffn_block import FFN from .ffn_moe_block import FFNMOE -from .mixture_of_experts_block import SparseMoeBlock, PreGatherMoeBlock +from .mixture_of_experts_block import MoeBlock from .configs import * diff --git a/sharktank/sharktank/layers/mixture_of_experts_block.py b/sharktank/sharktank/layers/mixture_of_experts_block.py index 5b62c0d48..ddce16c55 100644 --- a/sharktank/sharktank/layers/mixture_of_experts_block.py +++ b/sharktank/sharktank/layers/mixture_of_experts_block.py @@ -16,107 +16,11 @@ from .ffn_moe_block import FFNMOE, PreGatherFFNMOE __all__ = [ - "SparseMoeBlock", - "PreGatherMoeBlock", + "MoeBlock", ] -class SparseMoeBlock(ThetaLayer): - """ - This implementation considers MoE operations as block-sparse - operations to support imbalanced token assignments to experts. - This enables the MoE to operate at a faster rate and in full capacity without any dropped tokens - (or reduced performance). - """ - - def __init__( - self, - theta: Theta, - expert_count: int, - expert_used_count: int, - rms_epsilon: float, - ): - super().__init__(theta) - - # Add router gate - self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp"))) - - # Add FFN norm - self.add_module( - "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon) - ) - - # Add FFN output norm - self.add_module( - "layer_output_norm", - RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon), - ) - - # Add expert_count x FFN - self.experts = nn.ModuleList( - [FFNMOE(theta, expert_idx=i) for i in range(expert_count)] - ) - - self.expert_count = expert_count - self.expert_used_count = expert_used_count - - def forward( - self, - h: torch.Tensor, - ): - ffn_input = self.ffn_norm(h) - batch_size, sequence_length, feature_dim = ffn_input.shape - ffn_input = ffn_input.view(-1, feature_dim) - - # For each token, the router calculates the router weights for all experts - # router_logits: (batch_size * sequence_length, expert_count) - router_logits = self.ffn_gate_inp(ffn_input) - router_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - - # Select top k experts from router weights - router_weights, top_k_experts = torch.topk( - router_weights, self.expert_used_count, dim=-1 - ) - router_weights /= router_weights.sum(dim=-1, keepdim=True) - router_weights = router_weights.to(ffn_input.dtype) - - moe_output = torch.zeros( - (batch_size * sequence_length, feature_dim), dtype=ffn_input.dtype - ) - - # Create an expert mask by one hot encoding the selected top k experts - # used to index which expert is to be invoked for each token - # expert_mask: (expert_count, expert_used_count, sequence_length) - expert_mask = F.one_hot(top_k_experts, num_classes=self.expert_count).permute( - 2, 1, 0 - ) - - # Iterate over all experts in the model - for expert_idx in range(self.expert_count): - expert_layer = self.experts[expert_idx] - top_k_expert_idx, token_idx = torch.where(expert_mask[expert_idx]) - - # Given the hidden states, index the tokens assigned to this expert - # and calculate the current expert's hidden state and weigh the - # output expert hidden states by the router weights, based on the - # appropriate tokens - current_expert_tokens = ffn_input[None, token_idx] - - current_expert = ( - expert_layer(current_expert_tokens) - * router_weights[token_idx, top_k_expert_idx, None] - ) - - current_expert = current_expert.reshape(-1, feature_dim) - - moe_output.index_add_(0, token_idx, current_expert.to(ffn_input.dtype)) - - moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim) - moe_output = self.layer_output_norm(moe_output) - return h + moe_output - - -class PreGatherMoeBlock(ThetaLayer): +class MoeBlock(ThetaLayer): """ This implementation considers MoE operations as block-sparse operations to support imbalanced token assignments to experts. diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py index f7e244e1f..1abbf7113 100644 --- a/sharktank/sharktank/models/grok/grok.py +++ b/sharktank/sharktank/models/grok/grok.py @@ -98,7 +98,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): ) ) self.moe_blocks.append( - PreGatherMoeBlock( + MoeBlock( theta("blk", n), expert_count=hp.expert_count, expert_used_count=hp.expert_used_count, diff --git a/sharktank/sharktank/models/mixtral/mixtral.py b/sharktank/sharktank/models/mixtral/mixtral.py index 34afdd6cd..e2995dfde 100644 --- a/sharktank/sharktank/models/mixtral/mixtral.py +++ b/sharktank/sharktank/models/mixtral/mixtral.py @@ -100,7 +100,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): ) ) self.moe_blocks.append( - SparseMoeBlock( + MoeBlock( theta("blk", n), expert_count=hp.expert_count, expert_used_count=hp.expert_used_count, diff --git a/sharktank/sharktank/models/mixtral/mixtral_ref.py b/sharktank/sharktank/models/mixtral/mixtral_ref.py index 779f7480c..70a9b9cf8 100644 --- a/sharktank/sharktank/models/mixtral/mixtral_ref.py +++ b/sharktank/sharktank/models/mixtral/mixtral_ref.py @@ -80,7 +80,7 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig): ) ) self.moe_blocks.append( - SparseMoeBlock( + MoeBlock( theta("blk", n), expert_count=hp.expert_count, expert_used_count=hp.expert_used_count, diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py index 2f3bd3cf0..dd8a19649 100644 --- a/sharktank/tests/models/llama/moe_block_test.py +++ b/sharktank/tests/models/llama/moe_block_test.py @@ -10,11 +10,11 @@ import torch from iree.turbine.aot import * from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch -from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock +from sharktank.layers.mixture_of_experts_block import MoeBlock from sharktank import ops -class SparseMoeBlockTest(unittest.TestCase): +class MoeBlockTest(unittest.TestCase): def test(self): model = PreGatherMoeBlock( theta=make_moe_block_theta()("blk.0"),