Commit

fix according to black
POI-WX committed Mar 19, 2024
1 parent bacd543 commit db78907
Showing 4 changed files with 54 additions and 34 deletions.
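
All four files get the same mechanical change: the assignment that unpacks the six return values of ext.fa_fwd / ext.fa_varlen_fwd is rewritten from a parenthesized right-hand side into a parenthesized tuple of targets, one per line, which is the layout black (as run in this project, per the commit message) expects. A minimal, self-contained sketch of the before/after pattern; fa_fwd_stub and its values are hypothetical stand-ins for the real extension call:

# Hypothetical stand-in for ext.fa_fwd, which in the real extension returns
# six tensors; plain Python values are enough to show the formatting.
def fa_fwd_stub(q, k, v):
    return q, k, v, 0.0, 0.0, 0.0

q, k, v = 1.0, 2.0, 3.0

# Old layout: the right-hand side wrapped in parentheses.
out, attention_mask, dropout_mask, softmax_max, softmax_sum, softmax_out = (
    fa_fwd_stub(q, k, v)
)

# New layout, as committed here: the assignment targets are parenthesized
# and listed one per line instead.
(
    out,
    attention_mask,
    dropout_mask,
    softmax_max,
    softmax_sum,
    softmax_out,
) = fa_fwd_stub(q, k, v)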
11 changes: 8 additions & 3 deletions deeplink_ext/internlm_ops/mha/fa_kvpacked_func.py
@@ -11,9 +11,14 @@ class DeepLinkFlashAttentionKVPackedFunc(torch.autograd.Function):
     def forward(ctx, q, kv, dropout_p, softmax_scale, causal):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, attention_mask, dropout_mask, softmax_max, softmax_sum, softmax_out = (
-            ext.fa_fwd(q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal)
-        )
+        (
+            out,
+            attention_mask,
+            dropout_mask,
+            softmax_max,
+            softmax_sum,
+            softmax_out,
+        ) = ext.fa_fwd(q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal)
         ctx.save_for_backward(
             q,
             kv,
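One unchanged context line above is worth a note: when no softmax_scale is passed, forward defaults it to q.shape[-1] ** (-0.5), i.e. the standard 1/sqrt(d_k) scaling of scaled dot-product attention. A quick check of that equivalence, with an arbitrary example head dimension:

import math

head_dim = 64  # hypothetical; the real value comes from q.shape[-1]
assert math.isclose(head_dim ** (-0.5), 1 / math.sqrt(head_dim))  # 0.125 either way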
23 changes: 14 additions & 9 deletions deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py
@@ -11,15 +11,20 @@ class DeepLinkFlashAttentionQKVPackedFunc(torch.autograd.Function):
     def forward(ctx, qkv, dropout_p, softmax_scale, causal):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        out, attention_mask, dropout_mask, softmax_max, softmax_sum, softmax_out = (
-            ext.fa_fwd(
-                qkv[:, :, 0],
-                qkv[:, :, 1],
-                qkv[:, :, 2],
-                dropout_p,
-                softmax_scale,
-                causal,
-            )
+        (
+            out,
+            attention_mask,
+            dropout_mask,
+            softmax_max,
+            softmax_sum,
+            softmax_out,
+        ) = ext.fa_fwd(
+            qkv[:, :, 0],
+            qkv[:, :, 1],
+            qkv[:, :, 2],
+            dropout_p,
+            softmax_scale,
+            causal,
         )
         ctx.save_for_backward(
             qkv,
27 changes: 16 additions & 11 deletions deeplink_ext/internlm_ops/mha/fa_varlen_kvpacked_func.py
@@ -22,17 +22,22 @@ def forward(
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, attention_mask, dropout_mask, softmax_max, softmax_sum, softmax_out = (
-            ext.fa_varlen_fwd(
-                q,
-                kv[:, :, 0],
-                kv[:, :, 1],
-                cu_seqlens_q[1:],
-                cu_seqlens_k[1:],
-                dropout_p,
-                softmax_scale,
-                causal,
-            )
+        (
+            out,
+            attention_mask,
+            dropout_mask,
+            softmax_max,
+            softmax_sum,
+            softmax_out,
+        ) = ext.fa_varlen_fwd(
+            q,
+            kv[:, :, 0],
+            kv[:, :, 1],
+            cu_seqlens_q[1:],
+            cu_seqlens_k[1:],
+            dropout_p,
+            softmax_scale,
+            causal,
         )
         ctx.save_for_backward(
             q,
27 changes: 16 additions & 11 deletions deeplink_ext/internlm_ops/mha/fa_varlen_qkvpacked_func.py
@@ -19,17 +19,22 @@ def forward(
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        out, attention_mask, dropout_mask, softmax_max, softmax_sum, softmax_out = (
-            ext.fa_varlen_fwd(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
-                cu_seqlens[1:].to(torch.int64),
-                cu_seqlens[1:].to(torch.int64),
-                dropout_p,
-                softmax_scale,
-                causal,
-            )
+        (
+            out,
+            attention_mask,
+            dropout_mask,
+            softmax_max,
+            softmax_sum,
+            softmax_out,
+        ) = ext.fa_varlen_fwd(
+            qkv[:, 0],
+            qkv[:, 1],
+            qkv[:, 2],
+            cu_seqlens[1:].to(torch.int64),
+            cu_seqlens[1:].to(torch.int64),
+            dropout_p,
+            softmax_scale,
+            causal,
         )
         ctx.save_for_backward(
             qkv,
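
Both varlen variants slice their cumulative-sequence-length tensors with [1:] before handing them to ext.fa_varlen_fwd, which suggests the DeepLink kernel wants each sequence's end offset rather than the usual batch+1 prefix sum that starts with a leading zero. A small illustration of that convention, with hypothetical values:

import torch

# Three packed sequences of lengths 3, 5, and 2: cu_seqlens is the running
# prefix sum with a leading zero, one entry per sequence boundary.
cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)

# Dropping the leading zero leaves each sequence's end offset, which is what
# the calls above pass down (the qkv-packed file also widens to int64).
print(cu_seqlens[1:].to(torch.int64))  # tensor([ 3,  8, 10])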
