diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp
index 8b57071f..23f8fbcb 100644
--- a/csrc/extensions.cpp
+++ b/csrc/extensions.cpp
@@ -17,6 +17,7 @@
 #include
 #include
+#include
 #include
 #include
@@ -26,6 +27,15 @@

 namespace dipu::dipu_ext {

+void extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
+              at::Tensor& max_exp_avg_sq, at::Tensor& grad, float lr,
+              float beta1, float beta2, float epsilon, float weight_decay,
+              int64_t step, bool amsgrad) {
+  // the diopiAdamW func has no "maximize" param
+  callDiopi(diopiAdamW, param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, lr,
+            beta1, beta2, epsilon, weight_decay, step, amsgrad);
+}
+
 auto extRmsNorm(at::Tensor& output, at::Tensor& inv_rms,
                 const at::Tensor& input,
                 const OptionalIntArray& normalized_shape,
@@ -217,7 +227,11 @@ auto extRmsNormLightllm(const at::Tensor& x, const at::Tensor& weight,
 // 否则不注册, 等到 python 层处理.
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  if (&diopiRMSNorm != nullptr) {  // Check if weak symbol defined
+  // Check if weak symbol defined
+  if (&diopiAdamW != nullptr) {
+    m.def("adamw", &extAdamW, "deeplink ext_adamw");
+  }
+  if (&diopiRMSNorm != nullptr) {
     m.def("rms_norm", &extRmsNorm, "deeplink ext_rms_norm");
   }
   if (&diopiRMSNormBackward != nullptr) {
diff --git a/deeplink_ext/ascend_speed/__init__.py b/deeplink_ext/ascend_speed/__init__.py
index 399f9138..fab32da5 100644
--- a/deeplink_ext/ascend_speed/__init__.py
+++ b/deeplink_ext/ascend_speed/__init__.py
@@ -1,3 +1,4 @@
 from .rotary_embedding import apply_rotary, RotaryEmbedding
+from .adamw import adamw

-__all__ = ["apply_rotary", "RotaryEmbedding"]
+__all__ = ["apply_rotary", "RotaryEmbedding", "adamw"]
diff --git a/deeplink_ext/ascend_speed/adamw.py b/deeplink_ext/ascend_speed/adamw.py
new file mode 100644
index 00000000..9133a71a
--- /dev/null
+++ b/deeplink_ext/ascend_speed/adamw.py
@@ -0,0 +1,59 @@
+from typing import List
+import torch
+import deeplink_ext.cpp_extensions as ext
+
+
+assert hasattr(ext, "adamw")
+
+
+def adamw(
+    params: List[torch.Tensor],
+    grads: List[torch.Tensor],
+    exp_avgs: List[torch.Tensor],
+    exp_avg_sqs: List[torch.Tensor],
+    max_exp_avg_sqs: List[torch.Tensor],
+    state_steps: List[int],
+    *,
+    amsgrad: bool,
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+    eps: float,
+    maximize: bool,
+    norm_coeff_scale: float
+):
+    r"""Functional API that performs AdamW algorithm computation.
+    See :class:`~torch.optim.AdamW` for details.
+    """
+
+    assert maximize == False, "ascend diopiAdamW only support False 'maximize'."
+    assert amsgrad == False, "ascend diopiAdamW only support False 'amsgrad'."
+
+    for i, param in enumerate(params):
+        if norm_coeff_scale is not None:
+            grad = grads[i].float() * norm_coeff_scale
+        else:
+            grad = grads[i]
+        exp_avg = exp_avgs[i]
+        exp_avg_sq = exp_avg_sqs[i]
+        step = state_steps[i]
+        if not max_exp_avg_sqs:
+            max_exp_avg_sq = torch.Tensor().cuda()
+        else:
+            max_exp_avg_sq = max_exp_avg_sqs[i]
+        ext.adamw(
+            param,
+            exp_avg,
+            exp_avg_sq,
+            max_exp_avg_sq,
+            grad,
+            lr,
+            beta1,
+            beta2,
+            eps,
+            weight_decay,
+            step,
+            amsgrad,
+        )
+    return params, exp_avgs, exp_avg_sqs
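
Usage note (illustration, not part of the patch): the new functional `adamw` follows the shape of PyTorch's functional AdamW entry point with an extra `norm_coeff_scale` argument, and it is driven with per-parameter state lists. The sketch below shows one possible call under two assumptions: the extension was built with `diopiAdamW` resolved (so `cpp_extensions.adamw` is registered), and the DIPU/Ascend device is reachable through torch's "cuda" device, which the patch's own `torch.Tensor().cuda()` fallback already presumes.

# Minimal sketch: one parameter with its gradient and freshly initialized state.
import torch
from deeplink_ext.ascend_speed import adamw

param = torch.randn(4, 4, device="cuda")
grad = torch.randn_like(param)
exp_avg = torch.zeros_like(param)      # first-moment state
exp_avg_sq = torch.zeros_like(param)   # second-moment state

adamw(
    [param],
    [grad],
    [exp_avg],
    [exp_avg_sq],
    [],              # max_exp_avg_sqs stays empty because amsgrad must be False
    [1],             # per-parameter step counters
    amsgrad=False,   # diopiAdamW only supports amsgrad=False
    beta1=0.9,
    beta2=0.999,
    lr=1e-3,
    weight_decay=0.01,
    eps=1e-8,
    maximize=False,  # diopiAdamW has no "maximize" path, so only False is accepted
    norm_coeff_scale=None,  # when set, grads are cast to float and scaled first
)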