Refactor llama / mixtral / grok for shared features #267

Merged · 6 commits · Oct 16, 2024
13 changes: 8 additions & 5 deletions sharktank/sharktank/export_layer/export_moe.py
@@ -5,9 +5,12 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
import torch.nn.functional as F

from iree.turbine.aot import *

from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock
from sharktank.layers.mixture_of_experts_block import MoeBlock
from ..utils import cli


@@ -37,21 +40,21 @@ def main():
action="store_true",
)
parser.add_argument(
"--use-grok",
help="Enable to export Grok model's version of MOE block",
"--use-gelu",
help="Enable to use gelu for moe activation",
action="store_true",
)

args = cli.parse(parser)

bs = args.batch_size

model = PreGatherMoeBlock(
model = MoeBlock(
theta=make_moe_block_theta()("blk.0"),
expert_count=8,
expert_used_count=2,
rms_epsilon=1e-5,
use_grok=args.use_grok,
moe_activation=F.gelu if args.use_gelu else F.silu,
)
fxb = FxProgramsBuilder(model)
input = make_rand_torch((bs, 32, 6144))
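The export script now selects the MoE activation with a generic `--use-gelu` flag instead of the grok-specific `--use-grok` switch, and hands the chosen callable to `MoeBlock` as `moe_activation`. A rough illustration of that mapping follows; the argparse wiring mirrors the diff above, while the parsed argument list and the final `print` are illustrative additions, not part of the PR.

```python
import argparse

import torch.nn.functional as F

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-gelu",
    help="Enable to use gelu for moe activation",
    action="store_true",
)
args = parser.parse_args(["--use-gelu"])

# The block no longer needs to know it is "grok"; it only receives an activation.
moe_activation = F.gelu if args.use_gelu else F.silu
print(moe_activation.__name__)  # gelu
```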
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/__init__.py
@@ -16,6 +16,6 @@
from .paged_llama_attention_block import PagedLlamaAttentionBlock
from .ffn_block import FFN
from .ffn_moe_block import FFNMOE
from .mixture_of_experts_block import SparseMoeBlock, PreGatherMoeBlock
from .mixture_of_experts_block import MoeBlock

from .configs import *
12 changes: 5 additions & 7 deletions sharktank/sharktank/layers/ffn_moe_block.py
@@ -12,7 +12,7 @@
from .base import ThetaLayer
from .linear import LinearLayer
from ..types import Theta, DefaultPrimitiveTensor
from ..ops import einsum_2args
from ..ops import einsum_2args, elementwise

__all__ = [
"FFNMOE",
@@ -24,15 +24,15 @@ class PreGatherFFNMOE(ThetaLayer):
def __init__(
self,
theta: Theta,
use_grok: bool = False,
activation=F.silu,
):

super().__init__(theta)
self.use_grok = use_grok

self.ffn_gate = theta.tensor("ffn_gate_exps", "weight")
self.ffn_up = theta.tensor("ffn_up_exps", "weight")
self.ffn_down = theta.tensor("ffn_down_exps", "weight")
self.activation = activation

def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"):
inputs = inputs[:, :]
@@ -63,10 +63,8 @@ def forward(
experts: torch.Tensor,
expert_gate: torch.Tensor,
):
if self.use_grok:
ffn_gate = F.gelu(self.pre_matmul_gather(h, self.ffn_gate, experts))
else:
ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate, experts))
ffn_gate = self.pre_matmul_gather(h, self.ffn_gate, experts)
ffn_gate = elementwise(self.activation, ffn_gate)

ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
ffn_down = self.pre_matmul_gather(
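`PreGatherFFNMOE` likewise takes an `activation` callable (applied through sharktank's `elementwise` op) rather than branching on `use_grok`. A minimal sketch of the same injection pattern in plain PyTorch, with made-up layer names and a direct call standing in for `elementwise`:

```python
import torch
import torch.nn.functional as F


class TinyGatedFFN(torch.nn.Module):
    def __init__(self, dim: int, activation=F.silu):
        super().__init__()
        self.gate = torch.nn.Linear(dim, dim, bias=False)
        self.up = torch.nn.Linear(dim, dim, bias=False)
        self.down = torch.nn.Linear(dim, dim, bias=False)
        self.activation = activation  # injected once; no per-model branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The old forward branched on `use_grok` to choose gelu vs. silu here.
        return self.down(self.activation(self.gate(x)) * self.up(x))


llama_like = TinyGatedFFN(16, activation=F.silu)
grok_like = TinyGatedFFN(16, activation=F.gelu)
print(grok_like(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```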
6 changes: 3 additions & 3 deletions sharktank/sharktank/layers/llama_attention_block.py
@@ -6,8 +6,6 @@

from typing import Optional

import math

import torch
import torch.nn.functional as F

@@ -110,7 +108,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
values = values.transpose(1, 2)

# Flash attention.
attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt(
self.head_dim
)

# Apply attention mask.
if attention_mask is not None:
114 changes: 9 additions & 105 deletions sharktank/sharktank/layers/mixture_of_experts_block.py
@@ -16,12 +16,11 @@
from .ffn_moe_block import FFNMOE, PreGatherFFNMOE

__all__ = [
"SparseMoeBlock",
"PreGatherMoeBlock",
"MoeBlock",
]


class SparseMoeBlock(ThetaLayer):
class MoeBlock(ThetaLayer):
"""
This implementation considers MoE operations as block-sparse
operations to support imbalanced token assignments to experts.
@@ -35,108 +34,12 @@ def __init__(
expert_count: int,
expert_used_count: int,
rms_epsilon: float,
):
super().__init__(theta)

# Add router gate
self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp")))

# Add FFN norm
self.add_module(
"ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon)
)

# Add FFN output norm
self.add_module(
"layer_output_norm",
RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon),
)

# Add expert_count x FFN
self.experts = nn.ModuleList(
[FFNMOE(theta, expert_idx=i) for i in range(expert_count)]
)

self.expert_count = expert_count
self.expert_used_count = expert_used_count

def forward(
self,
h: torch.Tensor,
):
ffn_input = self.ffn_norm(h)
batch_size, sequence_length, feature_dim = ffn_input.shape
ffn_input = ffn_input.view(-1, feature_dim)

# For each token, the router calculates the router weights for all experts
# router_logits: (batch_size * sequence_length, expert_count)
router_logits = self.ffn_gate_inp(ffn_input)
router_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

# Select top k experts from router weights
router_weights, top_k_experts = torch.topk(
router_weights, self.expert_used_count, dim=-1
)
router_weights /= router_weights.sum(dim=-1, keepdim=True)
router_weights = router_weights.to(ffn_input.dtype)

moe_output = torch.zeros(
(batch_size * sequence_length, feature_dim), dtype=ffn_input.dtype
)

# Create an expert mask by one hot encoding the selected top k experts
# used to index which expert is to be invoked for each token
# expert_mask: (expert_count, expert_used_count, sequence_length)
expert_mask = F.one_hot(top_k_experts, num_classes=self.expert_count).permute(
2, 1, 0
)

# Iterate over all experts in the model
for expert_idx in range(self.expert_count):
expert_layer = self.experts[expert_idx]
top_k_expert_idx, token_idx = torch.where(expert_mask[expert_idx])

# Given the hidden states, index the tokens assigned to this expert
# and calculate the current expert's hidden state and weigh the
# output expert hidden states by the router weights, based on the
# appropriate tokens
current_expert_tokens = ffn_input[None, token_idx]

current_expert = (
expert_layer(current_expert_tokens)
* router_weights[token_idx, top_k_expert_idx, None]
)

current_expert = current_expert.reshape(-1, feature_dim)

moe_output.index_add_(0, token_idx, current_expert.to(ffn_input.dtype))
moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)

moe_output = self.layer_output_norm(moe_output)
return h + moe_output


class PreGatherMoeBlock(ThetaLayer):
"""
This implementation considers MoE operations as block-sparse
operations to support imbalanced token assignments to experts.
This enables the MoE to operate at a faster rate and in full capacity without any dropped tokens
(or reduced performance).
"""

def __init__(
self,
theta: Theta,
expert_count: int,
expert_used_count: int,
rms_epsilon: float,
use_grok: Optional[bool] = False,
moe_activation=F.silu,
):
super().__init__(theta)

self.expert_count = expert_count
self.expert_used_count = expert_used_count
self.use_grok = use_grok

# Add router gate
self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp")))
@@ -146,15 +49,17 @@ def __init__(
"ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon)
)

# Add FFN output norm layer for Grok
if self.use_grok:
# Add optional FFN output norm layer
if theta.optional_tensor("layer_output_norm") is not None:
self.add_module(
"layer_output_norm",
RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon),
)
else:
self.add_module("layer_output_norm", torch.nn.Identity())

# Add expert_count x FFN
self.experts = PreGatherFFNMOE(theta, use_grok=self.use_grok)
self.experts = PreGatherFFNMOE(theta, activation=moe_activation)

def forward(
self,
@@ -180,7 +85,6 @@ def forward(
moe_output = self.experts(ffn_input, top_k_experts, expert_gate)
moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)

if self.use_grok:
moe_output = self.layer_output_norm(moe_output)
moe_output = self.layer_output_norm(moe_output)

return h + moe_output
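With `SparseMoeBlock` removed and `PreGatherMoeBlock` renamed to `MoeBlock`, the remaining grok-specific behavior, the output RMS norm, is now driven by whether the checkpoint actually contains a `layer_output_norm` tensor, with `torch.nn.Identity()` as the fallback so `forward` stays branch-free. A hedged sketch of that pattern, using a plain dict in place of sharktank's `Theta.optional_tensor` and PyTorch's built-in `RMSNorm` in place of `RMSNormLayer`:

```python
import torch


class MoeOutputNorm(torch.nn.Module):
    def __init__(self, weights: dict, dim: int, eps: float = 1e-5):
        super().__init__()
        # Mirrors the diff: presence of the tensor decides the module, so the
        # forward pass applies the norm unconditionally.
        if weights.get("layer_output_norm") is not None:
            self.layer_output_norm = torch.nn.RMSNorm(dim, eps=eps)  # PyTorch >= 2.4
        else:
            self.layer_output_norm = torch.nn.Identity()

    def forward(self, moe_output: torch.Tensor) -> torch.Tensor:
        return self.layer_output_norm(moe_output)


grok_style = MoeOutputNorm({"layer_output_norm": torch.ones(64)}, dim=64)
mixtral_style = MoeOutputNorm({}, dim=64)  # Identity; output norm is a no-op
```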
36 changes: 20 additions & 16 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -37,7 +37,8 @@ def __init__(
head_dim: int,
head_count_kv: int,
rms_epsilon: float,
use_grok: Optional[bool] = False,
attention_scale: Optional[float] = None,
softcap: Optional[float] = None,
):
super().__init__(theta)

@@ -46,7 +47,8 @@ def __init__(
self.head_count = head_count
self.head_dim = head_dim
self.head_count_kv = head_count_kv
self.use_grok = use_grok
self.attention_scale = attention_scale
self.softcap = softcap

self.add_module(
"attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon)
@@ -56,7 +58,12 @@ def __init__(
self.add_module("attn_v", LinearLayer(theta("attn_v")))
self.add_module("attn_output", LinearLayer(theta("attn_output")))

if self.use_grok:
if theta.optional_tensor("attn_output_norm") is None:
self.add_module(
"attn_output_norm",
torch.nn.Identity(),
)
else:
self.add_module(
"attn_output_norm",
RMSNormLayer(theta("attn_output_norm"), epsilon=rms_epsilon),
@@ -147,16 +154,16 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
keys = xk.transpose(1, 2)
values = xv.transpose(1, 2)

attn_weights = ops.matmul(xq, keys.transpose(2, 3))
if self.attention_scale is None:
attn_weights = attn_weights / math.sqrt(self.head_dim)
else:
attn_weights = attn_weights * self.attention_scale

# Flash attention.
if not self.use_grok:
attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt(
self.head_dim
)
elif self.use_grok:
attn_weights = ops.matmul(xq, keys.transpose(2, 3))
attn_weights = 30.0 * torch.tanh(
attn_weights * (0.08838834764831845 / 30.0)
)
if self.softcap is not None:
attn_weights = self.softcap * torch.tanh(attn_weights / self.softcap)

self.assert_not_nan(attn_weights)

# Apply attention mask.
@@ -172,12 +179,9 @@

# Project.
attn_output = self.attn_output(attn_output)

if self.use_grok:
attn_output = self.attn_output_norm(attn_output)
attn_output = self.attn_output_norm(attn_output)

h = h + attn_output

return h

def transact_cache_direct(
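In the attention block, the `use_grok` branches give way to two optional parameters: `attention_scale` overrides the default `1/sqrt(head_dim)` scaling, and `softcap` applies `softcap * tanh(logits / softcap)`. Reading the removed branch, grok's behavior appears to correspond to `attention_scale = 0.08838834764831845` with `softcap = 30.0`, though that mapping is an inference from this diff rather than something the PR states. A standalone sketch under that assumption:

```python
import math
from typing import Optional

import torch


def scale_and_softcap(
    attn_weights: torch.Tensor,
    head_dim: int,
    attention_scale: Optional[float] = None,
    softcap: Optional[float] = None,
) -> torch.Tensor:
    # Default llama-style scaling unless an explicit scale is provided.
    if attention_scale is None:
        attn_weights = attn_weights / math.sqrt(head_dim)
    else:
        attn_weights = attn_weights * attention_scale
    # Optional logit soft-capping, matching the added code above.
    if softcap is not None:
        attn_weights = softcap * torch.tanh(attn_weights / softcap)
    return attn_weights


logits = torch.randn(1, 8, 4, 4)
llama_like = scale_and_softcap(logits, head_dim=128)
grok_like = scale_and_softcap(
    logits, head_dim=128, attention_scale=0.08838834764831845, softcap=30.0
)
```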