From 6b64d42abeec1224851be97b663e7879e917200c Mon Sep 17 00:00:00 2001
From: shw
Date: Wed, 6 Mar 2024 17:45:43 +0800
Subject: [PATCH 1/2] modify clip grad

---
 python/oneflow/nn/utils/clip_grad.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/python/oneflow/nn/utils/clip_grad.py b/python/oneflow/nn/utils/clip_grad.py
index a667d3cc00e..d3e8e37d25a 100644
--- a/python/oneflow/nn/utils/clip_grad.py
+++ b/python/oneflow/nn/utils/clip_grad.py
@@ -119,6 +119,8 @@ def clip_grad_norm_(
         ]
         total_norm = norms[0] if len(norms) == 1 else flow.min(flow.stack(norms))
     else:
+        '''
+        # data parallel
         total_norm = flow.linalg.vector_norm(
             flow.stack(
                 [
@@ -130,6 +132,22 @@ def clip_grad_norm_(
             ),
             norm_type,
         )
+        '''
+        # tensor parallel:
+        partial_grad_square_sum = flow.sum(
+            flow.stack(
+                [
+                    flow.sum(
+                        flow.pow(flow.abs(p.grad.detach()), norm_type)
+                    ).to_local()
+                    for p in parameters
+                ]
+            )
+        )
+
+        flow.comm.all_reduce(partial_grad_square_sum)
+        total_norm = flow.pow(partial_grad_square_sum, 1 / norm_type)
+        total_norm = total_norm.to_global(sbp=sbp_broadcast, placement=param0_placement)
     if error_if_nonfinite and flow.logical_or(
         total_norm.isnan(), total_norm.isinf()
     ):

From 40877372472afadfe197192becff30a95708b182 Mon Sep 17 00:00:00 2001
From: oneflow-ci-bot
Date: Wed, 6 Mar 2024 11:32:29 +0000
Subject: [PATCH 2/2] auto format by CI

---
 python/oneflow/nn/utils/clip_grad.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/python/oneflow/nn/utils/clip_grad.py b/python/oneflow/nn/utils/clip_grad.py
index d3e8e37d25a..508295385e9 100644
--- a/python/oneflow/nn/utils/clip_grad.py
+++ b/python/oneflow/nn/utils/clip_grad.py
@@ -119,7 +119,7 @@ def clip_grad_norm_(
         ]
         total_norm = norms[0] if len(norms) == 1 else flow.min(flow.stack(norms))
     else:
-        '''
+        """
         # data parallel
         total_norm = flow.linalg.vector_norm(
             flow.stack(
@@ -132,14 +132,12 @@ def clip_grad_norm_(
             ),
             norm_type,
         )
-        '''
+        """
         # tensor parallel:
         partial_grad_square_sum = flow.sum(
             flow.stack(
                 [
-                    flow.sum(
-                        flow.pow(flow.abs(p.grad.detach()), norm_type)
-                    ).to_local()
+                    flow.sum(flow.pow(flow.abs(p.grad.detach()), norm_type)).to_local()
                     for p in parameters
                 ]
             )
@@ -146,8 +144,10 @@ def clip_grad_norm_(
         )
 
         flow.comm.all_reduce(partial_grad_square_sum)
         total_norm = flow.pow(partial_grad_square_sum, 1 / norm_type)
-        total_norm = total_norm.to_global(sbp=sbp_broadcast, placement=param0_placement)
+        total_norm = total_norm.to_global(
+            sbp=sbp_broadcast, placement=param0_placement
+        )
     if error_if_nonfinite and flow.logical_or(
         total_norm.isnan(), total_norm.isinf()
     ):
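
Note on the tensor-parallel branch above: it computes the global gradient norm
from per-rank partial sums of |g| ** norm_type, all-reduces the partials, and
takes the norm_type-th root. Below is a minimal single-process sketch of that
math, not OneFlow code: the shard list and the plain Python sum standing in
for flow.comm.all_reduce are hypothetical illustrations.

    # Sketch: summing per-shard partial sums of |g| ** p across "ranks"
    # and taking the p-th root equals the p-norm of the full vector.
    norm_type = 2.0

    # Hypothetical per-rank gradient shards (stand-ins for each rank's
    # local piece of p.grad under a split sbp).
    shards = [[0.5, -1.0], [2.0], [-0.25, 3.0]]

    # Each rank computes its local partial sum of |g| ** norm_type ...
    partials = [sum(abs(g) ** norm_type for g in shard) for shard in shards]

    # ... the all-reduce (simulated here by a plain sum) adds them up ...
    total = sum(partials)

    # ... and the norm_type-th root gives the global norm.
    total_norm = total ** (1.0 / norm_type)

    # Check against the norm computed on the concatenated gradients.
    flat = [g for shard in shards for g in shard]
    expected = sum(abs(g) ** norm_type for g in flat) ** (1.0 / norm_type)
    assert abs(total_norm - expected) < 1e-12

After the all-reduce every rank holds the same scalar, which is presumably why
the patch finishes with to_global(sbp=sbp_broadcast, ...): wrapping the local
result back into a broadcast global tensor keeps the subsequent clipping
arithmetic consistent with the global parameter tensors.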