Skip to content

Commit

Permalink
Update utils.py
Browse files Browse the repository at this point in the history
fix the formatting issue for commit 2f382d8
  • Loading branch information
harygo2 committed Apr 27, 2024
1 parent 1e5e1b4 commit 2eadf2f
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion deepspeed/runtime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,9 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=No
# # mask_tensor_[mask_idx[0]:mask_idx[1]] = True
cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device_name(),
dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1)
mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device_name(), dtype=p.dtype)
mask_tensor = torch.zeros(p.shape[0] + 1,
device=get_accelerator().current_device_name(),
dtype=p.dtype)
mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1),
cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1]

Expand Down

0 comments on commit 2eadf2f

Please sign in to comment.