bug fix
zhaoting committed Sep 2, 2024
1 parent 29be10b commit 2f78b18
Showing 3 changed files with 55 additions and 40 deletions.
12 changes: 5 additions & 7 deletions mindone/trainers/train_step.py
@@ -90,9 +90,10 @@ def __init__(
# zero init
self.zero_helper = zero_helper
self.zero_stage = zero_helper.zero_stage if zero_helper is not None else 0
self.need_dp = zero_helper.need_dp if zero_helper is not None else False
self.run_optimizer = zero_helper.run_optimizer if zero_helper is not None else self.optimizer
self.grad_reducer = self.grad_reducer if self.zero_stage in [0, 1] else nn.Identity()
if self.zero_stage != 0:
self.zero_helper.split_params()

def construct(self, *inputs):
# compute loss
@@ -113,13 +114,10 @@ def construct(self, *inputs):

# 1. compute gradients (of the up-scaled loss w.r.t. the model weights)
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)

# Gradient communication
if self.zero_stage == 1:
grads = self.zero_helper.split_gradients(grads)
if self.zero_stage == 2:
grads = self.zero_helper.reduce_scatter_gradients(grads)
if self.need_dp:
grads = self.zero_helper.dp_allreduce_gradients(grads)
grads = self.zero_helper.cal_gradients(grads)

if self.accum_steps == 1:
grads = self.grad_reducer(grads)
scaling_sens = ops.depend(scaling_sens, grads)
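Note: the stage-dependent gradient communication that used to live in construct is now a single call into the ZeRO helper, and the wrapper triggers the helper's parameter split from its __init__ when zero_stage != 0. A condensed sketch of what the new helper call does per stage, taken from the zero.py changes below (the comments are added here for orientation and are not part of the commit):

    def cal_gradients(self, gradients):
        if self.zero_stage == 1:
            # stage 1: keep only this rank's shard of each gradient that was split
            gradients = self.split_gradients(gradients)
        if self.zero_stage == 2:
            # stage 2: reduce-scatter across the optimizer-parallel group, one shard per rank
            gradients = self.reduce_scatter_gradients(gradients)
        if self.need_dp:
            # optional all-reduce over a separate data-parallel group, when one is configured
            gradients = self.dp_allreduce_gradients(gradients)
        return gradients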
80 changes: 48 additions & 32 deletions mindone/trainers/zero.py
@@ -89,7 +89,7 @@ def __init__(
self.ori_parameters = self.optimizer._parameters
# Init parallel settings
self.is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL
if self.is_parallel and self.zero_stage != 0:
if not self.is_parallel and self.zero_stage != 0:
_logger.warning("Not in DATA_PARALLEL, set zero_stage to 0.")
self.zero_stage = 0
self.split_op = ops.Identity()
@@ -101,15 +101,8 @@ def __init__(
self.need_dp = False
self.last_assign = False
self.dp_group_size = 1
self.need_allgather = [False] * len(self.optimizer._parameters)
if self.zero_stage in [1, 2] and self.is_parallel:
_logger.info("Clone optimizer.parameters, will increase memory.")
# Because the first input of MindSpore optimizer must be ms.Parameter,
# copy optimizer.parameters for optimizer parameters update.
# It will increase 1/n parameters' memory.
self.optimizer.parameters = self.optimizer.parameters.clone(prefix="wrapper", init="same")
self.optimizer._parameters = self.optimizer.parameters
self.last_assign = True
self.need_allgather = tuple([False] * len(self.optimizer._parameters))

if self.zero_stage in [1, 2, 3] and self.is_parallel:
if self.zero_stage == 2:
self.op_reduce_scatter = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=self.op_group)
@@ -122,8 +115,7 @@ def __init__(
self.dp_allreduce = ops.AllReduce(op=ops.ReduceOp.SUM, group=dp_group)
self.dp_group_size = ms.Tensor(get_group_size(group=dp_group), ms.float32)
self.split_op = ops.Split(0, self.op_group_size) # optimizer parallel split
self.split_params()
self.need_allgather = tuple(self.need_allgather)

self.hyper_map = ops.HyperMap()
if optimizer_offload:
if isinstance(self.optimizer, nn.AdamWeightDecay):
@@ -135,9 +127,9 @@ def __init__(
_logger.info(
f"Build TrainOneStepWrapper with ZeRO stage: {self.zero_stage}, "
f"optimizer_offload: {optimizer_offload}, "
f"op_group_size: {self.op_group_size} "
f"op_rank_id: {self.op_rank_id} "
f"dp_group_size: {self.dp_group_size} "
f"op_group_size: {self.op_group_size}, "
f"op_rank_id: {self.op_rank_id}, "
f"dp_group_size: {self.dp_group_size}."
)

def split_param(self, param):
@@ -161,6 +153,16 @@ def get_optimizer_param_tuples(self):
return param_tuples

def split_params(self):
if self.zero_stage in [1, 2] and self.is_parallel:
_logger.info("Clone optimizer.parameters, will increase memory.")
# Because the first input of MindSpore optimizer must be ms.Parameter,
# copy optimizer.parameters for optimizer parameters update.
# It will increase 1/n parameters' memory.
self.optimizer.parameters = self.optimizer.parameters.clone(prefix="wrapper", init="same")
self.optimizer._parameters = self.optimizer.parameters
self.last_assign = True

self.need_allgather = [False] * len(self.optimizer._parameters)
param_tuples = self.get_optimizer_param_tuples()
for i, param in enumerate(self.optimizer._parameters):
_logger.debug(f"Split optimizer param {param.name} {param.shape}")
@@ -184,8 +186,8 @@ def split_params(self):
ori_shape = param_tuple[i].shape
param_tuple[i].assign_value(self.split_param(param_tuple[i]))
_logger.debug(f"Optimizer {param_tuple[i].name} " f"from {ori_shape} to {param_tuple[i].shape}")
self.need_allgather = tuple(self.need_allgather)

@ms.jit
def reduce_scatter_gradients(self, gradients):
dtype = gradients[0].dtype
gradients = self.hyper_map(
@@ -199,7 +201,6 @@ def reduce_scatter_gradients(self, gradients):
)
return gradients

@ms.jit
def dp_allreduce_gradients(self, gradients):
dtype = gradients[0].dtype
gradients = self.hyper_map(
@@ -212,7 +213,6 @@ def dp_allreduce_gradients(self, gradients):
)
return gradients

@ms.jit
def split_gradients(self, gradients):
gradients = self.hyper_map(
ops.partial(
@@ -225,7 +225,15 @@ def split_gradients(self, gradients):
)
return gradients

@ms.jit
def cal_gradients(self, gradients):
if self.zero_stage == 1:
gradients = self.split_gradients(gradients)
if self.zero_stage == 2:
gradients = self.reduce_scatter_gradients(gradients)
if self.need_dp:
gradients = self.dp_allreduce_gradients(gradients)
return gradients

def run_optimizer(self, grads):
optim_result = self.optimizer(grads)
if self.zero_stage == 1 or self.zero_stage == 2:
@@ -246,48 +254,55 @@ class ZeroParamWrapper(nn.Cell):
a cell to Insert communication operators before and after parameters when `zero_stage == 3`.
"""

def __init__(self, param: ms.Parameter, zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP):
def __init__(
self, param: ms.Parameter, zero_stage: int = 0, op_group: str = GlobalComm.WORLD_COMM_GROUP, cell_type=None
):
super().__init__(auto_prefix=False)
self.op_group = op_group
self.zero_stage = zero_stage
self.cell_type = cell_type
if zero_stage != 3:
raise ValueError(f"ZeroParamWrapper not support zero_stage {zero_stage}.")
self.need_rewrite = self.check_rewrite(param)

# Init parallel settings
self.is_parallel = _get_parallel_mode() == ParallelMode.DATA_PARALLEL
self.op_group_size = get_group_size(self.op_group) if self.is_parallel else 1
self.allgather = ops.Identity()
self.reduce_scatter = None
if self.need_rewrite and self.zero_stage == 3:

self.need_rewrite = self.check_rewrite(param)
if self.need_rewrite:
self.op_allgather = ops.AllGather(group=self.op_group)
self.op_reduce_scatter = ops.ReduceScatter(group=self.op_group, op=ops.ReduceOp.SUM)

def check_rewrite(self, param):
"""Check the parameter need to split or not."""
need_rewrite = self.is_parallel and self.zero_stage == 3
need_rewrite = self.is_parallel
B = param.shape[0]
if not param.parallel_optimizer or B < self.op_group_size or B % self.op_group_size != 0:
need_rewrite = False
return need_rewrite

def construct(self, param):
if self.need_rewrite:
if self.cell_type is not None:
param = param.to(self.cell_type)
return self.op_allgather(param)
return param

def bprop(self, param, out, dout):
if self.need_rewrite:
r = self.op_reduce_scatter(dout) / self.op_group_size
r = self.op_reduce_scatter(dout.to(param.dtype)) / self.op_group_size
return (r,)
return (dout,)


def get_cell_dtype(cell):
if cell.fp16:
if getattr(cell, "fp16", False):
return ms.float16
if cell.fp32:
if getattr(cell, "fp32", False):
return ms.float32
if cell.bf16:
if getattr(cell, "bf16", False):
return ms.bfloat16
return None

@@ -337,6 +352,7 @@ def _prepare_network(network: nn.Cell, op_group: str, op_group_size: int = 1, op
rewrite_params, new_cell = rewrite_res
_logger.debug(f"Rewrite cell {name} with params {rewrite_params}")
network.__setattr__(name, new_cell)
cell_type = get_cell_dtype(sub_net)

# parameter name will update after __setattr__, reset to ori parameter name.
for param_name in rewrite_params:
@@ -345,17 +361,15 @@ def _prepare_network(network: nn.Cell, op_group: str, op_group_size: int = 1, op
for param_name in rewrite_params:
param = getattr(sub_net, param_name)
# Set zero_param_wrapper same type with sub_net
cell_type = get_cell_dtype(sub_net)
if cell_type:
zero_param_wrapper = ZeroParamWrapper(param, zero_stage=3, op_group=op_group).to_float(
cell_type
)
zero_param_wrapper = ZeroParamWrapper(param, zero_stage=3, op_group=op_group, cell_type=cell_type)
new_cell.__setattr__(f"param_w_{param_name}", zero_param_wrapper)
if zero_param_wrapper.need_rewrite:
split_op = ops.Split(0, op_group_size)
ori_shape = param.shape
new_cell.__getattr__(param_name).assign_value(split_op(param)[op_rank_id])
_logger.debug(f"Cell {name} split {param_name} from {ori_shape} to {param.shape}")
if cell_type and ms.get_context("mode") == ms.PYNATIVE_MODE:
new_cell.to_float(cell_type)

_prepare_network(sub_net, op_group, op_group_size, op_rank_id)

@@ -417,6 +431,8 @@ def prepare_train_network(

new_network = prepare_network(network, zero_stage, op_group)
zero_helper = ZeroHelper(optimizer, zero_stage, op_group, dp_group, optimizer_offload)
if isinstance(scale_sense, float):
scale_sense = ms.Tensor(scale_sense, ms.float32)
train_network = TrainOneStepWrapper(
new_network,
optimizer,
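Note: prepare_train_network now coerces a plain Python float scale_sense into a ms.Tensor before building the TrainOneStepWrapper. A minimal usage sketch, assuming scale_sense is a keyword argument of prepare_train_network (as the isinstance check above suggests), that the distributed context has already been initialized, and with a toy network standing in for a real model:

    from mindspore import nn
    from mindspore.communication import GlobalComm
    from mindone.trainers.zero import prepare_train_network

    net = nn.WithLossCell(nn.Dense(8, 8), nn.MSELoss())  # toy stand-in network
    opt = nn.AdamWeightDecay(net.trainable_params(), learning_rate=1e-3)
    train_net = prepare_train_network(
        net,
        opt,
        zero_stage=2,
        op_group=GlobalComm.WORLD_COMM_GROUP,
        scale_sense=1024.0,  # a bare float is now wrapped into ms.Tensor(scale_sense, ms.float32)
    )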
3 changes: 2 additions & 1 deletion tests/st/test_zero.py
@@ -43,6 +43,7 @@ def __init__(self, p=1):
self.conv2.bias.parallel_optimizer = False
self.dense.recompute()
self.conv1.to_float(ms.float16)
self.dense.to_float(ms.float32)

def construct(self, x):
y = self.conv1(x) * self.p
@@ -57,7 +58,7 @@ def test_zero(x, y, zero_stage=0):
print("-" * 30)
print("-" * 6, f"zero_stage={zero_stage}", "-" * 6)
print("-" * 30)
net = TestNet()
net = nn.WithLossCell(TestNet(), nn.MSELoss())
opt = nn.AdamWeightDecay(net.trainable_params(), learning_rate=1e-3)
train_net = prepare_train_network(net, opt, zero_stage=zero_stage, op_group=GlobalComm.WORLD_COMM_GROUP)

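Note: the test changes wrap TestNet in nn.WithLossCell with nn.MSELoss so the train wrapper receives a scalar loss, and run conv1 in float16 while dense stays in float32, which exercises the new cell_type handling in ZeroParamWrapper (used when zero_stage is 3). Condensed from the zero.py changes above, the dtype path is roughly:

    cell_type = get_cell_dtype(sub_net)  # ms.float16 for conv1, ms.float32 for dense, else None
    wrapper = ZeroParamWrapper(param, zero_stage=3, op_group=op_group, cell_type=cell_type)
    # forward: the sharded parameter is cast to cell_type before AllGather reassembles it
    # backward: the incoming gradient is cast back to param.dtype, reduce-scattered, and averaged over the op group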
