Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/moe #55

Merged
merged 5 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions pipegoose/distributed/_initializers/initialize_expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,26 @@
from pipegoose.distributed.parallel_mode import ParallelMode


class ExpertParallelGroupInitializer(ProcessGroupInitializer):
class ExpertDataParallelGroupInitializer(ProcessGroupInitializer):
"""
Initialize the process group for data parallelism in expert parallelism.

Pipeline MoE: A Flexible MoE Implementation with Pipeline Parallelism" by Xin Chen et al
https://arxiv.org/abs/2304.11414

NOTE: This looks similar to TensorParallelGroupInitializer, because it aligns with the paper.
"""

def init_dist_group(self) -> ProcessGroupResult:
num_tensor_parallel_groups = self.world_size // self.tensor_parallel_size
local_rank = None
process_group = None
local_world_size = None
ranks_in_group = None
parallel_mode = ParallelMode.TENSOR
parallel_mode = ParallelMode.EXPERT_DATA

for i in range(num_tensor_parallel_groups):
ranks = list(range(i * self.tensor_parallel_size, (i + 1) * self.tensor_parallel_size))

# NOTE: dist.new_group() must be called collectively by all the processes
# that would be part of the group, which means every process in the group
# needs to call this function. If only a subset of the processes call new_group(),
# it will hang because it's waiting for the rest of the processes to join.
group = dist.new_group(ranks=ranks)

if self.rank in ranks:
Expand Down
15 changes: 14 additions & 1 deletion pipegoose/distributed/parallel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from pipegoose.distributed._initializers.initialize_data import (
DataParallelGroupInitializer,
)
from pipegoose.distributed._initializers.initialize_expert import (
ExpertDataParallelGroupInitializer,
)
from pipegoose.distributed._initializers.initialize_pipeline import (
PipelineParallelGroupInitializer,
)
Expand Down Expand Up @@ -188,6 +191,7 @@ def init_parallel_groups(self):
TensorParallelGroupInitializer(**params).init_dist_group(),
PipelineParallelGroupInitializer(**params).init_dist_group(),
DataParallelGroupInitializer(**params).init_dist_group(),
ExpertDataParallelGroupInitializer(**params).init_dist_group(),
]

for result in results:
Expand Down Expand Up @@ -270,7 +274,16 @@ def map_rank_to_device(self):
dist.all_gather(tensor_list=rank_tensor_list, tensor=rank_tensor)

for _rank, _rank_tensor in enumerate(rank_tensor_list):
modes_and_ranks = {mode: rank for mode, rank in zip(self._local_ranks.keys(), _rank_tensor.tolist())}
# NOTE: In 3D parallelism for MoE, the gpu assignment only depends on
# tensor parallelism, pipeline parallelism and data parallelism.
# according to the paper: Pipeline MoE: A Flexible MoE Implementatio
# with Pipeline Parallelism by Xin Chen et al
# https://arxiv.org/abs/2304.11414
modes_and_ranks = {
mode: rank
for mode, rank in zip(self._local_ranks.keys(), _rank_tensor.tolist())
if mode != ParallelMode.EXPERT_DATA
}
self._ranks_to_device[tuple(modes_and_ranks.items())] = _rank

def ranks2device(self, ranks: RanksToDevice) -> int:
Expand Down
3 changes: 3 additions & 0 deletions pipegoose/distributed/parallel_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ class ParallelMode(Enum):
TENSOR = "tensor"
PIPELINE = "pipeline"
DATA = "data"

# NOTE: for expert data parallelism
EXPERT_DATA = "expert"
15 changes: 12 additions & 3 deletions pipegoose/nn/data_parallel/data_parallel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import torch
import torch.distributed as dist
from torch import nn
Expand Down Expand Up @@ -26,9 +28,16 @@ def parallelize(self) -> nn.Module:
def _register_grad_avg_hook(self, module: nn.Module):
for p in module.parameters():
if p.requires_grad is True:
p.register_hook(self._average_grad)
is_expert = getattr(p, "is_expert", False)
p.register_hook(partial(self._average_grad, is_expert=is_expert))

def _average_grad(self, grad: torch.Tensor):
def _average_grad(self, grad: torch.Tensor, is_expert: bool):
# NOTE: (grad1 + grad2 + ... + gradn) / n = grad1/n + grad2/n + ... + gradn/n
grad.div_(self.parallel_context.data_parallel_size)
all_reduce(grad, op=dist.ReduceOp.SUM, parallel_context=self.parallel_context, parallel_mode=ParallelMode.DATA)

