Skip to content

modify clip_grad with no to_global #10443

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions python/oneflow/nn/utils/clip_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,31 +100,23 @@ def clip_grad_norm_(
param0_placement = parameters[0].placement
if norm_type == float("inf"):
norms = [
p.grad.detach()
.to_global(sbp=sbp_broadcast)
.abs()
.max()
.to_global(placement=param0_placement)
p.grad.detach().abs().max().to_global(placement=param0_placement)
for p in parameters
]
total_norm = norms[0] if len(norms) == 1 else flow.max(flow.stack(norms))
elif norm_type == float("-inf"):
norms = [
p.grad.detach()
.to_global(sbp=sbp_broadcast)
.abs()
.min()
.to_global(placement=param0_placement)
p.grad.detach().abs().min().to_global(placement=param0_placement)
for p in parameters
]
total_norm = norms[0] if len(norms) == 1 else flow.min(flow.stack(norms))
else:
total_norm = flow.linalg.vector_norm(
flow.stack(
[
flow.linalg.vector_norm(
p.grad.detach().to_global(sbp=sbp_broadcast), norm_type
).to_global(placement=param0_placement)
flow.linalg.vector_norm(p.grad.detach(), norm_type).to_global(
placement=param0_placement
)
for p in parameters
]
),
Expand Down
1 change: 1 addition & 0 deletions python/oneflow/test/modules/test_clip_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def _test_graph_clip_grad_value_impl(test_case, shape, device, clip_value):
)


# TODO(lml): find why fail on ci machine
def _test_clip_grad_norm_global_impl(
test_case, shape, sbp, placement, max_norm, norm_type
):
Expand Down