diff --git a/sharktank/sharktank/layers/ffn_moe_block.py b/sharktank/sharktank/layers/ffn_moe_block.py
index 73fea9a9e..0746f0fa0 100644
--- a/sharktank/sharktank/layers/ffn_moe_block.py
+++ b/sharktank/sharktank/layers/ffn_moe_block.py
@@ -12,7 +12,7 @@
 from .base import ThetaLayer
 from .linear import LinearLayer
 from ..types import Theta, DefaultPrimitiveTensor
-from ..ops import einsum_2args
+from ..ops import einsum_2args, elementwise
 
 __all__ = [
     "FFNMOE",
@@ -32,6 +32,7 @@ def __init__(
         self.ffn_gate = theta.tensor("ffn_gate_exps", "weight")
         self.ffn_up = theta.tensor("ffn_up_exps", "weight")
         self.ffn_down = theta.tensor("ffn_down_exps", "weight")
+        self.activation = activation
 
     def pre_matmul_gather(self, inputs, weights, experts, einstring="mk,menk->men"):
         inputs = inputs[:, :]
@@ -63,7 +64,7 @@ def forward(
         expert_gate: torch.Tensor,
     ):
         ffn_gate = self.pre_matmul_gather(h, self.ffn_gate, experts)
-        ffn_gate = ops.elementwise(self.activation, ffn_gate)
+        ffn_gate = elementwise(self.activation, ffn_gate)
 
         ffn_up = self.pre_matmul_gather(h, self.ffn_up, experts)
         ffn_down = self.pre_matmul_gather(
diff --git a/sharktank/tests/models/llama/moe_block_test.py b/sharktank/tests/models/llama/moe_block_test.py
index dd8a19649..9b3daabdf 100644
--- a/sharktank/tests/models/llama/moe_block_test.py
+++ b/sharktank/tests/models/llama/moe_block_test.py
@@ -16,7 +16,7 @@ class MoeBlockTest(unittest.TestCase):
     def test(self):
-        model = PreGatherMoeBlock(
+        model = MoeBlock(
             theta=make_moe_block_theta()("blk.0"),
             expert_count=8,
             expert_used_count=2,
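
For context, a minimal standalone sketch of the pattern the first file's hunks adopt: the activation callable is stored on the layer in __init__ and later applied through a module-level elementwise helper (imported directly, rather than reached via an ops. attribute lookup). The elementwise helper and TinyGate class below are illustrative stand-ins, not sharktank's actual dispatching implementation.

import torch
import torch.nn.functional as F

def elementwise(fn, *args):
    # Stand-in for sharktank's elementwise op: apply an elementwise
    # function to its tensor argument(s).
    return fn(*args)

class TinyGate:
    # Hypothetical layer mirroring the two changed lines in the diff.
    def __init__(self, activation=F.silu):
        self.activation = activation  # mirrors `self.activation = activation`

    def forward(self, ffn_gate: torch.Tensor) -> torch.Tensor:
        # mirrors `ffn_gate = elementwise(self.activation, ffn_gate)`
        return elementwise(self.activation, ffn_gate)

gated = TinyGate().forward(torch.randn(2, 8, 16))
assert gated.shape == (2, 8, 16)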