Refactor llama / mixtral / grok for shared features
Many of these features toggle on or off depending on the architecture.
Replumbing the configurations separately allows better reuse and a clearer
understanding of how the models vary from each other.

grok uses a softcap; plumbing a value through enables `sc * tanh(v / sc)`.
grok also has some hardcoded values with clearer representations, e.g.
`sqrt(6144)` and `sqrt(3)`.
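
A rough sketch of what the plumbed softcap does to the attention logits (the helper name is illustrative, not the sharktank API; the 30.0 matches the constant the grok attention block previously hardcoded, and the square-root identities match the literals replaced in the diff below):

```python
import math
import torch


def softcap(logits: torch.Tensor, sc: float) -> torch.Tensor:
    # Smoothly bound values to (-sc, sc); approximately linear near zero.
    return sc * torch.tanh(logits / sc)


# grok caps its attention logits at 30.0 before masking and softmax.
attn_weights = torch.randn(1, 8, 16, 16)
attn_weights = softcap(attn_weights, 30.0)

# The hardcoded grok constants are square roots of model dimensions:
assert abs(78.38367176906169 - math.sqrt(6144)) < 1e-9      # embedding scale
assert abs(0.5773502691896257 - 1 / math.sqrt(3.0)) < 1e-9  # final logit scale
```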

Output normalization is optional but used by mixtral. The presence of the
tensor is sufficient to enable the normalization.
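
A minimal sketch of the presence-based toggle, using a plain dict of weights instead of the sharktank `Theta` API (the class and key names here are illustrative):

```python
import torch


class OptionalOutputNorm(torch.nn.Module):
    """RMS-normalize a block's output only when its norm weight is present."""

    def __init__(self, weights: dict, epsilon: float = 1e-5):
        super().__init__()
        weight = weights.get("layer_output_norm")  # presence toggles behavior
        self.epsilon = epsilon
        self.weight = torch.nn.Parameter(weight) if weight is not None else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.weight is None:
            return x  # no tensor in the checkpoint: behave like nn.Identity
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.epsilon)


# e.g. OptionalOutputNorm({}) is a no-op; passing a weight enables RMS norm.
norm = OptionalOutputNorm({"layer_output_norm": torch.ones(6144)})
```

This mirrors the pattern in the diff below, where `theta.optional_tensor("layer_output_norm")` selects between an `RMSNormLayer` and `torch.nn.Identity()`.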
rsuderman committed Oct 11, 2024
1 parent b55065a commit c458027
Showing 8 changed files with 137 additions and 129 deletions.
5 changes: 4 additions & 1 deletion sharktank/sharktank/export_layer/export_moe.py
@@ -5,7 +5,10 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
import torch.nn.functional as F

from iree.turbine.aot import *

from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
from sharktank.layers.mixture_of_experts_block import PreGatherMoeBlock
from ..utils import cli
@@ -51,7 +54,7 @@ def main():
expert_count=8,
expert_used_count=2,
rms_epsilon=1e-5,
use_grok=args.use_grok,
moe_activation=F.gelu if args.use_grok else F.silu,
)
fxb = FxProgramsBuilder(model)
input = make_rand_torch((bs, 32, 6144))
9 changes: 3 additions & 6 deletions sharktank/sharktank/layers/ffn_moe_block.py
@@ -24,11 +24,10 @@ class PreGatherFFNMOE(ThetaLayer):
def __init__(
self,
theta: Theta,
use_grok: bool = False,
activation=F.silu,
):

super().__init__(theta)
self.use_grok = use_grok

self.ffn_gate = theta.tensor("ffn_gate_exps", "weight")
self.ffn_up = theta.tensor("ffn_up_exps", "weight")
@@ -63,10 +62,8 @@ def forward(
experts: torch.Tensor,
expert_gate: torch.Tensor,
):
if self.use_grok:
ffn_gate = F.gelu(self.pre_matmul_gather(h, self.ffn_gate, experts))
else:
ffn_gate = F.silu(self.pre_matmul_gather(h, self.ffn_gate, experts))
ffn_gate = self.pre_matmul_gather(h, self.ffn_gate, experts)
ffn_gate = ops.elementwise(self.activation, ffn_gate)

ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
ffn_down = self.pre_matmul_gather(
16 changes: 8 additions & 8 deletions sharktank/sharktank/layers/mixture_of_experts_block.py
@@ -110,8 +110,8 @@ def forward(
current_expert = current_expert.reshape(-1, feature_dim)

moe_output.index_add_(0, token_idx, current_expert.to(ffn_input.dtype))
moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)

moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)
moe_output = self.layer_output_norm(moe_output)
return h + moe_output

@@ -130,13 +130,12 @@ def __init__(
expert_count: int,
expert_used_count: int,
rms_epsilon: float,
use_grok: Optional[bool] = False,
moe_activation=F.silu,
):
super().__init__(theta)

self.expert_count = expert_count
self.expert_used_count = expert_used_count
self.use_grok = use_grok

# Add router gate
self.add_module("ffn_gate_inp", LinearLayer(theta("ffn_gate_inp")))
@@ -146,15 +145,17 @@ def __init__(
"ffn_norm", RMSNormLayer(theta("ffn_norm"), epsilon=rms_epsilon)
)

# Add FFN output norm layer for Grok
if self.use_grok:
# Add optional FFN output norm layer
if theta.optional_tensor("layer_output_norm") is not None:
self.add_module(
"layer_output_norm",
RMSNormLayer(theta("layer_output_norm"), epsilon=rms_epsilon),
)
else:
self.add_module("layer_output_norm", torch.nn.Identity())

# Add expert_count x FFN
self.experts = PreGatherFFNMOE(theta, use_grok=self.use_grok)
self.experts = PreGatherFFNMOE(theta, activation=moe_activation)

def forward(
self,
@@ -180,7 +181,6 @@ def forward(
moe_output = self.experts(ffn_input, top_k_experts, expert_gate)
moe_output = moe_output.reshape(batch_size, sequence_length, feature_dim)

if self.use_grok:
moe_output = self.layer_output_norm(moe_output)
moe_output = self.layer_output_norm(moe_output)

return h + moe_output
36 changes: 20 additions & 16 deletions sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -37,7 +37,8 @@ def __init__(
head_dim: int,
head_count_kv: int,
rms_epsilon: float,
use_grok: Optional[bool] = False,
attention_scale: Optional[float] = None,
softcap: Optional[float] = None,
):
super().__init__(theta)

@@ -46,7 +47,8 @@ def __init__(
self.head_count = head_count
self.head_dim = head_dim
self.head_count_kv = head_count_kv
self.use_grok = use_grok
self.attention_scale = attention_scale
self.softcap = softcap

self.add_module(
"attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon)
@@ -56,7 +58,12 @@ def __init__(
self.add_module("attn_v", LinearLayer(theta("attn_v")))
self.add_module("attn_output", LinearLayer(theta("attn_output")))

if self.use_grok:
if theta.optional_tensor("attn_output_norm") is None:
self.add_module(
"attn_output_norm",
torch.nn.Identity(),
)
else:
self.add_module(
"attn_output_norm",
RMSNormLayer(theta("attn_output_norm"), epsilon=rms_epsilon),
@@ -147,16 +154,16 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
keys = xk.transpose(1, 2)
values = xv.transpose(1, 2)

attn_weights = ops.matmul(xq, keys.transpose(2, 3))
if self.attention_scale is None:
attn_weights = attn_weights / math.sqrt(self.head_dim)
else:
attn_weights = attn_weights * self.attention_scale

# Flash attention.
if not self.use_grok:
attn_weights = ops.matmul(xq, keys.transpose(2, 3)) / math.sqrt(
self.head_dim
)
elif self.use_grok:
attn_weights = ops.matmul(xq, keys.transpose(2, 3))
attn_weights = 30.0 * torch.tanh(
attn_weights * (0.08838834764831845 / 30.0)
)
if self.softcap is not None:
attn_weights = self.softcap * torch.tanh(attn_weights / self.softcap)

self.assert_not_nan(attn_weights)

# Apply attention mask.
@@ -172,12 +179,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:

# Project.
attn_output = self.attn_output(attn_output)

if self.use_grok:
attn_output = self.attn_output_norm(attn_output)
attn_output = self.attn_output_norm(attn_output)

h = h + attn_output

return h

def transact_cache_direct(
84 changes: 42 additions & 42 deletions sharktank/sharktank/models/grok/grok.py
@@ -7,6 +7,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F


from ...layers import *
@@ -82,6 +83,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
self.add_module("output_lm_head", LinearLayer(theta("output")))

self.attn_blocks = nn.ModuleList()
self.moe_blocks = nn.ModuleList()

for n in range(hp.block_count):
self.attn_blocks.append(
@@ -93,16 +95,15 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
head_dim=hp.attn_head_dim,
head_count_kv=hp.attention_head_count_kv,
rms_epsilon=hp.attention_layer_norm_rms_epsilon,
use_grok=True,
)
)
self.attn_blocks.append(
self.moe_blocks.append(
PreGatherMoeBlock(
theta("blk", n),
expert_count=hp.expert_count,
expert_used_count=hp.expert_used_count,
rms_epsilon=hp.attention_layer_norm_rms_epsilon,
use_grok=True,
                    moe_activation=F.gelu,
)
)

@@ -122,33 +123,32 @@ def prefill(
self._assert_device(seq_block_ids)
self._assert_device(*cache_state, dtype=self.activation_dtype)
h = self.token_embedding(tokens)
h *= 78.38367176906169
h *= math.sqrt(h.shape[-1])
self.trace_tensor("grok.token_embedding", h)

# Iterate over attention blocks.
for block_idx, block in enumerate(self.attn_blocks):
for block_idx, (attn_block, moe_block) in enumerate(
zip(self.attn_blocks, self.moe_blocks)
):
if block_idx == 0:
self.trace_tensor(f"grok.attn_block.{block_idx}.input", h)

if block.__class__.__name__ == "PagedLlamaAttentionBlock":
h = block(
h,
embedding=self.attention_embedding,
start_index=0,
attention_mask=attention_mask,
cache_state=cache_state,
seq_block_ids=seq_block_ids,
)
self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)
elif block.__class__.__name__ == "PreGatherMoeBlock":
h = block(
h,
)
self.trace_tensor(f"grok.moe_block.{block_idx}.output", h)
h = attn_block(
h,
embedding=self.attention_embedding,
start_index=0,
attention_mask=attention_mask,
cache_state=cache_state,
seq_block_ids=seq_block_ids,
)
self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)

h = moe_block(h)
self.trace_tensor(f"grok.moe_block.{block_idx}.output", h)

h = self.output_norm(h)
logits = self.output_lm_head(h)
logits = logits * 0.5773502691896257
logits = logits / math.sqrt(3.0)
return logits

def decode(
@@ -200,34 +200,34 @@ def decode(
)

h = self.token_embedding(tokens)
h *= 78.38367176906169
h *= math.sqrt(h.shape[-1])
self.trace_tensor("grok.token_embedding", h)

# Iterate over attention blocks.
for block_idx, block in enumerate(self.attn_blocks):
for block_idx, (attn_block, moe_block) in enumerate(
            zip(self.attn_blocks, self.moe_blocks)
):
if block_idx == 0:
self.trace_tensor(f"grok.attn_block.{block_idx}.input", h)

if block.__class__.__name__ == "PagedLlamaAttentionBlock":
h = block(
h,
start_positions=start_positions,
embedding=self.attention_embedding,
embedding_batch_mask=embedding_batch_mask,
attention_mask=attention_mask,
cache_state=cache_state,
seq_block_ids=seq_block_ids,
xk_temp=xk_temp,
xv_temp=xv_temp,
)
self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)
elif block.__class__.__name__ == "PreGatherMoeBlock":
h = block(
h,
)
self.trace_tensor(f"grok.moe_block.{block_idx}.output", h)
h = attn_block(
h,
start_positions=start_positions,
embedding=self.attention_embedding,
embedding_batch_mask=embedding_batch_mask,
attention_mask=attention_mask,
cache_state=cache_state,
seq_block_ids=seq_block_ids,
xk_temp=xk_temp,
xv_temp=xv_temp,
softcap=30.0,
)
self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)

h = moe_block(h)
self.trace_tensor(f"grok.moe_block.{block_idx}.output", h)

h = self.output_norm(h)
logits = self.output_lm_head(h)
logits = logits * 0.5773502691896257
logits = logits / math.sqrt(3.0)
return logits