update: pass the DIPU random generator in from Python instead of creating it inside the C++ flash-attention extensions
POI-WX committed Apr 1, 2024
1 parent 4b5ca01 commit 24e046b
Showing 2 changed files with 9 additions and 9 deletions.
15 changes: 6 additions & 9 deletions csrc/extensions.cpp
@@ -167,16 +167,15 @@ auto extMultiHeadAttentionVarLenBackward(
 
 auto extFlashAttention(at::Tensor& out, const at::Tensor& q,
                        const at::Tensor& k, const at::Tensor& v,
-                       double p_dropout, double softmax_scale, bool is_causal,
-                       int64_t head_num, const std::string& input_layout) {
+                       at::Generator& gen, double p_dropout,
+                       double softmax_scale, bool is_causal, int64_t head_num,
+                       const std::string& input_layout) {
   diopiTensorHandle_t attention_mask = nullptr;
   diopiTensorHandle_t dropout_mask = nullptr;
   diopiTensorHandle_t softmax_max = nullptr;
   diopiTensorHandle_t softmax_sum = nullptr;
   diopiTensorHandle_t softmax_out = nullptr;
 
-  auto gen = createDIPUGenerator();
-
   [[maybe_unused]] auto context = callDiopiKeepContext(
       diopiFlashAttention, out, &attention_mask, &dropout_mask, &softmax_max,
       &softmax_sum, &softmax_out, gen, q, k, v, p_dropout, softmax_scale,
@@ -195,16 +194,14 @@ auto extFlashAttention(at::Tensor& out, const at::Tensor& q,
 
 auto extFlashAttentionV2(at::Tensor& out, const at::Tensor& q,
                          const at::Tensor& k, const at::Tensor& v,
-                         const at::Tensor& attention_mask, double p_dropout,
-                         double softmax_scale, int64_t head_num,
-                         const std::string& input_layout) {
+                         at::Generator& gen, const at::Tensor& attention_mask,
+                         double p_dropout, double softmax_scale,
+                         int64_t head_num, const std::string& input_layout) {
   diopiTensorHandle_t dropout_mask = nullptr;
   diopiTensorHandle_t softmax_max = nullptr;
   diopiTensorHandle_t softmax_sum = nullptr;
   diopiTensorHandle_t softmax_out = nullptr;
 
-  auto gen = createDIPUGenerator();
-
   [[maybe_unused]] auto context = callDiopiKeepContext(
       diopiFlashAttentionV2, out, &dropout_mask, &softmax_max, &softmax_sum,
       &softmax_out, gen, q, k, v, attention_mask, p_dropout, softmax_scale,
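
Taken together, the two hunks above move generator creation out of the C++ extensions: extFlashAttention and extFlashAttentionV2 now receive an at::Generator& from the caller instead of building one internally with createDIPUGenerator(). A minimal sketch of the resulting calling convention from Python, assuming extFlashAttentionV2 is exposed as ext.fa_fwd_v2 (the assert in flash_attention.py suggests this) and with illustrative shapes and device handling that may differ on a real DIPU setup:

# Sketch only, not part of this commit. Assumptions: the binding is named
# ext.fa_fwd_v2, DIPU tensors are reachable through the "cuda" device string
# (common in torch_dipu setups, but may differ), and the shapes are illustrative.
import torch
import torch_dipu
import deeplink_ext.cpp_extensions as ext

B, S, N, D = 2, 1024, 32, 128  # hypothetical batch, seq len, heads, head dim (BSND)
q = torch.randn(B, S, N, D, dtype=torch.half, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)
attention_mask = torch.zeros(S, S, dtype=torch.bool, device="cuda")

# After this commit the caller owns the generator and passes it in explicitly;
# the extension no longer calls createDIPUGenerator() itself.
gen = torch_dipu._C._create_dipu_generator(-1)  # -1: default/current device index

# Argument order mirrors the new extFlashAttentionV2 signature:
# out, q, k, v, gen, attention_mask, p_dropout, softmax_scale, head_num, input_layout.
fwd_stats = ext.fa_fwd_v2(
    out, q, k, v, gen, attention_mask,
    0.0,        # p_dropout
    D ** -0.5,  # softmax_scale
    N,          # head_num
    "BSND",     # input_layout
)
# fwd_stats is expected to carry the dropout mask and softmax statistics that the
# C++ code allocates (dropout_mask, softmax_max, softmax_sum, softmax_out).

The sketch only illustrates argument order; the actual wrapper lives in deeplink_ext/ascend_speed/flash_attention.py, shown below.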
3 changes: 3 additions & 0 deletions deeplink_ext/ascend_speed/flash_attention.py
@@ -1,4 +1,5 @@
 import torch
+import torch_dipu
 import deeplink_ext.cpp_extensions as ext
 
 assert hasattr(ext, "fa_fwd_v2") and hasattr(ext, "fa_bwd")
@@ -10,6 +11,7 @@ def forward(
         ctx, q, k, v, attention_mask, dropout_p, softmax_scale, head_num, input_layout
     ):
         out = torch.empty_like(q)
+        gen = torch_dipu._C._create_dipu_generator(-1)
         (
             dropout_mask,
             softmax_max,
@@ -20,6 +22,7 @@ q,
             q,
             k,
             v,
+            gen,
             attention_mask,
             dropout_p,
             softmax_scale,
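
On the Python side the wrapper now imports torch_dipu, creates a device generator with torch_dipu._C._create_dipu_generator(-1) on each forward call, and passes it to the extension right after v. If reproducible dropout were needed, the generator could presumably be seeded before the call; a small sketch under the assumption that the returned object follows the standard torch.Generator API:

import torch_dipu

# Sketch only, not part of this commit: assumes the object returned by
# _create_dipu_generator behaves like a torch.Generator (manual_seed, seed).
gen = torch_dipu._C._create_dipu_generator(-1)  # -1: default/current device index
gen.manual_seed(1234)  # fixed seed so the dropout mask is reproducible across runs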
