Fix grok and add ToyGrok model (#651)
The `hf` layout for rotary embedding was broken during a previous
refactor. Re-adding the layout handling fixes this. This change also
includes a toy Grok model for testing, which was validated to produce
the same results for prefill and decode.
rsuderman authored Dec 6, 2024
1 parent 2faadd2 commit 7df7ad4
Showing 3 changed files with 225 additions and 40 deletions.
86 changes: 46 additions & 40 deletions sharktank/sharktank/layers/rotary_embedding.py
@@ -89,6 +89,43 @@ def forward(
rotary_embed_table=self.rotary_embed_table,
)

def _create_interleaved_tensor(_, dim):
"""Creates a tensor which indexes an tensor such that
it alternates between elements of its first and second
half. Intended for use for HuggingFace's rotation
implementation.
Args:
dim: Size of tensor
Returns:
Interleaved indexing tensor
"""
first_half = torch.arange(dim // 2)
second_half = torch.arange(dim // 2, dim)

interleaved_tensor = torch.empty(dim, dtype=torch.long)
interleaved_tensor[0::2] = first_half
interleaved_tensor[1::2] = second_half

return interleaved_tensor

def _create_ordering_tensor(_, dim):
"""Creates a tensor which indexes an tensor such that
it reverses the alternation induced by create_interleaved_tesnor.
Intended for use for HuggingFace's rotation implementation.
Args:
dim: Size of tensor
Returns:
Ordering indexing tensor
"""
order_tensor = torch.empty(dim, dtype=torch.long)
order_tensor[: dim // 2] = torch.arange(0, dim, 2)
order_tensor[dim // 2 :] = torch.arange(1, dim, 2)
return order_tensor

def forward_unsharded(
self,
*,
@@ -98,46 +135,8 @@ def forward_unsharded(
):
# xq_, xk_ shape: bs, sl, _, dim
# freqs_cis shape: max_sl, dim

def create_interleaved_tensor(dim):
"""Creates a tensor which indexes an tensor such that
it alternates between elements of its first and second
half. Intended for use for HuggingFace's rotation
implementation.
Args:
dim: Size of tensor
Returns:
Interleaved indexing tensor
"""
first_half = torch.arange(dim // 2)
second_half = torch.arange(dim // 2, dim)

interleaved_tensor = torch.empty(dim, dtype=torch.long)
interleaved_tensor[0::2] = first_half
interleaved_tensor[1::2] = second_half

return interleaved_tensor

def create_ordering_tensor(dim):
"""Creates a tensor which indexes an tensor such that
it reverses the alternation induced by create_interleaved_tesnor.
Intended for use for HuggingFace's rotation implementation.
Args:
dim: Size of tensor
Returns:
Ordering indexing tensor
"""
order_tensor = torch.empty(dim, dtype=torch.long)
order_tensor[: dim // 2] = torch.arange(0, dim, 2)
order_tensor[dim // 2 :] = torch.arange(1, dim, 2)
return order_tensor

if self.use_hf:
xt = xt[..., create_interleaved_tensor(xt.shape[-1])]
xt = xt[..., self._create_interleaved_tensor(xt.shape[-1])]
xt_ = xt
_, sl, _, _ = xt_.shape

@@ -158,7 +157,7 @@ def create_ordering_tensor(dim):
xt_out = ops.view_as_real(xt_)

if self.use_hf:
xt_out = xt_out[..., create_ordering_tensor(xt_out.shape[-1])]
xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])]

return ops.to(xt_out, xt.dtype)

@@ -229,10 +228,17 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
"""
# xq_, xk_ shape: bs, sl, _, dim
# freqs_cis shape: max_sl, dim

if self.use_hf:
xt = xt[..., self._create_interleaved_tensor(xt.shape[-1])]

xt_ = ops.view_as_complex(xt)
xt_ = xt_ * mask
xt_out = ops.view_as_real(xt_)

if self.use_hf:
xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])]

return xt_out.type_as(xt)

def _compute_rotary_embed_table(self, t):
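
Note (not part of the commit): a minimal standalone sketch showing that the ordering index inverts the interleaved index used for the `hf` layout. The two helpers are re-implemented here as free functions that mirror the methods added above, purely for illustration.

import torch

def create_interleaved_tensor(dim):
    # Mirrors _create_interleaved_tensor above: even slots take the first half,
    # odd slots take the second half of the index range.
    first_half = torch.arange(dim // 2)
    second_half = torch.arange(dim // 2, dim)
    interleaved = torch.empty(dim, dtype=torch.long)
    interleaved[0::2] = first_half
    interleaved[1::2] = second_half
    return interleaved

def create_ordering_tensor(dim):
    # Mirrors _create_ordering_tensor above: undoes the interleaving.
    order = torch.empty(dim, dtype=torch.long)
    order[: dim // 2] = torch.arange(0, dim, 2)
    order[dim // 2 :] = torch.arange(1, dim, 2)
    return order

dim = 8
x = torch.arange(dim)
roundtrip = x[create_interleaved_tensor(dim)][create_ordering_tensor(dim)]
assert torch.equal(roundtrip, x)  # ordering inverts interleaving, as applied around the rotation
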
110 changes: 110 additions & 0 deletions sharktank/sharktank/models/grok/testing.py
@@ -0,0 +1,110 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import List, Optional

import torch

from ...types.tensors import *
from ...types.theta import Theta
from ..llama.llama import LlamaModelConfig
from ...utils.testing import make_rand_torch
from ...layers.testing import make_llama_attention_block_theta


def make_attention_block_ffn_theta_v2(
*,
block_idx: int,
head_count: int,
head_count_kv: int,
head_dim: int,
embedding_length: int,
expert_count: int,
dtype: torch.dtype | None = None,
) -> Theta:
attention_theta = make_llama_attention_block_theta(
block_idx=block_idx,
head_count=head_count,
head_count_kv=head_count_kv,
head_dim=head_dim,
embedding_length=embedding_length,
dtype=dtype,
)
moe_theta = make_moe_block_theta(
block_idx=block_idx,
feature_dim=embedding_length,
ffn_dim=embedding_length,
num_experts=expert_count,
)
res_dict = attention_theta.tree
res_dict.update(moe_theta.tree)
return Theta(res_dict)


def make_moe_block_theta(
block_idx=0, feature_dim=1024, ffn_dim=6144, num_experts=8
) -> Theta:
return Theta(
{
f"blk.{block_idx}.ffn_gate_inp.weight": DefaultPrimitiveTensor(
name=f"blk.{block_idx}.ffn_gate_inp.weight",
data=make_rand_torch((num_experts, ffn_dim)),
),
f"blk.{block_idx}.ffn_norm.weight": DefaultPrimitiveTensor(
name=f"blk.{block_idx}.ffn_norm.weight", data=make_rand_torch((ffn_dim))
),
f"blk.{block_idx}.layer_output_norm.weight": DefaultPrimitiveTensor(
name=f"blk.{block_idx}.layer_output_norm.weight",
data=make_rand_torch((ffn_dim)),
),
f"blk.{block_idx}.ffn_gate_exps.weight": DefaultPrimitiveTensor(
name=f"blk.{block_idx}.ffn_gate_exps.weight",
data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim)),
),
f"blk.{block_idx}.ffn_up_exps.weight": DefaultPrimitiveTensor(
name=f"blk.{block_idx}.ffn_up_exps.weight",
data=make_rand_torch((num_experts, feature_dim * num_experts, ffn_dim)),
),
f"blk.{block_idx}.ffn_down_exps.weight": DefaultPrimitiveTensor(
name=f"blk.{block_idx}.ffn_down_exps.weight",
data=make_rand_torch((num_experts, ffn_dim, feature_dim * num_experts)),
),
}
)


def make_random_grok_theta(
config: LlamaModelConfig, vocab_size: int, dtype: Optional[torch.dtype] = None
) -> Theta:
res = {
"token_embd.weight": DefaultPrimitiveTensor(
name="token_embd.weight",
data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype),
)
}
for i in range(config.hp.block_count):
res[f"blk.{i}"] = make_attention_block_ffn_theta_v2(
block_idx=i,
head_count=config.hp.attention_head_count,
head_count_kv=config.hp.attention_head_count_kv,
head_dim=config.hp.attn_head_dim,
embedding_length=config.hp.embedding_length,
expert_count=config.hp.expert_count,
dtype=dtype,
).tree

res[f"output.weight"] = DefaultPrimitiveTensor(
name="output.weight",
data=make_rand_torch((vocab_size, config.hp.embedding_length), dtype=dtype),
)
res[f"output_norm.weight"] = DefaultPrimitiveTensor(
name="output_norm.weight",
data=make_rand_torch((1, config.hp.embedding_length), dtype=dtype),
)

return Theta(res)
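
Note (not part of the commit): for orientation, a small sketch of the tensor names and shapes that make_moe_block_theta produces for block 0 with its defaults, read directly from the code above. It assumes make_rand_torch treats a bare int argument such as (ffn_dim) as a 1-D shape.

# Plain Python, no sharktank dependency needed; defaults: feature_dim=1024, ffn_dim=6144, num_experts=8.
num_experts, feature_dim, ffn_dim = 8, 1024, 6144
expected = {
    "blk.0.ffn_gate_inp.weight": (num_experts, ffn_dim),
    "blk.0.ffn_norm.weight": (ffn_dim,),
    "blk.0.layer_output_norm.weight": (ffn_dim,),
    "blk.0.ffn_gate_exps.weight": (num_experts, feature_dim * num_experts, ffn_dim),
    "blk.0.ffn_up_exps.weight": (num_experts, feature_dim * num_experts, ffn_dim),
    "blk.0.ffn_down_exps.weight": (num_experts, ffn_dim, feature_dim * num_experts),
}
for name, shape in expected.items():
    print(f"{name}: {shape}")
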
69 changes: 69 additions & 0 deletions sharktank/sharktank/models/grok/toy_grok.py
@@ -0,0 +1,69 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .testing import make_random_grok_theta

from sharktank.layers.configs import LlamaHParams
from sharktank.models.llama.llama import LlamaModelConfig
from sharktank.types import Dataset

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--seed", default=12345)
parser.add_argument("-o", "--output", default="/tmp/toy_grok.irpa")


def main():
args = parser.parse_args()
torch.manual_seed(args.seed)

dtype = torch.float32
block_seq_stride = 16
max_blocks = 8
attention_head_count = 8
attn_head_dim = 16
attention_head_count_kv = 2
rope_dimension_count = 16
vocabulary_size = 256
expert_count = 4
used_experts = 2

config = LlamaModelConfig(
hp=LlamaHParams(
context_length=block_seq_stride * max_blocks,
embedding_length=attention_head_count * attn_head_dim,
block_count=1,
feed_forward_length=23,
rope_dimension_count=rope_dimension_count,
rope_freq_base=500000.0,
attention_head_count=attention_head_count,
attn_head_dim=attn_head_dim,
attention_layer_norm_rms_epsilon=0.01,
attention_head_count_kv=attention_head_count_kv,
expert_count=expert_count,
expert_used_count=used_experts,
model_arch="grok",
),
block_seq_stride=block_seq_stride,
activation_dtype=dtype,
attention_dtype=dtype,
)

theta = make_random_grok_theta(
config=config,
vocab_size=vocabulary_size,
)

config_dict = config.hp.to_gguf_props()

dataset = Dataset(config_dict, theta)
dataset.save(args.output)


if __name__ == "__main__":
main()
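
Note (not part of the commit): a hedged usage sketch that drives the script programmatically. The module path is assumed from the new file's location, and the flags match the argparse definition above.

import sys
from sharktank.models.grok import toy_grok  # module path assumed from sharktank/sharktank/models/grok/toy_grok.py

# Roughly equivalent to: python -m sharktank.models.grok.toy_grok --seed 0 --output /tmp/toy_grok_0.irpa
sys.argv = ["toy_grok", "--seed", "0", "--output", "/tmp/toy_grok_0.irpa"]
toy_grok.main()  # writes an IRPA dataset holding the toy Grok config and random weights
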
