Skip to content

Commit

Permalink
add lm_head and embed_out tensor parallel (#3962)
Browse files Browse the repository at this point in the history
* add lm_head and embed_out tensor parallel

* fix load lm_head.weight name issue

* replace all_reduce with inference_all_reduce

* refactor lm_head tensor parallel

---------

Co-authored-by: Chen, Zhenhuan <[email protected]>
  • Loading branch information
Yejing-Lai and dc3671 authored Oct 9, 2023
1 parent 6b634d0 commit 6763e2d
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 2 deletions.
20 changes: 19 additions & 1 deletion deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw

Expand Down Expand Up @@ -318,6 +318,11 @@ def _replace(self, child, name, conv_linear_layer):
del data

setattr(child, "replaced", True)
if name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(
torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
child.bias if child.bias is None else torch.nn.parameter.Parameter(
child.bias.to(get_accelerator().current_device_name())), self.mp_group)
return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
torch.nn.parameter.Parameter(child.bias.to(get_accelerator().current_device_name())), self.mp_group)
else:
Expand Down Expand Up @@ -436,3 +441,16 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
self.update_mp_params(child)
self._replace_module(child, name, class_name)
return r_module

def _replace_last_linear_module(self, r_module):
if hasattr(r_module, "lm_head"):
name = "lm_head"
child = r_module.lm_head
elif hasattr(r_module, "embed_out"):
name = "embed_out"
child = r_module.embed_out
else:
return r_module
if child.__class__ in self.linear_policies:
setattr(r_module, name, self.linear_policies[child.__class__](child, name, self.conv_linear_layer))
return r_module
30 changes: 30 additions & 0 deletions deepspeed/module_inject/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,36 @@ def forward(self, input):
return output


class LmHeadLinearAllreduce(nn.Module):

def __init__(
self,
weight,
rank,
world_size,
bias=None,
mp_group=None,
):
super(LmHeadLinearAllreduce, self).__init__()
self.weight = weight
self.bias = bias
self.mp_group = mp_group
self.rank = rank
self.world_size = world_size

def forward(self, input):
assert input.shape[
-1] % self.world_size == 0, 'Please ensure that self.world_size is divisible by input.shape[-1]'
input_shard = input.shape[-1] // self.world_size
output = torch.matmul(input[:, :, self.rank * input_shard:(self.rank + 1) * input_shard],
self.weight.transpose(-1, -2))
if self.mp_group is not None:
dist.inference_all_reduce(output, group=self.mp_group)
if self.bias is not None:
output += self.bias
return output


class LinearLayer(nn.Module):

def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
Expand Down
11 changes: 10 additions & 1 deletion deepspeed/module_inject/replace_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
_autotp.update_linear_policies()

# 4. Replace modules
if "lm_head" in all_reduce_linears or "embed_out" in all_reduce_linears:
return _autotp._replace_last_linear_module(module)
return _autotp._replace_module(module)

def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
Expand Down Expand Up @@ -304,6 +306,13 @@ def set_lm_head(module):
if embedding_weight is not None and hasattr(module, "lm_head") and hasattr(
module.lm_head, "weight") and module.lm_head.weight.is_meta:
module.lm_head.weight = embedding_weight
# enable tensor parallel for the last linear
if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and not module.lm_head.weight.is_meta:
module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
elif hasattr(module, "embed_out") and hasattr(module.embed_out,
"weight") and not module.embed_out.weight.is_meta:
module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
return module

if checkpoint_dict is not None and not config.replace_with_kernel_inject:
# AutoTP shard loading
Expand All @@ -318,7 +327,7 @@ def set_lm_head(module):
checkpoint=checkpoint_file)
pbar.update(1)
gc.collect()
set_lm_head(replaced_module)
replaced_module = set_lm_head(replaced_module)
else:
replaced_module = replace_module(model=model,
orig_class=orig_layer_impl,
Expand Down

0 comments on commit 6763e2d

Please sign in to comment.