feat: support varlen flash attention on ascend (#101)
Support varlen flash attention on ascend.

---------

Co-authored-by: yangbo <[email protected]>
POI-WX and yangbofun authored May 6, 2024
1 parent 9dfc6b1 commit 63bc8ce
Showing 9 changed files with 957 additions and 89 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/static.yml
@@ -18,8 +18,8 @@ concurrency:
cancel-in-progress: true

jobs:
static-checks-on-sco:
name: static checks on sco
static-checks-on-sco-nvidia:
name: static checks on nv
runs-on: tps-sco-ci
steps:
- name: setting up environment
@@ -88,10 +88,11 @@ jobs:
- name: prepare code
run: |
set -ex
rm ${DEEPLINK_PATH}/${{ github.run_number }}/* -rf
mkdir -p ${DEEPLINK_PATH}/${{ github.run_number }} && cd ${DEEPLINK_PATH}/${{ github.run_number }}
git clone https://github.com/DeepLink-org/DeepLinkExt.git && cd DeepLinkExt
git checkout ${{ github.event.pull_request.head.sha }} && git merge --no-edit ${{ github.base_ref }}
- name: build dipu and deeplink_ext
- name: build deeplink_ext
run: |
source /mnt/cache/share/platform/cienv/dipu_latest_ci
cd ${DEEPLINK_PATH}/${{ github.run_number }}/DeepLinkExt
@@ -106,5 +107,4 @@ jobs:
run: |
source /mnt/cache/share/platform/cienv/dipu_latest_ci
cd ${DEEPLINK_PATH}/${{ github.run_number }}/DeepLinkExt
python -m pytest tests/internevo
python -m pytest tests/ascend_speed
python -m pytest tests
61 changes: 58 additions & 3 deletions csrc/extensions.cpp
@@ -215,7 +215,7 @@ auto extFlashAttentionV2(at::Tensor& out, const at::Tensor& q,
*dipu::diopi_helper::fromDiopiTensorHandle(softmax_out));
}

auto extFlashAttentionBackward(at::Tensor& grad_q, at::Tensor& grad_k,
void extFlashAttentionBackward(at::Tensor& grad_q, at::Tensor& grad_k,
at::Tensor& grad_v, const at::Tensor& grad_out,
const at::Tensor& q, const at::Tensor& k,
const at::Tensor& v, const at::Tensor& out,
@@ -230,8 +230,6 @@ auto extFlashAttentionBackward(at::Tensor& grad_q, at::Tensor& grad_k,
v, out, attention_mask, dropout_mask, softmax_max, softmax_sum,
softmax_out, p_dropout, softmax_scale, head_num,
input_layout.c_str());
return std::make_tuple(std::move(grad_q), std::move(grad_k),
std::move(grad_v));
}

void extFlashAttentionV3(at::Tensor& out, at::Tensor& softmax_lse,
@@ -255,6 +253,55 @@ void extFlashAttentionV3Backward(at::Tensor& grad_q, at::Tensor& grad_k,
is_causal);
}

// for ascend
auto extFlashAttentionVarLen(at::Tensor& out, const at::Tensor& q,
const at::Tensor& k, const at::Tensor& v,
at::Generator& gen,
const at::IntArrayRef& cum_seq_q,
const at::IntArrayRef& cum_seq_k,
int64_t max_seqlen_q, int64_t max_seqlen_kv,
double p_dropout, double softmax_scale,
bool is_causal) {
diopiTensorHandle_t attention_mask = nullptr;
diopiTensorHandle_t dropout_mask = nullptr;
diopiTensorHandle_t softmax_max = nullptr;
diopiTensorHandle_t softmax_sum = nullptr;
diopiTensorHandle_t softmax_out = nullptr;

[[maybe_unused]] auto context = callDiopiKeepContext(
diopiFlashAttentionVarLen, out, &attention_mask, &dropout_mask,
&softmax_max, &softmax_sum, &softmax_out, gen, q, k, v, cum_seq_q,
cum_seq_k, max_seqlen_q, max_seqlen_kv, p_dropout, softmax_scale,
is_causal);

return std::make_tuple(
attention_mask
? *dipu::diopi_helper::fromDiopiTensorHandle(attention_mask)
: at::Tensor(),
dropout_mask ? *dipu::diopi_helper::fromDiopiTensorHandle(dropout_mask)
: at::Tensor(),
*dipu::diopi_helper::fromDiopiTensorHandle(softmax_max),
*dipu::diopi_helper::fromDiopiTensorHandle(softmax_sum),
*dipu::diopi_helper::fromDiopiTensorHandle(softmax_out));
}

// for ascend
void extFlashAttentionVarLenBackward(
at::Tensor& grad_q, at::Tensor& grad_k, at::Tensor& grad_v,
const at::Tensor& grad_out, const at::Tensor& q, const at::Tensor& k,
const at::Tensor& v, const at::IntArrayRef& cum_seq_q,
const at::IntArrayRef& cum_seq_k, const at::Tensor& out,
const c10::optional<at::Tensor>& attention_mask,
const c10::optional<at::Tensor>& dropout_mask,
const at::Tensor& softmax_max, const at::Tensor& softmax_sum,
const at::Tensor& softmax_out, int64_t max_seqlen_q, int64_t max_seqlen_kv,
double p_dropout, double softmax_scale) {
callDiopi(diopiFlashAttentionVarLenBackward, grad_q, grad_k, grad_v, grad_out,
q, k, v, cum_seq_q, cum_seq_k, out, attention_mask, dropout_mask,
softmax_max, softmax_sum, softmax_out, max_seqlen_q, max_seqlen_kv,
p_dropout, softmax_scale);
}

void extScaledMaskedSoftmax(at::Tensor& out, const at::Tensor& input,
const at::Tensor& mask, double scale,
bool fixed_triu_mask) {
@@ -337,6 +384,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
if (&diopiFlashAttentionV3Backward != nullptr) {
m.def("fa_bwd_v3", &extFlashAttentionV3Backward, "deeplink ext_fa_bwd_v3");
}
if (&diopiFlashAttentionVarLen != nullptr) {
m.def("fa_varlen_fwd", &extFlashAttentionVarLen,
"deeplink ext_fa_varlen_fwd");
}
if (&diopiFlashAttentionVarLenBackward != nullptr) {
m.def("fa_varlen_bwd", &extFlashAttentionVarLenBackward,
"deeplink ext_fa_varlen_bwd");
}
if (&diopiRMSNorm != nullptr) {
m.def("rms_norm", &extRmsNorm, "deeplink ext_rms_norm");
}
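
The commit exports the new kernels to Python as `fa_varlen_fwd` and `fa_varlen_bwd`. Below is a minimal sketch of how a Python-side `torch.autograd.Function` could drive them, based only on the C++ signatures above. The import path `deeplink_ext.cpp_extensions`, the generator handling, the empty-tensor convention for the optional masks, and all variable names are assumptions for illustration; the repository's own Python wrappers are authoritative.

```python
# Hypothetical wrapper around the fa_varlen_fwd / fa_varlen_bwd bindings added in this
# commit. Import path and shape/dtype choices are assumptions, not part of the commit.
import torch

import deeplink_ext.cpp_extensions as ext  # assumed name of the built pybind11 module


class FlashAttentionVarLen(torch.autograd.Function):
    """Sketch showing how the forward outputs feed the backward binding."""

    @staticmethod
    def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_kv,
                max_seqlen_q, max_seqlen_kv, p_dropout, softmax_scale, is_causal):
        # cu_seqlens_* are plain Python int lists, matching the at::IntArrayRef arguments.
        out = torch.empty_like(q)
        gen = torch.Generator(device=q.device)  # device support may differ on ascend/dipu
        # The binding returns (attention_mask, dropout_mask, softmax_max, softmax_sum,
        # softmax_out); the first two may be empty if the kernel does not produce them.
        (attention_mask, dropout_mask,
         softmax_max, softmax_sum, softmax_out) = ext.fa_varlen_fwd(
            out, q, k, v, gen, cu_seqlens_q, cu_seqlens_kv,
            max_seqlen_q, max_seqlen_kv, p_dropout, softmax_scale, is_causal)
        ctx.save_for_backward(q, k, v, out, attention_mask, dropout_mask,
                              softmax_max, softmax_sum, softmax_out)
        ctx.cu_seqlens = (cu_seqlens_q, cu_seqlens_kv)      # not tensors, stash on ctx
        ctx.max_seqlens = (max_seqlen_q, max_seqlen_kv)
        ctx.p_dropout = p_dropout
        ctx.softmax_scale = softmax_scale
        return out

    @staticmethod
    def backward(ctx, grad_out):
        (q, k, v, out, attention_mask, dropout_mask,
         softmax_max, softmax_sum, softmax_out) = ctx.saved_tensors
        cu_seqlens_q, cu_seqlens_kv = ctx.cu_seqlens
        max_seqlen_q, max_seqlen_kv = ctx.max_seqlens
        grad_q = torch.empty_like(q)
        grad_k = torch.empty_like(k)
        grad_v = torch.empty_like(v)
        # Argument order follows extFlashAttentionVarLenBackward in csrc/extensions.cpp.
        ext.fa_varlen_bwd(grad_q, grad_k, grad_v, grad_out, q, k, v,
                          cu_seqlens_q, cu_seqlens_kv, out,
                          attention_mask, dropout_mask,
                          softmax_max, softmax_sum, softmax_out,
                          max_seqlen_q, max_seqlen_kv,
                          ctx.p_dropout, ctx.softmax_scale)
        return grad_q, grad_k, grad_v, None, None, None, None, None, None, None
```

This is only meant to show how the softmax statistics and optional masks returned by the forward binding flow into the backward one; it is not the wrapper shipped by this commit.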