fix python lint
zhangzefeng92 committed Mar 27, 2024
1 parent 3f100c0 commit fce081f
Showing 2 changed files with 42 additions and 17 deletions.
deeplink_ext/internlm_ops/rms_norm/deeplink.py (41 additions, 16 deletions)
@@ -12,7 +12,9 @@ class _DeepLinkRMSNormFunction(torch.autograd.Function):
     def forward(ctx, hidden_states, weight, bias, eps):
         output = torch.empty_like(hidden_states)
         inv_rms_shape = list(hidden_states.shape[:-1]) + [1]
-        inv_rms = torch.empty(inv_rms_shape, dtype=hidden_states.dtype, device=hidden_states.device)
+        inv_rms = torch.empty(
+            inv_rms_shape, dtype=hidden_states.dtype, device=hidden_states.device
+        )
         ext.rms_norm(output, inv_rms, hidden_states, None, weight, bias, eps)

         ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
@@ -22,12 +24,23 @@ def forward(ctx, hidden_states, weight, bias, eps):
     def backward(ctx, grad_output):
         hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
         eps = eps_tensor.item()

         grad_input = torch.empty_like(hidden_states)
         grad_weight = torch.empty_like(weight)
         grad_bias = torch.empty_like(bias)

-        ext.rms_norm_backward(grad_input, grad_weight, grad_bias, grad_output, hidden_states, weight, bias, inv_rms, None, eps)
+        ext.rms_norm_backward(
+            grad_input,
+            grad_weight,
+            grad_bias,
+            grad_output,
+            hidden_states,
+            weight,
+            bias,
+            inv_rms,
+            None,
+            eps
+        )
         return grad_input, grad_weight, grad_bias, None


@@ -36,8 +49,18 @@ class _DeepLinkRMSNormFunctionWithNormalizedShape(torch.autograd.Function):
     def forward(ctx, hidden_states, weight, bias, eps, normalized_shape):
         output = torch.empty_like(hidden_states, dtype=torch.float32)
         inv_rms_shape = list(hidden_states.shape[:-1]) + [1]
-        inv_rms = torch.empty(inv_rms_shape, dtype=torch.float32, device=hidden_states.device)
-        ext.rms_norm(output, inv_rms, hidden_states.float(), normalized_shape, weight.float(), bias.float(), eps)
+        inv_rms = torch.empty(
+            inv_rms_shape, dtype=torch.float32, device=hidden_states.device
+        )
+        ext.rms_norm(
+            output,
+            inv_rms,
+            hidden_states.float(),
+            normalized_shape,
+            weight.float(),
+            bias.float(),
+            eps
+        )
         output = output.half()
         inv_rms = inv_rms.half()
         ctx.save_for_backward(hidden_states, inv_rms, weight, bias, torch.tensor(eps))
@@ -49,20 +72,22 @@ def backward(ctx, grad_output):
         hidden_states, inv_rms, weight, bias, eps_tensor = ctx.saved_tensors
         eps = eps_tensor.item()
         normalized_shape = ctx.intermediate_results

         grad_input = torch.empty_like(hidden_states, dtype=torch.float32)
         grad_weight = torch.empty_like(weight, dtype=torch.float32)
         grad_bias = torch.empty_like(bias, dtype=torch.float32)
-        ext.rms_norm_backward(grad_input,
-                              grad_weight,
-                              grad_bias,
-                              grad_output.float(),
-                              hidden_states.float(),
-                              weight.float(),
-                              bias.float(),
-                              inv_rms.float(),
-                              normalized_shape,
-                              eps)
+        ext.rms_norm_backward(
+            grad_input,
+            grad_weight,
+            grad_bias,
+            grad_output.float(),
+            hidden_states.float(),
+            weight.float(),
+            bias.float(),
+            inv_rms.float(),
+            normalized_shape,
+            eps
+        )
         grad_output = grad_output.half()
         hidden_states = hidden_states.half()
         inv_rms = inv_rms.half()
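Since the ext kernels require the DeepLink extension, the following is a minimal pure-PyTorch sketch of the semantics assumed for ext.rms_norm in this file; rms_norm_reference is a hypothetical helper for illustration, not the extension's implementation.

import torch

def rms_norm_reference(hidden_states, weight, bias, eps):
    # inv_rms keeps all but the last dimension, with a trailing 1 so it
    # broadcasts back over hidden_states (cf. inv_rms_shape above).
    variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(variance + eps)
    output = hidden_states * inv_rms * weight + bias
    return output, inv_rms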
tests/test_rms_lightlm.py (1 addition, 1 deletion)
@@ -38,7 +38,7 @@
     bias,
     inv_rms,
     weight.shape,
-    1e-6
+    1e-6,
 )

 print("Output:", output)
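A quick sanity check of shapes and gradients can be run against the reference sketch above; the tensor sizes here are arbitrary and the snippet is illustrative, not part of the test file.

x = torch.randn(2, 8, 16, requires_grad=True)
weight = torch.ones(16, requires_grad=True)
bias = torch.zeros(16, requires_grad=True)
output, inv_rms = rms_norm_reference(x, weight, bias, 1e-6)
assert inv_rms.shape == (2, 8, 1)
output.sum().backward()  # populates x.grad, weight.grad, bias.grad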
