* add zero optimizer parallel
* code check
* add some comments
* update
* add some info
* zero helper
* bug fix
* reconstruct
* comm fusion
* ema update
* update
* update
* update
* update
* checkpoint merging
* fix bug
* fix bug
* fix bug

Co-authored-by: zhaoting <[email protected]>
1 parent aa1c32d · commit 5831703 · 7 changed files with 971 additions and 2 deletions.
@@ -0,0 +1,13 @@
```python
from mindspore import nn

from .conv import Conv1d, Conv2d, Conv3d
from .dense import Dense

# {Original MindSpore Cell: New Cell in ZeRO3}
PARALLEL_MODULES = {
    nn.Conv1d: Conv1d,
    nn.Conv2d: Conv2d,
    nn.Conv3d: Conv3d,
    nn.Dense: Dense,
}
__all__ = ["Conv1d", "Conv2d", "Conv3d", "Dense"]
```
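This mapping is presumably consumed by the rewrite logic elsewhere in the commit (the remaining changed files are not shown on this page). As a rough illustration of how such a table can drive ZeRO-3 cell substitution, here is a hypothetical helper; `replace_cells_for_zero3` is not part of the commit:

```python
from mindspore import nn
from mindspore.communication.management import GlobalComm


def replace_cells_for_zero3(network: nn.Cell, op_group: str = GlobalComm.WORLD_COMM_GROUP):
    """Recursively swap supported cells for their ZeRO-3 wrappers (illustrative only)."""
    for name, cell in list(network.name_cells().items()):
        wrapper_cls = PARALLEL_MODULES.get(type(cell))
        if wrapper_cls is not None:
            # The wrapper shards the cell's weight/bias in place and adds the
            # AllGather/ReduceScatter plumbing around them.
            setattr(network, name, wrapper_cls(cell, zero_stage=3, op_group=op_group))
        else:
            replace_cells_for_zero3(cell, op_group)
    return network
```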
@@ -0,0 +1,73 @@
```python
from mindspore import nn, ops
from mindspore.communication import get_group_size, get_rank
from mindspore.communication.management import GlobalComm
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode

from .param_wrapper import ZeroParamWrapper


class _Conv(nn.Cell):
    def __init__(self, net, zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None):
        super(_Conv, self).__init__(auto_prefix=False)
        self.net = net
        self.set_param_wrapper(zero_stage, op_group, cell_type)

    def set_param_wrapper(self, zero_stage, op_group, cell_type=None):
        self.param_wrapper_w = nn.Identity()
        self.param_wrapper_b = nn.Identity()
        if zero_stage == 3:
            # Init parallel settings
            is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL
            op_group_size = get_group_size(op_group) if is_parallel else 1
            op_rank_id = get_rank(op_group) if is_parallel else 0
            self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, op_group, cell_type)
            # Split along dim 0 (output channels); each rank keeps only its shard.
            split_op = ops.Split(0, op_group_size)
            if self.param_wrapper_w.need_rewrite:
                self.net.weight.assign_value(split_op(self.net.weight)[op_rank_id])
            if self.net.has_bias:
                self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, op_group, cell_type)
                if self.param_wrapper_b.need_rewrite:
                    self.net.bias.assign_value(split_op(self.net.bias)[op_rank_id])


class Conv1d(_Conv):
    def construct(self, x):
        # nn.Conv1d is built on conv2d: expand to 4D, convolve, squeeze back.
        x = self.net.expand_dims(x, 2)
        output = self.net.conv2d(x, self.param_wrapper_w(self.net.weight))
        if self.net.has_bias:
            output = self.net.bias_add(output, self.param_wrapper_b(self.net.bias))

        output = self.net.squeeze(output)
        return output


class Conv2d(_Conv):
    def construct(self, x):
        output = self.net.conv2d(x, self.param_wrapper_w(self.net.weight))
        if self.net.has_bias:
            output = self.net.bias_add(output, self.param_wrapper_b(self.net.bias))
        return output


class Conv3d(_Conv):
    def construct(self, x):
        weight = self.param_wrapper_w(self.net.weight)
        bias = self.param_wrapper_b(self.net.bias)
        if self.net.group == 1:
            out = self.net.conv3d(x, weight)
            if self.net.has_bias:
                out = self.net.bias_add(out, bias)
        else:
            # Grouped convolution: run each group separately, then concatenate.
            features = self.net.split_1(x)
            weights = self.net.split_0(weight)
            outputs = ()
            for i in range(self.net.group):
                output = self.net.conv3d(features[i], weights[i])
                outputs = outputs + (output,)
            out = self.net.concat(outputs)
            if self.net.bias is not None:
                new_shape = [1 for _ in range(out.ndim)]
                new_shape[1] = self.net.out_channels
                out = out + bias.reshape(new_shape)
        return out
```
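A minimal usage sketch, assuming a distributed data-parallel launch (e.g. via msrun or mpirun); the wrapper class names follow this commit, everything else is illustrative:

```python
import numpy as np
import mindspore as ms
from mindspore import nn
from mindspore.communication import init
from mindspore.context import ParallelMode

ms.set_context(mode=ms.GRAPH_MODE)
init()  # requires a multi-process launch
ms.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)

conv = nn.Conv2d(16, 32, 3, has_bias=True)
conv_zero3 = Conv2d(conv, zero_stage=3)  # weight/bias now hold only this rank's shard
y = conv_zero3(ms.Tensor(np.ones((2, 16, 8, 8)), ms.float32))  # full output on every rank
```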
@@ -0,0 +1,45 @@
```python
from mindspore import nn, ops
from mindspore.communication import get_group_size, get_rank
from mindspore.communication.management import GlobalComm
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode

from .param_wrapper import ZeroParamWrapper


class Dense(nn.Cell):
    def __init__(self, net, zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None):
        super(Dense, self).__init__(auto_prefix=False)
        self.net = net
        self.set_param_wrapper(zero_stage, op_group, cell_type)

    def set_param_wrapper(self, zero_stage, op_group, cell_type=None):
        self.param_wrapper_w = nn.Identity()
        self.param_wrapper_b = nn.Identity()
        if zero_stage == 3:
            # Init parallel settings
            is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL
            op_group_size = get_group_size(op_group) if is_parallel else 1
            op_rank_id = get_rank(op_group) if is_parallel else 0
            self.param_wrapper_w = ZeroParamWrapper(self.net.weight, zero_stage, op_group, cell_type)
            split_op = ops.Split(0, op_group_size)
            if self.param_wrapper_w.need_rewrite:
                self.net.weight.assign_value(split_op(self.net.weight)[op_rank_id])
            if self.net.has_bias:
                self.param_wrapper_b = ZeroParamWrapper(self.net.bias, zero_stage, op_group, cell_type)
                if self.param_wrapper_b.need_rewrite:
                    self.net.bias.assign_value(split_op(self.net.bias)[op_rank_id])

    def construct(self, x):
        x_shape = x.shape
        if len(x_shape) != 2:
            x = x.reshape(-1, x_shape[-1])
        x = self.net.matmul(x, self.param_wrapper_w(self.net.weight))
        if self.net.has_bias:
            x = self.net.bias_add(x, self.param_wrapper_b(self.net.bias))
        if self.net.activation_flag:
            x = self.net.activation(x)
        if len(x_shape) != 2:
            out_shape = x_shape[:-1] + (x.shape[-1],)
            x = x.reshape(out_shape)
        return x
```
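Because the wrappers fall back to nn.Identity when `zero_stage != 3`, this cell can be exercised in a single process. A minimal sketch showing that `construct` flattens leading dimensions before the matmul and restores them afterwards (`Dense` here is the wrapper class defined above):

```python
import numpy as np
import mindspore as ms
from mindspore import nn

net = Dense(nn.Dense(64, 128))               # zero_stage defaults to 0 -> Identity wrappers
x = ms.Tensor(np.ones((2, 10, 64)), ms.float32)  # (B, T, C_in)
y = net(x)                                   # (B*T, 64) @ (64, 128)^T, then reshape back
print(y.shape)                               # (2, 10, 128)
```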
@@ -0,0 +1,57 @@
```python
import mindspore as ms
from mindspore import nn, ops
from mindspore.communication import get_group_size
from mindspore.communication.management import GlobalComm
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode


class ZeroParamWrapper(nn.Cell):
    """
    A cell that inserts communication operators around a parameter when
    `zero_stage == 3`: AllGather rebuilds the full parameter in the forward
    pass, and ReduceScatter returns the local gradient shard in the backward pass.
    """

    def __init__(
        self, param: ms.Parameter, zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None
    ):
        super().__init__(auto_prefix=False)
        self.op_group = op_group
        self.zero_stage = zero_stage
        self.cell_type = cell_type
        if zero_stage != 3:
            raise ValueError(f"ZeroParamWrapper does not support zero_stage {zero_stage}.")

        # Init parallel settings
        self.is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL
        self.op_group_size = get_group_size(self.op_group) if self.is_parallel else 1
        self.allgather = ops.Identity()
        self.reduce_scatter = None
        self.dtype = param.dtype
        self.allreduce = ops.AllReduce(group=self.op_group, op=ops.ReduceOp.SUM)

        self.need_rewrite = self.check_rewrite(param)
        if self.need_rewrite:
            self.op_allgather = ops.AllGather(group=self.op_group)
            self.op_reduce_scatter = ops.ReduceScatter(group=self.op_group, op=ops.ReduceOp.SUM)

    def check_rewrite(self, param):
        """Check whether the parameter needs to be split."""
        need_rewrite = self.is_parallel
        B = param.shape[0]  # size of dim 0, the dimension that gets sharded
        if not param.parallel_optimizer or B < self.op_group_size or B % self.op_group_size != 0:
            need_rewrite = False
        return need_rewrite

    def construct(self, param):
        if self.need_rewrite:
            if self.cell_type is not None:
                param = param.to(self.cell_type)
            # Gather the shards from all ranks to rebuild the full parameter.
            return self.op_allgather(param)
        return param

    def bprop(self, param, out, dout):
        if self.need_rewrite:
            # Sum the full-parameter gradients across ranks, keep only the
            # local shard, and average over the group.
            r = self.op_reduce_scatter(dout.to(self.dtype)) / self.op_group_size
            return (r,)
        # Parameter was not sharded: fall back to a plain data-parallel mean.
        dout = self.allreduce(dout.to(self.dtype)) / self.op_group_size
        return (dout,)
```
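For intuition about the backward rule: ReduceScatter sums the per-rank gradients of the full parameter and hands each rank only its shard, and the division by `op_group_size` turns that sum into the usual data-parallel mean, matching the AllReduce-mean in the non-rewritten branch. A toy NumPy check (no distributed launch involved, purely illustrative):

```python
import numpy as np

N = 2                               # group size (number of ranks)
g = [np.array([1.0, 3.0]),          # rank 0's gradient w.r.t. the FULL parameter
     np.array([5.0, 7.0])]          # rank 1's gradient w.r.t. the FULL parameter
full_sum = g[0] + g[1]              # ReduceScatter's sum stage
shards = np.split(full_sum / N, N)  # scatter + divide by group size
print(shards[0], shards[1])         # [3.] [5.] -> mean gradient for each rank's shard
```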