diff --git a/pipegoose/nn/expert_parallel/expert.py b/pipegoose/nn/expert_parallel/experts.py similarity index 100% rename from pipegoose/nn/expert_parallel/expert.py rename to pipegoose/nn/expert_parallel/experts.py diff --git a/pipegoose/nn/expert_parallel/layers.py b/pipegoose/nn/expert_parallel/layers.py new file mode 100644 index 0000000..5fd8a94 --- /dev/null +++ b/pipegoose/nn/expert_parallel/layers.py @@ -0,0 +1,5 @@ +from torch import nn + + +class MoELayer(nn.Module): + pass diff --git a/pipegoose/nn/expert_parallel/routers.py b/pipegoose/nn/expert_parallel/routers.py index 8b71605..bc32227 100644 --- a/pipegoose/nn/expert_parallel/routers.py +++ b/pipegoose/nn/expert_parallel/routers.py @@ -1,6 +1,28 @@ """DON'T USE THIS MODULE: Under development.""" from abc import ABC +from enum import Enum, auto +from torch import nn -class MoERouter(ABC): + +class RouterType(Enum): + """An enum for router types.""" + + TOP_1 = auto() + TOP_2 = auto() + + +class Router(ABC): + pass + + +class Top1Router(Router, nn.Module): + def __init__(self): + super().__init__() + + def forward(self, inputs): + pass + + +def get_router(router_type: RouterType) -> Router: pass diff --git a/tests/nn/expert_parallel/test_layer.py b/tests/nn/expert_parallel/test_layer.py new file mode 100644 index 0000000..7364e8d --- /dev/null +++ b/tests/nn/expert_parallel/test_layer.py @@ -0,0 +1,28 @@ +import torch +from torch import nn + +from pipegoose.nn.expert_parallel.layers import MoELayer +from pipegoose.nn.expert_parallel.routers import Top1Router + + +def test_moe_layer(): + BATCH_SIZE = 10 + SEQ_LEN = 5 + HIDDEN_DIM = 64 + N_EXPERTS = 10 + + inputs = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM) + expert = nn.Linear(10, 10) + + router = Top1Router(n_experts=N_EXPERTS) + + layer = MoELayer( + expert=expert, + n_experts=10, + router=router, + ) + + outputs = layer(inputs) + + assert isinstance(outputs, torch.Tensor) + assert outputs.shape == (BATCH_SIZE, SEQ_LEN, HIDDEN_DIM) diff --git a/tests/nn/expert_parallel/test_router.py b/tests/nn/expert_parallel/test_router.py index 0a437a3..236f5e2 100644 --- a/tests/nn/expert_parallel/test_router.py +++ b/tests/nn/expert_parallel/test_router.py @@ -1,8 +1,20 @@ +import pytest import torch -from pipegoose.nn.expert_parallel.routers import TopKRouter +from pipegoose.nn.expert_parallel.routers import RouterType, get_router -def test_topk_router(): - torch.randn() - TopKRouter() +@pytest.mark.parametrize("router_type", [RouterType.TOP_1]) +def test_topk_router(router_type): + SEQ_LEN = 10 + HIDDEN_DIM = 64 + N_EXPERTS = 5 + + inputs = torch.randn(SEQ_LEN, HIDDEN_DIM) + router = get_router(router_type)( + n_experts=N_EXPERTS, + ) + + outputs = router(inputs) + + assert isinstance(outputs, torch.Tensor)