all_reduce(
grad,
op=dist.ReduceOp.SUM,
parallel_context=self.parallel_context,
parallel_mode=ParallelMode.EXPERT_DATA if is_expert else ParallelMode.DATA,
)
20 changes: 14 additions & 6 deletions pipegoose/nn/expert_parallel/expert_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def __init__(
num_experts: int,
expert: Optional[nn.Module] = None,
mapping: Optional[List[int]] = None,
router: Union[int, Callable] = 1,
router: Callable = None,
# noise_poligy: Union[str, Callable],
enable_tensor_parallelism: bool = False,
parallel_context: ParallelContext = None
parallel_context: ParallelContext = None,
):
tensor_parallel_size = parallel_context.get_world_size(ParallelMode.TENSOR)
assert parallel_context is not None, "parallel_context must be provided"
Expand All @@ -52,20 +52,28 @@ def __init__(

@torch.no_grad()
def parallelize(self) -> nn.Module:
pattern = re.compile(r"^transformer\.h\.(\d+)\.mlp$")

for name, module in self.module.named_modules():
# TODO: make it generalize
def _is_mlp(name) -> Union[bool, Optional[int]]:
pattern = re.compile(r"^transformer\.h\.(\d+)\.mlp$")
match = pattern.match(name)
if match:
layer_idx = int(match.group(1))
return True, layer_idx
else:
return False, None

for name, module in self.module.named_modules():
is_mlp, layer_idx = _is_mlp(name)
if is_mlp:
if layer_idx in self.mapping:
expert_layer = ExpertLayer(
self.num_experts,
module if self.expert is None else self.expert,
self.router,
self.enable_tensor_parallelism,
self.parallel_context
self.parallel_context,
)
# TODO: make it generalize
getattr(self.module, "transformer").h[layer_idx].mlp = expert_layer

return self.module
Expand Down
8 changes: 8 additions & 0 deletions pipegoose/nn/expert_parallel/experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def __init__(
expert = expert() if not isinstance(expert, nn.Module) else expert
self.num_local_experts = num_local_experts
self.experts = nn.ModuleList([deepcopy(expert) for _ in range(num_local_experts)])
self._set_expert_attr(self.experts)

def _set_expert_attr(self, experts: nn.ModuleList):
# NOTE: for filtering out the expert parameters later on
# in data parallelism
for expert in experts:
for p in expert.parameters():
setattr(p, "is_expert", True)

def forward(
self,
Expand Down
16 changes: 4 additions & 12 deletions pipegoose/nn/expert_parallel/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def __init__(
expert_capacity: Optional[Tuple[float, float]] = None,
alpha: float = 0.01,
eps: float = 0.1,
aux_loss_weight: float = 1.0,
z_loss_weight: float = 1.0,
# aux_loss_weight: float = 1.0,
# z_loss_weight: float = 1.0,
):
super().__init__()
self.noise_policy = noise_policy
Expand All @@ -66,8 +66,8 @@ def __init__(
self.expert_capacity = expert_capacity
self.alpha = alpha
self.eps = eps
self.aux_loss_weight = aux_loss_weight
self.z_loss_weight = z_loss_weight
# self.aux_loss_weight = aux_loss_weight
# self.z_loss_weight = z_loss_weight
self.gate = nn.Linear(d_model, num_experts)

def _aux_loss(
Expand Down Expand Up @@ -156,8 +156,6 @@ def __init__(
expert_capacity: Optional[Tuple[float, float]] = None,
alpha: float = 0.01,
eps: float = 0.1,
aux_loss_weight: float = 1.0,
z_loss_weight: float = 1.0,
):
super().__init__(
noise_policy=noise_policy,
Expand All @@ -167,8 +165,6 @@ def __init__(
expert_capacity=expert_capacity,
alpha=alpha,
eps=eps,
aux_loss_weight=aux_loss_weight,
z_loss_weight=z_loss_weight,
)


Expand All @@ -181,8 +177,6 @@ def __init__(
expert_capacity: Optional[Tuple[float, float]] = None,
alpha: float = 0.01,
eps: float = 0.1,
aux_loss_weight: float = 1.0,
z_loss_weight: float = 1.0,
):
super().__init__(
noise_policy=noise_policy,
Expand All @@ -192,6 +186,4 @@ def __init__(
expert_capacity=expert_capacity,
alpha=alpha,
eps=eps,
aux_loss_weight=aux_loss_weight,
z_loss_weight=z_loss_weight,
)
22 changes: 21 additions & 1 deletion pipegoose/nn/tensor_parallel/tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch import nn

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.nn.expert_parallel.layers import ExpertLayer
from pipegoose.nn.parallel import Parallel
from pipegoose.nn.tensor_parallel.parallelizer import (
EmbeddingParallelizer,
Expand Down Expand Up @@ -42,10 +43,29 @@ def parallelize(self) -> nn.Module:
return module

def _get_leaf_modules(self, model: nn.Module) -> List[Tuple[str, nn.Module]]:
"""Return non-expert leaf modules."""
leaf_modules = []
expert_names = []

def is_child_of_expert(module_name):
# NOTE: suppose an mlp expert has name "transformer.h.0.mlp"
# then its children will have names like "transformer.h.0.mlp.{child_name}"
# so we can check if a module is a child of an expert by checking if its name
# starts with "transformer.h.0.mlp"
for expert_name in expert_names:
if module_name.startswith(expert_name):
return True
return False

for module_name, module in model.named_modules():
if list(module.children()):
if isinstance(module, ExpertLayer):
expert_names.append(module_name)
continue

# NOTE: skip leaf modules that belong to ExpertLayer
if is_child_of_expert(module_name) or list(module.children()):
continue

leaf_modules.append((module_name, module))

return leaf_modules
Expand Down
15 changes: 14 additions & 1 deletion pipegoose/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import random
import socket
from functools import partial
from typing import Callable
from typing import Callable, Tuple

import pytest
import torch
Expand Down Expand Up @@ -118,3 +118,16 @@ def calculate_parameter_similarity(module1: nn.Module, module2: nn.Module, rtol:

def count_model_parameters(model):
return sum(p.numel() for p in model.parameters())


def get_microbatch(
inputs, labels, parallel_context: ParallelContext, parallel_mode: ParallelMode
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
local_rank = parallel_context.get_local_rank(parallel_mode)
world_size = parallel_context.get_world_size(parallel_mode)

input_chunks = torch.chunk(inputs["input_ids"], chunks=world_size, dim=0)
attention_chunks = torch.chunk(inputs["attention_mask"], chunks=world_size, dim=0)
label_chunks = torch.chunk(labels, chunks=world_size, dim=0)

return input_chunks[local_rank], attention_chunks[local_rank], label_chunks[local_rank]
Loading