Commit

fix zip
rsuderman committed Oct 16, 2024
1 parent 63e8659 commit 4fb7663
Showing 1 changed file with 6 additions and 5 deletions.
sharktank/sharktank/models/grok/grok.py (6 additions, 5 deletions)
@@ -4,6 +4,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+import math
 
 import torch
 import torch.nn as nn
@@ -124,7 +125,7 @@ def prefill(
         self._assert_device(seq_block_ids)
         self._assert_device(*cache_state, dtype=self.activation_dtype)
         h = self.token_embedding(tokens)
-        h *= torch.sqrt(h.shape[-1])
+        h *= math.sqrt(h.shape[-1])
         self.trace_tensor("grok.token_embedding", h)
 
         # Iterate over attention blocks.
@@ -149,7 +150,7 @@ def prefill(
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
-        logits = logits / torch.sqrt(3.0)
+        logits = logits / math.sqrt(3.0)
         return logits
 
     def decode(
@@ -201,12 +202,12 @@ def decode(
         )
 
         h = self.token_embedding(tokens)
-        h *= torch.sqrt(h.shape[-1])
+        h *= math.sqrt(h.shape[-1])
        self.trace_tensor("grok.token_embedding", h)
 
         # Iterate over attention blocks.
         for block_idx, (attn_block, moe_block) in enumerate(
-            self.attn_blocks, self.moe_blocks
+            zip(self.attn_blocks, self.moe_blocks)
         ):
             if block_idx == 0:
                 self.trace_tensor(f"grok.attn_block.{block_idx}.input", h)
@@ -229,5 +230,5 @@ def decode(
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
-        logits = logits / torch.sqrt(3.0)
+        logits = logits / math.sqrt(3.0)
         return logits
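
Context for the change: both edits fix Python-level TypeErrors. torch.sqrt only accepts a Tensor, so calling it on the plain int h.shape[-1] (or the float 3.0) fails at runtime, while math.sqrt handles plain numbers. Likewise, enumerate treats its second positional argument as the integer start offset, so enumerate(self.attn_blocks, self.moe_blocks) fails; wrapping the pair in zip iterates the two block lists together, which is what the loop intends. A minimal sketch reproducing both failure modes (not part of the commit; the toy tensor shape and block lists are illustrative assumptions):

import math

import torch

h = torch.randn(1, 4, 8)  # stand-in for the token embeddings

# torch.sqrt() requires a Tensor, so the Python int h.shape[-1]
# raises TypeError; math.sqrt() accepts plain numbers.
try:
    h * torch.sqrt(h.shape[-1])
except TypeError as e:
    print(e)
h = h * math.sqrt(h.shape[-1])  # works: scale by sqrt(embedding dim)

# enumerate() interprets its second argument as the integer `start`
# offset, so enumerate(attn_blocks, moe_blocks) raises TypeError.
# zip() pairs the two sequences element-wise instead.
attn_blocks = ["attn0", "attn1"]
moe_blocks = ["moe0", "moe1"]
for block_idx, (attn_block, moe_block) in enumerate(
    zip(attn_blocks, moe_blocks)
):
    print(block_idx, attn_block, moe_block)

Using math.sqrt for scalar constants also avoids allocating a tensor for what is a fixed Python number.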
