diff --git a/python/oneflow/nn/utils/clip_grad.py b/python/oneflow/nn/utils/clip_grad.py
index a667d3cc00e..e293f78b7e3 100644
--- a/python/oneflow/nn/utils/clip_grad.py
+++ b/python/oneflow/nn/utils/clip_grad.py
@@ -100,21 +100,13 @@ def clip_grad_norm_(
         param0_placement = parameters[0].placement
         if norm_type == float("inf"):
             norms = [
-                p.grad.detach()
-                .to_global(sbp=sbp_broadcast)
-                .abs()
-                .max()
-                .to_global(placement=param0_placement)
+                p.grad.detach().abs().max().to_global(placement=param0_placement)
                 for p in parameters
             ]
             total_norm = norms[0] if len(norms) == 1 else flow.max(flow.stack(norms))
         elif norm_type == float("-inf"):
             norms = [
-                p.grad.detach()
-                .to_global(sbp=sbp_broadcast)
-                .abs()
-                .min()
-                .to_global(placement=param0_placement)
+                p.grad.detach().abs().min().to_global(placement=param0_placement)
                 for p in parameters
             ]
             total_norm = norms[0] if len(norms) == 1 else flow.min(flow.stack(norms))
@@ -122,9 +114,9 @@ def clip_grad_norm_(
             total_norm = flow.linalg.vector_norm(
                 flow.stack(
                     [
-                        flow.linalg.vector_norm(
-                            p.grad.detach().to_global(sbp=sbp_broadcast), norm_type
-                        ).to_global(placement=param0_placement)
+                        flow.linalg.vector_norm(p.grad.detach(), norm_type).to_global(
+                            placement=param0_placement
+                        )
                         for p in parameters
                     ]
                 ),
diff --git a/python/oneflow/test/modules/test_clip_grad.py b/python/oneflow/test/modules/test_clip_grad.py
index c78dba0a1bd..84f7c79c536 100644
--- a/python/oneflow/test/modules/test_clip_grad.py
+++ b/python/oneflow/test/modules/test_clip_grad.py
@@ -117,6 +117,7 @@ def _test_graph_clip_grad_value_impl(test_case, shape, device, clip_value):
     )
 
 
+# TODO(lml): find out why this fails on the CI machine
 def _test_clip_grad_norm_global_impl(
     test_case, shape, sbp, placement, max_norm, norm_type
 ):
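
The simplification drops the per-parameter `to_global(sbp=sbp_broadcast)` conversion: each reduction (`abs().max()`, `abs().min()`, or the inner `flow.linalg.vector_norm`) now runs directly on the detached gradient, and only the resulting scalar is moved to the first parameter's placement before the final cross-parameter reduction. Below is a minimal sketch of the equivalent inf-norm computation; the single-rank placement, tensor shape, and values are illustrative assumptions, not part of the patch:

```python
import oneflow as flow

# Illustrative single-rank placement; real callers use whatever placement
# their parameters already live on.
placement = flow.placement("cpu", ranks=[0])

# Toy global parameter with a populated gradient.
p = flow.ones(4, placement=placement, sbp=flow.sbp.broadcast, requires_grad=True)
(p * p).sum().backward()

# Equivalent of the new inf-norm path: reduce the detached gradient directly,
# then move only the scalar norm to the reference placement.
param0_placement = p.placement
total_norm = p.grad.detach().abs().max().to_global(placement=param0_placement)

# The public entry point exercises the same branch.
flow.nn.utils.clip_grad_norm_([p], max_norm=1.0, norm_type=float("inf"))
```

The apparent upside is one fewer collective per parameter: converting an entire gradient tensor to broadcast SBP before a scalar reduction is redundant when the scalar result can be re-placed afterwards.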