Update cudnn-frontend to v1.6.1 (#1108)
* update FE to 1.6

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update to 1.6.1-rc for testing

Signed-off-by: Charlene Yang <[email protected]>

* update to fe 1.6.1

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
cyanguwa and pre-commit-ci[bot] committed Aug 21, 2024
1 parent 8e3561b commit 525de6c
Showing 2 changed files with 25 additions and 7 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 118 files
30 changes: 24 additions & 6 deletions transformer_engine/common/fused_attn/fused_attn_fp8.cu
@@ -1835,8 +1835,14 @@ void fused_attn_fp8_fwd_impl_v1(
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_O_Matrix);
O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);
-amax_o->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT);
-amax_s->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT);
+amax_o->set_output(true)
+    .set_dim({1, 1, 1, 1})
+    .set_stride({1, 1, 1, 1})
+    .set_data_type(fe::DataType_t::FLOAT);
+amax_s->set_output(true)
+    .set_dim({1, 1, 1, 1})
+    .set_stride({1, 1, 1, 1})
+    .set_data_type(fe::DataType_t::FLOAT);

Stats->set_output(true)
.set_data_type(fe::DataType_t::FLOAT)
@@ -2182,10 +2188,22 @@ void fused_attn_fp8_bwd_impl_v1(
dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride);
dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride);
dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride);
-amax_dQ->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT);
-amax_dK->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT);
-amax_dV->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT);
-amax_dP->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT);
+amax_dQ->set_output(true)
+    .set_dim({1, 1, 1, 1})
+    .set_stride({1, 1, 1, 1})
+    .set_data_type(fe::DataType_t::FLOAT);
+amax_dK->set_output(true)
+    .set_dim({1, 1, 1, 1})
+    .set_stride({1, 1, 1, 1})
+    .set_data_type(fe::DataType_t::FLOAT);
+amax_dV->set_output(true)
+    .set_dim({1, 1, 1, 1})
+    .set_stride({1, 1, 1, 1})
+    .set_data_type(fe::DataType_t::FLOAT);
+amax_dP->set_output(true)
+    .set_dim({1, 1, 1, 1})
+    .set_stride({1, 1, 1, 1})
+    .set_data_type(fe::DataType_t::FLOAT);

dO->set_data_type(bwd_tensor_type);
dQ->set_data_type(bwd_tensor_type);
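
Note on the stride values: both hunks give every scalar amax output an explicit stride of {1, 1, 1, 1} next to its {1, 1, 1, 1} dimensions. The commit message does not say why the explicit stride is now set, so the following is only a sketch of why that particular value is the natural one. It assumes a fully packed, row-major layout, and `packed_strides` is a hypothetical helper written for illustration; it is not part of cudnn-frontend or Transformer Engine.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (illustration only, not a cudnn-frontend API):
// returns the strides of a fully packed, row-major tensor of shape `dim`.
std::vector<int64_t> packed_strides(const std::vector<int64_t>& dim) {
  std::vector<int64_t> stride(dim.size(), 1);
  // The innermost stride stays 1; every outer stride is the product of
  // all dimension sizes to its right.
  for (int i = static_cast<int>(dim.size()) - 2; i >= 0; --i) {
    stride[i] = stride[i + 1] * dim[i + 1];
  }
  return stride;
}
```

Under this assumption, a shape of {1, 1, 1, 1} yields strides of {1, 1, 1, 1}, which matches the value passed to `set_stride` for each amax tensor in the diff.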