Refactor llama / mixtral / grok for shared features
Many of these features can be toggled on or off depending on the architecture.
Plumbing the configurations through separately allows better reuse and a
clearer understanding of how the models vary from one another.

Grok uses a softcap; plumbing a value through enables `sc * tanh(v / sc)`.
Grok also has some hardcoded values with cleaner symbolic representations,
e.g. `sqrt(6144)` and `sqrt(3)`.
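
For reference, a minimal sketch of the softcap operation described above,
assuming a PyTorch tensor input (the helper name and how `sc` is plumbed in
from the configuration are illustrative, not the exact wiring in this commit):

```python
import torch

def softcap(v: torch.Tensor, sc: float) -> torch.Tensor:
    # Smoothly bound values to (-sc, sc) via sc * tanh(v / sc);
    # `sc` is the softcap value plumbed in from the model configuration.
    return sc * torch.tanh(v / sc)
```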

Output normalization is optional but used by Mixtral. The presence of the
normalization tensor is sufficient to enable it.
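
A rough sketch of the presence-gated output normalization, assuming RMS-norm
math (consistent with the `RMSNormLayer` usage visible in this diff); the
`weight`/`epsilon` argument plumbing is illustrative only:

```python
from typing import Optional

import torch

def maybe_output_norm(
    h: torch.Tensor, weight: Optional[torch.Tensor], epsilon: float = 1e-5
) -> torch.Tensor:
    # No output-norm tensor present: skip the normalization entirely.
    if weight is None:
        return h
    # RMS normalization: scale by the reciprocal root-mean-square, then by weight.
    variance = h.pow(2).mean(-1, keepdim=True)
    return weight * (h * torch.rsqrt(variance + epsilon))
```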

We remove `SparseMoeBlock`, as we now know it will not be used due to its
poor performance.
rsuderman committed Oct 11, 2024
1 parent af2fe79 commit 2e9e006
Showing 7 changed files with 10 additions and 106 deletions.
4 changes: 2 additions & 2 deletions sharktank/sharktank/export_layer/export_moe.py
@@ -10,7 +10,7 @@
 from iree.turbine.aot import *

 from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
-from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock
+from sharktank.layers.mixture_of_experts_block import MoeBlock
 from ..utils import cli


@@ -49,7 +49,7 @@ def main():

     bs = args.batch_size

-    model = PreGatherMoeBlock(
+    model = MoeBlock(
         theta=make_moe_block_theta()("blk.0"),
         expert_count=8,
         expert_used_count=2,
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/__init__.py
@@ -16,6 +16,6 @@
 from .paged_llama_attention_block import PagedLlamaAttentionBlock
 from .ffn_block import FFN
 from .ffn_moe_block import FFNMOE
-from .mixture_of_experts_block import SparseMoeBlock, PreGatherMoeBlock
+from .mixture_of_experts_block import MoeBlock

 from .configs import *
100 changes: 2 additions & 98 deletions sharktank/sharktank/layers/mixture_of_experts_block.py
@@ -16,107 +16,11 @@
 from .ffn_moe_block import FFNMOE, PreGatherFFNMOE

 __all__ = [
-    "SparseMoeBlock",
-    "PreGatherMoeBlock",
+    "MoeBlock",
 ]


-class SparseMoeBlock(ThetaLayer):
-    """
-    This implementation considers MoE operations as block-sparse
-    operations to support imbalanced token assignments to experts.
-    This enables the MoE to operate at a faster rate and in full capacity without any dropped tokens
-    (or reduced performance).
-    """
-
-    def __init__(
-        self,
-        theta: Theta,
-        expert_count: int,
-        expert_used_count: int,
-        rms_epsilon: float,
-    ):
-        super().__init__(theta)
-
-        # Add router gate
-        self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp")))
-
-        # Add FFN norm
-        self.add_module(
-            "ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon)
-        )
-
-        # Add FFN output norm
-        self.add_module(
-            "layer_output_norm",
-            RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon),
-        )
-
-        # Add expert_count x FFN
-        self.experts = nn.ModuleList(
-            [FFNMOE(theta, expert_idx=i) for i in range(expert_count)]
-        )
-
-        self.expert_count = expert_count
-        self.expert_used_count = expert_used_count
-
-    def forward(
-        self,
-        h: torch.Tensor,
-    ):
-        ffn_input = self.ffn_norm(h)
-        batch_size, sequence_length, feature_dim = ffn_input.shape
-        ffn_input = ffn_input.view(-1, feature_dim)
-
-        # For each token, the router calculates the router weights for all experts
-        # router_logits: (batch_size * sequence_length, expert_count)
-        router_logits = self.ffn_gate_inp(ffn_input)
-        router_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-
-        # Select top k experts from router weights
-        router_weights, top_k_experts = torch.topk(
-            router_weights, self.expert_used_count, dim=-1
-        )
-        router_weights /= router_weights.sum(dim=-1, keepdim=True)
-        router_weights = router_weights.to(ffn_input.dtype)
-
-        moe_output = torch.zeros(
-            (batch_size * sequence_length, feature_dim), dtype=ffn_input.dtype
-        )
-
-        # Create an expert mask by one hot encoding the selected top k experts
-        # used to index which expert is to be invoked for each token
-        # expert_mask: (expert_count, expert_used_count, sequence_length)
-        expert_mask = F.one_hot(top_k_experts, num_classes=self.expert_count).permute(
-            2, 1, 0
-        )
-
-        # Iterate over all experts in the model
-        for expert_idx in range(self.expert_count):
-            expert_layer = self.experts[expert_idx]
-            top_k_expert_idx, token_idx = torch.where(expert_mask[expert_idx])
-
-            # Given the hidden states, index the tokens assigned to this expert
-            # and calculate the current expert's hidden state and weigh the
-            # output expert hidden states by the router weights, based on the
-            # appropriate tokens
-            current_expert_tokens = ffn_input[None, token_idx]
-
-            current_expert = (
-                expert_layer(current_expert_tokens)
-                * router_weights[token_idx, top_k_expert_idx, None]
-            )
-
-            current_expert = current_expert.reshape(-1, feature_dim)
-
-            moe_output.index_add_(0, token_idx, current_expert.to(ffn_input.dtype))
-
-        moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)
-        moe_output = self.layer_output_norm(moe_output)
-        return h + moe_output
-
-
-class PreGatherMoeBlock(ThetaLayer):
+class MoeBlock(ThetaLayer):
     """
     This implementation considers MoE operations as block-sparse
     operations to support imbalanced token assignments to experts.
2 changes: 1 addition & 1 deletion sharktank/sharktank/models/grok/grok.py
@@ -98,7 +98,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
                 )
             )
             self.moe_blocks.append(
-                PreGatherMoeBlock(
+                MoeBlock(
                     theta("blk", n),
                     expert_count=hp.expert_count,
                     expert_used_count=hp.expert_used_count,
2 changes: 1 addition & 1 deletion sharktank/sharktank/models/mixtral/mixtral.py
@@ -100,7 +100,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
                 )
             )
             self.moe_blocks.append(
-                SparseMoeBlock(
+                MoeBlock(
                     theta("blk", n),
                     expert_count=hp.expert_count,
                     expert_used_count=hp.expert_used_count,
2 changes: 1 addition & 1 deletion sharktank/sharktank/models/mixtral/mixtral_ref.py
@@ -80,7 +80,7 @@ def __init__(self, theta: Theta, config: RefLlamaModelConfig):
                 )
             )
             self.moe_blocks.append(
-                SparseMoeBlock(
+                MoeBlock(
                     theta("blk", n),
                     expert_count=hp.expert_count,
                     expert_used_count=hp.expert_used_count,
4 changes: 2 additions & 2 deletions sharktank/tests/models/llama/moe_block_test.py
@@ -10,11 +10,11 @@
 import torch
 from iree.turbine.aot import *
 from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
-from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock
+from sharktank.layers.mixture_of_experts_block import MoeBlock
 from sharktank import ops


-class SparseMoeBlockTest(unittest.TestCase):
+class MoeBlockTest(unittest.TestCase):
     def test(self):
         model = PreGatherMoeBlock(
             theta=make_moe_block_theta()("blk.0"),
