diff --git a/bmtrain/nn/burst_utils.py b/bmtrain/nn/burst_utils.py
index 413100f..ebd9099 100644
--- a/bmtrain/nn/burst_utils.py
+++ b/bmtrain/nn/burst_utils.py
@@ -1,11 +1,5 @@
 import bmtrain as bmt
 import torch
-from flash_attn.flash_attn_interface import (
-    _flash_attn_forward as _flash_attn_forward_cuda,
-)
-from flash_attn.flash_attn_interface import (
-    _flash_attn_backward as _flash_attn_backward_cuda,
-)
 import inspect
 
 class ops_wrapper:
@@ -228,6 +222,9 @@ def inter_flash_attn_backward_triton(
 
 
 def inter_flash_cuda_fwd(q, k, v, o, lse, softmax_scale=1.0, causal=False):
+    from flash_attn.flash_attn_interface import (
+        _flash_attn_forward as _flash_attn_forward_cuda,
+    )
     o_i, _, _, _, _, lse_i, _, _ = _flash_attn_forward_cuda(
         q,
         k,
@@ -250,6 +247,9 @@ def inter_flash_cuda_fwd(q, k, v, o, lse, softmax_scale=1.0, causal=False):
 def inter_flash_cuda_bwd(do, q, k, v, o, lse, dq, dk, dv, softmax_scale, causal):
     dk_ = torch.empty_like(q)
     dv_ = torch.empty_like(q)
+    from flash_attn.flash_attn_interface import (
+        _flash_attn_backward as _flash_attn_backward_cuda,
+    )
     if len(o.shape) == 3:
         # use sum(o_i * gradoutput) as delta and pass a empty out to flash backward
         # this feature requires a build of this PR: https://github.com/Dao-AILab/flash-attention/pull/905
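
Note (reviewer sketch, not part of the patch): the change above moves the flash-attn imports from module import time into the two functions that call them, so bmtrain.nn.burst_utils stays importable on machines where flash-attn is not installed and the ImportError only surfaces when the CUDA attention path is actually used. A minimal, hypothetical illustration of that deferred-import pattern; the helper name and error message below are invented for illustration:

def _require_flash_attn_forward():
    # Deferred import: the enclosing module can be imported without flash-attn;
    # the failure is reported only when this CUDA code path is exercised.
    try:
        from flash_attn.flash_attn_interface import _flash_attn_forward
    except ImportError as exc:
        raise ImportError(
            "the flash-attn package is required for the CUDA attention path"
        ) from exc
    return _flash_attn_forward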