Skip to content

Commit

Permalink
feat: support adamw op for ascend speed (#65)
Browse files Browse the repository at this point in the history
Support adamw op for ascend speed.

---------

Co-authored-by: Zhangzefeng <[email protected]>
  • Loading branch information
POI-WX and zhangzefeng92 authored Apr 1, 2024
1 parent 8261278 commit 220857e
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 2 deletions.
16 changes: 15 additions & 1 deletion csrc/extensions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <pybind11/cast.h>
#include <pybind11/detail/common.h>

#include <diopi/functions.h>
#include <diopi/functions_ext.h>

#include <csrc_dipu/runtime/core/DIPUGeneratorImpl.h>
Expand All @@ -26,6 +27,15 @@

namespace dipu::dipu_ext {

void extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
at::Tensor& max_exp_avg_sq, at::Tensor& grad, float lr,
float beta1, float beta2, float epsilon, float weight_decay,
int64_t step, bool amsgrad) {
// the diopiAdamW func has no "maximize" param
callDiopi(diopiAdamW, param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, lr,
beta1, beta2, epsilon, weight_decay, step, amsgrad);
}

auto extRmsNorm(at::Tensor& output, at::Tensor& inv_rms,
const at::Tensor& input,
const OptionalIntArray& normalized_shape,
Expand Down Expand Up @@ -217,7 +227,11 @@ auto extRmsNormLightllm(const at::Tensor& x, const at::Tensor& weight,
// 否则不注册, 等到 python 层处理.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
if (&diopiRMSNorm != nullptr) { // Check if weak symbol defined
// Check if weak symbol defined
if (&diopiAdamW != nullptr) {
m.def("adamw", &extAdamW, "deeplink ext_adamw");
}
if (&diopiRMSNorm != nullptr) {
m.def("rms_norm", &extRmsNorm, "deeplink ext_rms_norm");
}
if (&diopiRMSNormBackward != nullptr) {
Expand Down
3 changes: 2 additions & 1 deletion deeplink_ext/ascend_speed/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .rotary_embedding import apply_rotary, RotaryEmbedding
from .adamw import adamw

__all__ = ["apply_rotary", "RotaryEmbedding"]
__all__ = ["apply_rotary", "RotaryEmbedding", "adamw"]
59 changes: 59 additions & 0 deletions deeplink_ext/ascend_speed/adamw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import List
import torch
import deeplink_ext.cpp_extensions as ext


assert hasattr(ext, "adamw")


def adamw(
params: List[torch.Tensor],
grads: List[torch.Tensor],
exp_avgs: List[torch.Tensor],
exp_avg_sqs: List[torch.Tensor],
max_exp_avg_sqs: List[torch.Tensor],
state_steps: List[int],
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
eps: float,
maximize: bool,
norm_coeff_scale: float
):
r"""Functional API that performs AdamW algorithm computation.
See :class:`~torch.optim.AdamW` for details.
"""

assert maximize == False, "ascend diopiAdamW only support False 'maximize'."
assert amsgrad == False, "ascend diopiAdamW only support False 'amsgrad'."

for i, param in enumerate(params):
if norm_coeff_scale is not None:
grad = grads[i].float() * norm_coeff_scale
else:
grad = grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step = state_steps[i]
if not max_exp_avg_sqs:
max_exp_avg_sq = torch.Tensor().cuda()
else:
max_exp_avg_sq = max_exp_avg_sqs[i]
ext.adamw(
param,
exp_avg,
exp_avg_sq,
max_exp_avg_sq,
grad,
lr,
beta1,
beta2,
eps,
weight_decay,
step,
amsgrad,
)
return params, exp_avgs, exp_avg_sqs

0 comments on commit 220857e

Please sign in to comment.