Commit

fix bug
POI-WX committed Jan 30, 2024
1 parent 0b158cc commit acbf8dc
Showing 1 changed file with 6 additions and 9 deletions.
csrc/extensions.cpp: 15 changes (6 additions & 9 deletions)
@@ -18,6 +18,7 @@
 
 #include <diopi/functions_ext.h>
 
+#include <csrc_dipu/diopirt/diopirt_impl.h>
 #include <csrc_dipu/runtime/core/DIPUGeneratorImpl.h>
 
 #include "diopi_helper.h"
@@ -128,11 +129,6 @@ auto extMultiHeadAttentionBackward(const at::Tensor& grad_out,
 auto extFlashAttention(const at::Tensor& q, const at::Tensor& k,
                        const at::Tensor& v, double p_dropout,
                        double softmax_scale, bool is_causal) {
-  const auto batch_size = q.sizes()[0];
-  const auto q_seq_len = q.sizes()[1];
-  const auto head_num = q.sizes()[2];
-  const auto k_seq_len = k.sizes()[1];
-
   auto out = at::empty_like(q);
   diopiTensorHandle_t softmax_max;
   diopiTensorHandle_t softmax_sum;
@@ -141,12 +137,13 @@ auto extFlashAttention(const at::Tensor& q, const at::Tensor& k,
   auto gen = createDIPUGenerator();
 
   [[maybe_unused]] auto context = callDiopiKeepContext(
-      diopiFlashAttention, out, gen, &softmax_max, &softmax_sum, &softmax_out,
+      diopiFlashAttention, out, &softmax_max, &softmax_sum, &softmax_out, gen,
       q, k, v, p_dropout, softmax_scale, is_causal);
 
-  return std::make_tuple(std::move(out), *fromDiopiTensorHandle(softmax_max),
-                         *fromDiopiTensorHandle(softmax_sum),
-                         *fromDiopiTensorHandle(softmax_out), std::move(gen));
+  return std::make_tuple(
+      std::move(out), *dipu::diopi_helper::fromDiopiTensorHandle(softmax_max),
+      *dipu::diopi_helper::fromDiopiTensorHandle(softmax_sum),
+      *dipu::diopi_helper::fromDiopiTensorHandle(softmax_out), std::move(gen));
 }
 
 // grad_q, grad_k, grad_v are output args, and should be pre-allocated.
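For readability, here is the extFlashAttention wrapper as it reads after this commit, assembled from the two hunks above. This is only a consolidated sketch of the diff, not new code; the softmax_out declaration and anything else hidden behind the collapsed context lines is assumed to match the rest of the file.

auto extFlashAttention(const at::Tensor& q, const at::Tensor& k,
                       const at::Tensor& v, double p_dropout,
                       double softmax_scale, bool is_causal) {
  auto out = at::empty_like(q);
  diopiTensorHandle_t softmax_max;
  diopiTensorHandle_t softmax_sum;
  diopiTensorHandle_t softmax_out;  // assumed: declared in a context line not shown in the diff
  auto gen = createDIPUGenerator();

  // The generator argument now comes after the three softmax handles,
  // presumably to match the parameter order of diopiFlashAttention.
  [[maybe_unused]] auto context = callDiopiKeepContext(
      diopiFlashAttention, out, &softmax_max, &softmax_sum, &softmax_out, gen,
      q, k, v, p_dropout, softmax_scale, is_causal);

  // fromDiopiTensorHandle is now qualified with dipu::diopi_helper::; the newly
  // added <csrc_dipu/diopirt/diopirt_impl.h> include presumably provides it.
  return std::make_tuple(
      std::move(out), *dipu::diopi_helper::fromDiopiTensorHandle(softmax_max),
      *dipu::diopi_helper::fromDiopiTensorHandle(softmax_sum),
      *dipu::diopi_helper::fromDiopiTensorHandle(softmax_out), std::move(gen));
}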
