feat: support flash attention for InternLM on ascend #41

Closed
wants to merge 18 commits
122 changes: 122 additions & 0 deletions csrc/extensions.cpp
@@ -20,6 +20,7 @@

#include <diopi/functions_ext.h>

#include <csrc_dipu/diopirt/diopirt_impl.h>
#include <csrc_dipu/runtime/core/DIPUGeneratorImpl.h>

#include "diopi_helper.h"
@@ -126,6 +127,57 @@ auto extMultiHeadAttentionBackward(const at::Tensor& grad_out,
std::move(grad_v));
}

// for ascend
auto extFlashAttention(const at::Tensor& q, const at::Tensor& k,
                       const at::Tensor& v, double p_dropout,
                       double softmax_scale, bool is_causal, int64_t head_num) {
  auto out = at::empty_like(q);
  diopiTensorHandle_t attention_mask = nullptr;
  diopiTensorHandle_t dropout_mask = nullptr;
  diopiTensorHandle_t softmax_max = nullptr;
  diopiTensorHandle_t softmax_sum = nullptr;
  diopiTensorHandle_t softmax_out = nullptr;

  auto gen = createDIPUGenerator();

  [[maybe_unused]] auto context = callDiopiKeepContext(
      diopiFlashAttention, out, &attention_mask, &dropout_mask, &softmax_max,
      &softmax_sum, &softmax_out, gen, q, k, v, p_dropout, softmax_scale,
      is_causal, head_num);

  return std::make_tuple(
      std::move(out),
      attention_mask
          ? *dipu::diopi_helper::fromDiopiTensorHandle(attention_mask)
          : at::Tensor(),
      dropout_mask ? *dipu::diopi_helper::fromDiopiTensorHandle(dropout_mask)
                   : at::Tensor(),
      *dipu::diopi_helper::fromDiopiTensorHandle(softmax_max),
      *dipu::diopi_helper::fromDiopiTensorHandle(softmax_sum),
      *dipu::diopi_helper::fromDiopiTensorHandle(softmax_out));
}

// for ascend
// grad_q, grad_k, grad_v are output args, and should be pre-allocated.
auto extFlashAttentionBackward(
    c10::optional<at::Tensor>& grad_q_opt,
    c10::optional<at::Tensor>& grad_k_opt,
    c10::optional<at::Tensor>& grad_v_opt, const at::Tensor& grad_out,
    const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
    const at::Tensor& out, const at::Tensor& attention_mask,
    const at::Tensor& dropout_mask, const at::Tensor& softmax_max,
    const at::Tensor& softmax_sum, const at::Tensor& softmax_out,
    double p_dropout, double softmax_scale, int64_t head_num) {
  auto grad_q = grad_q_opt.has_value() ? grad_q_opt.value() : at::empty_like(q);
  auto grad_k = grad_k_opt.has_value() ? grad_k_opt.value() : at::empty_like(k);
  auto grad_v = grad_v_opt.has_value() ? grad_v_opt.value() : at::empty_like(v);
  callDiopi(diopiFlashAttentionBackward, grad_q, grad_k, grad_v, grad_out, q, k,
            v, out, attention_mask, dropout_mask, softmax_max, softmax_sum,
            softmax_out, p_dropout, softmax_scale, head_num);
  return std::make_tuple(std::move(grad_q), std::move(grad_k),
                         std::move(grad_v));
}

auto extMultiHeadAttentionVarLen(at::Tensor& q, at::Tensor& k, at::Tensor& v,
const at::Tensor& cum_seq_q,
const at::Tensor& cum_seq_k,
@@ -176,6 +228,62 @@ auto extMultiHeadAttentionVarLenBackward(
std::move(grad_v));
}

// // for ascend
// auto extFlashAttentionVarLen(const at::Tensor& q, const at::Tensor& k,
//                              const at::Tensor& v, const at::Tensor& cum_seq_q,
//                              const at::Tensor& cum_seq_k, double p_dropout,
//                              double softmax_scale, bool is_causal) {
//   auto out = at::empty_like(q);
//   diopiTensorHandle_t attention_mask = nullptr;
//   diopiTensorHandle_t dropout_mask = nullptr;
//   diopiTensorHandle_t softmax_max = nullptr;
//   diopiTensorHandle_t softmax_sum = nullptr;
//   diopiTensorHandle_t softmax_out = nullptr;

//   auto gen = createDIPUGenerator();

//   [[maybe_unused]] auto context = callDiopiKeepContext(
//       diopiFlashAttentionVarLen, out, &attention_mask, &dropout_mask,
//       &softmax_max, &softmax_sum, &softmax_out, gen, q, k, v, cum_seq_q,
//       cum_seq_k, p_dropout, softmax_scale, is_causal);

//   return std::make_tuple(
//       std::move(out),
//       attention_mask
//           ? *dipu::diopi_helper::fromDiopiTensorHandle(attention_mask)
//           : at::Tensor(),
//       dropout_mask ? *dipu::diopi_helper::fromDiopiTensorHandle(dropout_mask)
//                    : at::Tensor(),
//       *dipu::diopi_helper::fromDiopiTensorHandle(softmax_max),
//       *dipu::diopi_helper::fromDiopiTensorHandle(softmax_sum),
//       *dipu::diopi_helper::fromDiopiTensorHandle(softmax_out));
// }

// // for ascend
// // grad_q, grad_k, grad_v are output args, and should be pre-allocated.
// auto extFlashAttentionVarLenBackward(
//     c10::optional<at::Tensor>& grad_q_opt,
//     c10::optional<at::Tensor>& grad_k_opt,
//     c10::optional<at::Tensor>& grad_v_opt, const at::Tensor& grad_out,
//     const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
//     const at::Tensor& cum_seq_q, const at::Tensor& cum_seq_k,
//     const at::Tensor& out, const at::Tensor& attention_mask,
//     const at::Tensor& dropout_mask, const at::Tensor& softmax_max,
//     const at::Tensor& softmax_sum, const at::Tensor& softmax_out,
//     double p_dropout, double softmax_scale) {
//   auto grad_q = grad_q_opt.has_value() ? grad_q_opt.value() : at::empty_like(q);
//   auto grad_k = grad_k_opt.has_value() ? grad_k_opt.value() : at::empty_like(k);
//   auto grad_v = grad_v_opt.has_value() ? grad_v_opt.value() : at::empty_like(v);
//   callDiopi(diopiFlashAttentionVarLenBackward, grad_q, grad_k, grad_v, grad_out,
//             q, k, v, cum_seq_q, cum_seq_k, out, attention_mask, dropout_mask,
//             softmax_max, softmax_sum, softmax_out, p_dropout, softmax_scale);
//   return std::make_tuple(std::move(grad_q), std::move(grad_k),
//                          std::move(grad_v));
// }

void extDestIndexCopyKV(const at::Tensor& k, const at::Tensor& dest_loc,
                        at::Tensor& out) {
  callDiopi(diopiDestIndexCopyKV, out, k, dest_loc);
@@ -277,6 +385,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("mha_varlen_bwd", &extMultiHeadAttentionVarLenBackward,
          "deeplink ext_mha_varlen_bwd");
  }
  if (&diopiFlashAttention != nullptr) {
    m.def("fa_fwd", &extFlashAttention, "deeplink ext_fa_fwd");
  }
  if (&diopiFlashAttentionBackward != nullptr) {
    m.def("fa_bwd", &extFlashAttentionBackward, "deeplink ext_fa_bwd");
  }
  // if (&diopiFlashAttentionVarLen != nullptr) {
  //   m.def("fa_varlen_fwd", &extFlashAttentionVarLen,
  //         "deeplink ext_fa_varlen_fwd");
  // }
  // if (&diopiFlashAttentionVarLenBackward != nullptr) {
  //   m.def("fa_varlen_bwd", &extFlashAttentionVarLenBackward,
  //         "deeplink ext_fa_varlen_bwd");
  // }
  if (&diopiDestIndexCopyKV != nullptr) {
    m.def("dest_index_copy_kv", &extDestIndexCopyKV,
          "deeplink ext_dest_index_copy_kv");
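The registration hunk above only exposes the new bindings when the address of the corresponding DIOPI function is non-null, so fa_fwd / fa_bwd appear only on backends (such as Ascend) whose DIOPI implementation provides diopiFlashAttention. For reference, a minimal sketch of calling the raw bindings directly; the shapes, dtype, and device string are illustrative assumptions (the device name follows the torch.Tensor().cuda() usage in the Python wrappers below), not values taken from this PR:

import torch
import deeplink_ext.cpp_extensions as ext

# Assumed example sizes: batch=2, seqlen=128, heads=8, head_dim=64, with tensors
# in the (batch, seqlen, nheads, headdim) layout used by the wrappers below.
B, S, H, D = 2, 128, 8, 64
q = torch.randn(B, S, H, D, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# extFlashAttention(q, k, v, p_dropout, softmax_scale, is_causal, head_num);
# with p_dropout=0.0 the returned dropout_mask may be an empty placeholder tensor.
out, attention_mask, dropout_mask, softmax_max, softmax_sum, softmax_out = ext.fa_fwd(
    q, k, v, 0.0, D**-0.5, True, H
)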
76 changes: 76 additions & 0 deletions deeplink_ext/internlm_ops/mha/fa_kvpacked_func.py
@@ -0,0 +1,76 @@
# Copyright (c) 2024, DeepLink.

import torch
import deeplink_ext.cpp_extensions as ext

assert hasattr(ext, "fa_fwd") and hasattr(ext, "fa_bwd")


class DeepLinkFlashAttentionKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, kv, dropout_p, softmax_scale, causal):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        head_num = q.shape[2]
        (
            out,
            attention_mask,
            dropout_mask,
            softmax_max,
            softmax_sum,
            softmax_out,
        ) = ext.fa_fwd(
            q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal, head_num
        )
        ctx.save_for_backward(
            q,
            kv,
            out,
            attention_mask,
            dropout_mask,
            softmax_max,
            softmax_sum,
            softmax_out,
        )
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.head_num = head_num
        return out

    @staticmethod
    def backward(ctx, dout):
        (
            q,
            kv,
            out,
            attention_mask,
            dropout_mask,
            softmax_max,
            softmax_sum,
            softmax_out,
        ) = ctx.saved_tensors
        attention_mask = (
            torch.Tensor().cuda() if attention_mask is None else attention_mask
        )
        dropout_mask = torch.Tensor().cuda() if dropout_mask is None else dropout_mask
        dq = torch.empty_like(q)
        dkv = torch.empty_like(kv)
        ext.fa_bwd(
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            dout,
            q,
            kv[:, :, 0],
            kv[:, :, 1],
            out,
            attention_mask,
            dropout_mask,
            softmax_max,
            softmax_sum,
            softmax_out,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.head_num,
        )
        return dq, dkv, None, None, None, None
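For reference, a minimal usage sketch of the KV-packed function above; the sizes, dtype, and device are illustrative assumptions, with kv laid out as (batch, seqlen, 2, nheads, headdim) to match the kv[:, :, 0] / kv[:, :, 1] indexing in forward:

import torch
from deeplink_ext.internlm_ops.mha.fa_kvpacked_func import DeepLinkFlashAttentionKVPackedFunc

B, S, H, D = 2, 128, 8, 64  # assumed example sizes
q = torch.randn(B, S, H, D, dtype=torch.float16, device="cuda", requires_grad=True)
kv = torch.randn(B, S, 2, H, D, dtype=torch.float16, device="cuda", requires_grad=True)

# apply(q, kv, dropout_p, softmax_scale, causal); softmax_scale=None falls back to head_dim ** -0.5
out = DeepLinkFlashAttentionKVPackedFunc.apply(q, kv, 0.0, None, True)
out.sum().backward()  # backward routes through ext.fa_bwd and fills q.grad / kv.grad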
79 changes: 79 additions & 0 deletions deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py
@@ -0,0 +1,79 @@
# Copyright (c) 2024, DeepLink.

import torch
import deeplink_ext.cpp_extensions as ext

assert hasattr(ext, "fa_fwd") and hasattr(ext, "fa_bwd")


class DeepLinkFlashAttentionQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, dropout_p, softmax_scale, causal):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        head_num = qkv.shape[3]
        (
            out,
            attention_mask,
            dropout_mask,
            softmax_max,
            softmax_sum,
            softmax_out,
        ) = ext.fa_fwd(
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            dropout_p,
            softmax_scale,
            causal,
            head_num,
        )
        ctx.save_for_backward(
            qkv,
            out,
            attention_mask,
            dropout_mask,
            softmax_max,
            softmax_sum,
            softmax_out,
        )
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.head_num = head_num
        return out

    @staticmethod
    def backward(ctx, dout):
        (
            qkv,
            out,
            attention_mask,
            dropout_mask,
            softmax_max,
            softmax_sum,
            softmax_out,
        ) = ctx.saved_tensors
        attention_mask = (
            torch.Tensor().cuda() if attention_mask is None else attention_mask
        )
        dropout_mask = torch.Tensor().cuda() if dropout_mask is None else dropout_mask
        dqkv = torch.empty_like(qkv)
        ext.fa_bwd(
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            dout,
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            out,
            attention_mask,
            dropout_mask,
            softmax_max,
            softmax_sum,
            softmax_out,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.head_num,
        )
        return dqkv, None, None, None, None
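Likewise, a minimal usage sketch for the QKV-packed variant; the example sizes are assumptions, with qkv laid out as (batch, seqlen, 3, nheads, headdim) to match qkv[:, :, i] and head_num = qkv.shape[3]:

import torch
from deeplink_ext.internlm_ops.mha.fa_qkvpacked_func import DeepLinkFlashAttentionQKVPackedFunc

B, S, H, D = 2, 128, 8, 64  # assumed example sizes
qkv = torch.randn(B, S, 3, H, D, dtype=torch.float16, device="cuda", requires_grad=True)

out = DeepLinkFlashAttentionQKVPackedFunc.apply(qkv, 0.0, None, True)  # dropout_p=0.0, causal=True
out.sum().backward()  # the packed gradient returned by backward lands in qkv.grad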