feat: fix flash atten #66

Closed
wants to merge 15 commits
2 changes: 1 addition & 1 deletion .github/workflows/static.yml
@@ -9,7 +9,7 @@ on:

env:
  DEEPLINK_PATH: /mnt/cache/share/deeplinkci/github/${{ github.repository }}
-  ENV_SOURCE: /mnt/cache/share/platform/env/dipu_latest
+  ENV_SOURCE: /mnt/cache/share/platform/env/dipu_latest_ci
  PROXY_SOURCE: /mnt/cache/share/platform/env/proxy
  CLANGD_EXEC: /mnt/cache/share/platform/dep/clang-17/bin/clangd

108 changes: 108 additions & 0 deletions csrc/extensions.cpp
@@ -36,6 +36,80 @@ void extAdamW(at::Tensor& param, at::Tensor& exp_avg, at::Tensor& exp_avg_sq,
            beta1, beta2, epsilon, weight_decay, step, amsgrad);
}

auto extFlashAttention(const at::Tensor& q, const at::Tensor& k,
                       const at::Tensor& v, double p_dropout,
                       double softmax_scale, bool is_causal, int64_t head_num) {
  auto out = at::empty_like(q);
  // Auxiliary outputs produced by the DIOPI kernel and returned via handles.
  diopiTensorHandle_t attention_mask = nullptr;
  diopiTensorHandle_t dropout_mask = nullptr;
  diopiTensorHandle_t softmax_max = nullptr;
  diopiTensorHandle_t softmax_sum = nullptr;
  diopiTensorHandle_t softmax_out = nullptr;

  auto gen = createDIPUGenerator();

  // Keep the DIOPI call context alive so the tensor handles created by the
  // kernel remain valid while they are converted back to at::Tensor below.
  [[maybe_unused]] auto context = callDiopiKeepContext(
      diopiFlashAttention, out, &attention_mask, &dropout_mask, &softmax_max,
      &softmax_sum, &softmax_out, gen, q, k, v, p_dropout, softmax_scale,
      is_causal, head_num);

  // attention_mask and dropout_mask are optional; return empty tensors when
  // the kernel did not produce them.
  return std::make_tuple(
      std::move(out),
      attention_mask
          ? *dipu::diopi_helper::fromDiopiTensorHandle(attention_mask)
          : at::Tensor(),
      dropout_mask ? *dipu::diopi_helper::fromDiopiTensorHandle(dropout_mask)
                   : at::Tensor(),
      *dipu::diopi_helper::fromDiopiTensorHandle(softmax_max),
      *dipu::diopi_helper::fromDiopiTensorHandle(softmax_sum),
      *dipu::diopi_helper::fromDiopiTensorHandle(softmax_out));
}
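
For orientation, a minimal Python-side sketch of calling this function through the fa_fwd binding registered further down. The module name ext, the tensor shapes, and the device string are illustrative placeholders (the real import name comes from TORCH_EXTENSION_NAME at build time), not something defined in this PR:

import torch
import ext  # placeholder: the compiled extension's import name is set at build time

# Illustrative shapes/device only; the exact q/k/v layout is defined by the backend kernel.
q = torch.randn(2, 1024, 32, 64, dtype=torch.half, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Returns (out, attention_mask, dropout_mask, softmax_max, softmax_sum, softmax_out);
# attention_mask / dropout_mask come back as empty tensors when the kernel does not emit them.
out, attn_mask, drop_mask, smax, ssum, sout = ext.fa_fwd(
    q, k, v, 0.0, q.shape[-1] ** -0.5, True, 32)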

auto extFlashAttentionV2(const at::Tensor& q, const at::Tensor& k,
                         const at::Tensor& v, const at::Tensor& attention_mask,
                         double p_dropout, double softmax_scale,
                         int64_t head_num) {
  // Same as extFlashAttention, except the caller supplies the attention mask
  // instead of requesting causal masking via a flag.
  auto out = at::empty_like(q);
  diopiTensorHandle_t dropout_mask = nullptr;
  diopiTensorHandle_t softmax_max = nullptr;
  diopiTensorHandle_t softmax_sum = nullptr;
  diopiTensorHandle_t softmax_out = nullptr;

  auto gen = createDIPUGenerator();

  [[maybe_unused]] auto context = callDiopiKeepContext(
      diopiFlashAttentionV2, out, &dropout_mask, &softmax_max, &softmax_sum,
      &softmax_out, gen, q, k, v, attention_mask, p_dropout, softmax_scale,
      head_num);

  return std::make_tuple(
      std::move(out),
      dropout_mask ? *dipu::diopi_helper::fromDiopiTensorHandle(dropout_mask)
                   : at::Tensor(),
      *dipu::diopi_helper::fromDiopiTensorHandle(softmax_max),
      *dipu::diopi_helper::fromDiopiTensorHandle(softmax_sum),
      *dipu::diopi_helper::fromDiopiTensorHandle(softmax_out));
}
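
A corresponding sketch for the V2 variant, under the same placeholder module name and illustrative shapes; the mask dtype and shape here are assumptions, since the exact contract is set by the backend kernel:

import torch
import ext  # placeholder module name, as in the sketch above

q = torch.randn(2, 1024, 32, 64, dtype=torch.half, device="cuda")  # illustrative
k, v = torch.randn_like(q), torch.randn_like(q)
mask = torch.ones(1024, 1024, dtype=torch.bool, device="cuda").triu(1)  # assumed causal-mask layout

# V2 takes the attention mask explicitly instead of an is_causal flag, so the
# returned tuple has one fewer element (no attention_mask).
out, drop_mask, smax, ssum, sout = ext.fa_fwd_v2(
    q, k, v, mask, 0.0, q.shape[-1] ** -0.5, 32)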

auto extFlashAttentionBackward(
    c10::optional<at::Tensor>& grad_q_opt,
    c10::optional<at::Tensor>& grad_k_opt,
    c10::optional<at::Tensor>& grad_v_opt, const at::Tensor& grad_out,
    const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
    const at::Tensor& out, const at::Tensor& attention_mask,
    const at::Tensor& dropout_mask, const at::Tensor& softmax_max,
    const at::Tensor& softmax_sum, const at::Tensor& softmax_out,
    double p_dropout, double softmax_scale, int64_t head_num) {
  // Use caller-provided gradient buffers when given, otherwise allocate them.
  auto grad_q = grad_q_opt.has_value() ? grad_q_opt.value() : at::empty_like(q);
  auto grad_k = grad_k_opt.has_value() ? grad_k_opt.value() : at::empty_like(k);
  auto grad_v = grad_v_opt.has_value() ? grad_v_opt.value() : at::empty_like(v);
  callDiopi(diopiFlashAttentionBackward, grad_q, grad_k, grad_v, grad_out, q, k,
            v, out, attention_mask, dropout_mask, softmax_max, softmax_sum,
            softmax_out, p_dropout, softmax_scale, head_num);
  return std::make_tuple(std::move(grad_q), std::move(grad_k),
                         std::move(grad_v));
}
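
And a sketch of wiring the backward binding to the forward outputs (same placeholder module name; passing None for the gradient buffers lets the binding allocate them via at::empty_like):

import torch
import ext  # placeholder module name, as above

q = torch.randn(2, 1024, 32, 64, dtype=torch.half, device="cuda")  # illustrative
k, v = torch.randn_like(q), torch.randn_like(q)
out, attn_mask, drop_mask, smax, ssum, sout = ext.fa_fwd(
    q, k, v, 0.0, q.shape[-1] ** -0.5, True, 32)

grad_out = torch.ones_like(out)
grad_q, grad_k, grad_v = ext.fa_bwd(
    None, None, None, grad_out, q, k, v, out,
    attn_mask, drop_mask, smax, ssum, sout,
    0.0, q.shape[-1] ** -0.5, 32)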

auto extRmsNorm(at::Tensor& output, at::Tensor& inv_rms,
                const at::Tensor& input,
                const OptionalIntArray& normalized_shape,
@@ -169,6 +243,23 @@ void extDestIndexCopyKV(const at::Tensor& k, const at::Tensor& dest_loc,
  callDiopi(diopiDestIndexCopyKV, out, k, dest_loc);
}

auto extScaledMaskedSoftmax(const at::Tensor& input, const at::Tensor& mask,
                            double scale, bool fixed_triu_mask) {
  auto out = at::empty_like(input);
  callDiopi(diopiScaledMaskedSoftmax, out, input, mask, scale, fixed_triu_mask);
  return out;
}

auto extScaledMaskedSoftmaxBackward(const at::Tensor& grad_output,
                                    const at::Tensor& out,
                                    const at::Tensor& mask, double scale,
                                    bool fixed_triu_mask) {
  auto grad_input = at::empty_like(grad_output);
  callDiopi(diopiScaledMaskedSoftmaxBackward, grad_input, grad_output, out,
            mask, scale, fixed_triu_mask);
  return grad_input;
}
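
A minimal sketch of this forward/backward pair as it would be called through the bindings registered below (placeholder module name and illustrative shapes, as in the earlier sketches; the mask dtype is an assumption):

import torch
import ext  # placeholder module name

x = torch.randn(2, 8, 128, 128, dtype=torch.half, device="cuda")  # illustrative
mask = torch.zeros(2, 8, 128, 128, dtype=torch.bool, device="cuda")  # assumed mask dtype

probs = ext.scaled_masked_softmax_fwd(x, mask, 1.0, False)
grad_x = ext.scaled_masked_softmax_bwd(torch.ones_like(probs), probs, mask, 1.0, False)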

void extTokenAttentionInference(const at::Tensor& q, const at::Tensor& k,
                                at::Tensor& out, const at::Tensor& b_loc,
                                const at::Tensor& b_start_loc,
@@ -231,6 +322,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  if (&diopiAdamW != nullptr) {
    m.def("adamw", &extAdamW, "deeplink ext_adamw");
  }
  if (&diopiFlashAttention != nullptr) {
    m.def("fa_fwd", &extFlashAttention, "deeplink ext_fa_fwd");
  }
  if (&diopiFlashAttentionV2 != nullptr) {
    m.def("fa_fwd_v2", &extFlashAttentionV2, "deeplink ext_fa_fwd_v2");
  }
  if (&diopiFlashAttentionBackward != nullptr) {
    m.def("fa_bwd", &extFlashAttentionBackward, "deeplink ext_fa_bwd");
  }
  if (&diopiRMSNorm != nullptr) {
    m.def("rms_norm", &extRmsNorm, "deeplink ext_rms_norm");
  }
@@ -259,6 +359,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dest_index_copy_kv", &extDestIndexCopyKV,
"deeplink ext_dest_index_copy_kv");
}
if (&diopiScaledMaskedSoftmax != nullptr) {
m.def("scaled_masked_softmax_fwd", &extScaledMaskedSoftmax,
"deeplink ext_scaled_masked_softmax_fwd");
}
if (&diopiScaledMaskedSoftmaxBackward != nullptr) {
m.def("scaled_masked_softmax_bwd", &extScaledMaskedSoftmaxBackward,
"deeplink ext_scaled_masked_softmax_bwd");
}
if (&diopiTokenAttentionInference != nullptr) {
m.def("token_attention_inference", &extTokenAttentionInference,
"deeplink ext_token_attention_inference");
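
Since each of the new bindings is registered only when the corresponding DIOPI function is present (the &diopiXxx != nullptr guards above), a Python caller can feature-detect the ops at runtime; a sketch, again with a placeholder module name:

import ext  # placeholder module name

# The binding exists only if the backend implements the DIOPI op, so guard call sites.
if hasattr(ext, "fa_fwd"):
    use_fused_flash_attention = True
else:
    use_fused_flash_attention = False  # fall back to a non-fused reference path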