diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp
index 7b361054..7e5fe344 100644
--- a/csrc/extensions.cpp
+++ b/csrc/extensions.cpp
@@ -167,16 +167,15 @@ auto extMultiHeadAttentionVarLenBackward(
 auto extFlashAttention(at::Tensor& out, const at::Tensor& q,
                        const at::Tensor& k, const at::Tensor& v,
-                       double p_dropout, double softmax_scale, bool is_causal,
-                       int64_t head_num, const std::string& input_layout) {
+                       at::Generator& gen, double p_dropout,
+                       double softmax_scale, bool is_causal, int64_t head_num,
+                       const std::string& input_layout) {
   diopiTensorHandle_t attention_mask = nullptr;
   diopiTensorHandle_t dropout_mask = nullptr;
   diopiTensorHandle_t softmax_max = nullptr;
   diopiTensorHandle_t softmax_sum = nullptr;
   diopiTensorHandle_t softmax_out = nullptr;
 
-  auto gen = createDIPUGenerator();
-
   [[maybe_unused]] auto context = callDiopiKeepContext(
       diopiFlashAttention, out, &attention_mask, &dropout_mask, &softmax_max,
       &softmax_sum, &softmax_out, gen, q, k, v, p_dropout, softmax_scale,
@@ -195,16 +194,14 @@ auto extFlashAttention(at::Tensor& out, const at::Tensor& q,
 
 auto extFlashAttentionV2(at::Tensor& out, const at::Tensor& q,
                          const at::Tensor& k, const at::Tensor& v,
-                         const at::Tensor& attention_mask, double p_dropout,
-                         double softmax_scale, int64_t head_num,
-                         const std::string& input_layout) {
+                         at::Generator& gen, const at::Tensor& attention_mask,
+                         double p_dropout, double softmax_scale,
+                         int64_t head_num, const std::string& input_layout) {
   diopiTensorHandle_t dropout_mask = nullptr;
   diopiTensorHandle_t softmax_max = nullptr;
   diopiTensorHandle_t softmax_sum = nullptr;
   diopiTensorHandle_t softmax_out = nullptr;
 
-  auto gen = createDIPUGenerator();
-
   [[maybe_unused]] auto context = callDiopiKeepContext(
       diopiFlashAttentionV2, out, &dropout_mask, &softmax_max, &softmax_sum,
       &softmax_out, gen, q, k, v, attention_mask, p_dropout, softmax_scale,
diff --git a/deeplink_ext/ascend_speed/flash_attention.py b/deeplink_ext/ascend_speed/flash_attention.py
index 8d1b1a93..6d0d14e1 100644
--- a/deeplink_ext/ascend_speed/flash_attention.py
+++ b/deeplink_ext/ascend_speed/flash_attention.py
@@ -1,4 +1,5 @@
 import torch
+import torch_dipu
 import deeplink_ext.cpp_extensions as ext
 
 assert hasattr(ext, "fa_fwd_v2") and hasattr(ext, "fa_bwd")
@@ -10,6 +11,7 @@ def forward(
         ctx, q, k, v, attention_mask, dropout_p, softmax_scale, head_num, input_layout
     ):
         out = torch.empty_like(q)
+        gen = torch_dipu._C._create_dipu_generator(-1)
         (
             dropout_mask,
             softmax_max,
@@ -20,6 +22,7 @@ def forward(
             q,
             k,
             v,
+            gen,
             attention_mask,
             dropout_p,
             softmax_scale,