diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp index c04c3223..ac617bce 100644 --- a/csrc/extensions.cpp +++ b/csrc/extensions.cpp @@ -140,12 +140,13 @@ auto extFlashAttention(const at::Tensor& q, const at::Tensor& k, auto gen = createDIPUGenerator(); - callDiopi(diopiFlashAttention, out, gen, &softmax_max, &softmax_sum, - &softmax_out, q, k, v, p_dropout, softmax_scale, is_causal); + [[maybe_unused]] auto context = callDiopiKeepContext( + diopiFlashAttention, out, gen, &softmax_max, &softmax_sum, &softmax_out, + q, k, v, p_dropout, softmax_scale, is_causal); return std::make_tuple(std::move(out), *fromDiopiTensorHandle(softmax_max), *fromDiopiTensorHandle(softmax_sum), - std::move(softmax_out), std::move(gen)); + *fromDiopiTensorHandle(softmax_out), std::move(gen)); } // grad_q, grad_k, grad_v are output args, and should be pre-allocated.