From 5beb357db33558860fcef798b441fbdf76ec91c1 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 15 Oct 2024 17:07:39 -0700 Subject: [PATCH] fix zip --- sharktank/sharktank/models/grok/grok.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py index 4dfd7ce5c..077e4e064 100644 --- a/sharktank/sharktank/models/grok/grok.py +++ b/sharktank/sharktank/models/grok/grok.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import math import torch import torch.nn as nn @@ -124,7 +125,7 @@ def prefill( self._assert_device(seq_block_ids) self._assert_device(*cache_state, dtype=self.activation_dtype) h = self.token_embedding(tokens) - h *= torch.sqrt(h.shape[-1]) + h *= math.sqrt(h.shape[-1]) self.trace_tensor("grok.token_embedding", h) # Iterate over attention blocks. @@ -149,7 +150,7 @@ def prefill( h = self.output_norm(h) logits = self.output_lm_head(h) - logits = logits / torch.sqrt(3.0) + logits = logits / math.sqrt(3.0) return logits def decode( @@ -201,12 +202,12 @@ def decode( ) h = self.token_embedding(tokens) - h *= torch.sqrt(h.shape[-1]) + h *= math.sqrt(h.shape[-1]) self.trace_tensor("grok.token_embedding", h) # Iterate over attention blocks. for block_idx, (attn_block, moe_block) in enumerate( - self.attn_blocks, self.moe_blocks + zip(self.attn_blocks, self.moe_blocks) ): if block_idx == 0: self.trace_tensor(f"grok.attn_block.{block_idx}.input", h) @@ -229,5 +230,5 @@ def decode( h = self.output_norm(h) logits = self.output_lm_head(h) - logits = logits / torch.sqrt(3.0) + logits = logits / math.sqrt(3.0) return logits