Fix optimizer rollback for distributed optimizer (#26)
Co-authored-by: Qi Penghui <[email protected]>
QPHutu and QPH-SAIL authored Jun 28, 2024
1 parent 331ee01 commit e7cf420
Showing 2 changed files with 40 additions and 34 deletions.
23 changes: 21 additions & 2 deletions megatron/optimizer/distrib_optimizer.py
@@ -491,6 +491,10 @@ def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
         self.optimizer.param_groups = \
             [ g["orig_group"] for g in self.opt_group_ranges ]
         self.optimizer.load_state_dict(self.optimizer.state_dict())
+        fp32_size = 0
+        for groups in self.shard_fp32_groups:
+            fp32_size += len(groups)
+        assert fp32_size == 0, "Not supported, because it is rarely used and makes code messy"


     def get_model_param_range_map(self, param):
@@ -1153,6 +1157,21 @@ def step(self, args, timers):

         return self.update_successful, grad_norm, num_zeros_in_grad

+    def _release_grad_fp32_from_fp16(self, set_to_none=True):
+        """
+        Only used when optimizer post validation is enabled.
+        """
+        for group in self.shard_fp32_from_float16_groups:
+            _zero_grad_group_helper(group, set_to_none)
+
+    def get_mp_group_except_pp_for_bypassing_sync(self):
+        """
+        Only used when optimizer post validation is enabled.
+        """
+        assert get_args().enable_optimizer_post_validation
+        # Note: expert parallel are not supported yet
+        return mpu.get_tensor_and_data_parallel_group()
+
     @torch.no_grad()
     def do_all_gather(self):
         # Reset metadata needed to track results of all-gathers.
@@ -1173,8 +1192,8 @@ def pre_step(self, args, timers):
         self.do_all_gather()

     @torch.no_grad()
-    def post_validation(self):
-        updated, grad_norm, rollback, succeed = super().post_validation()
+    def post_validation(self, free_buffers_callback):
+        updated, grad_norm, rollback, succeed = super().post_validation(free_buffers_callback)
         if rollback:
             self.update_successful = True
             self.do_all_gather()
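A note for context rather than part of the commit: the distributed optimizer shards its main parameters and gradients across data-parallel ranks, so the post-validation found-inf decision has to agree across the tensor- and data-parallel ranks together, which is why the override above returns mpu.get_tensor_and_data_parallel_group() rather than the tensor-parallel group alone. Below is a minimal sketch of that agreement step, assuming an already-initialized torch.distributed process group supplied by the caller; it is not the repository's code.

import torch
import torch.distributed as dist

def any_rank_found_inf(local_grads, group=None):
    # Set a local flag to 1.0 if any locally held (sharded) grad has inf/nan.
    found_inf = torch.zeros(1, dtype=torch.float32)
    for g in local_grads:
        if not torch.isfinite(g).all():
            found_inf.fill_(1.0)
            break
    # MAX all-reduce over the group: if any rank saw inf/nan, every rank
    # ends up with 1.0, so all ranks take the same skip/rollback decision.
    dist.all_reduce(found_inf, op=dist.ReduceOp.MAX, group=group)
    return found_inf.item() > 0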
51 changes: 19 additions & 32 deletions megatron/optimizer/optimizer.py
@@ -126,11 +126,16 @@ def rollback_parameters(self):
             self.optimizer.state[param]["exp_avg_sq"] = s2
         self.parameters_backup = None

+    def get_mp_group_except_pp_for_bypassing_sync(self):
+        """Default returned here, but the distributed optimizer overrides this."""
+        # Note: expert parallel are not supported yet
+        return mpu.get_tensor_model_parallel_group()
+
     def calc_local_grad_norm(self):
         grads_for_norm = self.get_main_grads_for_grad_norm()
         return self.do_clac_local_grad_norm(
             grads_for_norm,
-            tensor_parallel_group=parallel_state.get_tensor_model_parallel_group())
+            tensor_parallel_group=self.get_mp_group_except_pp_for_bypassing_sync())

     def get_clip_coeff_and_grad_norm(self, max_norm, norm_type=2):
         _total_norm = self.partial_reduced_total_norm
@@ -525,6 +530,10 @@ def _local_unscale_main_grads_and_check_for_nan(self):
         # Unscale and set found inf/nan
         torch._amp_foreach_non_finite_check_and_unscale_(
             main_grads, self.found_inf, self.grad_scaler.inv_scale)
+        # Update across all model parallel instances, except pipeline parallel
+        torch.distributed.all_reduce(self.found_inf,
+                                     op=torch.distributed.ReduceOp.MAX,
+                                     group=self.get_mp_group_except_pp_for_bypassing_sync())

     def partially_reduce_local_found_inf(self):
         # self.partial_reduced_found_inf = self.recv_one(self.partial_reduced_found_inf)
@@ -646,10 +655,10 @@ def send_post_validation(self):

     @torch.no_grad()
     def recompute_fp32_grad(self):
-        self._copy_fp32_model_grads_to_fp16_main_grads()
+        self._copy_model_grads_to_main_grads()
         if self.grad_scaler:
             # Collect fp32 main grads from fp16.
-            main_grads = self._collect_main_grad_fp32_from_fp16()
+            main_grads = self._collect_main_grad_data_for_unscaling()
             # Reset found inf.
             self.found_inf.fill_(0.0)
             # Unscale and set found inf/nan
@@ -660,12 +669,13 @@ def post_validation(self, free_buffers_callback):
     @torch.no_grad()
     def post_validation(self, free_buffers_callback):
         rank = parallel_state.get_pipeline_model_parallel_rank()
+        global_rank = torch.distributed.get_rank()
         if self.grad_scaler:
             # found_inf_flag = self.get_found_inf_flag()
             found_inf_flag = self.fully_reduced_global_states["found_inf_flag"]
             if found_inf_flag:
                 if self.do_this_step:
-                    print("found inf rollback")
+                    print(f"{rank}-{global_rank} found inf rollback")
                     free_buffers_callback()
                     self.recompute_fp32_grad()
                     rollback_optimizer_step(self.optimizer)
@@ -686,7 +696,7 @@ def post_validation(self, free_buffers_callback):
             assert not is_nan
             if clip_coeff < 1.0:
                 if self.do_this_step:
-                    print(f"{rank} grad rollback")
+                    print(f"{rank}-{global_rank} grad rollback {clip_coeff}")
                     free_buffers_callback()
                     self.recompute_fp32_grad()
                     rollback_optimizer_step(self.optimizer)
@@ -859,7 +869,10 @@ def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
             self.fp32_from_float16_groups.append(
                 fp32_from_float16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)
-
+        fp32_size = 0
+        for groups in self.fp32_from_fp32_groups:
+            fp32_size += len(groups)
+        assert fp32_size == 0, "Not supported, because it is rarely used and makes code messy"

     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
@@ -878,32 +891,6 @@ def _release_grad_fp32_from_fp16(self, set_to_none=True):
         for group in self.fp32_from_float16_groups:
             _zero_grad_group_helper(group, set_to_none)

-    def _collect_main_grad_fp32_from_fp16(self):
-        main_grads = []
-        # fp32 params from float16 ones.
-        for main_group in self.fp32_from_float16_groups:
-            for main_param in main_group:
-                if main_param.grad is not None:
-                    main_grads.append(main_param.grad.data)
-        return main_grads
-
-    def _copy_fp32_model_grads_to_fp16_main_grads(self):
-        # This only needs to be done for the float16 group.
-        for model_group, main_group in zip(self.float16_groups,
-                                           self.fp32_from_float16_groups):
-            for model_param, main_param in zip(model_group, main_group):
-                if self.params_have_main_grad and hasattr(model_param, 'main_grad'):
-                    main_param.grad = model_param.main_grad.float()
-                else:
-                    assert False
-                # if model_param.grad is not None:
-                #     main_param.grad = model_param.grad.float()
-
-                # Safe to deallocate model's grad/main_grad after copying.
-                # (If using contiguous buffers, main_grad's memory should
-                # persist and therefore should not be deallocated.)
-                model_param.grad = None
-
     def _collect_main_grad_data_for_unscaling(self):

         main_grads = []
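A note for context rather than part of the commit: post_validation undoes an optimistically applied optimizer step via rollback_optimizer_step and the parameter backups that rollback_parameters restores (visible at the top of this file's diff). Below is a minimal sketch of that backup/rollback pattern for a plain torch.optim.Adam, with names and scope chosen for illustration only; it assumes the optimizer has already stepped at least once so its state is populated, and it ignores the Adam step counter for brevity.

import torch

def backup_adam_state(optimizer):
    # Stash each parameter and its Adam moments before an "optimistic" step.
    backup = []
    for group in optimizer.param_groups:
        for p in group["params"]:
            state = optimizer.state[p]
            backup.append((p, p.detach().clone(),
                           state["exp_avg"].clone(),
                           state["exp_avg_sq"].clone()))
    return backup

def rollback_adam_step(optimizer, backup):
    # Restore parameters and moments, undoing the step that was applied
    # before validation completed.
    for p, p_saved, m_saved, v_saved in backup:
        p.data.copy_(p_saved)
        optimizer.state[p]["exp_avg"].copy_(m_saved)
        optimizer.state[p]["exp_avg_sq"].copy_(v_saved)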
