diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 7ac89a233808..ecb2a527f870 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -310,6 +310,7 @@ def __init__(self,
             for param in param_group['params']:
                 if param.requires_grad:
                     param.grad_accum = None
+                    param.param_idx_in_group = len(trainable_parameters)
                     trainable_parameters.append(param)
 
             self.bit16_groups.append(trainable_parameters)
@@ -961,7 +962,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
         assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"
 
         self.grads_in_ipg_bucket.append(grad_reduc)
-        self.params_in_ipg_bucket.append((i, param, param_id))
+        self.params_in_ipg_bucket.append((i, param.param_idx_in_group, param_id))
 
         #make sure the average tensor function knows how to average the gradients
         if is_moe_param(param):
@@ -1067,7 +1068,8 @@ def average_tensor(self, tensor):
 
             process_group = self.dp_process_group
             # count = 0
-            for i, param, param_id in self.params_in_ipg_bucket:
+            for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+                param = self.bit16_groups[i][param_idx_in_group]
 
                 process_group = self.dp_process_group
 
@@ -1383,7 +1385,8 @@ def reduce_ipg_grads(self):
             stream = get_accelerator().current_stream()
 
         with get_accelerator().stream(stream):
-            for _, param, param_id in self.params_in_ipg_bucket:
+            for group_idx, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+                param = self.bit16_groups[group_idx][param_idx_in_group]
 
                 assert self.params_already_reduced[param_id] == False, \
                     f"The parameter {param_id} has already been reduced. \
diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py
index 9ee546437f6c..c67a907c6785 100644
--- a/tests/unit/moe/test_moe.py
+++ b/tests/unit/moe/test_moe.py
@@ -93,7 +93,8 @@ def strict_average_tensor(tensor):
         process_group = optimizer.dp_process_group
         curr_size = 0
         pg_offsets = []
-        for i, param, param_id in optimizer.params_in_ipg_bucket:
+        for i, param_idx, param_id in optimizer.params_in_ipg_bucket:
+            param = optimizer.bit16_groups[i][param_idx]
             process_group = optimizer.dp_process_group
             if optimizer.ipg_bucket_has_moe_params:
                 process_group = optimizer.expert_dp_process_group[param.group_name] if is_moe_param(
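
The change is the same in every hunk: `params_in_ipg_bucket` now stores `(group_index, param_idx_in_group, param_id)` tuples instead of the parameter object itself, and each consumer re-resolves the parameter via `self.bit16_groups[group_index][param_idx_in_group]`. This is safe because `param_idx_in_group` is assigned in `__init__` at the moment the param is appended to its group, and the group ordering is never mutated afterwards, so the index lookup always recovers the same object. Below is a minimal standalone sketch of the pattern; `TinyBucket` and `FakeParam` are hypothetical names for illustration, not DeepSpeed API:

```python
class FakeParam:
    """Stand-in for torch.nn.Parameter; only needs settable attributes."""

    def __init__(self, name):
        self.name = name
        self.requires_grad = True


class TinyBucket:
    """Sketch of the index-based bucket used by the patched optimizer."""

    def __init__(self, param_groups):
        self.bit16_groups = []
        self.params_in_ipg_bucket = []
        for group in param_groups:
            trainable = []
            for param in group:
                if param.requires_grad:
                    # Record the param's position within its group, as the
                    # patched __init__ does, so the bucket can refer to it
                    # by index instead of by object reference.
                    param.param_idx_in_group = len(trainable)
                    trainable.append(param)
            self.bit16_groups.append(trainable)

    def add(self, group_idx, param, param_id):
        # Store indices only; the tuple no longer pins the param object.
        self.params_in_ipg_bucket.append((group_idx, param.param_idx_in_group, param_id))

    def drain(self):
        # Re-resolve each param the same way the patched loops do.
        for group_idx, param_idx_in_group, param_id in self.params_in_ipg_bucket:
            param = self.bit16_groups[group_idx][param_idx_in_group]
            yield param_id, param
        self.params_in_ipg_bucket.clear()


if __name__ == "__main__":
    p0, p1 = FakeParam("w"), FakeParam("b")
    bucket = TinyBucket([[p0, p1]])
    bucket.add(0, p1, param_id=17)
    for param_id, param in bucket.drain():
        assert param is p1 and param_id == 17
```

In the patch itself this re-resolution step appears in `average_tensor`, `reduce_ipg_grads`, and the `strict_average_tensor` helper in `tests/unit/moe/test_moe.py`, which mirrors the optimizer's bucket iteration.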