Merge pull request #39 from xrsrke/feature/moe
[Feature] Add Expert Parallel
Showing 22 changed files with 927 additions and 80 deletions.
pipegoose/nn/__init__.py

@@ -1,3 +1,4 @@
 from pipegoose.nn.data_parallel.data_parallel import DataParallel
 from pipegoose.nn.tensor_parallel.tensor_parallel import TensorParallel
 from pipegoose.nn.pipeline_parallel.pipeline_parallel import PipelineParallel
+from pipegoose.nn.expert_parallel.expert_parallel import ExpertParallel
Binary file not shown.
pipegoose/nn/expert_parallel/__init__.py

@@ -0,0 +1,3 @@
+from pipegoose.nn.expert_parallel.expert_parallel import ExpertParallel
+from pipegoose.nn.expert_parallel.loss import ExpertLoss
+from pipegoose.nn.expert_parallel.routers import Top1Router, Top2Router, SwitchNoisePolicy
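
These re-exports define the public surface of the new expert-parallel package. A minimal sketch of how downstream code would import it after this change:

    from pipegoose.nn.expert_parallel import (
        ExpertParallel,
        ExpertLoss,
        Top1Router,
        Top2Router,
        SwitchNoisePolicy,
    )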
pipegoose/nn/expert_parallel/experts.py

@@ -1,16 +1,87 @@
"""DON'T USE THIS MODULE: Under development.""" | ||
from copy import deepcopy | ||
from typing import Tuple | ||
|
||
import torch | ||
import torch.distributed as dist | ||
from einops import rearrange | ||
from torch import nn | ||
from torchtyping import TensorType | ||
|
||
from pipegoose.distributed.parallel_context import ParallelContext | ||
from pipegoose.distributed.parallel_mode import ParallelMode | ||
from pipegoose.nn.tensor_parallel._functional import all_reduce | ||
|
||
|
||
class Experts(nn.Module): | ||
def __init__(self, num_experts: int, expert: nn.Module, parallel_context: ParallelContext): | ||
"""A collection of experts in an expert layer.""" | ||
|
||
def __init__( | ||
self, | ||
num_local_experts: int, | ||
expert: nn.Module, | ||
enable_tensor_parallel: bool, | ||
parallel_context: ParallelContext, | ||
): | ||
super().__init__() | ||
self.num_experts = num_experts | ||
self.experts = nn.ModuleList([expert for _ in range(num_experts)]) | ||
self.enable_tensor_parallel = enable_tensor_parallel | ||
self.parallel_context = parallel_context | ||
|
||
def forward(self): | ||
pass | ||
expert = expert() if not isinstance(expert, nn.Module) else expert | ||
self.num_local_experts = num_local_experts | ||
self.experts = nn.ModuleList([deepcopy(expert) for _ in range(num_local_experts)]) | ||
|
||
def forward( | ||
self, | ||
inputs: TensorType["batch_size", "seq_len", "d_model"], | ||
dispatch_order: TensorType["batch_size * seq_len"], | ||
*args, | ||
**kwargs, | ||
) -> TensorType["batch_size", "seq_len", "d_model"]: | ||
outputs = torch.zeros_like(inputs) | ||
|
||
for expert_idx, expert in enumerate(self.experts): | ||
dispatched_inputs, indices = self._get_dispatch_inputs(inputs, dispatch_order, expert_idx) | ||
if dispatched_inputs.numel() == 0: | ||
# NOTE: if there are no tokens to dispatch to the expert, skip the expert | ||
continue | ||
|
||
if len(args) > 1: | ||
# NOTE: In some transformers models, it also passes last | ||
# hidden states or other arguments to the MLP expert. | ||
# how do we detect this and pass the corresponding arguments to the expert? | ||
# For example, hidden_states.shape = (batch_size, seq_len, hidden_size), | ||
# but we need to dispatch the hidden_states to the corresponding expert | ||
expert_output = expert(dispatched_inputs, *args[1][:, indices], **kwargs) | ||
else: | ||
expert_output = expert(dispatched_inputs) | ||
|
||
outputs.view(-1, outputs.size(-1))[indices] = expert_output | ||
|
||
all_reduce( | ||
outputs, | ||
op=dist.ReduceOp.SUM, | ||
parallel_context=self.parallel_context, | ||
parallel_mode=ParallelMode.TENSOR, | ||
) | ||
|
||
return outputs | ||
|
||
@torch.no_grad() | ||
def _get_dispatch_inputs( | ||
self, | ||
inputs: TensorType["batch_size", "seq_len", "d_model"], | ||
dispatch_order: TensorType["batch_size * seq_len"], | ||
expert_idx: int, | ||
) -> Tuple[TensorType["batch_size * seq_len", "d_model"], TensorType["batch_size * seq_len"]]: | ||
"""Dispatch embeddings to the corresponding expert.""" | ||
|
||
def get_global_expert_idx(expert_idx: int) -> int: | ||
rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR) | ||
global_expert_idx = rank * self.num_local_experts + expert_idx | ||
return global_expert_idx | ||
|
||
global_expert_idx = get_global_expert_idx(expert_idx) | ||
token_indices = (dispatch_order == global_expert_idx).nonzero(as_tuple=True)[0] | ||
inputs = rearrange(inputs, "b s d -> (b s) d") | ||
dispatched_inputs = inputs[token_indices] | ||
return dispatched_inputs, token_indices |
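
To make the routing concrete, here is a single-process sketch of the dispatch logic above. It is not the PR's code: dispatch_order is hard-coded rather than produced by a router, plain nn.Linear modules stand in for experts, and the all-reduce is omitted since there is only one rank.

    import torch
    from torch import nn
    from einops import rearrange

    batch_size, seq_len, d_model, num_experts = 2, 4, 8, 2
    inputs = torch.randn(batch_size, seq_len, d_model)
    # one expert index per token, as a router would produce
    dispatch_order = torch.randint(0, num_experts, (batch_size * seq_len,))
    experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_experts)])

    outputs = torch.zeros_like(inputs)
    flat_inputs = rearrange(inputs, "b s d -> (b s) d")
    for expert_idx, expert in enumerate(experts):
        # same indexing as Experts._get_dispatch_inputs, with rank = 0
        token_indices = (dispatch_order == expert_idx).nonzero(as_tuple=True)[0]
        if token_indices.numel() == 0:
            continue
        # scatter each expert's output back to its tokens' positions
        outputs.view(-1, d_model)[token_indices] = expert(flat_inputs[token_indices])

    assert outputs.shape == inputs.shape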
New file in pipegoose/nn/expert_parallel defining ExpertLayer

@@ -0,0 +1,44 @@
+from torch import nn
+from torchtyping import TensorType
+
+from pipegoose.distributed.parallel_context import ParallelContext
+from pipegoose.nn.expert_parallel.experts import Experts
+from pipegoose.nn.expert_parallel.routers import Router
+from pipegoose.nn.expert_parallel.utils import get_num_local_experts
+
+
+class ExpertLayer(nn.Module):
+    """
+    An expert layer.
+
+    NOTE: Switch Transformer: https://arxiv.org/abs/2101.03961
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        expert: nn.Module,
+        router: Router,
+        enable_tensor_parallel: bool,
+        parallel_context: ParallelContext,
+    ):
+        super().__init__()
+        self.router = router
+        if enable_tensor_parallel is True:
+            self.num_local_experts = num_experts
+        else:
+            self.num_local_experts = get_num_local_experts(num_experts, parallel_context)
+
+        self._experts = Experts(self.num_local_experts, expert, enable_tensor_parallel, parallel_context)
+        self.parallel_context = parallel_context
+
+    @property
+    def experts(self) -> nn.ModuleList:
+        return self._experts.experts
+
+    def forward(self, *args, **kwargs) -> TensorType["batch_size", "seq_len", "d_model"]:
+        # TODO: use torch.fx to extract the inputs from args and kwargs
+        inputs = args[0]
+        dispatching_order, _, _ = self.router(inputs)
+        outputs = self._experts(inputs, dispatching_order, *args, **kwargs)
+        return outputs
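
When enable_tensor_parallel is False, the experts are sharded across the tensor-parallel group, and the global-index arithmetic in Experts._get_dispatch_inputs determines which tokens each rank's local experts own. A hedged illustration of that partitioning: the real get_num_local_experts takes a parallel_context, and this standalone stand-in assumes the expert count divides the group size evenly.

    def get_num_local_experts(num_experts: int, tp_world_size: int) -> int:
        # simplified stand-in for pipegoose.nn.expert_parallel.utils.get_num_local_experts
        return num_experts // tp_world_size

    num_experts, tp_world_size = 8, 4
    num_local_experts = get_num_local_experts(num_experts, tp_world_size)  # 2 experts per rank
    for rank in range(tp_world_size):
        # mirrors get_global_expert_idx: rank * num_local_experts + local index
        print(rank, [rank * num_local_experts + i for i in range(num_local_experts)])
    # 0 [0, 1]
    # 1 [2, 3]
    # 2 [4, 5]
    # 3 [6, 7]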
New file in pipegoose/nn/expert_parallel defining ExpertParallelMapping

@@ -0,0 +1,18 @@
+from pipegoose.nn.parallel_mapping import ParallelInfo, ParallelMapping
+
+
+class MLP(ParallelInfo):
+    pass
+
+
+class ExpertParallelMapping(ParallelMapping):
+    __MAPPING__ = {
+        "bloom-560m": [MLP("mlp")],
+    }
+
+    @staticmethod
+    def is_mlp(module_name: str) -> bool:
+        item = ExpertParallelMapping._search(module_name)
+        if item is None:
+            return False
+        return isinstance(item, MLP)
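
Assuming ParallelMapping._search matches the registered substring ("mlp" here) against a module's name (its exact semantics live in pipegoose.nn.parallel_mapping and are not shown in this diff), hypothetical usage would look like:

    # module names follow bloom-560m's layer naming; the expected results
    # assume substring matching in ParallelMapping._search
    ExpertParallelMapping.is_mlp("transformer.h.0.mlp")             # True
    ExpertParallelMapping.is_mlp("transformer.h.0.self_attention")  # False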