move import statement
MayDomine committed Jun 11, 2024
1 parent 1614ad1 commit 4e1bc98
Showing 1 changed file with 6 additions and 6 deletions.

bmtrain/nn/burst_utils.py
@@ -1,11 +1,5 @@
 import bmtrain as bmt
 import torch
-from flash_attn.flash_attn_interface import (
-    _flash_attn_forward as _flash_attn_forward_cuda,
-)
-from flash_attn.flash_attn_interface import (
-    _flash_attn_backward as _flash_attn_backward_cuda,
-)
 import inspect
 
 class ops_wrapper:
@@ -228,6 +222,9 @@ def inter_flash_attn_backward_triton(
 
 
 def inter_flash_cuda_fwd(q, k, v, o, lse, softmax_scale=1.0, causal=False):
+    from flash_attn.flash_attn_interface import (
+        _flash_attn_forward as _flash_attn_forward_cuda,
+    )
     o_i, _, _, _, _, lse_i, _, _ = _flash_attn_forward_cuda(
         q,
         k,
@@ -250,6 +247,9 @@ def inter_flash_cuda_fwd(q, k, v, o, lse, softmax_scale=1.0, causal=False):
 def inter_flash_cuda_bwd(do, q, k, v, o, lse, dq, dk, dv, softmax_scale, causal):
     dk_ = torch.empty_like(q)
     dv_ = torch.empty_like(q)
+    from flash_attn.flash_attn_interface import (
+        _flash_attn_backward as _flash_attn_backward_cuda,
+    )
     if len(o.shape) == 3:
         # use sum(o_i * gradoutput) as delta and pass a empty out to flash backward
         # this feature requires a build of this PR: https://github.com/Dao-AILab/flash-attention/pull/905
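A note on the change (an inference, not stated in the commit message): importing the flash-attn kernels inside inter_flash_cuda_fwd and inter_flash_cuda_bwd instead of at module level means bmtrain/nn/burst_utils.py can be imported even when flash-attn is not installed; the dependency is only pulled in when the CUDA flash-attention path actually runs. Below is a minimal sketch of the same deferred-import pattern, assuming flash-attn is treated as an optional dependency; the helper name _load_flash_attn_cuda_kernels is hypothetical and not part of bmtrain.

def _load_flash_attn_cuda_kernels():
    """Import the flash-attn CUDA kernels only when they are first needed."""
    try:
        # Same symbols the moved imports bind in the commit above.
        from flash_attn.flash_attn_interface import (
            _flash_attn_forward as _flash_attn_forward_cuda,
            _flash_attn_backward as _flash_attn_backward_cuda,
        )
    except ImportError as exc:
        # Surface a clear error only when the CUDA path is actually requested.
        raise ImportError(
            "flash-attn is required for the CUDA burst-attention path; "
            "install flash-attn or use the triton implementation instead."
        ) from exc
    return _flash_attn_forward_cuda, _flash_attn_backward_cuda

# Usage sketch: call the loader at the top of the functions that need the
# kernels, so `import bmtrain.nn.burst_utils` alone never requires flash-attn.
# fwd_kernel, bwd_kernel = _load_flash_attn_cuda_kernels()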
