Skip to content

Commit

Permalink
add adamw for ascend speed
Browse files Browse the repository at this point in the history
  • Loading branch information
POI-WX committed Mar 20, 2024
1 parent dfe9065 commit 5891226
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 1 deletion.
12 changes: 12 additions & 0 deletions csrc/extensions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ at::IntArrayRef optionalIntArrayToIntArrayRefOrDefault(

} // namespace

// auto extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
// at::Tensor& max_exp_avg_sq, at::Tensor& grad, float lr,
// float beta1, float beta2, float epsilon, float weight_decay,
// int64_t step, bool amsgrad) {
// // the diopiAdamW func has no "maximize" param
// callDiopi(diopiAdamW, param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, lr,
// beta1, beta2, epsilon, weight_decay, step, amsgrad);
// }

auto extRmsNorm(const at::Tensor& input,
const OptionalIntArray& normalized_shape,
const at::Tensor& weight, const at::Tensor& bias, double eps) {
Expand Down Expand Up @@ -358,6 +367,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
if (&diopiApplyPenalty != nullptr) {
m.def("apply_penalty", &extApplyPenalty, "deeplink ext_apply_penalty");
}
if (&diopiAdamW != nullptr) {
m.def("adamw", &extAdamW, "deeplink ext_adamw");
}
}

} // namespace dipu::dipu_ext
56 changes: 55 additions & 1 deletion deeplink_ext/llm_ops_for_ascend_speed.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Copyright (c) 2024, DeepLink.

from typing import Optional, Union
from typing import Optional, Union, List
import torch
import deeplink_ext.cpp_extensions as ext

assert hasattr(ext, "fa_fwd") and hasattr(ext, "fa_bwd")
assert hasattr(ext, "apply_rotary")
assert hasattr(ext, "rms_norm") and hasattr(ext, "rms_norm_backward")
assert hasattr(ext, "adamw")


class DeepLinkFlashSelfAttention(torch.autograd.Function):
Expand Down Expand Up @@ -128,3 +129,56 @@ def backward(ctx, grad_output):
hidden_states, grad_output, inv_rms, None, weight, bias, ctx.eps
)
return grad_input, grad_weight, grad_bias, None


def adamw_for_ascend_speed(
params: List[torch.Tensor],
grads: List[torch.Tensor],
exp_avgs: List[torch.Tensor],
exp_avg_sqs: List[torch.Tensor],
max_exp_avg_sqs: List[torch.Tensor],
state_steps: List[int],
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
eps: float,
maximize: bool,
norm_coeff_scale: float
):
r"""Functional API that performs AdamW algorithm computation.
See :class:`~torch.optim.AdamW` for details.
"""

assert maximize == False, "ascend diopiAdamW only support False 'maximize'."
assert amsgrad == False, "ascend diopiAdamW only support False 'amsgrad'."

for i, param in enumerate(params):
if norm_coeff_scale is not None:
grad = grads[i].float() * norm_coeff_scale
else:
grad = grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step = state_steps[i]
if not max_exp_avg_sqs:
max_exp_avg_sq = torch.Tensor().cuda()
else:
max_exp_avg_sq = max_exp_avg_sqs[i]
ext.adamw(
param,
exp_avg,
exp_avg_sq,
max_exp_avg_sq,
grad,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
amsgrad,
)
return params, exp_avgs, exp_avg_sq

0 comments on commit 5891226

Please sign in to comment.