Merge branch 'main' into shortfin-windows-tests
ScottTodd authored Oct 17, 2024
2 parents 830f46b + 95792af commit 5dbd4a5
Showing 18 changed files with 362 additions and 248 deletions.
13 changes: 8 additions & 5 deletions sharktank/sharktank/export_layer/export_moe.py
@@ -5,9 +5,12 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
import torch.nn.functional as F

from iree.turbine.aot import *

from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock
from sharktank.layers.mixture_of_experts_block import MoeBlock
from ..utils import cli


@@ -37,21 +40,21 @@ def main():
action="store_true",
)
parser.add_argument(
"--use-grok",
help="Enable to export Grok model's version of MOE block",
"--use-gelu",
help="Enable to use gelu for moe activation",
action="store_true",
)

args = cli.parse(parser)

bs = args.batch_size

model = PreGatherMoeBlock(
model = MoeBlock(
theta=make_moe_block_theta()("blk.0"),
expert_count=8,
expert_used_count=2,
rms_epsilon=1e-5,
use_grok=args.use_grok,
moe_activation=F.gelu if args.use_gelu else F.silu,
)
fxb = FxProgramsBuilder(model)
input = make_rand_torch((bs, 32, 6144))
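
A minimal standalone sketch (not part of this diff) of the flag-to-activation mapping introduced above: the exporter now passes an activation callable into the MoE block instead of a model-specific `use_grok` flag. The helper name `select_moe_activation` is hypothetical.

import torch
import torch.nn.functional as F

def select_moe_activation(use_gelu: bool):
    # Mirrors the exporter above: the activation is handed over as a callable
    # (F.gelu or F.silu) rather than being toggled by a model-specific flag.
    return F.gelu if use_gelu else F.silu

# Exercise both paths on a small dummy tensor.
gate = torch.randn(2, 8, 16)
print(select_moe_activation(True)(gate).shape)   # torch.Size([2, 8, 16])
print(select_moe_activation(False)(gate).shape)  # torch.Size([2, 8, 16])
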
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/__init__.py
@@ -16,6 +16,6 @@
from .paged_llama_attention_block import PagedLlamaAttentionBlock
from .ffn_block import FFN
from .ffn_moe_block import FFNMOE
from .mixture_of_experts_block import SparseMoeBlock, PreGatherMoeBlock
from .mixture_of_experts_block import MoeBlock

from .configs import *
12 changes: 5 additions & 7 deletions sharktank/sharktank/layers/ffn_moe_block.py
@@ -12,7 +12,7 @@
from .base import ThetaLayer
from .linear import LinearLayer
from ..types import Theta, DefaultPrimitiveTensor
from ..ops import einsum_2args
from ..ops import einsum_2args, elementwise

__all__ = [
"FFNMOE",
@@ -24,15 +24,15 @@ class PreGatherFFNMOE(ThetaLayer):
def __init__(
self,
theta: Theta,
use_grok: bool = False,
activation=F.silu,
):

super().__init__(theta)
self.use_grok = use_grok

self.ffn_gate = theta.tensor("ffn_gate_exps", "weight")
self.ffn_up = theta.tensor("ffn_up_exps", "weight")
self.ffn_down = theta.tensor("ffn_down_exps", "weight")
self.activation = activation

def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"):
inputs = inputs[:, :]
@@ -63,10 +63,8 @@ def forward(
experts: torch.Tensor,
expert_gate: torch.Tensor,
):
if self.use_grok:
ffn_gate = F.gelu(self.pre_matmul_gather(h, self.ffn_gate, experts))
else:
ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate, experts))
ffn_gate = self.pre_matmul_gather(h, self.ffn_gate, experts)
ffn_gate = elementwise(self.activation, ffn_gate)

ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
ffn_down = self.pre_matmul_gather(
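
A standalone sketch (outside this diff, with hypothetical names) of the gather-then-activate pattern that `PreGatherFFNMOE.forward` now follows: the per-expert matmul result is passed through an injected activation instead of a hard-coded `F.gelu`/`F.silu` branch.

import torch
import torch.nn.functional as F

def gather_then_activate(inputs, expert_weights, expert_ids, activation=F.silu):
    # inputs:         (tokens, k)          per-token hidden states
    # expert_weights: (num_experts, n, k)  stacked expert weight matrices
    # expert_ids:     (tokens,)            selected expert per token
    w = expert_weights[expert_ids]               # gather: (tokens, n, k)
    out = torch.einsum("tk,tnk->tn", inputs, w)  # per-token expert matmul
    return activation(out)                       # activation is injected, not hard-coded

tokens, k, n, num_experts = 6, 16, 32, 4
out = gather_then_activate(
    torch.randn(tokens, k),
    torch.randn(num_experts, n, k),
    torch.randint(0, num_experts, (tokens,)),
    activation=F.gelu,
)
print(out.shape)  # torch.Size([6, 32])
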
6 changes: 3 additions & 3 deletions sharktank/sharktank/layers/llama_attention_block.py
@@ -6,8 +6,6 @@

from typing import Optional

import math

import torch
import torch.nn.functional as F

@@ -110,7 +108,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
values = values.transpose(1, 2)

# Flash attention.
attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / torch.sqrt(
self.head_dim
)

# Apply attention mask.
if attention_mask is not None:
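
For reference, a self-contained sketch of the attention-weight scaling this hunk touches, written with `math.sqrt` on a plain-int `head_dim` (the `torch.sqrt` call above assumes a tensor-valued `head_dim`); the shapes are illustrative only.

import math
import torch

bs, heads, seq, head_dim = 1, 2, 5, 8
xq = torch.randn(bs, heads, seq, head_dim)
keys = torch.randn(bs, heads, seq, head_dim)

# Scaled dot-product scores; dividing by sqrt(head_dim) keeps the logits in a stable range.
attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
print(attn_weights.shape)  # torch.Size([1, 2, 5, 5])
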
114 changes: 9 additions & 105 deletions sharktank/sharktank/layers/mixture_of_experts_block.py
@@ -16,12 +16,11 @@
from .ffn_moe_block import FFNMOE, PreGatherFFNMOE

__all__ = [
"SparseMoeBlock",
"PreGatherMoeBlock",
"MoeBlock",
]


class SparseMoeBlock(ThetaLayer):
class MoeBlock(ThetaLayer):
"""
This implementation considers MoE operations as block-sparse
operations to support imbalanced token assignments to experts.
@@ -35,108 +34,12 @@ def __init__(
expert_count: int,
expert_used_count: int,
rms_epsilon: float,
):
super().__init__(theta)

# Add router gate
self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp")))

# Add FFN norm
self.add_module(
"ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon)
)

# Add FFN output norm
self.add_module(
"layer_output_norm",
RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon),
)

# Add expert_count x FFN
self.experts = nn.ModuleList(
[FFNMOE(theta, expert_idx=i) for i in range(expert_count)]
)

self.expert_count = expert_count
self.expert_used_count = expert_used_count

def forward(
self,
h: torch.Tensor,
):
ffn_input = self.ffn_norm(h)
batch_size, sequence_length, feature_dim = ffn_input.shape
ffn_input = ffn_input.view(-1, feature_dim)

# For each token, the router calculates the router weights for all experts
# router_logits: (batch_size * sequence_length, expert_count)
router_logits = self.ffn_gate_inp(ffn_input)
router_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

# Select top k experts from router weights
router_weights, top_k_experts = torch.topk(
router_weights, self.expert_used_count, dim=-1
)
router_weights /= router_weights.sum(dim=-1, keepdim=True)
router_weights = router_weights.to(ffn_input.dtype)

moe_output = torch.zeros(
(batch_size * sequence_length, feature_dim), dtype=ffn_input.dtype
)

# Create an expert mask by one hot encoding the selected top k experts
# used to index which expert is to be invoked for each token
# expert_mask: (expert_count, expert_used_count, sequence_length)
expert_mask = F.one_hot(top_k_experts, num_classes=self.expert_count).permute(
2, 1, 0
)

# Iterate over all experts in the model
for expert_idx in range(self.expert_count):
expert_layer = self.experts[expert_idx]
top_k_expert_idx, token_idx = torch.where(expert_mask[expert_idx])

# Given the hidden states, index the tokens assigned to this expert
# and calculate the current expert's hidden state and weigh the
# output expert hidden states by the router weights, based on the
# appropriate tokens
current_expert_tokens = ffn_input[None, token_idx]

current_expert = (
expert_layer(current_expert_tokens)
* router_weights[token_idx, top_k_expert_idx, None]
)

current_expert = current_expert.reshape(-1, feature_dim)

moe_output.index_add_(0, token_idx, current_expert.to(ffn_input.dtype))
moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)

moe_output = self.layer_output_norm(moe_output)
return h + moe_output
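
A self-contained sketch (outside this diff) of the routing math in the block above, which both MoE variants share: softmax over the expert logits, keep the top-k weights per token, renormalize them, and build a one-hot expert mask.

import torch
import torch.nn.functional as F

tokens, expert_count, expert_used_count = 6, 8, 2
router_logits = torch.randn(tokens, expert_count)

# Router weights over all experts, then the top-k per token, renormalized.
router_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
router_weights, top_k_experts = torch.topk(router_weights, expert_used_count, dim=-1)
router_weights /= router_weights.sum(dim=-1, keepdim=True)

# One-hot mask shaped (expert_count, expert_used_count, tokens), matching the
# permute(2, 1, 0) above, so each expert can locate its assigned tokens.
expert_mask = F.one_hot(top_k_experts, num_classes=expert_count).permute(2, 1, 0)
print(router_weights.shape, top_k_experts.shape, expert_mask.shape)
# torch.Size([6, 2]) torch.Size([6, 2]) torch.Size([8, 2, 6])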


class PreGatherMoeBlock(ThetaLayer):
"""
This implementation considers MoE operations as block-sparse
operations to support imbalanced token assignments to experts.
This enables the MoE to operate at a faster rate and in full capacity without any dropped tokens
(or reduced performance).
"""

def __init__(
self,
theta: Theta,
expert_count: int,
expert_used_count: int,
rms_epsilon: float,
use_grok: Optional[bool] = False,
moe_activation=F.silu,
):
super().__init__(theta)

self.expert_count = expert_count
self.expert_used_count = expert_used_count
self.use_grok = use_grok

# Add router gate
self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp")))
@@ -146,15 +49,17 @@ def __init__(
"ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon)
)

# Add FFN output norm layer for Grok
if self.use_grok:
# Add optional FFN output norm layer
if theta.optional_tensor("layer_output_norm") is not None:
self.add_module(
"layer_output_norm",
RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon),
)
else:
self.add_module("layer_output_norm", torch.nn.Identity())

# Add expert_count x FFN
self.experts = PreGatherFFNMOE(theta, use_grok=self.use_grok)
self.experts = PreGatherFFNMOE(theta, activation=moe_activation)

def forward(
self,
@@ -180,7 +85,6 @@ def forward(
moe_output = self.experts(ffn_input, top_k_experts, expert_gate)
moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)

if self.use_grok:
moe_output = self.layer_output_norm(moe_output)
moe_output = self.layer_output_norm(moe_output)

return h + moe_output
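
A minimal sketch of the new presence-based branch above: when the parameter tree carries a `layer_output_norm` weight it is applied, otherwise the layer degenerates to an `Identity` no-op, removing the need for a `use_grok` flag. The tiny `RMSNorm` class here is only an illustrative stand-in for the project's `RMSNormLayer`, and `make_output_norm` is a hypothetical helper.

import torch

class RMSNorm(torch.nn.Module):
    # Illustrative stand-in for RMSNormLayer.
    def __init__(self, weight: torch.Tensor, epsilon: float = 1e-5):
        super().__init__()
        self.weight, self.epsilon = weight, epsilon

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.epsilon)

def make_output_norm(layer_output_norm_weight, epsilon=1e-5):
    # Normalize only when the weight exists; otherwise a no-op.
    if layer_output_norm_weight is None:
        return torch.nn.Identity()
    return RMSNorm(layer_output_norm_weight, epsilon)

moe_output = torch.randn(2, 4, 8)
print(make_output_norm(None)(moe_output).shape)           # Identity path
print(make_output_norm(torch.ones(8))(moe_output).shape)  # RMS-norm path
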
36 changes: 20 additions & 16 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -37,7 +37,8 @@ def __init__(
head_dim: int,
head_count_kv: int,
rms_epsilon: float,
use_grok: Optional[bool] = False,
attention_scale: Optional[float] = None,
softcap: Optional[float] = None,
):
super().__init__(theta)

@@ -46,7 +47,8 @@ def __init__(
self.head_count = head_count
self.head_dim = head_dim
self.head_count_kv = head_count_kv
self.use_grok = use_grok
self.attention_scale = attention_scale
self.softcap = softcap

self.add_module(
"attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon)
@@ -56,7 +58,12 @@ def __init__(
self.add_module("attn_v", LinearLayer(theta("attn_v")))
self.add_module("attn_output", LinearLayer(theta("attn_output")))

if self.use_grok:
if theta.optional_tensor("attn_output_norm") is None:
self.add_module(
"attn_output_norm",
torch.nn.Identity(),
)
else:
self.add_module(
"attn_output_norm",
RMSNormLayer(theta("attn_output_norm"), epsilon=rms_epsilon),
@@ -147,16 +154,16 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
keys = xk.transpose(1, 2)
values = xv.transpose(1, 2)

attn_weights = ops.matmul(xq, keys.transpose(2, 3))
if self.attention_scale is None:
attn_weights = attn_weights / math.sqrt(self.head_dim)
else:
attn_weights = attn_weights * self.attention_scale

# Flash attention.
if not self.use_grok:
attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt(
self.head_dim
)
elif self.use_grok:
attn_weights = ops.matmul(xq, keys.transpose(2, 3))
attn_weights = 30.0 * torch.tanh(
attn_weights * (0.08838834764831845 / 30.0)
)
if self.softcap is not None:
attn_weights = self.softcap * torch.tanh(attn_weights / self.softcap)

self.assert_not_nan(attn_weights)

# Apply attention mask.
@@ -172,12 +179,9 @@

# Project.
attn_output = self.attn_output(attn_output)

if self.use_grok:
attn_output = self.attn_output_norm(attn_output)
attn_output = self.attn_output_norm(attn_output)

h = h + attn_output

return h

def transact_cache_direct(
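A self-contained sketch (outside this diff, helper name hypothetical) of the generalized scaling above: an optional `attention_scale` overrides the default `1/sqrt(head_dim)`, and an optional `softcap` bounds the logits with a scaled tanh, replacing the hard-coded Grok constants.

import math
import torch

def scale_and_softcap(attn_weights, head_dim, attention_scale=None, softcap=None):
    if attention_scale is None:
        attn_weights = attn_weights / math.sqrt(head_dim)  # default 1/sqrt(head_dim) scaling
    else:
        attn_weights = attn_weights * attention_scale      # model-specific scale
    if softcap is not None:
        # Soft cap: smoothly bound the logits to (-softcap, softcap).
        attn_weights = softcap * torch.tanh(attn_weights / softcap)
    return attn_weights

w = torch.randn(1, 2, 5, 5) * 1000.0
print(scale_and_softcap(w, head_dim=64).abs().max())                # scaled by 1/8, unbounded
print(scale_and_softcap(w, head_dim=64, softcap=30.0).abs().max())  # bounded by 30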