Commit

Merge branch 'moe'
xrsrke committed Sep 21, 2023
2 parents 1ffb1b1 + ab8ba1f · commit ed0408e
Showing 4 changed files with 48 additions and 1 deletion.
File renamed without changes.
pipegoose/nn/expert_parallel/layers.py (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
from torch import nn


class MoELayer(nn.Module):
    pass
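
For context only (not part of this commit): MoELayer is committed as an empty stub. Below is a minimal sketch of the pattern such a layer typically implements, routing each token to experts and combining their outputs; the class name and the router and experts arguments are hypothetical, not the pipegoose API.

import torch
from torch import nn


class SketchMoELayer(nn.Module):
    """Hypothetical MoE layer: route tokens, run experts, combine their outputs."""

    def __init__(self, router: nn.Module, experts: nn.ModuleList):
        super().__init__()
        self.router = router
        self.experts = experts

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # (seq_len, n_experts) routing weights, assumed normalized per token
        weights = self.router(inputs)
        # Dense fallback: run every expert on every token (no capacity/dispatch logic)
        expert_outputs = torch.stack([expert(inputs) for expert in self.experts], dim=1)
        # Per-token weighted sum over experts -> (seq_len, hidden_dim)
        return torch.einsum("se,seh->sh", weights, expert_outputs)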
pipegoose/nn/expert_parallel/routers.py (23 additions, 1 deletion)
@@ -1,6 +1,28 @@
"""DON'T USE THIS MODULE: Under development."""
from abc import ABC
from enum import Enum, auto

from torch import nn

class MoERouter(ABC):

class RouterType(Enum):
"""An enum for router types."""

TOP_1 = auto()
TOP_2 = auto()


class Router(ABC):
pass


class Top1Router(Router, nn.Module):
def __init__(self):
super().__init__()

def forward(self, inputs):
pass


def get_router(router_type: RouterType) -> Router:
pass
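
For context only (not part of this commit): Top1Router.forward is committed as a stub. Below is a minimal sketch of what a top-1 routing forward pass typically computes, assuming the constructor eventually takes n_experts (as the test below expects) and a learned gating projection; SketchTop1Router and its gate are hypothetical, not the committed API.

import torch
from torch import nn
import torch.nn.functional as F


class SketchTop1Router(nn.Module):
    """Hypothetical top-1 router: choose one expert per token via a learned gate."""

    def __init__(self, n_experts: int):
        super().__init__()
        # LazyLinear infers the hidden dimension from the first inputs it sees
        self.gate = nn.LazyLinear(n_experts)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # (seq_len, n_experts) gating probabilities
        probs = F.softmax(self.gate(inputs), dim=-1)
        # Keep only the top-1 expert's probability per token; zero out the rest
        top1 = probs.argmax(dim=-1)
        mask = F.one_hot(top1, num_classes=probs.size(-1)).to(probs.dtype)
        return probs * mask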
tests/nn/expert_parallel/test_router.py (20 additions, 0 deletions)
@@ -0,0 +1,20 @@
import pytest
import torch

from pipegoose.nn.expert_parallel.routers import RouterType, get_router


@pytest.mark.parametrize("router_type", [RouterType.TOP_1])
def test_topk_router(router_type):
    SEQ_LEN = 10
    HIDDEN_DIM = 64
    N_EXPERTS = 5

    inputs = torch.randn(SEQ_LEN, HIDDEN_DIM)
    router = get_router(router_type)(
        n_experts=N_EXPERTS,
    )

    outputs = router(inputs)

    assert isinstance(outputs, torch.Tensor)
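
For context only (not part of this commit): the call pattern get_router(router_type)(n_experts=N_EXPERTS) above implies that get_router returns a router class, which the test then instantiates with n_experts. Below is a minimal sketch of such a factory, assuming Top1Router eventually accepts that argument; sketch_get_router is hypothetical.

from pipegoose.nn.expert_parallel.routers import RouterType, Top1Router


def sketch_get_router(router_type: RouterType) -> type:
    # Map each RouterType member to the class that implements it
    routers = {
        RouterType.TOP_1: Top1Router,
        # RouterType.TOP_2 would map to a Top2Router once one exists
    }
    return routers[router_type]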
