Fix QKV dtype in the bwd of FP8+CP (#1134)
* fix qkv_dtype of FP8+CP

Signed-off-by: Xiaowei Ren <[email protected]>

* config cp correction dtype of FP8+CP

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* code style change

Signed-off-by: Xiaowei Ren <[email protected]>

* always do FP8 CP correction in FP32

Signed-off-by: Xiaowei Ren <[email protected]>

---------

Signed-off-by: Xiaowei Ren <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <[email protected]>
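
The "always do FP8 CP correction in FP32" item refers to how the partial attention results produced by each context-parallel (CP) step are combined. The diff below covers only the dtype selection, so the following is a rough, hypothetical illustration (function name, shapes, and mechanics are assumptions, not Transformer Engine code) of combining per-step outputs via their log-sum-exp statistics while accumulating in FP32:

import torch

# Hypothetical sketch, not the Transformer Engine implementation: combine the
# partial attention outputs from each CP step using their per-row log-sum-exp
# (softmax) statistics, accumulating in FP32 so the correction does not
# inherit FP8/BF16 rounding error.
def combine_cp_partial_outputs(partial_outs, partial_lses):
    # partial_outs: list of [seq, heads, head_dim] tensors (possibly low precision)
    # partial_lses: list of [seq, heads] log-sum-exp stats, one per CP step
    lse_total = torch.logsumexp(torch.stack([l.float() for l in partial_lses]), dim=0)
    out = torch.zeros_like(partial_outs[0], dtype=torch.float32)  # FP32 accumulator
    for o, l in zip(partial_outs, partial_lses):
        weight = torch.exp(l.float() - lse_total).unsqueeze(-1)  # correction factor
        out = out + weight * o.float()                            # correct in FP32
    return out  # caller casts back to the working dtype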
3 people authored Aug 30, 2024
1 parent aecd5a8 commit 9437ceb
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions transformer_engine/pytorch/attention.py
@@ -2261,8 +2261,9 @@ def backward(ctx, dout):
 
         if ctx.fp8:
             if ctx.use_fused_attention:
                 fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                 fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
-                fused_attn_qkv_dtype = fp8_dtype_backward
+                fused_attn_qkv_dtype = fp8_dtype_forward
+                fused_attn_dqkv_dtype = fp8_dtype_backward
                 fused_attn_backend = FusedAttnBackend["FP8"]
                 dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device)
@@ -2304,7 +2305,7 @@ def backward(ctx, dout):
             if ctx.use_fused_attention:
                 fp8_meta_kwargs = {}
                 fused_attn_qkv_dtype = TE_DType[q.dtype]
-                fused_attn_dqkv_dtype = TE_DType[q.dtype]
+                fused_attn_dqkv_dtype = TE_DType[dout.dtype]
                 fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
 
         out = out.view(*q.shape)
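
For context, the fix follows the usual FP8 training convention: with a delayed-scaling recipe, forward tensors (Q, K, V, O) use the forward FP8 dtype (typically E4M3) while gradient tensors use the backward FP8 dtype (typically E5M2). Below is a minimal, self-contained sketch of that selection logic; the variable names mirror the diff, but the helper itself is illustrative, not the library's get_fp8_te_dtype:

import torch

# Illustrative only: mimic the forward/backward FP8 dtype split that the commit
# restores. In Transformer Engine this choice comes from the FP8 recipe via
# get_fp8_te_dtype(recipe, fprop_tensor=...); here it is hard-coded for clarity.
def select_fused_attn_dtypes(fp8: bool, q: torch.Tensor, dout: torch.Tensor):
    if fp8:
        fp8_dtype_forward = "E4M3"   # forward tensors: more mantissa bits, narrower range
        fp8_dtype_backward = "E5M2"  # gradients: fewer mantissa bits, wider range
        fused_attn_qkv_dtype = fp8_dtype_forward    # QKV are forward tensors (the fix)
        fused_attn_dqkv_dtype = fp8_dtype_backward  # dQ/dK/dV are gradients
    else:
        fused_attn_qkv_dtype = q.dtype
        fused_attn_dqkv_dtype = dout.dtype          # follow dout, not q (the fix)
    return fused_attn_qkv_dtype, fused_attn_dqkv_dtype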
