Commit

draft
POI-WX committed Jan 30, 2024
1 parent 78fcb5c commit 7460756
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions csrc/extensions.cpp
@@ -140,12 +140,13 @@ auto extFlashAttention(const at::Tensor& q, const at::Tensor& k,

   auto gen = createDIPUGenerator();

-  callDiopi(diopiFlashAttention, out, gen, &softmax_max, &softmax_sum,
-            &softmax_out, q, k, v, p_dropout, softmax_scale, is_causal);
+  [[maybe_unused]] auto context = callDiopiKeepContext(
+      diopiFlashAttention, out, gen, &softmax_max, &softmax_sum, &softmax_out,
+      q, k, v, p_dropout, softmax_scale, is_causal);

   return std::make_tuple(std::move(out), *fromDiopiTensorHandle(softmax_max),
                          *fromDiopiTensorHandle(softmax_sum),
-                         std::move(softmax_out), std::move(gen));
+                         *fromDiopiTensorHandle(softmax_out), std::move(gen));
 }

 // grad_q, grad_k, grad_v are output args, and should be pre-allocated.
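
The commit message does not explain the change. One plausible reading is that the tensors behind `softmax_max`, `softmax_sum`, and `softmax_out` are allocated during the call and owned by a context object, so the context has to outlive the `*fromDiopiTensorHandle(...)` dereferences in the `return` statement. The sketch below illustrates that lifetime pattern only. `Tensor`, `TensorHandle`, `Context`, `fakeKernel`, `callKeepContext`, and `run` are all hypothetical stand-ins, not the DIPU or DIOPI API, and the ownership model is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-ins; none of these are DIPU or DIOPI types.
struct Tensor {
  std::vector<float> data;
};
using TensorHandle = Tensor*;  // non-owning handle, similar in spirit to an opaque C handle

// Assumed ownership model: the context owns every tensor allocated for a kernel call.
class Context {
 public:
  TensorHandle requireTensor(std::size_t n) {
    owned_.push_back(std::make_unique<Tensor>(Tensor{std::vector<float>(n, 1.0f)}));
    return owned_.back().get();
  }

 private:
  std::vector<std::unique_ptr<Tensor>> owned_;
};

// A kernel that allocates its output through the context and reports it
// through an out-parameter, so the caller only ever sees a raw handle.
void fakeKernel(Context& ctx, TensorHandle* out, std::size_t n) {
  *out = ctx.requireTensor(n);
}

// Runs the kernel and hands the context back, so the caller decides how long
// the kernel-allocated tensors stay alive.
[[nodiscard]] Context callKeepContext(TensorHandle* out, std::size_t n) {
  Context ctx;
  fakeKernel(ctx, out, n);
  return ctx;  // moving the context does not move the heap-allocated tensors
}

Tensor run(std::size_t n) {
  TensorHandle handle = nullptr;
  // The variable is never read again; it exists only to extend the lifetime.
  [[maybe_unused]] auto context = callKeepContext(&handle, n);
  return *handle;  // copied out while `context` still owns the tensor
}
```

Had `callKeepContext` destroyed its `Context` before returning, `handle` in `run` would dangle and `*handle` would be undefined behavior. Under the assumption above, that is the hazard the named `context` variable guards against.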
