
Commit

update

inkcherry committed Apr 15, 2024
1 parent 79cc4ce commit 3ebed5e
Showing 2 changed files with 15 additions and 10 deletions.
12 changes: 8 additions & 4 deletions deepspeed/runtime/fp16/fused_optimizer.py
@@ -232,10 +232,14 @@ def _get_norm_mask_idx(self, group):

         for p in group:
             grad_flat_en_idx = grad_flat_st_idx + p.numel()
-            if p.grad is None or self._require_avoid_recompute_norm(p, bwc_tensor_model_parallel_rank(self.mpu)):
-                group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx])
-            else:
-                grad_flat_st_idx = grad_flat_en_idx
+            if p.grad is not None and self._require_avoid_recompute_norm(p, bwc_tensor_model_parallel_rank(self.mpu)):
+                # merge range
+                if len(group_mask_idx_list) > 0 and grad_flat_st_idx == group_mask_idx_list[-1][-1]:
+                    group_mask_idx_list[-1][-1] = grad_flat_en_idx
+                else:
+                    group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx])
+            grad_flat_st_idx = grad_flat_en_idx

         return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device())

     def step(self, closure=None):
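As an aside (not part of the commit), the new range-merging loop can be read in isolation roughly as follows. This is a minimal sketch with made-up parameter sizes; build_mask_ranges, param_sizes, and needs_mask are hypothetical names standing in for the optimizer's per-group state and the _require_avoid_recompute_norm check.

def build_mask_ranges(param_sizes, needs_mask):
    # param_sizes: number of elements of each parameter, in flattening order.
    # needs_mask: whether each parameter's slice should be excluded from the norm.
    ranges = []
    start = 0
    for numel, masked in zip(param_sizes, needs_mask):
        end = start + numel
        if masked:
            # Extend the previous range when it ends exactly where this one starts.
            if ranges and start == ranges[-1][-1]:
                ranges[-1][-1] = end
            else:
                ranges.append([start, end])
        start = end
    return ranges

# Two consecutive masked parameters (4 and 6 elements) collapse into one range.
print(build_mask_ranges([4, 6, 3], [True, True, False]))  # [[0, 10]]

Merging touching ranges keeps the index list free of repeated boundary values, which appears to matter for the scatter_-based mask construction in utils.py below, since scatter_ makes no guarantee about which value wins when indices repeat.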
13 changes: 7 additions & 6 deletions deepspeed/runtime/utils.py
@@ -425,16 +425,17 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=None):
                 # A loop-free implementation to create a mask tensor based on a range list,
                 # which is logically equivalent to the following implementation.

-                # # mask_tensor = torch.zeros_like(p, device=p.device, dtype=bool)
-                # # for mask_idx in grad_norm_mask[idx]:
-                # #     mask_tensor[mask_idx[0]:mask_idx[1]] = True
+                # # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool)
+                # # for mask_idx in grad_norm_mask[idx]:
+                # #     mask_tensor_[mask_idx[0]:mask_idx[1]] = True
                 cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(),
                                              dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1)
-                mask_tensor = torch.zeros_like(p, device=get_accelerator().current_device(), dtype=p.dtype)
+                mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype)
                 mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1),
-                                                   cum_sum_pairs.view(-1)).cumsum(0).bool()
+                                                   cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1]

+                # assert torch.equal(mask_tensor_, mask_tensor)
                 param_norm = torch.masked_fill(p.grad.data, mask_tensor, 0).float().norm(norm_type)

             else:
                 param_norm = p.grad.data.float().norm(norm_type)
             total_norm += param_norm.item()**norm_type
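For readers of this second hunk, here is a minimal self-contained sketch (not part of the commit, with made-up sizes) of the loop-free mask construction. It also shows why the buffer now gets one extra slot (p.shape[0] + 1) that is trimmed again with [:-1]: a range whose end index equals the tensor length can then be scattered without going out of bounds.

import torch

numel = 10
ranges = torch.tensor([[2, 5], [7, 10]])  # note: the last range ends at numel

# Reference loop version.
mask_loop = torch.zeros(numel, dtype=torch.bool)
for st, en in ranges.tolist():
    mask_loop[st:en] = True

# Loop-free version: +1 at each start index, -1 at each end index,
# then a running sum marks every position that lies inside some range.
pairs = torch.tensor([1., -1.]).repeat(ranges.shape[0], 1)
mask = torch.zeros(numel + 1)  # extra slot so an end index equal to numel is legal
mask = mask.scatter_(0, ranges.view(-1), pairs.view(-1)).cumsum(0).bool()[:-1]

assert torch.equal(mask, mask_loop)

If two ranges touched (say [2, 5] and [5, 8]), index 5 would receive both a +1 and a -1 and scatter_ would keep only one of them, so such ranges must be merged before reaching this code, which is apparently what the fused_optimizer.py change above takes care of.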
