Commit

Archana comments
rsuderman committed Oct 16, 2024
1 parent 265346d commit 63e8659
Showing 6 changed files with 4 additions and 7 deletions.
2 changes: 0 additions & 2 deletions sharktank/sharktank/layers/llama_attention_block.py
@@ -6,8 +6,6 @@

 from typing import Optional
 
-import math
-
 import torch
 import torch.nn.functional as F

2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -156,7 +156,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:

         attn_weights = ops.matmul(xq, keys.transpose(2, 3))
         if self.attention_scale is None:
-            attn_weights = attn_weights / torch.sqrt(self.head_dim)
+            attn_weights = attn_weights / math.sqrt(self.head_dim)
         else:
             attn_weights = attn_weights * self.attention_scale

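For context: self.head_dim is a plain Python integer, and torch.sqrt expects a tensor, so math.sqrt is the natural way to apply the scalar 1/sqrt(head_dim) attention scale. A minimal sketch of that pattern, with a hypothetical helper name (the real code uses sharktank's ops.matmul inside the attention block):

import math

import torch


def scaled_attention_logits(xq: torch.Tensor, keys: torch.Tensor, head_dim: int) -> torch.Tensor:
    # xq: (bs, heads, q_len, head_dim); keys: (bs, heads, kv_len, head_dim).
    # head_dim is a plain Python int, so math.sqrt (rather than torch.sqrt,
    # which expects a tensor) supplies the scalar 1/sqrt(d) factor.
    attn_weights = torch.matmul(xq, keys.transpose(2, 3))
    return attn_weights / math.sqrt(head_dim)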
4 changes: 2 additions & 2 deletions sharktank/sharktank/models/grok/grok.py
@@ -95,6 +95,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
                 head_dim=hp.attn_head_dim,
                 head_count_kv=hp.attention_head_count_kv,
                 rms_epsilon=hp.attention_layer_norm_rms_epsilon,
+                softcap=30.0, # https://github.com/xai-org/grok-1/blob/7050ed204b8206bb8645c7b7bbef7252f79561b0/model.py#L864
             )
         )
         self.moe_blocks.append(
@@ -123,7 +124,7 @@ def prefill(
         self._assert_device(seq_block_ids)
         self._assert_device(*cache_state, dtype=self.activation_dtype)
         h = self.token_embedding(tokens)
-        h *= tanh.sqrt(h.shape[-1])
+        h *= torch.sqrt(h.shape[-1])
         self.trace_tensor("grok.token_embedding", h)
 
         # Iterate over attention blocks.
@@ -220,7 +221,6 @@ def decode(
                 seq_block_ids=seq_block_ids,
                 xk_temp=xk_temp,
                 xv_temp=xv_temp,
-                softcap=30.0, # https://github.com/xai-org/grok-1/blob/7050ed204b8206bb8645c7b7bbef7252f79561b0/model.py#L864
             )
             self.trace_tensor(f"grok.attn_block.{block_idx}.output", h)

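For context: softcap=30.0 is the attention-logit soft-capping constant from the referenced grok-1 model.py, and this commit passes it once to the attention block's constructor instead of on each decode call. The capping itself is the usual tanh squashing of the raw scores; a minimal sketch, with a hypothetical helper name (how the sharktank block applies it is not shown in this diff):

import torch


def softcap_logits(attn_weights: torch.Tensor, softcap: float = 30.0) -> torch.Tensor:
    # Grok-style soft-capping: bound raw attention scores to (-softcap, softcap)
    # by rescaling them through tanh.
    return softcap * torch.tanh(attn_weights / softcap)

The prefill hunk also scales the token embeddings by the square root of the hidden size, the usual transformer embedding-scale convention; since h.shape[-1] is a plain Python int, the scalar form of that factor is math.sqrt(h.shape[-1]).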
1 change: 0 additions & 1 deletion sharktank/sharktank/models/llama/llama.py
@@ -7,7 +7,6 @@
 from typing import Optional
 
 from dataclasses import dataclass
-import math
 from typing import Union
 
 import torch
1 change: 0 additions & 1 deletion sharktank/sharktank/models/llama/llama_ref.py
@@ -7,7 +7,6 @@
 from typing import Optional
 
 from dataclasses import dataclass
-import math
 
 import torch
 import torch.nn as nn
1 change: 1 addition & 0 deletions sharktank/sharktank/utils/tokenizer.py
@@ -31,6 +31,7 @@ def encode(
         raw_rows = self._encode(texts)
         max_length = 0
         lengths: list[int] = []
+        raw_rows = [row[1:] for row in raw_rows]
         for row in raw_rows:
             lengths.append(len(row))
             max_length = max(max_length, len(row))
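For context: the added line drops the first id from every encoded row before lengths and padding are computed. The diff does not say which token is being stripped; a leading special token (e.g. BOS) is an assumption here. A minimal sketch of the effect, with made-up ids:

# Hypothetical encoder output; the leading 1 stands in for whatever prefix token is stripped.
raw_rows = [[1, 42, 7], [1, 99]]
raw_rows = [row[1:] for row in raw_rows]   # -> [[42, 7], [99]]
lengths = [len(row) for row in raw_rows]   # -> [2, 1]
max_length = max(lengths)                  # -> 2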
