From 5891226072281d457de8f0563ec163ad3055d9ee Mon Sep 17 00:00:00 2001
From: POI-WX
Date: Wed, 20 Mar 2024 20:52:24 +0800
Subject: [PATCH] add adamw for ascend speed

---
 csrc/extensions.cpp                      | 12 +++++
 deeplink_ext/llm_ops_for_ascend_speed.py | 56 +++++++++++++++++++++++-
 2 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp
index a80b9892..36f53501 100644
--- a/csrc/extensions.cpp
+++ b/csrc/extensions.cpp
@@ -40,6 +40,15 @@ at::IntArrayRef optionalIntArrayToIntArrayRefOrDefault(
 
 }  // namespace
 
+auto extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
+              at::Tensor& max_exp_avg_sq, at::Tensor& grad, float lr,
+              float beta1, float beta2, float epsilon, float weight_decay,
+              int64_t step, bool amsgrad) {
+  // note: diopiAdamW has no "maximize" parameter
+  callDiopi(diopiAdamW, param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, lr,
+            beta1, beta2, epsilon, weight_decay, step, amsgrad);
+}
+
 auto extRmsNorm(const at::Tensor& input,
                 const OptionalIntArray& normalized_shape,
                 const at::Tensor& weight, const at::Tensor& bias, double eps) {
@@ -358,6 +367,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   if (&diopiApplyPenalty != nullptr) {
     m.def("apply_penalty", &extApplyPenalty, "deeplink ext_apply_penalty");
   }
+  if (&diopiAdamW != nullptr) {
+    m.def("adamw", &extAdamW, "deeplink ext_adamw");
+  }
 }
 
 }  // namespace dipu::dipu_ext
diff --git a/deeplink_ext/llm_ops_for_ascend_speed.py b/deeplink_ext/llm_ops_for_ascend_speed.py
index cf9df98e..033bdb42 100644
--- a/deeplink_ext/llm_ops_for_ascend_speed.py
+++ b/deeplink_ext/llm_ops_for_ascend_speed.py
@@ -1,12 +1,13 @@
 # Copyright (c) 2024, DeepLink.
 
-from typing import Optional, Union
+from typing import Optional, Union, List
 import torch
 import deeplink_ext.cpp_extensions as ext
 
 assert hasattr(ext, "fa_fwd") and hasattr(ext, "fa_bwd")
 assert hasattr(ext, "apply_rotary")
 assert hasattr(ext, "rms_norm") and hasattr(ext, "rms_norm_backward")
+assert hasattr(ext, "adamw")
 
 
 class DeepLinkFlashSelfAttention(torch.autograd.Function):
@@ -128,3 +129,56 @@ def backward(ctx, grad_output):
             hidden_states, grad_output, inv_rms, None, weight, bias, ctx.eps
         )
         return grad_input, grad_weight, grad_bias, None
+
+
+def adamw_for_ascend_speed(
+    params: List[torch.Tensor],
+    grads: List[torch.Tensor],
+    exp_avgs: List[torch.Tensor],
+    exp_avg_sqs: List[torch.Tensor],
+    max_exp_avg_sqs: List[torch.Tensor],
+    state_steps: List[int],
+    *,
+    amsgrad: bool,
+    beta1: float,
+    beta2: float,
+    lr: float,
+    weight_decay: float,
+    eps: float,
+    maximize: bool,
+    norm_coeff_scale: Optional[float]
+):
+    r"""Functional API that performs the AdamW algorithm computation.
+    See :class:`~torch.optim.AdamW` for details.
+    """
+
+    assert not maximize, "ascend diopiAdamW only supports maximize=False."
+    assert not amsgrad, "ascend diopiAdamW only supports amsgrad=False."
+
+    for i, param in enumerate(params):
+        if norm_coeff_scale is not None:
+            grad = grads[i].float() * norm_coeff_scale
+        else:
+            grad = grads[i]
+        exp_avg = exp_avgs[i]
+        exp_avg_sq = exp_avg_sqs[i]
+        step = state_steps[i]
+        if not max_exp_avg_sqs:
+            max_exp_avg_sq = torch.Tensor().cuda()
+        else:
+            max_exp_avg_sq = max_exp_avg_sqs[i]
+        ext.adamw(
+            param,
+            exp_avg,
+            exp_avg_sq,
+            max_exp_avg_sq,
+            grad,
+            lr,
+            beta1,
+            beta2,
+            eps,
+            weight_decay,
+            step,
+            amsgrad,
+        )
+    return params, exp_avgs, exp_avg_sqs
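
Usage note (not part of the patch): a minimal sketch of how the new adamw_for_ascend_speed wrapper might be driven for one fused optimizer step. The adamw_step helper, the hyper-parameter values, and the tensor shapes below are illustrative assumptions, not code from this repository; only the adamw_for_ascend_speed signature comes from the patch itself.

    # sketch: one fused AdamW step via the new wrapper (illustrative only)
    import torch
    from deeplink_ext.llm_ops_for_ascend_speed import adamw_for_ascend_speed

    def adamw_step(params, grads, exp_avgs, exp_avg_sqs, step):
        # exp_avgs / exp_avg_sqs are optimizer state tensors kept by the caller,
        # typically torch.zeros_like(p) for each parameter before the first step
        adamw_for_ascend_speed(
            params,
            grads,
            exp_avgs,
            exp_avg_sqs,
            [],                     # max_exp_avg_sqs: unused because amsgrad=False
            [step] * len(params),   # one step counter per parameter
            amsgrad=False,          # the wrapper asserts amsgrad is False
            beta1=0.9,
            beta2=0.999,
            lr=1e-3,
            weight_decay=1e-2,
            eps=1e-8,
            maximize=False,         # the wrapper asserts maximize is False
            norm_coeff_scale=None,  # or a float, e.g. a gradient-clipping coefficient
        )

    # example call (device tensors; under dipu, "cuda" is typically mapped to the
    # vendor device, which is how the wrapper's torch.Tensor().cuda() placeholder works)
    params = [torch.randn(4, 4, device="cuda", requires_grad=True)]
    grads = [torch.randn(4, 4, device="cuda")]
    exp_avgs = [torch.zeros_like(p) for p in params]
    exp_avg_sqs = [torch.zeros_like(p) for p in params]
    adamw_step(params, grads, exp_avgs, exp_avg_sqs, step=1)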