From 525de6cc65d0a06ef13ba6cf67680796456475d3 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 20 Aug 2024 20:14:41 -0700 Subject: [PATCH] Update cudnn-frontend to v1.6.1 (#1108) * update FE to 1.6 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update to 1.6.1-rc for testing Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update to fe 1.6.1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- .../common/fused_attn/fused_attn_fp8.cu | 30 +++++++++++++++---- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 98ca4e1941..2533f5e5c1 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 98ca4e1941fe3263f128f74f10063a3ea35c7019 +Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index bda3f5beba..fb7765e1a4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1835,8 +1835,14 @@ void fused_attn_fp8_fwd_impl_v1( generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); - amax_o->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_s->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_o->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_s->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); Stats->set_output(true) .set_data_type(fe::DataType_t::FLOAT) @@ -2182,10 +2188,22 @@ void fused_attn_fp8_bwd_impl_v1( dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); - amax_dQ->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_dK->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_dV->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); - amax_dP->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_dQ->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_dK->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_dV->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); + amax_dP->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT); dO->set_data_type(bwd_tensor_type); dQ->set_data_type(bwd_tensor_type);