diff --git a/pipegoose/nn/expert_parallel/expert.py b/pipegoose/nn/expert_parallel/experts.py
similarity index 100%
rename from pipegoose/nn/expert_parallel/expert.py
rename to pipegoose/nn/expert_parallel/experts.py
diff --git a/pipegoose/nn/expert_parallel/layers.py b/pipegoose/nn/expert_parallel/layers.py
new file mode 100644
index 0000000..5fd8a94
--- /dev/null
+++ b/pipegoose/nn/expert_parallel/layers.py
@@ -0,0 +1,5 @@
+from torch import nn
+
+
+class MoELayer(nn.Module):
+    pass
diff --git a/pipegoose/nn/expert_parallel/routers.py b/pipegoose/nn/expert_parallel/routers.py
index 8b71605..bc32227 100644
--- a/pipegoose/nn/expert_parallel/routers.py
+++ b/pipegoose/nn/expert_parallel/routers.py
@@ -1,6 +1,28 @@
 """DON'T USE THIS MODULE: Under development."""
 
 from abc import ABC
+from enum import Enum, auto
+from torch import nn
 
-class MoERouter(ABC):
+
+class RouterType(Enum):
+    """An enum for router types."""
+
+    TOP_1 = auto()
+    TOP_2 = auto()
+
+
+class Router(ABC):
+    pass
+
+
+class Top1Router(Router, nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, inputs):
+        pass
+
+
+def get_router(router_type: RouterType) -> Router:
     pass
diff --git a/tests/nn/expert_parallel/test_router.py b/tests/nn/expert_parallel/test_router.py
new file mode 100644
index 0000000..236f5e2
--- /dev/null
+++ b/tests/nn/expert_parallel/test_router.py
@@ -0,0 +1,20 @@
+import pytest
+import torch
+
+from pipegoose.nn.expert_parallel.routers import RouterType, get_router
+
+
+@pytest.mark.parametrize("router_type", [RouterType.TOP_1])
+def test_topk_router(router_type):
+    SEQ_LEN = 10
+    HIDDEN_DIM = 64
+    N_EXPERTS = 5
+
+    inputs = torch.randn(SEQ_LEN, HIDDEN_DIM)
+    router = get_router(router_type)(
+        n_experts=N_EXPERTS,
+    )
+
+    outputs = router(inputs)
+
+    assert isinstance(outputs, torch.Tensor)
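Note (not part of the diff): `Top1Router.forward` and `get_router` are still stubs, so `test_topk_router` would currently fail — `get_router` returns `None`, which the test then calls as if it were a class. Below is a minimal sketch of one way the two stubs might be completed. The gating scheme (a learned linear projection, softmax, argmax dispatch) and the `hidden_dim=64` default are assumptions chosen to satisfy the test, not the author's confirmed design; `Router` and `RouterType` refer to the classes defined in `routers.py` above.

```python
import torch
import torch.nn.functional as F
from torch import nn

# Router and RouterType as defined in pipegoose/nn/expert_parallel/routers.py.


class Top1Router(Router, nn.Module):
    # ASSUMPTION: a single learned gate; hidden_dim defaults to 64 only to
    # match the HIDDEN_DIM used in test_topk_router.
    def __init__(self, n_experts: int, hidden_dim: int = 64):
        super().__init__()
        # Projects each token's hidden state to one logit per expert.
        self.gate = nn.Linear(hidden_dim, n_experts)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # (seq_len, hidden_dim) -> (seq_len, n_experts) routing probabilities.
        probs = F.softmax(self.gate(inputs), dim=-1)
        # Top-1 routing: each token goes to its highest-probability expert.
        top1 = probs.argmax(dim=-1)
        mask = F.one_hot(top1, num_classes=probs.size(-1)).to(probs.dtype)
        # One-hot dispatch mask scaled by the gate value, Switch-style;
        # shape (seq_len, n_experts).
        return mask * probs


def get_router(router_type: RouterType) -> type:
    # The test calls get_router(router_type)(n_experts=...), i.e. it expects
    # the router *class* back, so this returns a type rather than an instance
    # (despite the stub's `-> Router` annotation).
    registry = {RouterType.TOP_1: Top1Router}
    return registry[router_type]
```

With this sketch, `get_router(RouterType.TOP_1)(n_experts=5)` builds a router whose output on a `(10, 64)` input is a `(10, 5)` tensor, which satisfies the test's `isinstance(outputs, torch.Tensor)` assertion.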