From 3ebed5ea11cafc2a07aa1a98aa58ed832eebfa60 Mon Sep 17 00:00:00 2001
From: inkcherry
Date: Mon, 15 Apr 2024 14:46:17 +0800
Subject: [PATCH] update

---
 deepspeed/runtime/fp16/fused_optimizer.py | 12 ++++++++----
 deepspeed/runtime/utils.py                | 13 +++++++------
 2 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py
index 77824ef32149..a98681b3e9f8 100755
--- a/deepspeed/runtime/fp16/fused_optimizer.py
+++ b/deepspeed/runtime/fp16/fused_optimizer.py
@@ -232,10 +232,14 @@ def _get_norm_mask_idx(self, group):
 
         for p in group:
             grad_flat_en_idx = grad_flat_st_idx + p.numel()
-            if p.grad is None or self._require_avoid_recompute_norm(p, bwc_tensor_model_parallel_rank(self.mpu)):
-                group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx])
-            else:
-                grad_flat_st_idx = grad_flat_en_idx
+            if p.grad is not None and self._require_avoid_recompute_norm(p, bwc_tensor_model_parallel_rank(self.mpu)):
+                # merge range
+                if len(group_mask_idx_list) > 0 and grad_flat_st_idx == group_mask_idx_list[-1][-1]:
+                    group_mask_idx_list[-1][-1] = grad_flat_en_idx
+                else:
+                    group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx])
+            grad_flat_st_idx = grad_flat_en_idx
+
         return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device())
 
     def step(self, closure=None):
diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index d213e00c42d6..ee1b5655dfce 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -425,16 +425,17 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=No
 
                 # A loop-free implementation to create a mask tensor based on a range list,
                 # which is logically equivalent to the following implementation.
-                # # mask_tensor = torch.zeros_like(p, device=p.device, dtype=bool)
-                # # for mask_idx in grad_norm_mask[idx]:
-                # #     mask_tensor[mask_idx[0]:mask_idx[1]] = True
+                # # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool)
+                # #for mask_idx in grad_norm_mask[idx]:
+                # #    mask_tensor_[mask_idx[0]:mask_idx[1]] = True
                 cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(),
                                              dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1)
-                mask_tensor = torch.zeros_like(p, device=get_accelerator().current_device(), dtype=p.dtype)
+                mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype)
                 mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1),
-                                                   cum_sum_pairs.view(-1)).cumsum(0).bool()
-
+                                                   cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1]
+                # assert torch.equal(mask_tensor_, mask_tensor)
                 param_norm = torch.masked_fill(p.grad.data, mask_tensor, 0).float().norm(norm_type)
+
             else:
                 param_norm = p.grad.data.float().norm(norm_type)
             total_norm += param_norm.item()**norm_type
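
Below is a minimal standalone sketch of the scatter + cumsum trick the utils.py hunk relies on (the helper names ranges_to_mask and ranges_to_mask_ref are illustrative, not part of DeepSpeed). It also hints at why the range-merging added in fused_optimizer.py matters: scatter_ does not accumulate values at duplicate indices, so adjacent [start, end) pairs must be merged before the end marker of one range could land on the same index as the start marker of the next.

import torch

def ranges_to_mask(ranges, numel):
    # +1 at each range start and -1 at each range end, scattered into zeros.
    pairs = torch.tensor([1., -1.]).repeat(ranges.shape[0], 1)
    # One extra slot so an end index equal to `numel` stays in bounds.
    markers = torch.zeros(numel + 1)
    markers = markers.scatter_(0, ranges.view(-1), pairs.view(-1))
    # Running sum is 1 inside every [start, end) range and 0 outside.
    return markers.cumsum(0).bool()[:-1]

def ranges_to_mask_ref(ranges, numel):
    # Loop-based reference, matching the commented-out version in the patch.
    mask = torch.zeros(numel, dtype=torch.bool)
    for st, en in ranges.tolist():
        mask[st:en] = True
    return mask

ranges = torch.tensor([[2, 5], [7, 9]])  # already merged, non-adjacent pairs
assert torch.equal(ranges_to_mask(ranges, 10), ranges_to_mask_ref(ranges, 10))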