From 4a5a5c40be8fdd377e289da62f302c3d8750b9a4 Mon Sep 17 00:00:00 2001
From: root
Date: Fri, 22 Mar 2024 15:40:12 +0800
Subject: [PATCH] optimize all

---
 csrc/extensions.cpp                      | 72 ++++++++++++------------
 deeplink_ext/llm_ops_for_ascend_speed.py |  4 +-
 2 files changed, 38 insertions(+), 38 deletions(-)

diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp
index 44453801..3a7f35a8 100644
--- a/csrc/extensions.cpp
+++ b/csrc/extensions.cpp
@@ -40,31 +40,31 @@ at::IntArrayRef optionalIntArrayToIntArrayRefOrDefault(
 }
 
 }  // namespace
 
-// auto extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
-//               at::Tensor& max_exp_avg_sq, at::Tensor& grad, float lr,
-//               float beta1, float beta2, float epsilon, float weight_decay,
-//               int64_t step, bool amsgrad) {
-//   // the diopiAdamW func has no "maximize" param
-//   callDiopi(diopiAdamW, param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, lr,
-//             beta1, beta2, epsilon, weight_decay, step, amsgrad);
-// }
-
-// auto extScaledMaskedSoftmax(const at::Tensor& input, const at::Tensor& mask,
-//                             double scale, bool fixed_triu_mask) {
-//   auto out = at::empty_like(input);
-//   callDiopi(diopiScaledMaskedSoftmax, out, input, mask, scale, fixed_triu_mask);
-//   return out;
-// }
-
-// auto extScaledMaskedSoftmaxBackward(const at::Tensor& grad_output,
-//                                     const at::Tensor& out,
-//                                     const at::Tensor& mask, double scale,
-//                                     bool fixed_triu_mask) {
-//   auto grad_input = at::empty_like(grad_output);
-//   callDiopi(diopiScaledMaskedSoftmaxBackward, grad_input, grad_output, out,
-//             mask, scale, fixed_triu_mask);
-//   return grad_input;
-// }
+auto extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
+              at::Tensor& max_exp_avg_sq, at::Tensor& grad, float lr,
+              float beta1, float beta2, float epsilon, float weight_decay,
+              int64_t step, bool amsgrad) {
+  // the diopiAdamW func has no "maximize" param
+  callDiopi(diopiAdamW, param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, lr,
+            beta1, beta2, epsilon, weight_decay, step, amsgrad);
+}
+
+auto extScaledMaskedSoftmax(const at::Tensor& input, const at::Tensor& mask,
+                            double scale, bool fixed_triu_mask) {
+  auto out = at::empty_like(input);
+  callDiopi(diopiScaledMaskedSoftmax, out, input, mask, scale, fixed_triu_mask);
+  return out;
+}
+
+auto extScaledMaskedSoftmaxBackward(const at::Tensor& grad_output,
+                                    const at::Tensor& out,
+                                    const at::Tensor& mask, double scale,
+                                    bool fixed_triu_mask) {
+  auto grad_input = at::empty_like(grad_output);
+  callDiopi(diopiScaledMaskedSoftmaxBackward, grad_input, grad_output, out,
+            mask, scale, fixed_triu_mask);
+  return grad_input;
+}
 
 auto extRmsNorm(const at::Tensor& input,
                 const OptionalIntArray& normalized_shape,
@@ -414,17 +414,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   if (&diopiApplyPenalty != nullptr) {
     m.def("apply_penalty", &extApplyPenalty, "deeplink ext_apply_penalty");
   }
-  // if (&diopiAdamW != nullptr) {
-  //   m.def("adamw", &extAdamW, "deeplink ext_adamw");
-  // }
-  // if (&diopiScaledMaskedSoftmax != nullptr) {
-  //   m.def("scaled_masked_softmax_fwd", &extScaledMaskedSoftmax,
-  //         "deeplink ext_scaled_masked_softmax_fwd");
-  // }
-  // if (&diopiScaledMaskedSoftmaxBackward != nullptr) {
-  //   m.def("scaled_masked_softmax_bwd", &extScaledMaskedSoftmaxBackward,
-  //         "deeplink ext_scaled_masked_softmax_bwd");
-  // }
+  if (&diopiAdamW != nullptr) {
+    m.def("adamw", &extAdamW, "deeplink ext_adamw");
+  }
+  if (&diopiScaledMaskedSoftmax != nullptr) {
+    m.def("scaled_masked_softmax_fwd", &extScaledMaskedSoftmax,
+          "deeplink ext_scaled_masked_softmax_fwd");
+  }
+  if (&diopiScaledMaskedSoftmaxBackward != nullptr) {
+    m.def("scaled_masked_softmax_bwd", &extScaledMaskedSoftmaxBackward,
+          "deeplink ext_scaled_masked_softmax_bwd");
+  }
 }
 
 }  // namespace dipu::dipu_ext
diff --git a/deeplink_ext/llm_ops_for_ascend_speed.py b/deeplink_ext/llm_ops_for_ascend_speed.py
index 38a89ee1..62e7528e 100644
--- a/deeplink_ext/llm_ops_for_ascend_speed.py
+++ b/deeplink_ext/llm_ops_for_ascend_speed.py
@@ -7,8 +7,8 @@
 assert hasattr(ext, "fa_fwd") and hasattr(ext, "fa_bwd")
 assert hasattr(ext, "apply_rotary")
 assert hasattr(ext, "rms_norm") and hasattr(ext, "rms_norm_backward")
-# assert hasattr(ext, "adamw")
-# assert hasattr(ext, "scaled_masked_softmax_fwd") and hasattr(ext, "scaled_masked_softmax_bwd")
+assert hasattr(ext, "adamw")
+assert hasattr(ext, "scaled_masked_softmax_fwd") and hasattr(ext, "scaled_masked_softmax_bwd")
 
 
 class DeepLinkFlashSelfAttention(torch.autograd.Function):
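
Usage sketch (illustrative only, not part of the diff above): the bindings enabled by this patch can be consumed from Python roughly as follows. The argument order mirrors the C++ signatures of extScaledMaskedSoftmax and extScaledMaskedSoftmaxBackward; the module alias `ext` follows llm_ops_for_ascend_speed.py, while the exact import path and the wrapper class name DeepLinkScaledMaskedSoftmax are assumptions for this sketch, not names introduced by the patch.

import torch
import deeplink_ext.cpp_extensions as ext  # assumed import path; the patched file only shows the alias "ext"


class DeepLinkScaledMaskedSoftmax(torch.autograd.Function):
    """Hypothetical autograd wrapper over the newly exposed softmax bindings."""

    @staticmethod
    def forward(ctx, input, mask, scale, fixed_triu_mask):
        # mirrors extScaledMaskedSoftmax(input, mask, scale, fixed_triu_mask) -> out
        out = ext.scaled_masked_softmax_fwd(input, mask, scale, fixed_triu_mask)
        ctx.save_for_backward(out, mask)
        ctx.scale = scale
        ctx.fixed_triu_mask = fixed_triu_mask
        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, mask = ctx.saved_tensors
        # mirrors extScaledMaskedSoftmaxBackward(grad_output, out, mask, scale, fixed_triu_mask) -> grad_input
        grad_input = ext.scaled_masked_softmax_bwd(
            grad_output, out, mask, ctx.scale, ctx.fixed_triu_mask
        )
        # gradients only flow to the input tensor; mask, scale and fixed_triu_mask get None
        return grad_input, None, None, None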