Commit
Merge pull request #39 from xrsrke/feature/moe
[Feature] Add Expert Parallel
xrsrke authored Nov 27, 2023
2 parents db8ae11 + bc2664a commit 1236baa
Showing 22 changed files with 927 additions and 80 deletions.
1 change: 1 addition & 0 deletions pipegoose/nn/__init__.py
@@ -1,3 +1,4 @@
from pipegoose.nn.data_parallel.data_parallel import DataParallel
from pipegoose.nn.tensor_parallel.tensor_parallel import TensorParallel
from pipegoose.nn.pipeline_parallel.pipeline_parallel import PipelineParallel
+from pipegoose.nn.expert_parallel.expert_parallel import ExpertParallel
Binary file removed pipegoose/nn/expert_parallel/.DS_Store
Binary file not shown.
3 changes: 3 additions & 0 deletions pipegoose/nn/expert_parallel/__init__.py
@@ -0,0 +1,3 @@
from pipegoose.nn.expert_parallel.expert_parallel import ExpertParallel
from pipegoose.nn.expert_parallel.loss import ExpertLoss
from pipegoose.nn.expert_parallel.routers import Top1Router, Top2Router, SwitchNoisePolicy
52 changes: 42 additions & 10 deletions pipegoose/nn/expert_parallel/expert_parallel.py
@@ -1,42 +1,74 @@
-from typing import Callable
+import re
+from typing import Callable, List, Optional, Union

import torch
from torch import nn

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode
+from pipegoose.nn.expert_parallel.layers import ExpertLayer
from pipegoose.nn.parallel import Parallel


class ExpertParallel(Parallel):
    """
-    Turn a module into an Mixture of Experts module.
+    Turn a model into an Mixture of Experts model.
-    NOTE: The architecture is based on "A Flexible MoE Implementation with Pipeline Parallelism" by Xin Chen et al.
+    NOTE: The architecture is based on "Pipeline MoE: A Flexible MoE Implementation with Pipeline Parallelism" by Xin Chen et al.
    https://arxiv.org/abs/2304.11414
    """

    def __init__(
        self,
        module: nn.Module,
        num_experts: int,
-        expert: nn.Module,
-        router: Callable,
-        noise_poligy: Callable,
-        enable_tensor_parallelism: bool = True,
+        expert: Optional[nn.Module] = None,
+        mapping: Optional[List[int]] = None,
+        router: Union[int, Callable] = 1,
+        # noise_poligy: Union[str, Callable],
+        enable_tensor_parallelism: bool = False,
        parallel_context: ParallelContext = None,
    ):
+        tensor_parallel_size = parallel_context.get_world_size(ParallelMode.TENSOR)
+        assert parallel_context is not None, "parallel_context must be provided"
+        assert num_experts % tensor_parallel_size == 0, "The number of experts must be divisible by the tensor parallel size."
+        num_layers = module.config.num_hidden_layers
+        assert [
+            0 <= i < num_layers for i in mapping
+        ], f"There is a layer index that out of range. Expected range: [0, {num_layers}-1]"

+        if mapping is None:
+            # NOTE: default mapping is to parallelize all MLP layers
+            mapping = list(range(module.config.num_hidden_layers))

        self.module = module
        self.num_experts = num_experts
        self.expert = expert
+        self.mapping = mapping
        self.router = router
-        self.noise_policy = noise_poligy
+        # self.noise_policy = noise_poligy
        self.enable_tensor_parallelism = enable_tensor_parallelism
        self.parallel_context = parallel_context

    @torch.no_grad()
-    def parallelize(self):
-        pass
+    def parallelize(self) -> nn.Module:
+        pattern = re.compile(r"^transformer\.h\.(\d+)\.mlp$")

+        for name, module in self.module.named_modules():
+            match = pattern.match(name)
+            if match:
+                layer_idx = int(match.group(1))
+                if layer_idx in self.mapping:
+                    expert_layer = ExpertLayer(
+                        self.num_experts,
+                        module if self.expert is None else self.expert,
+                        self.router,
+                        self.enable_tensor_parallelism,
+                        self.parallel_context,
+                    )
+                    getattr(self.module, "transformer").h[layer_idx].mlp = expert_layer

+        return self.module

    @torch.no_grad()
    def deparallelize(self):
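For context, a minimal usage sketch of the API added above. Everything outside the ExpertParallel(...) call is an assumption rather than part of this commit: the helper name moefy, the bloom-style model layout (transformer.h.<i>.mlp, config.num_hidden_layers), and the idea that the router and ParallelContext are constructed elsewhere (their setup is not shown in this file).

from typing import Callable

from torch import nn

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.nn import ExpertParallel


def moefy(model: nn.Module, router: Callable, parallel_context: ParallelContext) -> nn.Module:
    """Hypothetical helper: turn every transformer.h.<i>.mlp of a bloom-style model
    into an expert layer with 8 experts."""
    num_layers = model.config.num_hidden_layers
    moe = ExpertParallel(
        model,
        num_experts=8,                      # must be divisible by the tensor-parallel world size
        expert=None,                        # reuse each matched MLP (deepcopied) as the expert
        mapping=list(range(num_layers)),    # convert the MLP of every decoder block
        router=router,                      # e.g. a Top2Router instance; its constructor is not shown in this diff
        enable_tensor_parallelism=False,    # False: experts are sharded across tensor-parallel ranks
        parallel_context=parallel_context,
    )
    return moe.parallelize()                # replaces the matched MLPs in place and returns the model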
83 changes: 77 additions & 6 deletions pipegoose/nn/expert_parallel/experts.py
@@ -1,16 +1,87 @@
"""DON'T USE THIS MODULE: Under development."""
from copy import deepcopy
from typing import Tuple

import torch
import torch.distributed as dist
from einops import rearrange
from torch import nn
from torchtyping import TensorType

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode
from pipegoose.nn.tensor_parallel._functional import all_reduce


class Experts(nn.Module):
def __init__(self, num_experts: int, expert: nn.Module, parallel_context: ParallelContext):
"""A collection of experts in an expert layer."""

def __init__(
self,
num_local_experts: int,
expert: nn.Module,
enable_tensor_parallel: bool,
parallel_context: ParallelContext,
):
super().__init__()
self.num_experts = num_experts
self.experts = nn.ModuleList([expert for _ in range(num_experts)])
self.enable_tensor_parallel = enable_tensor_parallel
self.parallel_context = parallel_context

def forward(self):
pass
expert = expert() if not isinstance(expert, nn.Module) else expert
self.num_local_experts = num_local_experts
self.experts = nn.ModuleList([deepcopy(expert) for _ in range(num_local_experts)])

def forward(
self,
inputs: TensorType["batch_size", "seq_len", "d_model"],
dispatch_order: TensorType["batch_size * seq_len"],
*args,
**kwargs,
) -> TensorType["batch_size", "seq_len", "d_model"]:
outputs = torch.zeros_like(inputs)

for expert_idx, expert in enumerate(self.experts):
dispatched_inputs, indices = self._get_dispatch_inputs(inputs, dispatch_order, expert_idx)
if dispatched_inputs.numel() == 0:
# NOTE: if there are no tokens to dispatch to the expert, skip the expert
continue

if len(args) > 1:
# NOTE: In some transformers models, it also passes last
# hidden states or other arguments to the MLP expert.
# how do we detect this and pass the corresponding arguments to the expert?
# For example, hidden_states.shape = (batch_size, seq_len, hidden_size),
# but we need to dispatch the hidden_states to the corresponding expert
expert_output = expert(dispatched_inputs, *args[1][:, indices], **kwargs)
else:
expert_output = expert(dispatched_inputs)

outputs.view(-1, outputs.size(-1))[indices] = expert_output

all_reduce(
outputs,
op=dist.ReduceOp.SUM,
parallel_context=self.parallel_context,
parallel_mode=ParallelMode.TENSOR,
)

return outputs

@torch.no_grad()
def _get_dispatch_inputs(
self,
inputs: TensorType["batch_size", "seq_len", "d_model"],
dispatch_order: TensorType["batch_size * seq_len"],
expert_idx: int,
) -> Tuple[TensorType["batch_size * seq_len", "d_model"], TensorType["batch_size * seq_len"]]:
"""Dispatch embeddings to the corresponding expert."""

def get_global_expert_idx(expert_idx: int) -> int:
rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR)
global_expert_idx = rank * self.num_local_experts + expert_idx
return global_expert_idx

global_expert_idx = get_global_expert_idx(expert_idx)
token_indices = (dispatch_order == global_expert_idx).nonzero(as_tuple=True)[0]
inputs = rearrange(inputs, "b s d -> (b s) d")
dispatched_inputs = inputs[token_indices]
return dispatched_inputs, token_indices
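To make the dispatch indexing in forward and _get_dispatch_inputs concrete, here is a self-contained, single-process sketch with made-up shapes and an identity stand-in for the expert MLP. Without tensor parallelism the rank offset in get_global_expert_idx is zero, so the global expert index equals the local one; in the real module each rank runs only its local experts and the partial outputs are summed with all_reduce over the tensor-parallel group.

import torch
from einops import rearrange

batch_size, seq_len, d_model, num_experts = 2, 4, 8, 4
inputs = torch.randn(batch_size, seq_len, d_model)
# one expert index per token over the flattened sequence, as a router would produce
dispatch_order = torch.randint(0, num_experts, (batch_size * seq_len,))

flat = rearrange(inputs, "b s d -> (b s) d")          # (batch_size * seq_len, d_model)
outputs = torch.zeros_like(flat)

for expert_idx in range(num_experts):
    token_indices = (dispatch_order == expert_idx).nonzero(as_tuple=True)[0]
    if token_indices.numel() == 0:
        continue                                       # no tokens routed to this expert
    dispatched = flat[token_indices]                   # what _get_dispatch_inputs returns
    outputs[token_indices] = dispatched                # a real expert MLP would transform this

outputs = rearrange(outputs, "(b s) d -> b s d", b=batch_size)
assert outputs.shape == inputs.shape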
44 changes: 44 additions & 0 deletions pipegoose/nn/expert_parallel/layers.py
@@ -0,0 +1,44 @@
from torch import nn
from torchtyping import TensorType

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.nn.expert_parallel.experts import Experts
from pipegoose.nn.expert_parallel.routers import Router
from pipegoose.nn.expert_parallel.utils import get_num_local_experts


class ExpertLayer(nn.Module):
    """
    An expert layer.
    NOTE: Switch Transformer: https://arxiv.org/abs/2101.03961
    """

    def __init__(
        self,
        num_experts: int,
        expert: nn.Module,
        router: Router,
        enable_tensor_parallel: bool,
        parallel_context: ParallelContext,
    ):
        super().__init__()
        self.router = router
        if enable_tensor_parallel is True:
            self.num_local_experts = num_experts
        else:
            self.num_local_experts = get_num_local_experts(num_experts, parallel_context)

        self._experts = Experts(self.num_local_experts, expert, enable_tensor_parallel, parallel_context)
        self.parallel_context = parallel_context

    @property
    def experts(self) -> nn.ModuleList:
        return self._experts.experts

    def forward(self, *args, **kwargs) -> TensorType["batch_size", "seq_len", "d_model"]:
        # TODO: use torch.fx to extract the inputs from args, and kwargs
        inputs = args[0]
        dispatching_order, _, _ = self.router(inputs)
        outputs = self._experts(inputs, dispatching_order, *args, **kwargs)
        return outputs
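ExpertLayer.forward only relies on the router returning a 3-tuple whose first element is the per-token dispatch order. The toy router below satisfies that contract; the other two return values are placeholders, since what Top1Router/Top2Router actually return (presumably routing weights and an auxiliary loss) is not shown in this diff.

import torch
from torch import nn


class ToyTop1Router(nn.Module):
    """Hypothetical stand-in matching the `dispatching_order, _, _ = router(inputs)` contract."""

    def __init__(self, d_model: int, num_experts: int):
        super().__init__()
        self.gate = nn.Linear(d_model, num_experts)

    def forward(self, inputs: torch.Tensor):
        # inputs: (batch_size, seq_len, d_model) -> one expert index per flattened token
        logits = self.gate(inputs.reshape(-1, inputs.size(-1)))
        dispatch_order = logits.argmax(dim=-1)      # (batch_size * seq_len,)
        routing_weights = logits.softmax(dim=-1)    # placeholder second value
        aux_loss = torch.tensor(0.0)                # placeholder third value
        return dispatch_order, routing_weights, aux_loss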
18 changes: 18 additions & 0 deletions pipegoose/nn/expert_parallel/parallel_mapping.py
@@ -0,0 +1,18 @@
from pipegoose.nn.parallel_mapping import ParallelInfo, ParallelMapping


class MLP(ParallelInfo):
    pass


class ExpertParallelMapping(ParallelMapping):
    __MAPPING__ = {
        "bloom-560m": [MLP("mlp")],
    }

    @staticmethod
    def is_mlp(module_name: str) -> bool:
        item = ExpertParallelMapping._search(module_name)
        if item is None:
            return False
        return isinstance(item, MLP)
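A hedged usage sketch: assuming ParallelMapping._search (defined outside this diff) matches the registered "mlp" pattern against a fully qualified module name, the new helper would behave roughly like this.

from pipegoose.nn.expert_parallel.parallel_mapping import ExpertParallelMapping

# Expected behaviour, given the "bloom-560m": [MLP("mlp")] entry above
assert ExpertParallelMapping.is_mlp("transformer.h.0.mlp") is True
assert ExpertParallelMapping.is_mlp("transformer.h.0.self_attention") is False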
(The remaining changed files in this commit are not shown here.)
