From acbf8dcc99226374225a57f08efc033fbd93302c Mon Sep 17 00:00:00 2001
From: POI-WX
Date: Tue, 30 Jan 2024 18:18:04 +0800
Subject: [PATCH] fix bug

---
 csrc/extensions.cpp | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp
index a3643f1d..e10527a3 100644
--- a/csrc/extensions.cpp
+++ b/csrc/extensions.cpp
@@ -18,6 +18,7 @@
 #include
 
+#include
 #include
 
 #include "diopi_helper.h"
 
@@ -128,11 +129,6 @@ auto extMultiHeadAttentionBackward(const at::Tensor& grad_out,
 auto extFlashAttention(const at::Tensor& q, const at::Tensor& k,
                        const at::Tensor& v, double p_dropout,
                        double softmax_scale, bool is_causal) {
-  const auto batch_size = q.sizes()[0];
-  const auto q_seq_len = q.sizes()[1];
-  const auto head_num = q.sizes()[2];
-  const auto k_seq_len = k.sizes()[1];
-
   auto out = at::empty_like(q);
   diopiTensorHandle_t softmax_max;
   diopiTensorHandle_t softmax_sum;
@@ -141,12 +137,13 @@ auto extFlashAttention(const at::Tensor& q, const at::Tensor& k,
   auto gen = createDIPUGenerator();
 
   [[maybe_unused]] auto context = callDiopiKeepContext(
-      diopiFlashAttention, out, gen, &softmax_max, &softmax_sum, &softmax_out,
+      diopiFlashAttention, out, &softmax_max, &softmax_sum, &softmax_out, gen,
       q, k, v, p_dropout, softmax_scale, is_causal);
 
-  return std::make_tuple(std::move(out), *fromDiopiTensorHandle(softmax_max),
-                         *fromDiopiTensorHandle(softmax_sum),
-                         *fromDiopiTensorHandle(softmax_out), std::move(gen));
+  return std::make_tuple(
+      std::move(out), *dipu::diopi_helper::fromDiopiTensorHandle(softmax_max),
+      *dipu::diopi_helper::fromDiopiTensorHandle(softmax_sum),
+      *dipu::diopi_helper::fromDiopiTensorHandle(softmax_out), std::move(gen));
 }
 
 // grad_q, grad_k, grad_v are output args, and should be pre-allocated.
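
Note: the argument reorder in the callDiopiKeepContext hunk matches the DIOPI
declaration of diopiFlashAttention, in which the dropout generator follows the
three softmax output handles rather than preceding them. A rough sketch of the
assumed prototype is below; the parameter names are illustrative, and the
authoritative declaration lives in DIOPI's functions_ext.h:

    // Sketch of the assumed DIOPI prototype, not copied from the header.
    // Output handles come first, then the dropout RNG state, then inputs,
    // matching the fixed call order in the patch above.
    extern "C" diopiError_t diopiFlashAttention(
        diopiContextHandle_t ctx,           // device/stream context
        diopiTensorHandle_t attention_out,  // pre-allocated, same shape as q
        diopiTensorHandle_t* softmax_max,   // allocated by the backend
        diopiTensorHandle_t* softmax_sum,   // allocated by the backend
        diopiTensorHandle_t* softmax_out,   // allocated by the backend
        diopiGeneratorHandle_t gen,         // RNG state for dropout
        diopiConstTensorHandle_t q,         // (batch, seq_q, head_num, head_dim)
        diopiConstTensorHandle_t k,         // (batch, seq_k, head_num, head_dim)
        diopiConstTensorHandle_t v,         // same layout as k
        double p_dropout, double softmax_scale, bool is_causal);

The switch to the fully qualified dipu::diopi_helper::fromDiopiTensorHandle is
presumably needed because the diopi handle types are plain C typedefs declared
outside the dipu namespace, so argument-dependent lookup cannot find the
unqualified helper.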