
resolve comments
Signed-off-by: Xin Yao <[email protected]>
yaox12 committed Aug 15, 2024
1 parent 0e837c3 commit 33c3ed6
Showing 4 changed files with 319 additions and 168 deletions.
108 changes: 72 additions & 36 deletions transformer_engine/pytorch/attention.py
@@ -3588,12 +3588,18 @@ def forward(
     fused_attention_backend,
     attn_bias,
     cu_seqlens_padded,
-    fp8_meta["scaling_fwd"].scale_inv,
-    fp8_meta["scaling_fwd"].scale,
-    fp8_meta["scaling_fwd"].amax_history,
-    META_QKV,
-    META_S,
-    META_O,
+    fp8_meta["scaling_fwd"].scale_inv,  # d_scale_qkv
+    META_QKV,  # d_scale_qkv_offset
+    fp8_meta["scaling_fwd"].scale_inv,  # d_scale_s
+    META_S,  # d_scale_s_offset
+    fp8_meta["scaling_fwd"].scale,  # q_scale_s
+    META_S,  # q_scale_s_offset
+    fp8_meta["scaling_fwd"].scale,  # q_scale_o
+    META_O,  # q_scale_o_offset
+    fp8_meta["scaling_fwd"].amax_history,  # amax_s
+    META_S,  # amax_s_offset
+    fp8_meta["scaling_fwd"].amax_history,  # amax_o
+    META_O,  # amax_o_offset
     attn_scale,
     dropout_p,
     fast_zero_fill,
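
This hunk (and each FP8 hunk below) replaces three buffers plus a trailing META_QKV/META_S/META_O triple with an explicit (buffer, offset) pair per kernel argument, so every dequantization scale (d_scale_*), quantization scale (q_scale_*), and amax slot is named at the call site. A minimal sketch of the indexing convention, assuming META_QKV/META_S/META_O are integer offsets into the shared fp8_meta["scaling_fwd"] buffers (buffer shapes and comments are illustrative assumptions, not the TransformerEngine API):

import torch

# Illustrative stand-ins (assumed): one shared buffer per quantity, with
# small integer offsets selecting each tensor's slot in that buffer.
META_QKV, META_S, META_O = 0, 1, 2
NUM_TENSORS = 3

scale = torch.ones(NUM_TENSORS)              # q_scale_* buffer (quantization)
scale_inv = torch.ones(NUM_TENSORS)          # d_scale_* buffer (dequantization)
amax_history = torch.zeros(16, NUM_TENSORS)  # amax_* buffer (rolling history)

# What the kernel conceptually reads/writes for each (buffer, offset) pair:
d_scale_qkv = scale_inv[META_QKV]  # dequantize the Q/K/V inputs
d_scale_s = scale_inv[META_S]      # dequantize the softmax output S
q_scale_s = scale[META_S]          # quantize S before the second GEMM
q_scale_o = scale[META_O]          # quantize the attention output O
amax_history[0, META_S] = 1.0      # slot where the kernel records amax(S)
amax_history[0, META_O] = 1.0      # slot where the kernel records amax(O)
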
@@ -3653,12 +3659,18 @@ def forward(
     fused_attention_backend,
     attn_bias,
     cu_seqlens_padded,
-    None,
-    None,
-    None,
-    0,
-    0,
-    0,
+    None,  # d_scale_qkv
+    0,  # d_scale_qkv_offset
+    None,  # d_scale_s
+    0,  # d_scale_s_offset
+    None,  # q_scale_s
+    0,  # q_scale_s_offset
+    None,  # q_scale_o
+    0,  # q_scale_o_offset
+    None,  # amax_s
+    0,  # amax_s_offset
+    None,  # amax_o
+    0,  # amax_o_offset
     attn_scale,
     dropout_p,
     fast_zero_fill,
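
In the non-FP8 path the same twelve positional slots must still be filled; a None buffer with a 0 offset disables FP8 scaling for that argument. Purely as an illustration (fp8_scaling_args is a hypothetical helper, not part of this commit), both call shapes could be built in one place:

def fp8_scaling_args(fp8_meta=None):
    # Build the twelve (buffer, offset) arguments in the order the fused
    # attention call above expects; None/0 is assumed to mean "no scaling".
    if fp8_meta is None:
        return [None, 0] * 6
    fwd = fp8_meta["scaling_fwd"]
    return [
        fwd.scale_inv, META_QKV,   # d_scale_qkv, d_scale_qkv_offset
        fwd.scale_inv, META_S,     # d_scale_s, d_scale_s_offset
        fwd.scale, META_S,         # q_scale_s, q_scale_s_offset
        fwd.scale, META_O,         # q_scale_o, q_scale_o_offset
        fwd.amax_history, META_S,  # amax_s, amax_s_offset
        fwd.amax_history, META_O,  # amax_o, amax_o_offset
    ]
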
@@ -3969,12 +3981,18 @@ def forward(
     attn_bias,
     cu_seqlens_q_padded,
     cu_seqlens_kv_padded,
-    fp8_meta["scaling_fwd"].scale_inv,
-    fp8_meta["scaling_fwd"].scale,
-    fp8_meta["scaling_fwd"].amax_history,
-    META_QKV,
-    META_S,
-    META_O,
+    fp8_meta["scaling_fwd"].scale_inv,  # d_scale_qkv
+    META_QKV,  # d_scale_qkv_offset
+    fp8_meta["scaling_fwd"].scale_inv,  # d_scale_s
+    META_S,  # d_scale_s_offset
+    fp8_meta["scaling_fwd"].scale,  # q_scale_s
+    META_S,  # q_scale_s_offset
+    fp8_meta["scaling_fwd"].scale,  # q_scale_o
+    META_O,  # q_scale_o_offset
+    fp8_meta["scaling_fwd"].amax_history,  # amax_s
+    META_S,  # amax_s_offset
+    fp8_meta["scaling_fwd"].amax_history,  # amax_o
+    META_O,  # amax_o_offset
     attn_scale,
     dropout_p,
     fast_zero_fill,
@@ -4042,12 +4060,18 @@ def forward(
     attn_bias,
     cu_seqlens_q_padded,
     cu_seqlens_kv_padded,
-    None,
-    None,
-    None,
-    0,
-    0,
-    0,
+    None,  # d_scale_qkv
+    0,  # d_scale_qkv_offset
+    None,  # d_scale_s
+    0,  # d_scale_s_offset
+    None,  # q_scale_s
+    0,  # q_scale_s_offset
+    None,  # q_scale_o
+    0,  # q_scale_o_offset
+    None,  # amax_s
+    0,  # amax_s_offset
+    None,  # amax_o
+    0,  # amax_o_offset
     attn_scale,
     dropout_p,
     fast_zero_fill,
@@ -4423,12 +4447,18 @@ def forward(
     attn_bias,
     cu_seqlens_q_padded,
     cu_seqlens_kv_padded,
-    fp8_meta["scaling_fwd"].scale_inv,
-    fp8_meta["scaling_fwd"].scale,
-    fp8_meta["scaling_fwd"].amax_history,
-    META_QKV,
-    META_S,
-    META_O,
+    fp8_meta["scaling_fwd"].scale_inv,  # d_scale_qkv
+    META_QKV,  # d_scale_qkv_offset
+    fp8_meta["scaling_fwd"].scale_inv,  # d_scale_s
+    META_S,  # d_scale_s_offset
+    fp8_meta["scaling_fwd"].scale,  # q_scale_s
+    META_S,  # q_scale_s_offset
+    fp8_meta["scaling_fwd"].scale,  # q_scale_o
+    META_O,  # q_scale_o_offset
+    fp8_meta["scaling_fwd"].amax_history,  # amax_s
+    META_S,  # amax_s_offset
+    fp8_meta["scaling_fwd"].amax_history,  # amax_o
+    META_O,  # amax_o_offset
     attn_scale,
     dropout_p,
     fast_zero_fill,
@@ -4546,12 +4576,18 @@ def forward(
     attn_bias,
     cu_seqlens_q_padded,
     cu_seqlens_kv_padded,
-    None,
-    None,
-    None,
-    0,
-    0,
-    0,
+    None,  # d_scale_qkv
+    0,  # d_scale_qkv_offset
+    None,  # d_scale_s
+    0,  # d_scale_s_offset
+    None,  # q_scale_s
+    0,  # q_scale_s_offset
+    None,  # q_scale_o
+    0,  # q_scale_o_offset
+    None,  # amax_s
+    0,  # amax_s_offset
+    None,  # amax_o
+    0,  # amax_o_offset
     attn_scale,
     dropout_p,
     fast_zero_fill,
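
Taken together, the six hunks make the same change in each attention path: the offset that applies to every scale and amax argument is now explicit at the call site instead of being inferred from a trailing META_QKV/META_S/META_O triple. This also decouples the arguments from one shared offset scheme; for example, amax_s and amax_o can point at different slots of the same amax_history buffer.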

0 comments on commit 33c3ed6
