From 7df7ad402f97f9ba256db6962048e81661ce67b4 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Fri, 6 Dec 2024 14:15:38 -0800
Subject: [PATCH] Fix grok and add ToyGrok model (#651)

The `hf` layout for rotary embedding was broken during a previous
refactor. Re-adding the layout changes fixes this.

This also includes a toy grok model for testing. It was validated to
produce the same results for prefill and decode.
---
 .../sharktank/layers/rotary_embedding.py    |  86 +++++++-------
 sharktank/sharktank/models/grok/testing.py  | 110 ++++++++++++++++++
 sharktank/sharktank/models/grok/toy_grok.py |  69 +++++++++++
 3 files changed, 225 insertions(+), 40 deletions(-)
 create mode 100644 sharktank/sharktank/models/grok/testing.py
 create mode 100644 sharktank/sharktank/models/grok/toy_grok.py

diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py
index 0664a9a46..99ecf5057 100644
--- a/sharktank/sharktank/layers/rotary_embedding.py
+++ b/sharktank/sharktank/layers/rotary_embedding.py
@@ -89,6 +89,43 @@ def forward(
             rotary_embed_table=self.rotary_embed_table,
         )

+    def _create_interleaved_tensor(_, dim):
+        """Creates a tensor which indexes a tensor such that
+        it alternates between elements of its first and second
+        half. Intended for use with HuggingFace's rotation
+        implementation.
+
+        Args:
+          dim: Size of tensor
+
+        Returns:
+          Interleaved indexing tensor
+        """
+        first_half = torch.arange(dim // 2)
+        second_half = torch.arange(dim // 2, dim)
+
+        interleaved_tensor = torch.empty(dim, dtype=torch.long)
+        interleaved_tensor[0::2] = first_half
+        interleaved_tensor[1::2] = second_half
+
+        return interleaved_tensor
+
+    def _create_ordering_tensor(_, dim):
+        """Creates a tensor which indexes a tensor such that
+        it reverses the alternation induced by _create_interleaved_tensor.
+        Intended for use with HuggingFace's rotation implementation.
+
+        Args:
+          dim: Size of tensor
+
+        Returns:
+          Ordering indexing tensor
+        """
+        order_tensor = torch.empty(dim, dtype=torch.long)
+        order_tensor[: dim // 2] = torch.arange(0, dim, 2)
+        order_tensor[dim // 2 :] = torch.arange(1, dim, 2)
+        return order_tensor
+
     def forward_unsharded(
         self,
         *,
@@ -98,46 +135,8 @@ def forward_unsharded(
     ):
         # xq_, xk_ shape: bs, sl, _, dim
         # freqs_cis shape: max_sl, dim
-
-        def create_interleaved_tensor(dim):
-            """Creates a tensor which indexes an tensor such that
-            it alternates between elements of its first and second
-            half. Intended for use for HuggingFace's rotation
-            implementation.
-
-            Args:
-              dim: Size of tensor
-
-            Returns:
-              Interleaved indexing tensor
-            """
-            first_half = torch.arange(dim // 2)
-            second_half = torch.arange(dim // 2, dim)
-
-            interleaved_tensor = torch.empty(dim, dtype=torch.long)
-            interleaved_tensor[0::2] = first_half
-            interleaved_tensor[1::2] = second_half
-
-            return interleaved_tensor
-
-        def create_ordering_tensor(dim):
-            """Creates a tensor which indexes an tensor such that
-            it reverses the alternation induced by create_interleaved_tesnor.
-            Intended for use for HuggingFace's rotation implementation.
- - Args: - dim: Size of tensor - - Returns: - Ordering indexing tensor - """ - order_tensor = torch.empty(dim, dtype=torch.long) - order_tensor[: dim // 2] = torch.arange(0, dim, 2) - order_tensor[dim // 2 :] = torch.arange(1, dim, 2) - return order_tensor - if self.use_hf: - xt = xt[..., create_interleaved_tensor(xt.shape[-1])] + xt = xt[..., self._create_interleaved_tensor(xt.shape[-1])] xt_ = xt _, sl, _, _ = xt_.shape @@ -158,7 +157,7 @@ def create_ordering_tensor(dim): xt_out = ops.view_as_real(xt_) if self.use_hf: - xt_out = xt_out[..., create_ordering_tensor(xt_out.shape[-1])] + xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])] return ops.to(xt_out, xt.dtype) @@ -229,10 +228,17 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor): """ # xq_, xk_ shape: bs, sl, _, dim # freqs_cis shape: max_sl, dim + + if self.use_hf: + xt = xt[..., self._create_interleaved_tensor(xt.shape[-1])] + xt_ = ops.view_as_complex(xt) xt_ = xt_ * mask xt_out = ops.view_as_real(xt_) + if self.use_hf: + xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])] + return xt_out.type_as(xt) def _compute_rotary_embed_table(self, t): diff --git a/sharktank/sharktank/models/grok/testing.py b/sharktank/sharktank/models/grok/testing.py new file mode 100644 index 000000000..d77774c2d --- /dev/null +++ b/sharktank/sharktank/models/grok/testing.py @@ -0,0 +1,110 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import List + +import torch + +from ...types.tensors import * +from ...types.theta import Theta +from typing import Optional +from ..llama.llama import LlamaModelConfig +import torch +from ...utils.testing import make_rand_torch +from ...layers.testing import make_llama_attention_block_theta + + +def make_attention_block_ffn_theta_v2( + *, + block_idx: int, + head_count: int, + head_count_kv: int, + head_dim: int, + embedding_length: int, + expert_count: int, + dtype: torch.dtype | None = None, +) -> Theta: + attention_theta = make_llama_attention_block_theta( + block_idx=block_idx, + head_count=head_count, + head_count_kv=head_count_kv, + head_dim=head_dim, + embedding_length=embedding_length, + dtype=dtype, + ) + moe_theta = make_moe_block_theta( + block_idx=block_idx, + feature_dim=embedding_length, + ffn_dim=embedding_length, + num_experts=expert_count, + ) + res_dict = attention_theta.tree + res_dict.update(moe_theta.tree) + return Theta(res_dict) + + +def make_moe_block_theta( + block_idx=0, feature_dim=1024, ffn_dim=6144, num_experts=8 +) -> Theta: + return Theta( + { + f"blk.{block_idx}.ffn_gate_inp.weight": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.ffn_gate_inp.weight", + data=make_rand_torch((num_experts, ffn_dim)), + ), + f"blk.{block_idx}.ffn_norm.weight": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.ffn_norm.weight", data=make_rand_torch((ffn_dim)) + ), + f"blk.{block_idx}.layer_output_norm.weight": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.layer_output_norm.weight", + data=make_rand_torch((ffn_dim)), + ), + f"blk.{block_idx}.ffn_gate_exps.weight": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.ffn_gate_exps.weight", + data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim)), + ), + f"blk.{block_idx}.ffn_up_exps.weight": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.ffn_up_exps.weight", + 
data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim)), + ), + f"blk.{block_idx}.ffn_down_exps.weight": DefaultPrimitiveTensor( + name=f"blk.{block_idx}.ffn_down_exps.weight", + data=make_rand_torch((num_experts, ffn_dim, feature_dim * num_experts)), + ), + } + ) + + +def make_random_grok_theta( + config: LlamaModelConfig, vocab_size: int, dtype: Optional[torch.dtype] = None +) -> Theta: + res = { + "token_embd.weight": DefaultPrimitiveTensor( + name="token_embd.weight", + data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype), + ) + } + for i in range(config.hp.block_count): + res[f"blk.{i}"] = make_attention_block_ffn_theta_v2( + block_idx=i, + head_count=config.hp.attention_head_count, + head_count_kv=config.hp.attention_head_count_kv, + head_dim=config.hp.attn_head_dim, + embedding_length=config.hp.embedding_length, + expert_count=config.hp.expert_count, + dtype=dtype, + ).tree + + res[f"output.weight"] = DefaultPrimitiveTensor( + name="output.weight", + data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype), + ) + res[f"output_norm.weight"] = DefaultPrimitiveTensor( + name="output_norm.weight", + data=make_rand_torch((1, config.hp.embedding_length), dtype=dtype), + ) + + return Theta(res) diff --git a/sharktank/sharktank/models/grok/toy_grok.py b/sharktank/sharktank/models/grok/toy_grok.py new file mode 100644 index 000000000..ab57d0c1d --- /dev/null +++ b/sharktank/sharktank/models/grok/toy_grok.py @@ -0,0 +1,69 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .testing import make_random_grok_theta + +from sharktank.layers.configs import LlamaHParams +from sharktank.models.llama.llama import LlamaModelConfig +from sharktank.types import Dataset + +import argparse +import torch + +parser = argparse.ArgumentParser() +parser.add_argument("-s", "--seed", default=12345) +parser.add_argument("-o", "--output", default="/tmp/toy_grok.irpa") + + +def main(): + args = parser.parse_args() + torch.manual_seed(args.seed) + + dtype = torch.float32 + block_seq_stride = 16 + max_blocks = 8 + attention_head_count = 8 + attn_head_dim = 16 + attention_head_count_kv = 2 + rope_dimension_count = 16 + vocabulary_size = 256 + expert_count = 4 + used_experts = 2 + + config = LlamaModelConfig( + hp=LlamaHParams( + context_length=block_seq_stride * max_blocks, + embedding_length=attention_head_count * attn_head_dim, + block_count=1, + feed_forward_length=23, + rope_dimension_count=rope_dimension_count, + rope_freq_base=500000.0, + attention_head_count=attention_head_count, + attn_head_dim=attn_head_dim, + attention_layer_norm_rms_epsilon=0.01, + attention_head_count_kv=attention_head_count_kv, + expert_count=expert_count, + expert_used_count=used_experts, + model_arch="grok", + ), + block_seq_stride=block_seq_stride, + activation_dtype=dtype, + attention_dtype=dtype, + ) + + theta = make_random_grok_theta( + config=config, + vocab_size=vocabulary_size, + ) + + config_dict = config.hp.to_gguf_props() + + dataset = Dataset(config_dict, theta) + dataset.save(args.output) + + +if __name__ == "__main__": + main()
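
For reference, a minimal standalone sketch of the index-permutation logic exercised by the
`use_hf` path in the patch above. The two helper bodies mirror the new
`_create_interleaved_tensor` / `_create_ordering_tensor` methods; the surrounding script
scaffolding is illustrative only and not part of the patch. It demonstrates that the
ordering indices invert the interleaving, which is what lets HF-layout activations
round-trip through the complex rotation.

import torch


def create_interleaved_tensor(dim: int) -> torch.Tensor:
    # Indices that interleave the first and second halves of the last dimension,
    # matching the layout used by HuggingFace-style rotary embeddings.
    first_half = torch.arange(dim // 2)
    second_half = torch.arange(dim // 2, dim)
    interleaved = torch.empty(dim, dtype=torch.long)
    interleaved[0::2] = first_half
    interleaved[1::2] = second_half
    return interleaved


def create_ordering_tensor(dim: int) -> torch.Tensor:
    # Indices that undo the interleaving above, restoring the native layout.
    order = torch.empty(dim, dtype=torch.long)
    order[: dim // 2] = torch.arange(0, dim, 2)
    order[dim // 2 :] = torch.arange(1, dim, 2)
    return order


if __name__ == "__main__":
    dim = 8
    x = torch.arange(dim)
    interleaved = x[create_interleaved_tensor(dim)]
    print(interleaved.tolist())  # [0, 4, 1, 5, 2, 6, 3, 7]
    # Applying the ordering indices recovers the original order.
    assert torch.equal(interleaved[create_ordering_tensor(dim)], x)

Assuming sharktank is installed, the toy model added in the second half of the patch can
then be generated with something like
`python -m sharktank.models.grok.toy_grok --output /tmp/toy_grok.irpa`
(module path inferred from the new file's location), producing an IRPA dataset for
prefill/decode testing.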