From 90482080f3dfcb8d41f0247226a975bafc0e8c75 Mon Sep 17 00:00:00 2001
From: POI-WX
Date: Thu, 21 Mar 2024 18:39:08 +0800
Subject: [PATCH] add scaled_masked_softmax

---
 csrc/extensions.cpp                      | 25 ++++++++++++++++++++++++
 deeplink_ext/llm_ops_for_ascend_speed.py | 18 +++++++++++++++++
 2 files changed, 43 insertions(+)

diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp
index c4ed8105..3a7f35a8 100644
--- a/csrc/extensions.cpp
+++ b/csrc/extensions.cpp
@@ -49,6 +49,23 @@ auto extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
             beta1, beta2, epsilon, weight_decay, step, amsgrad);
 }
 
+auto extScaledMaskedSoftmax(const at::Tensor& input, const at::Tensor& mask,
+                            double scale, bool fixed_triu_mask) {
+  auto out = at::empty_like(input);
+  callDiopi(diopiScaledMaskedSoftmax, out, input, mask, scale, fixed_triu_mask);
+  return out;
+}
+
+auto extScaledMaskedSoftmaxBackward(const at::Tensor& grad_output,
+                                    const at::Tensor& out,
+                                    const at::Tensor& mask, double scale,
+                                    bool fixed_triu_mask) {
+  auto grad_input = at::empty_like(grad_output);
+  callDiopi(diopiScaledMaskedSoftmaxBackward, grad_input, grad_output, out,
+            mask, scale, fixed_triu_mask);
+  return grad_input;
+}
+
 auto extRmsNorm(const at::Tensor& input,
                 const OptionalIntArray& normalized_shape,
                 const at::Tensor& weight, const at::Tensor& bias, double eps) {
@@ -400,6 +417,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   if (&diopiAdamW != nullptr) {
     m.def("adamw", &extAdamW, "deeplink ext_adamw");
   }
+  if (&diopiScaledMaskedSoftmax != nullptr) {
+    m.def("scaled_masked_softmax_fwd", &extScaledMaskedSoftmax,
+          "deeplink ext_scaled_masked_softmax_fwd");
+  }
+  if (&diopiScaledMaskedSoftmaxBackward != nullptr) {
+    m.def("scaled_masked_softmax_bwd", &extScaledMaskedSoftmaxBackward,
+          "deeplink ext_scaled_masked_softmax_bwd");
+  }
 }
 
 }  // namespace dipu::dipu_ext
diff --git a/deeplink_ext/llm_ops_for_ascend_speed.py b/deeplink_ext/llm_ops_for_ascend_speed.py
index d5a6a6e0..9cded5ad 100644
--- a/deeplink_ext/llm_ops_for_ascend_speed.py
+++ b/deeplink_ext/llm_ops_for_ascend_speed.py
@@ -182,3 +182,21 @@ def adamw_for_ascend_speed(
         amsgrad,
     )
     return params, exp_avgs, exp_avg_sqs
+
+
+class DeepLinkScaledMaskedSoftmax(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, mask, scale, fixed_triu_mask):
+        out = ext.scaled_masked_softmax_fwd(input, mask, scale, fixed_triu_mask)
+        ctx.save_for_backward(out, mask)
+        ctx.scale = scale
+        ctx.fixed_triu_mask = fixed_triu_mask
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        out, mask = ctx.saved_tensors
+        grad_input = ext.scaled_masked_softmax_bwd(
+            grad_output, out, mask, ctx.scale, ctx.fixed_triu_mask
+        )
+        # backward must return one value per forward input; mask, scale and
+        # fixed_triu_mask are not differentiable
+        return grad_input, None, None, None
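
The patch itself does not show a call site for the new wrapper. Below is a minimal usage sketch, assuming attention-style scores of shape [batch, heads, seq, seq], a boolean mask in which True marks positions to suppress, and scale = 1/sqrt(head_dim); none of these conventions are stated in the patch, and on an actual AscendSpeed setup the tensors would live on the DIPU/Ascend device rather than on CPU.

    import math
    import torch
    from deeplink_ext.llm_ops_for_ascend_speed import DeepLinkScaledMaskedSoftmax

    batch, heads, seq, head_dim = 2, 8, 128, 64
    # Hypothetical attention scores and causal mask; shapes and the
    # "True = masked out" convention are assumptions, not taken from the patch.
    scores = torch.randn(batch, heads, seq, seq, requires_grad=True)
    causal_mask = torch.triu(
        torch.ones(seq, seq, dtype=torch.bool), diagonal=1
    ).expand(batch, heads, seq, seq)

    probs = DeepLinkScaledMaskedSoftmax.apply(
        scores, causal_mask, 1.0 / math.sqrt(head_dim), False  # fixed_triu_mask
    )
    probs.sum().backward()  # exercises scaled_masked_softmax_bwd via autograd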
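
For checking the fused kernel's numerics, a plain-PyTorch reference of the usual scaled masked softmax can be compared against the forward output. This sketch assumes the common Megatron-style convention that the logits are multiplied by scale, masked positions are filled with a large negative value, and the softmax runs over the last dimension; the DIOPI kernel's exact convention is not spelled out in the patch.

    import torch

    def scaled_masked_softmax_reference(x, mask, scale):
        # Assumed semantics (not confirmed by the patch): scale multiplies the
        # logits and mask == True suppresses a position before the softmax.
        return torch.softmax((x * scale).masked_fill(mask, -10000.0), dim=-1)

Comparing this reference against DeepLinkScaledMaskedSoftmax.apply on the same inputs with torch.allclose, and doing the same for the input gradients, is a reasonable smoke test once the DIOPI kernel is available.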