From fce081f9e365f2e6b130c39e09592af82267d489 Mon Sep 17 00:00:00 2001
From: zhangzefeng92
Date: Wed, 27 Mar 2024 10:19:14 +0800
Subject: [PATCH] fix python lint

---
 .../internlm_ops/rms_norm/deeplink.py | 57 +++++++++++++------
 tests/test_rms_lightlm.py             |  2 +-
 2 files changed, 42 insertions(+), 17 deletions(-)

diff --git a/deeplink_ext/internlm_ops/rms_norm/deeplink.py b/deeplink_ext/internlm_ops/rms_norm/deeplink.py
index f9342fc0..82142e17 100644
--- a/deeplink_ext/internlm_ops/rms_norm/deeplink.py
+++ b/deeplink_ext/internlm_ops/rms_norm/deeplink.py
@@ -12,7 +12,9 @@ class _DeepLinkRMSNormFunction(torch.autograd.Function):
     def forward(ctx, hidden_states, weight, bias, eps):
         output = torch.empty_like(hidden_states)
         inv_rms_shape = list(hidden_states.shape[:-1], 1)
-        inv_rms = torch.empty(inv_rms_shape, dtype=hidden_states.dtype, device=hidden_states.device)
+        inv_rms = torch.empty(
+            inv_rms_shape, dtype=hidden_states.dtype, device=hidden_states.device
+        )
         ext.rms_norm(output, inv_rms, hidden_states, None, weight, bias, eps)
 
         ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
@@ -22,12 +24,23 @@ def forward(ctx, hidden_states, weight, bias, eps):
     def backward(ctx, grad_output):
         hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
         eps = eps_tensor.item()
-        
+
         grad_input = torch.empty_like(hidden_states)
         grad_weight = torch.empty_like(weight)
         grad_bias = torch.empty_like(bias)
-        ext.rms_norm_backward(grad_input, grad_weight, grad_bias, grad_output, hidden_states, weight, bias, inv_rms, None, eps)
+        ext.rms_norm_backward(
+            grad_input,
+            grad_weight,
+            grad_bias,
+            grad_output,
+            hidden_states,
+            weight,
+            bias,
+            inv_rms,
+            None,
+            eps
+        )
 
         return grad_input, grad_weight, grad_bias, None
@@ -36,8 +49,18 @@ class _DeepLinkRMSNormFunctionWithNormalizedShape(torch.autograd.Function):
     def forward(ctx, hidden_states, weight, bias, eps, normalized_shape):
         output = torch.empty_like(hidden_states, dtype=torch.float32)
         inv_rms_shape = list(hidden_states.shape[:-1], 1)
-        inv_rms = torch.empty(inv_rms_shape, dtype=torch.float32, device=hidden_states.device)
-        ext.rms_norm(output, inv_rms, hidden_states.float(), normalized_shape, weight.float(), bias.float(), eps)
+        inv_rms = torch.empty(
+            inv_rms_shape, dtype=torch.float32, device=hidden_states.device
+        )
+        ext.rms_norm(
+            output,
+            inv_rms,
+            hidden_states.float(),
+            normalized_shape,
+            weight.float(),
+            bias.float(),
+            eps
+        )
         output = output.half()
         inv_rms = inv_rms.half()
         ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
@@ -49,20 +72,22 @@ def backward(ctx, grad_output):
         hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
         eps = eps_tensor.item()
         normalized_shape = ctx.intermediate_results
-        
+
         grad_input = torch.empty_like(hidden_states, dtype=torch.float32)
         grad_weight = torch.empty_like(weight, dtype=torch.float32)
         grad_bias = torch.empty_like(bias, dtype=torch.float32)
-        ext.rms_norm_backward(grad_input,
-                              grad_weight,
-                              grad_bias,
-                              grad_output.float(),
-                              hidden_states.float(),
-                              weight.float(),
-                              bias.float(),
-                              inv_rms.float(),
-                              normalized_shape,
-                              eps)
+        ext.rms_norm_backward(
+            grad_input,
+            grad_weight,
+            grad_bias,
+            grad_output.float(),
+            hidden_states.float(),
+            weight.float(),
+            bias.float(),
+            inv_rms.float(),
+            normalized_shape,
+            eps
+        )
         grad_output = grad_output.half()
         hidden_states = hidden_states.half()
         inv_rms = inv_rms.half()
diff --git a/tests/test_rms_lightlm.py b/tests/test_rms_lightlm.py
index 35915e2d..31f081fa 100644
--- a/tests/test_rms_lightlm.py
+++ b/tests/test_rms_lightlm.py
@@ -38,7 +38,7 @@
     bias,
     inv_rms,
     weight.shape,
-    1e-6
+    1e-6,
 )
 
 print("Output:", output)
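
For context, a minimal usage sketch (not part of the patch) of the autograd function whose calls are reformatted above. It assumes the compiled deeplink_ext extension is importable and that tensors live on a device it supports; shapes, dtypes, and the eps value (taken from tests/test_rms_lightlm.py) are illustrative only.

# Hypothetical sketch; requires the deeplink_ext native extension and a
# supported accelerator. Class name and argument order come from the diff above.
import torch

from deeplink_ext.internlm_ops.rms_norm.deeplink import _DeepLinkRMSNormFunction

# Illustrative shapes; in practice these tensors would be moved to the
# DeepLink-supported device before use.
hidden_states = torch.randn(2, 64, 4096, dtype=torch.float16, requires_grad=True)
weight = torch.ones(4096, dtype=torch.float16, requires_grad=True)
bias = torch.zeros(4096, dtype=torch.float16, requires_grad=True)

# forward(ctx, hidden_states, weight, bias, eps); eps=1e-6 mirrors the test.
output = _DeepLinkRMSNormFunction.apply(hidden_states, weight, bias, 1e-6)
output.sum().backward()  # exercises the reformatted ext.rms_norm_backward call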