From 62aecb3ac2a010fce66c7049ed26859a4bdfefed Mon Sep 17 00:00:00 2001
From: Rob Suderman <rob.suderman@gmail.com>
Date: Fri, 11 Oct 2024 12:26:19 -0700
Subject: [PATCH] Fix elementwise import in ffn_moe_block and update MoE block test
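
The FFN MoE layer's forward pass called ops.elementwise even though
only einsum_2args was imported from ..ops, and the activation passed
to __init__ was never stored on the layer. Import elementwise
directly, save the activation in __init__, and point the MoE block
test at the renamed MoeBlock class.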

---
 sharktank/sharktank/layers/ffn_moe_block.py    | 5 +++--
 sharktank/tests/models/llama/moe_block_test.py | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)
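
Note for reviewers: below is a minimal standalone sketch, assuming a
plain torch tensor, of what the fixed elementwise(self.activation,
ffn_gate) call computes. apply_elementwise is a hypothetical stand-in
here, not sharktank's dispatching ops.elementwise implementation:

    import torch
    import torch.nn.functional as F

    def apply_elementwise(fn, x: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in: apply the activation function to every
        # element of the gate projection, as forward() does after the
        # pre_matmul_gather of ffn_gate.
        return fn(x)

    gate = torch.randn(4, 8, 32)  # e.g. (tokens, experts, features), per "mk,menk->men"
    activated = apply_elementwise(F.silu, gate)  # same shape as gate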

diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py
index 73fea9a9e..0746f0fa0 100644
--- a/sharktank/sharktank/layers/ffn_moe_block.py
+++ b/sharktank/sharktank/layers/ffn_moe_block.py
@@ -12,7 +12,7 @@
 from .base import ThetaLayer
 from .linear import LinearLayer
 from ..types import Theta, DefaultPrimitiveTensor
-from ..ops import einsum_2args
+from ..ops import einsum_2args, elementwise
 
 __all__ = [
     "FFNMOE",
@@ -32,6 +32,7 @@ def __init__(
         self.ffn_gate = theta.tensor("ffn_gate_exps", "weight")
         self.ffn_up = theta.tensor("ffn_up_exps", "weight")
         self.ffn_down = theta.tensor("ffn_down_exps", "weight")
+        self.activation = activation
 
     def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"):
         inputs = inputs[:, :]
@@ -63,7 +64,7 @@ def forward(
         expert_gate: torch.Tensor,
     ):
         ffn_gate = self.pre_matmul_gather(h, self.ffn_gate, experts)
-        ffn_gate = ops.elementwise(self.activation, ffn_gate)
+        ffn_gate = elementwise(self.activation, ffn_gate)
 
         ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
         ffn_down = self.pre_matmul_gather(
diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py
index dd8a19649..9b3daabdf 100644
--- a/sharktank/tests/models/llama/moe_block_test.py
+++ b/sharktank/tests/models/llama/moe_block_test.py
@@ -16,7 +16,7 @@
 
 class MoeBlockTest(unittest.TestCase):
     def test(self):
-        model = PreGatherMoeBlock(
+        model = MoeBlock(
             theta=make_moe_block_theta()("blk.0"),
             expert_count=8,
             expert_used_count=2,