Fix grok and add ToyGrok model (#651)
The `hf` layout for rotary embedding was broken during a previous refactor; re-adding the layout handling fixes this. This change also adds a toy Grok model for testing, which was validated to produce the same results for prefill and decode.
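For background, the two common rotary-embedding layouts differ only in how channels are paired before the rotation is applied: the interleaved layout rotates adjacent channel pairs, while the `hf` layout pairs channel `i` with channel `i + head_dim // 2`. The sketch below only illustrates that difference and is not sharktank's implementation; the function names and the assumption that `cos`/`sin` carry one value per channel pair are mine.

```python
# Illustration only (not the sharktank code): the two rotary-embedding layouts.
import torch


def rope_interleaved(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Interleaved layout: rotate adjacent channel pairs (x0, x1), (x2, x3), ...
    pairs = x.unflatten(-1, (-1, 2))
    x0, x1 = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((x0 * cos - x1 * sin, x0 * sin + x1 * cos), dim=-1)
    return rotated.flatten(-2)


def rope_hf(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # `hf` layout: pair channel i with channel i + head_dim // 2.
    half = x.shape[-1] // 2
    x0, x1 = x[..., :half], x[..., half:]
    return torch.cat((x0 * cos - x1 * sin, x0 * sin + x1 * cos), dim=-1)
```

A fixed permutation of the head dimension maps one layout onto the other, so weights stored in one convention but consumed by code assuming the other silently produce wrong results rather than an error.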
Showing 3 changed files with 225 additions and 40 deletions.
@@ -0,0 +1,110 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import List, Optional

import torch

from ...types.tensors import *
from ...types.theta import Theta
from ..llama.llama import LlamaModelConfig
from ...utils.testing import make_rand_torch
from ...layers.testing import make_llama_attention_block_theta


def make_attention_block_ffn_theta_v2(
    *,
    block_idx: int,
    head_count: int,
    head_count_kv: int,
    head_dim: int,
    embedding_length: int,
    expert_count: int,
    dtype: torch.dtype | None = None,
) -> Theta:
    """Build a random attention block theta merged with a random MoE FFN block theta."""
    attention_theta = make_llama_attention_block_theta(
        block_idx=block_idx,
        head_count=head_count,
        head_count_kv=head_count_kv,
        head_dim=head_dim,
        embedding_length=embedding_length,
        dtype=dtype,
    )
    moe_theta = make_moe_block_theta(
        block_idx=block_idx,
        feature_dim=embedding_length,
        ffn_dim=embedding_length,
        num_experts=expert_count,
    )
    res_dict = attention_theta.tree
    res_dict.update(moe_theta.tree)
    return Theta(res_dict)


def make_moe_block_theta(
    block_idx=0, feature_dim=1024, ffn_dim=6144, num_experts=8
) -> Theta:
    """Build random MoE FFN weights for a single block, keyed by GGUF-style names."""
    return Theta(
        {
            f"blk.{block_idx}.ffn_gate_inp.weight": DefaultPrimitiveTensor(
                name=f"blk.{block_idx}.ffn_gate_inp.weight",
                data=make_rand_torch((num_experts, ffn_dim)),
            ),
            f"blk.{block_idx}.ffn_norm.weight": DefaultPrimitiveTensor(
                name=f"blk.{block_idx}.ffn_norm.weight",
                data=make_rand_torch((ffn_dim,)),
            ),
            f"blk.{block_idx}.layer_output_norm.weight": DefaultPrimitiveTensor(
                name=f"blk.{block_idx}.layer_output_norm.weight",
                data=make_rand_torch((ffn_dim,)),
            ),
            f"blk.{block_idx}.ffn_gate_exps.weight": DefaultPrimitiveTensor(
                name=f"blk.{block_idx}.ffn_gate_exps.weight",
                data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim)),
            ),
            f"blk.{block_idx}.ffn_up_exps.weight": DefaultPrimitiveTensor(
                name=f"blk.{block_idx}.ffn_up_exps.weight",
                data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim)),
            ),
            f"blk.{block_idx}.ffn_down_exps.weight": DefaultPrimitiveTensor(
                name=f"blk.{block_idx}.ffn_down_exps.weight",
                data=make_rand_torch((num_experts, ffn_dim, feature_dim * num_experts)),
            ),
        }
    )


def make_random_grok_theta(
    config: LlamaModelConfig, vocab_size: int, dtype: Optional[torch.dtype] = None
) -> Theta:
    """Build a full random Grok-style theta: token embedding, blocks, and output head."""
    res = {
        "token_embd.weight": DefaultPrimitiveTensor(
            name="token_embd.weight",
            data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype),
        )
    }
    for i in range(config.hp.block_count):
        res[f"blk.{i}"] = make_attention_block_ffn_theta_v2(
            block_idx=i,
            head_count=config.hp.attention_head_count,
            head_count_kv=config.hp.attention_head_count_kv,
            head_dim=config.hp.attn_head_dim,
            embedding_length=config.hp.embedding_length,
            expert_count=config.hp.expert_count,
            dtype=dtype,
        ).tree

    res["output.weight"] = DefaultPrimitiveTensor(
        name="output.weight",
        data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype),
    )
    res["output_norm.weight"] = DefaultPrimitiveTensor(
        name="output_norm.weight",
        data=make_rand_torch((1, config.hp.embedding_length), dtype=dtype),
    )

    return Theta(res)
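As a quick sanity check, these helpers can be exercised on their own. The sketch below is mine, not part of the commit; the import path is hypothetical and the sizes are arbitrary toy values.

```python
# Sketch only: build one random MoE block theta with the helper above.
# The import path is hypothetical; adjust to wherever this testing module lives.
from sharktank.models.grok.testing import make_moe_block_theta

moe = make_moe_block_theta(block_idx=0, feature_dim=64, ffn_dim=64, num_experts=4)
# The dict passed to Theta is keyed by GGUF-style names such as
# "blk.0.ffn_gate_exps.weight" and "blk.0.layer_output_norm.weight".
```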
@@ -0,0 +1,69 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .testing import make_random_grok_theta

from sharktank.layers.configs import LlamaHParams
from sharktank.models.llama.llama import LlamaModelConfig
from sharktank.types import Dataset

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--seed", type=int, default=12345)
parser.add_argument("-o", "--output", default="/tmp/toy_grok.irpa")


def main():
    args = parser.parse_args()
    torch.manual_seed(args.seed)

    # Toy-sized hyperparameters: one block, small heads, tiny vocabulary.
    dtype = torch.float32
    block_seq_stride = 16
    max_blocks = 8
    attention_head_count = 8
    attn_head_dim = 16
    attention_head_count_kv = 2
    rope_dimension_count = 16
    vocabulary_size = 256
    expert_count = 4
    used_experts = 2

    config = LlamaModelConfig(
        hp=LlamaHParams(
            context_length=block_seq_stride * max_blocks,
            embedding_length=attention_head_count * attn_head_dim,
            block_count=1,
            feed_forward_length=23,
            rope_dimension_count=rope_dimension_count,
            rope_freq_base=500000.0,
            attention_head_count=attention_head_count,
            attn_head_dim=attn_head_dim,
            attention_layer_norm_rms_epsilon=0.01,
            attention_head_count_kv=attention_head_count_kv,
            expert_count=expert_count,
            expert_used_count=used_experts,
            model_arch="grok",
        ),
        block_seq_stride=block_seq_stride,
        activation_dtype=dtype,
        attention_dtype=dtype,
    )

    theta = make_random_grok_theta(
        config=config,
        vocab_size=vocabulary_size,
    )

    config_dict = config.hp.to_gguf_props()

    dataset = Dataset(config_dict, theta)
    dataset.save(args.output)


if __name__ == "__main__":
    main()
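Running the script writes the toy model to `/tmp/toy_grok.irpa` by default. The commit message notes that the toy model was validated to produce the same results for prefill and decode; the sketch below shows the general shape of such a parity check. It is only an illustration: `model`, `prefill`, and `decode` are placeholders of my own, not sharktank's actual API.

```python
# Generic prefill/decode parity check (sketch; `model` and its methods are placeholders).
import torch


def check_prefill_decode_parity(model, tokens: torch.Tensor, atol: float = 1e-5) -> bool:
    # Prefill: score the whole prompt in a single forward pass.
    prefill_logits = model.prefill(tokens)  # hypothetical signature
    # Decode: feed tokens one step at a time and collect the per-step logits.
    step_logits = [model.decode(tokens[:, t : t + 1]) for t in range(tokens.shape[1])]
    decode_logits = torch.cat(step_logits, dim=1)
    # The two paths should agree up to numerical tolerance.
    return torch.allclose(prefill_logits, decode_logits, atol=atol)
```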