From add8cca8832e4ee4668f9f83298d4a4ccd8a30db Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 11 Oct 2025 17:02:57 +0800 Subject: [PATCH 01/10] remove debug print --- tilelang/engine/phase.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index f64ac272b..5e2c9ec5c 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -156,11 +156,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: if allow_fence_proxy(target=target): # in hopper device, wgmma is an async proxy # so we need to inject a fence proxy before it - print("Before injectFenceProxy") - print(mod) mod = tilelang.transform.InjectFenceProxy()(mod) - print("After InjectFenceProxy") - print(mod) mod = tilelang.transform.LowerOpaqueBlock()(mod) mod = tir.transform.NarrowDataType(32)(mod) From 8783cd978f3c4bae2bd0797a8e56b17bddec3a2a Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 11 Oct 2025 17:48:09 +0800 Subject: [PATCH 02/10] pipeline fix --- src/transform/inject_pipeline.cc | 8 +++ src/transform/legalize_block_access.cc | 86 ++++++++++++++++++++++++++ tilelang/engine/phase.py | 3 + tilelang/transform/__init__.py | 11 ++++ 4 files changed, 108 insertions(+) create mode 100644 src/transform/legalize_block_access.cc diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index 20f0861e2..5ea03e773 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -243,8 +243,16 @@ class PipelineRewriter : public StmtExprMutator { // number of versions need to maintain for each buffer. std::unordered_map infos = GetBufferAccessInfo(); + LOG(INFO) << "buffer_infos:"; + for (const auto &kv : infos) { + const Buffer &buffer = kv.first; + const BufferAccessInfo &info = kv.second; + LOG(INFO) << " buffer=" << buffer->name << " def=" << info.def + << " use=" << info.use; + } for (const Buffer &buffer : pipeline_allocs_) { int num_versions = ComputeBufferVersions(buffer, infos.at(buffer)); + LOG(INFO) << "Get num_versions " << num_versions; if (num_versions > 1) { buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions)); } diff --git a/src/transform/legalize_block_access.cc b/src/transform/legalize_block_access.cc new file mode 100644 index 000000000..9c4b3817a --- /dev/null +++ b/src/transform/legalize_block_access.cc @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2024 TileLang + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file legalize_block_access.cc + * \brief Populate block read/write regions ahead of pipeline passes. 
+ */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +class BlockAccessLegalizer : public StmtExprMutator { +public: + explicit BlockAccessLegalizer(Map buffer_data_to_buffer) + : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {} + + Stmt VisitStmt_(const BlockNode *op) final { + for (const Buffer &alloc : op->alloc_buffers) { + buffer_data_to_buffer_.Set(alloc->data, alloc); + } + + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + + Array> access = + GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + BlockNode *n = block.CopyOnWrite(); + n->reads = access[0]; + n->writes = access[1]; + + for (const Buffer &alloc : op->alloc_buffers) { + buffer_data_to_buffer_.erase(alloc->data); + } + + return block; + } + +private: + Map buffer_data_to_buffer_; +}; + +namespace transform { + +tvm::transform::Pass LegalizeBlockAccess() { + auto pass_func = [](PrimFunc f, const IRModule &, + const tvm::transform::PassContext &) { + Map buffer_data_to_buffer; + for (const auto &[_, buffer] : f->buffer_map) { + buffer_data_to_buffer.Set(buffer->data, buffer); + } + BlockAccessLegalizer legalizer(buffer_data_to_buffer); + f.CopyOnWrite()->body = legalizer(f->body); + return f; + }; + return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, + "tl.LegalizeBlockAccess", {}); +} + +} // namespace transform + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LegalizeBlockAccess", + transform::LegalizeBlockAccess); +}); + +} // namespace tl +} // namespace tvm diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 5e2c9ec5c..02eb4f314 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -136,6 +136,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.WarpSpecialized()(mod) mod = tilelang.transform.InjectTmaBarrier()(mod) mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod) + mod = tilelang.transform.LegalizeBlockAccess()(mod) # if tma is not enabled, we can also do pipeline planning # to get better performance with async copy mod = tilelang.transform.PipelinePlanning()(mod) @@ -150,7 +151,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: else: mod = tilelang.transform.IfStmtBinding()(mod) mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) + mod = tilelang.transform.LegalizeBlockAccess()(mod) mod = tilelang.transform.PipelinePlanning()(mod) + print("before InjectSoftwarePipeline ", mod) mod = tilelang.transform.InjectSoftwarePipeline()(mod) mod = tilelang.transform.MergeIfStmt()(mod) if allow_fence_proxy(target=target): diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index 8a01d7111..7c122a0c4 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -25,6 +25,17 @@ def ClusterPlanning(): return _ffi_api.ClusterPlanning() # type: ignore +def LegalizeBlockAccess(): + """Populate block reads/writes via access analysis. 
+ + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LegalizeBlockAccess() # type: ignore + + def PipelinePlanning(): """infer the fragment/shared memory layout From 3a51846fe31128bb1e7d92f4ae9a81a679f0f42c Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 11 Oct 2025 18:15:06 +0800 Subject: [PATCH 03/10] use the correct buffer access scope --- src/transform/inject_pipeline.cc | 8 -- src/transform/legalize_block_access.cc | 86 -------------------- tilelang/engine/phase.py | 3 - tilelang/intrinsics/wgmma_macro_generator.py | 4 +- tilelang/transform/__init__.py | 11 --- 5 files changed, 2 insertions(+), 110 deletions(-) delete mode 100644 src/transform/legalize_block_access.cc diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index 5ea03e773..20f0861e2 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -243,16 +243,8 @@ class PipelineRewriter : public StmtExprMutator { // number of versions need to maintain for each buffer. std::unordered_map infos = GetBufferAccessInfo(); - LOG(INFO) << "buffer_infos:"; - for (const auto &kv : infos) { - const Buffer &buffer = kv.first; - const BufferAccessInfo &info = kv.second; - LOG(INFO) << " buffer=" << buffer->name << " def=" << info.def - << " use=" << info.use; - } for (const Buffer &buffer : pipeline_allocs_) { int num_versions = ComputeBufferVersions(buffer, infos.at(buffer)); - LOG(INFO) << "Get num_versions " << num_versions; if (num_versions > 1) { buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions)); } diff --git a/src/transform/legalize_block_access.cc b/src/transform/legalize_block_access.cc deleted file mode 100644 index 9c4b3817a..000000000 --- a/src/transform/legalize_block_access.cc +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright (c) 2024 TileLang - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! - * \file legalize_block_access.cc - * \brief Populate block read/write regions ahead of pipeline passes. 
- */ - -#include -#include -#include -#include -#include - -namespace tvm { -namespace tl { - -using namespace tir; - -class BlockAccessLegalizer : public StmtExprMutator { -public: - explicit BlockAccessLegalizer(Map buffer_data_to_buffer) - : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {} - - Stmt VisitStmt_(const BlockNode *op) final { - for (const Buffer &alloc : op->alloc_buffers) { - buffer_data_to_buffer_.Set(alloc->data, alloc); - } - - Block block = Downcast(StmtExprMutator::VisitStmt_(op)); - - Array> access = - GetBlockReadWriteRegion(block, buffer_data_to_buffer_); - BlockNode *n = block.CopyOnWrite(); - n->reads = access[0]; - n->writes = access[1]; - - for (const Buffer &alloc : op->alloc_buffers) { - buffer_data_to_buffer_.erase(alloc->data); - } - - return block; - } - -private: - Map buffer_data_to_buffer_; -}; - -namespace transform { - -tvm::transform::Pass LegalizeBlockAccess() { - auto pass_func = [](PrimFunc f, const IRModule &, - const tvm::transform::PassContext &) { - Map buffer_data_to_buffer; - for (const auto &[_, buffer] : f->buffer_map) { - buffer_data_to_buffer.Set(buffer->data, buffer); - } - BlockAccessLegalizer legalizer(buffer_data_to_buffer); - f.CopyOnWrite()->body = legalizer(f->body); - return f; - }; - return tvm::tir::transform::CreatePrimFuncPass(pass_func, 0, - "tl.LegalizeBlockAccess", {}); -} - -} // namespace transform - -TVM_FFI_STATIC_INIT_BLOCK({ - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tl.transform.LegalizeBlockAccess", - transform::LegalizeBlockAccess); -}); - -} // namespace tl -} // namespace tvm diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 02eb4f314..5e2c9ec5c 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -136,7 +136,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.WarpSpecialized()(mod) mod = tilelang.transform.InjectTmaBarrier()(mod) mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod) - mod = tilelang.transform.LegalizeBlockAccess()(mod) # if tma is not enabled, we can also do pipeline planning # to get better performance with async copy mod = tilelang.transform.PipelinePlanning()(mod) @@ -151,9 +150,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: else: mod = tilelang.transform.IfStmtBinding()(mod) mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) - mod = tilelang.transform.LegalizeBlockAccess()(mod) mod = tilelang.transform.PipelinePlanning()(mod) - print("before InjectSoftwarePipeline ", mod) mod = tilelang.transform.InjectSoftwarePipeline()(mod) mod = tilelang.transform.MergeIfStmt()(mod) if allow_fence_proxy(target=target): diff --git a/tilelang/intrinsics/wgmma_macro_generator.py b/tilelang/intrinsics/wgmma_macro_generator.py index b2ee0a23a..9d64a15fe 100644 --- a/tilelang/intrinsics/wgmma_macro_generator.py +++ b/tilelang/intrinsics/wgmma_macro_generator.py @@ -245,9 +245,9 @@ def _warp_mma(A_buf, B_buf, C_local_buf): # TODO(lei): inject warpgroup_fence_operand for C_local_buf desc_a = T.alloc_descriptor() desc_b = T.alloc_descriptor() - T.initialize_descriptor(desc_a, A_buf.access_ptr("w"), a_swizzle_mode, + T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4)) - T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode, + T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) 
T.warpgroup_arrive() for ki in T.serial(0, (k_dim // micro_size_k)): diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index 7c122a0c4..8a01d7111 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -25,17 +25,6 @@ def ClusterPlanning(): return _ffi_api.ClusterPlanning() # type: ignore -def LegalizeBlockAccess(): - """Populate block reads/writes via access analysis. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.LegalizeBlockAccess() # type: ignore - - def PipelinePlanning(): """infer the fragment/shared memory layout From bac37ae6dae9f12e7dec077d9871565b49630387 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 12 Oct 2025 02:51:13 +0800 Subject: [PATCH 04/10] rs support --- src/op/builtin.h | 10 +- src/op/gemm.cc | 2 + src/target/codegen_cuda.cc | 143 ++- src/target/codegen_cuda.h | 4 + src/target/ptx.cc | 15 + src/target/ptx.h | 5 + src/tl_templates/cuda/instruction/mma.h | 153 +++ src/tl_templates/cuda/instruction/wgmma.h | 1025 ++++++++---------- tilelang/intrinsics/mma_macro_generator.py | 12 +- tilelang/intrinsics/wgmma_macro_generator.py | 33 +- tilelang/jit/adapter/wrapper.py | 26 +- tilelang/language/tir/op.py | 2 - 12 files changed, 763 insertions(+), 667 deletions(-) create mode 100644 src/tl_templates/cuda/instruction/mma.h diff --git a/src/op/builtin.h b/src/op/builtin.h index f8a80e021..a331dfe90 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -238,11 +238,11 @@ TVM_DLL const Op &ptx_wgmma_ss(); /*! * \brief tvm intrinsics for ptx tensor core wgmma instructions. * - * void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool - * a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm - * b_dtype_abbrv, StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr - * A_offset, Var B_descriptor, Var B_offset, Var C_data, Var C_offset, bool - * scale_out, bool scale_in_a, bool scale_in_b); + * void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, + * bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv, + * StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var + * B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out, + * bool scale_in_a, bool scale_in_b); */ TVM_DLL const Op &ptx_wgmma_rs(); diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 059f7f6f3..afee0ebe4 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -582,6 +582,8 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { if (A.scope() == "local.fragment") { ICHECK(B.scope() != "local.fragment"); + ICHECK(!trans_A) + << "gemm_rs requires the A operand to be in non-transposed layout."; op_name = "tl::gemm_rs"; } else if (B.scope() == "local.fragment") { op_name = "tl::gemm_sr"; diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index f1993bdd9..4918bfcb3 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -259,6 +259,12 @@ std::string CodeGenTileLangCUDA::Finish() { if (need_mma_h_) { decl_stream << "#include \n"; } + if (need_mma_instruction_h_) { + decl_stream << "#include \n"; + } + if (need_wgmma_instruction_h_) { + decl_stream << "#include \n"; + } if (enable_fp8_) { decl_stream << "#include \n"; } @@ -1494,14 +1500,41 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { std::string b_bias = this->PrintExpr(op->args[9]); std::string c_ref = this->PrintExpr(op->args[10]); std::string c_bias = this->PrintExpr(op->args[11]); - bool 
saturate = Downcast(op->args[12])->value; - std::string bit_op = - op->args.size() > 13 ? Downcast(op->args[13])->value : ""; - std::string asm_code = PrintMMAAssembly( - shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, - b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate); + auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype); + auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype); + auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype); + auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape); + + need_mma_instruction_h_ = true; this->PrintIndent(); - this->stream << asm_code; + std::string mma_call = + "tl::mma_sync<(AType), (BType), (CType), (M), (N), (K), (TransA), " + "(TransB)>(reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), " + "reinterpret_cast((A_ptr) + (A_offset)), " + "reinterpret_cast((B_ptr) + (B_offset)));\n"; + tl::codegen::Replacer replacer; + replacer.register_rule("(AType)", + tl::codegen::ptx::DTypeEnumToString(dtype_a_enum)); + replacer.register_rule("(BType)", + tl::codegen::ptx::DTypeEnumToString(dtype_b_enum)); + replacer.register_rule("(CType)", + tl::codegen::ptx::DTypeEnumToString(dtype_c_enum)); + replacer.register_rule("(M)", std::to_string(m)); + replacer.register_rule("(N)", std::to_string(n)); + replacer.register_rule("(K)", std::to_string(k)); + replacer.register_rule("(TransA)", A_layout == "row" ? "false" : "true"); + replacer.register_rule("(TransB)", B_layout == "row" ? "false" : "true"); + replacer.register_rule("(ARegType)", + tl::codegen::GetMMARegisterType(dtype_a_enum)); + replacer.register_rule("(BRegType)", + tl::codegen::GetMMARegisterType(dtype_b_enum)); + replacer.register_rule("(A_ptr)", a_ref); + replacer.register_rule("(A_offset)", a_bias); + replacer.register_rule("(B_ptr)", b_ref); + replacer.register_rule("(B_offset)", b_bias); + replacer.register_rule("(C_ptr)", c_ref); + replacer.register_rule("(C_offset)", c_bias); + this->stream << replacer.rewrite(mma_call); } else if (op->op.same_as(builtin::ptx_mma_sp())) { // arg 0: shape: mXnXkX // arg 1: A layout: row/col @@ -1578,6 +1611,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a, scale_in_b, a_is_shared, "", "", "", false); auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape); + need_wgmma_instruction_h_ = true; std::string wgmma_asm_code = "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), " "(tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), " @@ -1606,41 +1640,74 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { wgmma_asm_code = replacer.rewrite(wgmma_asm_code); this->stream << wgmma_asm_code; } else if (op->op.same_as(tl::ptx_wgmma_rs())) { - // arg 0: dtype - // arg 1: shape - // arg 2: A_layout - // arg 3: B_layout - // arg 4: A_dtype - // arg 5: B_dtype - // arg 6: C_dtype - // arg 7: multiplicand_a - // arg 8: multiplicand_b + // arg 0: shape + // arg 1: B_layout + // arg 2: A_dtype + // arg 3: B_dtype + // arg 4: C_dtype + // arg 5: multiplicand_a + // arg 6: multiplicand_a offset + // arg 7: multiplicand_b descriptor + // arg 8: multiplicand_b offset // arg 9: accumulator - // arg 10: saturate - ICHECK_EQ(op->args.size(), 15U) << "ptx_wgmma_rs args is " << op->args; + // arg 10: accumulator offset + // arg 11: scale_out + // arg 12: scale_in_a + // arg 13: scale_in_b + ICHECK_EQ(op->args.size(), 14U) << "ptx_wgmma_rs args is " << op->args; std::string 
shape = Downcast<StringImm>(op->args[0])->value;
    bool b_is_k_major = Downcast<Bool>(op->args[1])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[2])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string a_ref = this->PrintExpr(op->args[5]);
    std::string A_offset = this->PrintExpr(op->args[6]);
    std::string b_desc = this->PrintExpr(op->args[7]);
    std::string B_offset = this->PrintExpr(op->args[8]);
    std::string c_ref = this->PrintExpr(op->args[9]);
    std::string c_offset = this->PrintExpr(op->args[10]);
    bool scale_out = Downcast<Bool>(op->args[11])->value;
    bool scale_in_a = Downcast<Bool>(op->args[12])->value;
    bool scale_in_b = Downcast<Bool>(op->args[13])->value;
+
+    auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype);
+    auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype);
+    auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
+    auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
 
+    need_wgmma_instruction_h_ = true;
     this->PrintIndent();
+    std::string wgmma_call =
+        "tl::wgmma_rs<(AType), (BType), (CType), (M), (N), (K), (tnspA), "
+        "(tnspB), (scaleA), (scaleB)>(reinterpret_cast<const uint32_t *>((A_ptr) + (A_offset)), "
+        "uint64_t((desc_b) + (B_offset)), reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), "
+        "(scale_out));\n";
+
+    tl::codegen::Replacer replacer;
+    replacer.register_rule("(AType)",
+                           tl::codegen::ptx::DTypeEnumToString(dtype_a_enum));
+    replacer.register_rule("(BType)",
+                           tl::codegen::ptx::DTypeEnumToString(dtype_b_enum));
+    replacer.register_rule("(CType)",
+                           tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
+    replacer.register_rule("(M)", std::to_string(m));
+    replacer.register_rule("(N)", std::to_string(n));
+    replacer.register_rule("(K)", std::to_string(k));
+    replacer.register_rule("(tnspA)", "false");
+    replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true");
+    replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1");
+    replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1");
+    replacer.register_rule("(CRegType)",
+                           tl::codegen::GetMMARegisterType(dtype_c_enum));
+    replacer.register_rule("(A_ptr)", a_ref);
+    replacer.register_rule("(A_offset)", A_offset);
+    replacer.register_rule("(desc_b)", b_desc);
+    replacer.register_rule("(B_offset)", B_offset);
+    replacer.register_rule("(C_ptr)", c_ref);
+    replacer.register_rule("(C_offset)", c_offset);
+    replacer.register_rule("(scale_out)", scale_out ?
"true" : "false"); + wgmma_call = replacer.rewrite(wgmma_call); + this->stream << wgmma_call; } else if (op->op.same_as(builtin::ptx_ldmatrix())) { // arg 0: whether the matrix is loaded in column major format or not. // arg 1: number of matrices to load. diff --git a/src/target/codegen_cuda.h b/src/target/codegen_cuda.h index 16ceff165..1618995e0 100644 --- a/src/target/codegen_cuda.h +++ b/src/target/codegen_cuda.h @@ -106,6 +106,10 @@ class CodeGenTileLangCUDA final : public CodeGenC { bool need_math_constants_h_{false}; // whether need mma.h bool need_mma_h_{false}; + // whether need tl mma instruction header + bool need_mma_instruction_h_{false}; + // whether need tl wgmma instruction header + bool need_wgmma_instruction_h_{false}; // whether need cast_smem_ptr_to_int helper function bool need_cast_smem_ptr_to_int_{false}; // whether need cooperative_groups.h diff --git a/src/target/ptx.cc b/src/target/ptx.cc index 9de548fc2..0710bffca 100644 --- a/src/target/ptx.cc +++ b/src/target/ptx.cc @@ -1529,5 +1529,20 @@ std::string PrintWaitBarrierAsm(const std::string &barrier) { return predicated_asm_code; } +std::string GetMMARegisterType(const ptx::DataType &dtype) { + switch (dtype) { + case ptx::DataType::kInt32: + return "unsigned"; + case ptx::DataType::kUInt32: + return "unsigned"; + case ptx::DataType::kFloat32: + return "float"; + case ptx::DataType::kFloat64: + return "double"; + default: + return "unsigned"; + } +} + } // namespace codegen } // namespace tvm::tl diff --git a/src/target/ptx.h b/src/target/ptx.h index 68d5b04a3..566cded6f 100644 --- a/src/target/ptx.h +++ b/src/target/ptx.h @@ -269,6 +269,11 @@ std::string PrintArriveBarrierExpectTxAsm(const std::string &barrier, */ std::string PrintWaitBarrierAsm(const std::string &barrier); +/*! + * \brief Return the register-level C++ type used by MMA fragments. 
+ */ +std::string GetMMARegisterType(const ptx::DataType &dtype); + } // namespace codegen } // namespace tvm::tl diff --git a/src/tl_templates/cuda/instruction/mma.h b/src/tl_templates/cuda/instruction/mma.h new file mode 100644 index 000000000..d114c58f3 --- /dev/null +++ b/src/tl_templates/cuda/instruction/mma.h @@ -0,0 +1,153 @@ +#pragma once + +#include "../common.h" +#include +#include + +#include +#include + +namespace tl { + +#ifndef TL_ALWAYS_FALSE_V_DEFINED +#define TL_ALWAYS_FALSE_V_DEFINED +template inline constexpr bool always_false_v = false; +#endif + +namespace detail { + +template struct MmaImplTraits { + using DReg = std::remove_extent_t; + using AReg = std::remove_extent_t; + using BReg = std::remove_extent_t; + using CReg = std::remove_extent_t; + + static constexpr int kDRegs = std::extent_v; + static constexpr int kARegs = std::extent_v; + static constexpr int kBRegs = std::extent_v; + static constexpr int kCRegs = std::extent_v; +}; + +template +TL_DEVICE void call_fma_impl( + typename MmaImplTraits::DReg *d, + const typename MmaImplTraits::AReg *a, + const typename MmaImplTraits::BReg *b, + const typename MmaImplTraits::CReg *c, std::index_sequence, + std::index_sequence, std::index_sequence, + std::index_sequence) { + Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]...); +} + +template +TL_DEVICE void call_fma(typename MmaImplTraits::DReg *d, + const typename MmaImplTraits::AReg *a, + const typename MmaImplTraits::BReg *b, + const typename MmaImplTraits::CReg *c) { + call_fma_impl( + d, a, b, c, + std::make_index_sequence::kDRegs>{}, + std::make_index_sequence::kARegs>{}, + std::make_index_sequence::kBRegs>{}, + std::make_index_sequence::kCRegs>{}); +} + +template +struct MmaDispatcher { + using CRegType = void; + using ARegType = void; + using BRegType = void; + + static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *, + const CRegType *) { + static_assert(always_false_v>, + "tl::mma_sync: unsupported configuration"); + } +}; + +#define TL_DEFINE_MMA_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, MValue, \ + NValue, KValue, TransAValue, TransBValue, \ + SaturateValue, ImplType) \ + template <> \ + struct MmaDispatcher { \ + using Impl = ImplType; \ + using Traits = MmaImplTraits; \ + using CRegType = typename Traits::DReg; \ + using ARegType = typename Traits::AReg; \ + using BRegType = typename Traits::BReg; \ + static_assert(std::is_same_v,\ + "tl::mma_sync requires matching accumulator/output regs"); \ + static TL_DEVICE void exec(CRegType *d, const ARegType *a, \ + const BRegType *b, const CRegType *c) { \ + call_fma(d, a, b, c); \ + } \ + }; + +// FP16 inputs (TN layout: A row-major, B column-major) +TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat16, 16, 8, 16, false, true, + false, cute::SM80_16x8x16_F16F16F16F16_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat32, 16, 8, 16, false, true, + false, cute::SM80_16x8x16_F32F16F16F32_TN) + +// BF16 inputs +TL_DEFINE_MMA_DISPATCHER(kBFloat16, kBFloat16, kFloat32, 16, 8, 16, false, true, + false, cute::SM80_16x8x16_F32BF16BF16F32_TN) + +// INT8 inputs (k32) +TL_DEFINE_MMA_DISPATCHER(kInt8, kInt8, kInt32, 16, 8, 32, false, true, false, + cute::SM80_16x8x32_S32S8S8S32_TN) +TL_DEFINE_MMA_DISPATCHER(kUInt8, kUInt8, kInt32, 16, 8, 32, false, true, false, + cute::SM80_16x8x32_S32U8U8S32_TN) + +// INT4 inputs (k32) +TL_DEFINE_MMA_DISPATCHER(kInt4, kInt4, kInt32, 16, 8, 32, false, true, false, + cute::SM80_16x8x32_S32S4S4S32_TN) +TL_DEFINE_MMA_DISPATCHER(kUInt4, kUInt4, kInt32, 16, 8, 32, 
false, true, false, + cute::SM80_16x8x32_S32U4U4S32_TN) + +// FP8 inputs (k32) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F16E4M3E4M3F16_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F32E4M3E4M3F32_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F16E4M3E5M2F16_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F32E4M3E5M2F32_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F16E5M2E4M3F16_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F32E5M2E4M3F32_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F16E5M2E5M2F16_TN) +TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 16, 8, 32, false, + true, false, cute::SM89_16x8x32_F32E5M2E5M2F32_TN) + +#undef TL_DEFINE_MMA_DISPATCHER + +} // namespace detail + +template +TL_DEVICE void +mma_sync(typename detail::MmaDispatcher::CRegType *c, + const typename detail::MmaDispatcher::ARegType *a, + const typename detail::MmaDispatcher::BRegType *b) { + using Dispatcher = + detail::MmaDispatcher; + static_assert(!std::is_void_v, + "tl::mma_sync: unsupported configuration"); + Dispatcher::exec(c, a, b, c); +} + +} // namespace tl diff --git a/src/tl_templates/cuda/instruction/wgmma.h b/src/tl_templates/cuda/instruction/wgmma.h index 0e9717280..743be7379 100644 --- a/src/tl_templates/cuda/instruction/wgmma.h +++ b/src/tl_templates/cuda/instruction/wgmma.h @@ -1,516 +1,458 @@ #pragma once + #include "../common.h" -#include "cute/arch/mma_sm90_gmma.hpp" +#include +#include + +#include +#include namespace tl { +#ifndef TL_ALWAYS_FALSE_V_DEFINED +#define TL_ALWAYS_FALSE_V_DEFINED template inline constexpr bool always_false_v = false; +#endif -// 主类模板 - 移除默认参数,因为特化不能有默认参数 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - printf("DEBUG: WgmmaSSImpl fallback - A_type=%d (kFloat16=%d), B_type=%d, " - "C_type=%d, M=%d, N=%d, K=%d, tnspA=%d, tnspB=%d, scaleA=%d, " - "scaleB=%d\n", - (int)A_type, (int)DataType::kFloat16, (int)B_type, (int)C_type, M, N, - K, (int)tnspA, (int)tnspB, scaleA, scaleB); - // 暂时注释掉 static_assert 来看调试输出 - // static_assert(always_false_v, - // "wgmma_ss: No specialization available for given template - // parameters!"); - }; -}; - -// ================================= F16 x F16 -> F16 -// ================================= - -// M64N8K16 F16 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %4, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " - "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } -}; - -// M64N16K16 F16 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - 
"wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " - "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } -}; +namespace detail { -// M64N32K16 F16 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), - "+r"(c[5]), "+r"(c[6]), "+r"(c[7]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), - "n"(int32_t(tnspB))); - } +template +struct MajorValue { + static constexpr auto value = + IsMnMajor ? cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K; }; -// M64N64K16 F16 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - "%8, %9, %10, %11, %12, %13, %14, %15}," - " %16, %17, p, %19, %20, %21, %22;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), - "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]), - "+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), - "+r"(c[14]), "+r"(c[15]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } +template +struct ScaleInValue { + static_assert(Scale == 1 || Scale == -1, + "tl::wgmma requires scale factors of +1 or -1."); + static constexpr auto value = + Scale == 1 ? 
cute::SM90::GMMA::ScaleIn::One + : cute::SM90::GMMA::ScaleIn::Neg; }; -// M64N96K16 F16 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %26, 0;\n" - "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - "%8, %9, %10, %11, %12, %13, %14, %15, " - "%16, %17, %18, %19, %20, %21, %22, %23}, " - "%24, %25, p, %27, %28, %29, %30;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), - "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]), - "+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), - "+r"(c[14]), "+r"(c[15]), "+r"(c[16]), "+r"(c[17]), - "+r"(c[18]), "+r"(c[19]), "+r"(c[20]), "+r"(c[21]), - "+r"(c[22]), "+r"(c[23]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } -}; +template +inline constexpr bool IsValidScale = (Scale == 1 || Scale == -1); -// M64N128K16 F16 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - "%8, %9, %10, %11, %12, %13, %14, %15, " - "%16, %17, %18, %19, %20, %21, %22, %23, " - "%24, %25, %26, %27, %28, %29, %30, %31}, " - "%32, %33, p, %35, %36, %37, %38;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), - "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]), - "+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), - "+r"(c[14]), "+r"(c[15]), "+r"(c[16]), "+r"(c[17]), - "+r"(c[18]), "+r"(c[19]), "+r"(c[20]), "+r"(c[21]), - "+r"(c[22]), "+r"(c[23]), "+r"(c[24]), "+r"(c[25]), - "+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]), - "+r"(c[30]), "+r"(c[31]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } -}; +template struct CallWgmmaSS { + using CReg = std::remove_extent_t; + static constexpr int kCRegs = std::extent_v; + static_assert(sizeof(CReg) == sizeof(uint32_t), + "tl::wgmma_ss expects 32-bit accumulator registers."); -// M64N192K16 F16 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %50, 0;\n" - "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - "%8, %9, %10, %11, %12, %13, %14, %15, " - "%16, %17, %18, %19, %20, %21, %22, %23, " - "%24, %25, %26, %27, %28, %29, %30, %31, " - "%32, %33, %34, %35, %36, %37, %38, %39, " - "%40, %41, %42, %43, %44, %45, %46, %47}, " - "%48, %49, p, %51, %52, %53, %54;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), - "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]), - "+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), "+r"(c[14]), - "+r"(c[15]), "+r"(c[16]), "+r"(c[17]), "+r"(c[18]), "+r"(c[19]), - "+r"(c[20]), "+r"(c[21]), "+r"(c[22]), "+r"(c[23]), "+r"(c[24]), - "+r"(c[25]), "+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]), - "+r"(c[30]), "+r"(c[31]), "+r"(c[32]), "+r"(c[33]), "+r"(c[34]), - "+r"(c[35]), "+r"(c[36]), "+r"(c[37]), "+r"(c[38]), "+r"(c[39]), - "+r"(c[40]), "+r"(c[41]), "+r"(c[42]), "+r"(c[43]), "+r"(c[44]), - 
"+r"(c[45]), "+r"(c[46]), "+r"(c[47]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), - "n"(int32_t(tnspB))); + template + TL_DEVICE static void Run(uint64_t desc_a, uint64_t desc_b, CReg *c, + cute::SM90::GMMA::ScaleOut scale, + std::index_sequence) { + Impl::fma(desc_a, desc_b, c[Idx]..., scale); } -}; -// M64N256K16 F16 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %66, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - "%8, %9, %10, %11, %12, %13, %14, %15, " - "%16, %17, %18, %19, %20, %21, %22, %23, " - "%24, %25, %26, %27, %28, %29, %30, %31, " - "%32, %33, %34, %35, %36, %37, %38, %39, " - "%40, %41, %42, %43, %44, %45, %46, %47, " - "%48, %49, %50, %51, %52, %53, %54, %55, " - "%56, %57, %58, %59, %60, %61, %62, %63}, " - "%64, %65, p, %67, %68, %69, %70;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), - "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]), - "+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), "+r"(c[14]), - "+r"(c[15]), "+r"(c[16]), "+r"(c[17]), "+r"(c[18]), "+r"(c[19]), - "+r"(c[20]), "+r"(c[21]), "+r"(c[22]), "+r"(c[23]), "+r"(c[24]), - "+r"(c[25]), "+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]), - "+r"(c[30]), "+r"(c[31]), "+r"(c[32]), "+r"(c[33]), "+r"(c[34]), - "+r"(c[35]), "+r"(c[36]), "+r"(c[37]), "+r"(c[38]), "+r"(c[39]), - "+r"(c[40]), "+r"(c[41]), "+r"(c[42]), "+r"(c[43]), "+r"(c[44]), - "+r"(c[45]), "+r"(c[46]), "+r"(c[47]), "+r"(c[48]), "+r"(c[49]), - "+r"(c[50]), "+r"(c[51]), "+r"(c[52]), "+r"(c[53]), "+r"(c[54]), - "+r"(c[55]), "+r"(c[56]), "+r"(c[57]), "+r"(c[58]), "+r"(c[59]), - "+r"(c[60]), "+r"(c[61]), "+r"(c[62]), "+r"(c[63]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), - "n"(int32_t(tnspB))); + TL_DEVICE static void exec(uint64_t desc_a, uint64_t desc_b, uint32_t *c_raw, + bool scale_out) { + auto scale = scale_out ? 
cute::SM90::GMMA::ScaleOut::One + : cute::SM90::GMMA::ScaleOut::Zero; + auto c = reinterpret_cast(c_raw); + Run(desc_a, desc_b, c, scale, std::make_index_sequence{}); } }; -// ================================= F16 x F16 -> F32 -// ================================= - -// M64N8K16 F16->F32 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " - "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +template struct CallWgmmaRS { + using AReg = std::remove_extent_t; + using CReg = std::remove_extent_t; + static constexpr int kARegs = std::extent_v; + static constexpr int kCRegs = std::extent_v; + static_assert(sizeof(AReg) == sizeof(uint32_t), + "tl::wgmma_rs expects 32-bit register operands for A."); + static_assert(sizeof(CReg) == sizeof(uint32_t) || sizeof(CReg) == sizeof(float), + "tl::wgmma_rs expects 32-bit accumulator registers."); + + template + TL_DEVICE static void Run(const AReg *a, uint64_t desc_b, CReg *c, + cute::SM90::GMMA::ScaleOut scale, + std::index_sequence, + std::index_sequence) { + Impl::fma(a[AIdx]..., desc_b, c[CIdx]..., scale); } -}; -// M64N16K16 F16->F32 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), - "+r"(c[5]), "+r"(c[6]), "+r"(c[7]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), - "n"(int32_t(tnspB))); + TL_DEVICE static void exec(const uint32_t *a_raw, uint64_t desc_b, + uint32_t *c_raw, bool scale_out) { + auto scale = scale_out ? 
cute::SM90::GMMA::ScaleOut::One + : cute::SM90::GMMA::ScaleOut::Zero; + auto a = reinterpret_cast(a_raw); + auto c = reinterpret_cast(c_raw); + Run(a, desc_b, c, scale, std::make_index_sequence{}, + std::make_index_sequence{}); } }; -// M64N32K16 F16->F32 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %18, 0;\n" - "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - "%8, %9, %10, %11, %12, %13, %14, %15}, " - "%16, %17, p, %19, %20, %21, %22;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), - "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]), - "+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), - "+r"(c[14]), "+r"(c[15]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } -}; +} // namespace detail -// M64N64K16 F16->F32 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %34, 0;\n" - "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " - "{%0, %1, %2, %3, %4, %5, %6, %7, " - "%8, %9, %10, %11, %12, %13, %14, %15, " - "%16, %17, %18, %19, %20, %21, %22, %23, " - "%24, %25, %26, %27, %28, %29, %30, %31}, " - "%32, %33, p, %35, %36, %37, %38;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), - "+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]), - "+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), - "+r"(c[14]), "+r"(c[15]), "+r"(c[16]), "+r"(c[17]), - "+r"(c[18]), "+r"(c[19]), "+r"(c[20]), "+r"(c[21]), - "+r"(c[22]), "+r"(c[23]), "+r"(c[24]), "+r"(c[25]), - "+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]), - "+r"(c[30]), "+r"(c[31]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +template +struct WgmmaSSImpl { + static_assert(detail::IsValidScale, "tl::wgmma_ss: invalid scaleA"); + static_assert(detail::IsValidScale, "tl::wgmma_ss: invalid scaleB"); + TL_DEVICE static void execute(uint64_t, uint64_t, uint32_t *, bool) { + static_assert(always_false_v>, + "tl::wgmma_ss: unsupported configuration"); } }; -// ================================= BF16 x BF16 -> F32 -// ================================= - -// M64N8K16 BF16->F32 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " - "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +template +struct WgmmaRSImpl { + static_assert(detail::IsValidScale, "tl::wgmma_rs: invalid scaleA"); + static_assert(detail::IsValidScale, "tl::wgmma_rs: invalid scaleB"); + TL_DEVICE static void execute(const uint32_t *, uint64_t, uint32_t *, bool) { + static_assert(always_false_v>, + "tl::wgmma_rs: unsupported configuration"); } }; -// M64N16K16 BF16->F32 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - 
asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " - "{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), - "+r"(c[5]), "+r"(c[6]), "+r"(c[7]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), - "n"(int32_t(tnspB))); - } -}; +#define TL_WGMMA_DEFINE_SS_GENERAL(AType, BType, CType, M, N, K, ImplName) \ + template \ + struct WgmmaSSImpl { \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_ss: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_ss: invalid scaleB"); \ + using Impl = cute::SM90::GMMA::ImplName< \ + detail::MajorValue::value, detail::MajorValue::value, \ + detail::ScaleInValue::value, \ + detail::ScaleInValue::value>; \ + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \ + uint32_t *c, bool scale_out) { \ + detail::CallWgmmaSS::exec(desc_a, desc_b, c, scale_out); \ + } \ + }; -// ================================= TF32 x TF32 -> F32 -// ================================= - -// M64N8K8 TF32->F32 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " - "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } -}; +#define TL_WGMMA_DEFINE_SS_TN(AType, BType, CType, M, N, K, ImplName) \ + template \ + struct WgmmaSSImpl { \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_ss: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_ss: invalid scaleB"); \ + using Impl = cute::SM90::GMMA::ImplName< \ + detail::ScaleInValue::value, \ + detail::ScaleInValue::value>; \ + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \ + uint32_t *c, bool scale_out) { \ + detail::CallWgmmaSS::exec(desc_a, desc_b, c, scale_out); \ + } \ + }; -// M64N16K8 TF32->F32 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile( - "{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %10, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), - "+r"(c[5]), "+r"(c[6]), "+r"(c[7]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), - "n"(int32_t(tnspB))); - } -}; +#define TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(AType, BType, CType, M, N, K, \ + ImplName) \ + template \ + struct WgmmaSSImpl { \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_ss: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_ss: invalid scaleB"); \ + static_assert(scaleA == 1 && scaleB == 1, \ + "tl::wgmma_ss: only +1 scaling supported for this WGMMA"); \ + using Impl = cute::SM90::GMMA::ImplName; \ + TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \ + uint32_t *c, bool scale_out) { \ + detail::CallWgmmaSS::exec(desc_a, desc_b, c, scale_out); \ + } \ + }; -// ================================= INT8 x INT8 -> 
INT32 -// ================================= - -// M64N8K32 S8->S32 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %4, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " - "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } -}; +#define TL_WGMMA_DEFINE_RS_GENERAL(AType, BType, CType, M, N, K, ImplName) \ + template \ + struct WgmmaRSImpl { \ + static_assert(!tnspA, "tl::wgmma_rs: operand A must be K-major"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_rs: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_rs: invalid scaleB"); \ + using Impl = cute::SM90::GMMA::ImplName< \ + detail::MajorValue::value, detail::MajorValue::value, \ + detail::ScaleInValue::value, \ + detail::ScaleInValue::value>; \ + TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \ + uint32_t *c, bool scale_out) { \ + detail::CallWgmmaRS::exec(a, desc_b, c, scale_out); \ + } \ + }; -// M64N16K32 S8->S32 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " - "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } -}; +#define TL_WGMMA_DEFINE_RS_TN(AType, BType, CType, M, N, K, ImplName) \ + template \ + struct WgmmaRSImpl { \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_rs: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_rs: invalid scaleB"); \ + using Impl = cute::SM90::GMMA::ImplName< \ + detail::ScaleInValue::value, \ + detail::ScaleInValue::value>; \ + TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \ + uint32_t *c, bool scale_out) { \ + detail::CallWgmmaRS::exec(a, desc_b, c, scale_out); \ + } \ + }; -// ================================= FP8 x FP8 -> F16/F32 -// ================================= - -// M64N8K32 E4M3->F16 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %4, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " - "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } -}; +#define TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(AType, BType, CType, M, N, K, \ + ImplName) \ + template \ + struct WgmmaRSImpl { \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_rs: invalid scaleA"); \ + static_assert(detail::IsValidScale, \ + "tl::wgmma_rs: invalid scaleB"); \ + static_assert(scaleA == 1 && scaleB == 1, \ + "tl::wgmma_rs: only +1 scaling supported for this WGMMA"); \ + using Impl = cute::SM90::GMMA::ImplName; \ + TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \ + uint32_t *c, bool scale_out) { \ + detail::CallWgmmaRS::exec(a, desc_b, c, scale_out); \ + } \ + }; -// M64N8K32 
E4M3->F32 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %6, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " - "{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } -}; +#define TL_WGMMA_FOREACH_N_FLOAT_MUL8(OP) \ + OP(8) \ + OP(16) \ + OP(24) \ + OP(32) \ + OP(40) \ + OP(48) \ + OP(56) \ + OP(64) \ + OP(72) \ + OP(80) \ + OP(88) \ + OP(96) \ + OP(104) \ + OP(112) \ + OP(120) \ + OP(128) \ + OP(136) \ + OP(144) \ + OP(152) \ + OP(160) \ + OP(168) \ + OP(176) \ + OP(184) \ + OP(192) \ + OP(200) \ + OP(208) \ + OP(216) \ + OP(224) \ + OP(232) \ + OP(240) \ + OP(248) \ + OP(256) + +#define TL_WGMMA_FOREACH_N_INT32_MUL8(OP) \ + OP(8) \ + OP(16) \ + OP(24) \ + OP(32) \ + OP(48) \ + OP(64) \ + OP(80) \ + OP(96) \ + OP(112) \ + OP(128) \ + OP(144) \ + OP(160) \ + OP(176) \ + OP(192) \ + OP(208) \ + OP(224) \ + OP(240) \ + OP(256) + +#define TL_WGMMA_DEFINE_F16_F16_F16_SS(N) \ + TL_WGMMA_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 16, \ + MMA_64x##N##x16_F16F16F16_SS) +#define TL_WGMMA_DEFINE_F16_F16_F32_SS(N) \ + TL_WGMMA_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat32, 64, N, 16, \ + MMA_64x##N##x16_F32F16F16_SS) +#define TL_WGMMA_DEFINE_BF16_BF16_F32_SS(N) \ + TL_WGMMA_DEFINE_SS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 16, \ + MMA_64x##N##x16_F32BF16BF16_SS) + +#define TL_WGMMA_DEFINE_F32_TF32_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, 8, \ + MMA_64x##N##x8_F32TF32TF32_SS_TN) + +#define TL_WGMMA_DEFINE_S32_S8S8_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32S8S8_SS_TN) +#define TL_WGMMA_DEFINE_S32_S8U8_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32S8U8_SS_TN) +#define TL_WGMMA_DEFINE_S32_U8S8_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32U8S8_SS_TN) +#define TL_WGMMA_DEFINE_S32_U8U8_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32U8U8_SS_TN) + +#define TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 32, \ + MMA_64x##N##x32_F16E4M3E4M3_SS_TN) +#define TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E4M3E4M3_SS_TN) +#define TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 32, \ + MMA_64x##N##x32_F16E4M3E5M2_SS_TN) +#define TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E4M3E5M2_SS_TN) +#define TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 32, \ + MMA_64x##N##x32_F16E5M2E4M3_SS_TN) +#define TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E5M2E4M3_SS_TN) +#define TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 32, \ + MMA_64x##N##x32_F16E5M2E5M2_SS_TN) +#define 
TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN(N) \ + TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E5M2E5M2_SS_TN) + +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F16_SS); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F32_SS); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_BF16_BF16_F32_SS); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_TF32_SS_TN); + +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8S8_SS_TN); +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8U8_SS_TN); +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8S8_SS_TN); +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8U8_SS_TN); + +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN); + +#define TL_WGMMA_DEFINE_F16_F16_F16_RS(N) \ + TL_WGMMA_DEFINE_RS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 16, \ + MMA_64x##N##x16_F16F16F16_RS) +#define TL_WGMMA_DEFINE_F16_F16_F32_RS(N) \ + TL_WGMMA_DEFINE_RS_GENERAL(kFloat16, kFloat16, kFloat32, 64, N, 16, \ + MMA_64x##N##x16_F32F16F16_RS) +#define TL_WGMMA_DEFINE_BF16_BF16_F32_RS(N) \ + TL_WGMMA_DEFINE_RS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 16, \ + MMA_64x##N##x16_F32BF16BF16_RS) + +#define TL_WGMMA_DEFINE_F32_TF32_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, 8, \ + MMA_64x##N##x8_F32TF32TF32_RS_TN) + +#define TL_WGMMA_DEFINE_S32_S8S8_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32S8S8_RS_TN) +#define TL_WGMMA_DEFINE_S32_S8U8_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32S8U8_RS_TN) +#define TL_WGMMA_DEFINE_S32_U8S8_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32U8S8_RS_TN) +#define TL_WGMMA_DEFINE_S32_U8U8_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 32, \ + MMA_64x##N##x32_S32U8U8_RS_TN) + +#define TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 32, \ + MMA_64x##N##x32_F16E4M3E4M3_RS_TN) +#define TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E4M3E4M3_RS_TN) +#define TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 32, \ + MMA_64x##N##x32_F16E4M3E5M2_RS_TN) +#define TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E4M3E5M2_RS_TN) +#define TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 32, \ + MMA_64x##N##x32_F16E5M2E4M3_RS_TN) +#define TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E5M2E4M3_RS_TN) +#define TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 32, \ + 
MMA_64x##N##x32_F16E5M2E5M2_RS_TN) +#define TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN(N) \ + TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 32, \ + MMA_64x##N##x32_F32E5M2E5M2_RS_TN) + +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F16_RS); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F32_RS); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_BF16_BF16_F32_RS); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_TF32_RS_TN); + +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8S8_RS_TN); +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8U8_RS_TN); +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8S8_RS_TN); +TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8U8_RS_TN); + +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN); +TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN); + +#undef TL_WGMMA_DEFINE_F16_F16_F16_SS +#undef TL_WGMMA_DEFINE_F16_F16_F32_SS +#undef TL_WGMMA_DEFINE_BF16_BF16_F32_SS +#undef TL_WGMMA_DEFINE_F32_TF32_SS_TN +#undef TL_WGMMA_DEFINE_S32_S8S8_SS_TN +#undef TL_WGMMA_DEFINE_S32_S8U8_SS_TN +#undef TL_WGMMA_DEFINE_S32_U8S8_SS_TN +#undef TL_WGMMA_DEFINE_S32_U8U8_SS_TN +#undef TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN +#undef TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN +#undef TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN +#undef TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN +#undef TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN +#undef TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN +#undef TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN +#undef TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN +#undef TL_WGMMA_DEFINE_F16_F16_F16_RS +#undef TL_WGMMA_DEFINE_F16_F16_F32_RS +#undef TL_WGMMA_DEFINE_BF16_BF16_F32_RS +#undef TL_WGMMA_DEFINE_F32_TF32_RS_TN +#undef TL_WGMMA_DEFINE_S32_S8S8_RS_TN +#undef TL_WGMMA_DEFINE_S32_S8U8_RS_TN +#undef TL_WGMMA_DEFINE_S32_U8S8_RS_TN +#undef TL_WGMMA_DEFINE_S32_U8U8_RS_TN +#undef TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN +#undef TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN +#undef TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN +#undef TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN +#undef TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN +#undef TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN +#undef TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN +#undef TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN +#undef TL_WGMMA_FOREACH_N_FLOAT_MUL8 +#undef TL_WGMMA_FOREACH_N_INT32_MUL8 +#undef TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE +#undef TL_WGMMA_DEFINE_SS_GENERAL +#undef TL_WGMMA_DEFINE_SS_TN +#undef TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE +#undef TL_WGMMA_DEFINE_RS_GENERAL +#undef TL_WGMMA_DEFINE_RS_TN -// 函数模板委托给类模板 template TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c, @@ -519,129 +461,12 @@ TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c, scaleB>::execute(desc_a, desc_b, c, scale_out); } -// ================================= Mixed Precision Support -// ================================= - -// Mixed precision: S8 x U8 -> S32 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %4, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " - "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" - "}\n" - : 
"+r"(c[0]), "+r"(c[1]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } -}; - -// Mixed precision: U8 x S8 -> S32 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %4, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " - "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } -}; - -// Mixed precision: U8 x U8 -> S32 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %4, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " - "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } -}; - -// Mixed precision FP8: E4M3 x E5M2 -> F16 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %4, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " - "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } -}; - -// Mixed precision FP8: E5M2 x E4M3 -> F16 -template -struct WgmmaSSImpl { - TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, - bool scale_out) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %4, 0;\n" - "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " - "{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" - "}\n" - : "+r"(c[0]), "+r"(c[1]) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), - "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), - "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); - } -}; - -// ================================= Convenience Templates -// ================================= - -// Type trait to determine the number of output registers needed -template struct WgmmaOutputRegs { - static constexpr int value = - (M * N * (C_type == DataType::kFloat32 ? 32 : 16)) / (32 * 8); -}; - -// Type trait to get element size in bits -template struct ElementBits { - static constexpr int value = - (dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 || - dtype == DataType::kInt32) - ? 32 - : (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 || - dtype == DataType::kInt16 || dtype == DataType::kUInt16) - ? 16 - : (dtype == DataType::kInt8 || dtype == DataType::kUInt8 || - dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2) - ? 8 - : (dtype == DataType::kInt4 || dtype == DataType::kUInt4) ? 
4 - : 8; -}; +template +TL_DEVICE void wgmma_rs(const uint32_t *a, uint64_t desc_b, uint32_t *c, + bool scale_out) { + WgmmaRSImpl::execute(a, desc_b, c, scale_out); +} -} // namespace tl \ No newline at end of file +} // namespace tl diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 65d2ab0ca..6d9a86f6b 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -104,9 +104,15 @@ def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): self.local_size_out = (m_dim * n_dim) // warp_size def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): - self.a_dtype_abbrv = self.dtype_abbrv[a_dtype] - self.b_dtype_abbrv = self.dtype_abbrv[b_dtype] - self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype] + self.a_dtype_abbrv = self._get_dtype_abbrv(a_dtype) + self.b_dtype_abbrv = self._get_dtype_abbrv(b_dtype) + self.accum_dtype_abbrv = self._get_dtype_abbrv(accum_dtype) + + def _get_dtype_abbrv(self, dtype: str) -> str: + try: + return self.dtype_abbrv[dtype] + except KeyError as err: + raise ValueError(f"Unsupported dtype: {dtype}") from err def _initialize_mma_prefix(self, k_dim: int = 16): if k_dim == 8: diff --git a/tilelang/intrinsics/wgmma_macro_generator.py b/tilelang/intrinsics/wgmma_macro_generator.py index 9d64a15fe..7be8b6da6 100644 --- a/tilelang/intrinsics/wgmma_macro_generator.py +++ b/tilelang/intrinsics/wgmma_macro_generator.py @@ -292,53 +292,56 @@ def wgmma_rs(self, assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" elems_in_bytes = DataType(self.a_dtype).bits // 8 - b_is_k_major = self.b_transposed b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout) + b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none( + ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * elems_in_bytes) - b_stride_byte_offset = (8 * k_dim * elems_in_bytes) if b_is_k_major else (8 * 8 * - elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * + elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else + (8 * 8 * elems_in_bytes)) if not b_swizzle_mode.is_none(): # swizzle mode doesn't require LBO/SBO to be 1 # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset if b_is_k_major: b_leading_byte_offset = 16 + b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size() else: # MN Major # LBO represents the distance between two atoms along the N dimension # SBO represents the distance between two atoms along the K dimension - b_n_axis_atoms = n_dim // (b_swizzle_mode.swizzle_byte_size() // elems_in_bytes) + b_n_axis_atoms = n_dim // b_swizzle_atom_elems if b_n_axis_atoms <= 1: b_leading_byte_offset = 0 else: - b_leading_byte_offset = 8 * b_swizzle_mode.swizzle_atom_size() * ( - b_swizzle_mode.swizzle_byte_size() // elems_in_bytes) - + b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim if b_n_axis_atoms <= 1: b_stride_byte_offset = 8 * elems_in_bytes * n_dim else: - b_stride_byte_offset = 8 * elems_in_bytes * ( - b_swizzle_mode.swizzle_byte_size() // elems_in_bytes) + b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems + + bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1) @T.macro def _warp_mma(A_buf, B_buf, C_local_buf): desc_b = T.alloc_descriptor() - T.initialize_descriptor(desc_b, B_buf.access_ptr("w"), b_swizzle_mode, + 
T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) + T.warpgroup_arrive() for ki in T.serial(0, (k_dim // micro_size_k)): for i in T.serial(m_dim // 64): - k_dim_offset = ki * micro_size_k A_offset = ki * warp_rows * local_size_a + i * local_size_a - B_offset = k_dim_offset if b_is_k_major else k_dim_offset * B_buf.shape[-1] + B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + ( + ki % bk_atom_size + ) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k C_offset = i * warp_cols * local_size_out # 4 warps as an unit T.ptx_wgmma_rs( accum_dtype, wgmma_prefix, - self.a_transposed, - not self.b_transposed, + self.b_transposed, a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, @@ -352,6 +355,8 @@ def _warp_mma(A_buf, B_buf, C_local_buf): scale_in_a, scale_in_b, ) + T.warpgroup_commit_batch() + T.warpgroup_wait(0) return _warp_mma(A_buf, B_buf, C_local_buf) diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py index f3b044605..b7cc065af 100644 --- a/tilelang/jit/adapter/wrapper.py +++ b/tilelang/jit/adapter/wrapper.py @@ -227,6 +227,12 @@ def __init__(self, def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str: return pythonic_expr(expr, self._TYPE_MAP) + def _lookup_type(self, dtype: Union[str, Any]) -> str: + key = dtype if isinstance(dtype, str) else str(dtype) + result = self._TYPE_MAP.get(key) + assert result is not None, f"Unsupported dtype {dtype}" + return result + def is_tma_descriptor_arg(self, arg_name: str) -> bool: return arg_name in self.prim_func.buffer_map @@ -244,10 +250,10 @@ def create_dispatch_func(self, code, function_informations): buffer = self.prim_func.buffer_map[param] function_args.append({ "name": buffer.data.name, - "type": self._TYPE_MAP[buffer.dtype] + "* __restrict__", + "type": self._lookup_type(buffer.dtype) + "* __restrict__", }) elif isinstance(param, tvm.tir.Var): - function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]}) + function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) else: raise ValueError( f"Parameter {param} is not in the buffer map of the primary function.") @@ -652,6 +658,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): "float16": "ctypes.c_uint16", "bfloat16": "ctypes.c_uint16", "float8_e4m3": "ctypes.c_uint8", + "float8_e4m3fn": "ctypes.c_uint8", "float8_e5m2": "ctypes.c_uint8", "float64": "ctypes.c_double", "int64": "ctypes.c_int64", @@ -688,7 +695,9 @@ def create_dispatch_func(self, code, function_informations): "type": "ctypes.c_void_p", }) elif isinstance(param, tvm.tir.Var): - function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]}) + function_args.append( + {"name": param.name, "type": self._lookup_type(param.dtype)} + ) else: raise ValueError( f"Parameter {param} is not in the buffer map of the primary function.") @@ -858,6 +867,7 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper): "float16": "half_t", "bfloat16": "bfloat16_t", "float8_e4m3": "fp8_e4_t", + "float8_e4m3fn": "fp8_e4_t", "float8_e5m2": "fp8_e5_t", "float8_e4m3fnuz": "fp8_e4_t", "e4m3fnuz_float8": "fp8_e4_t", @@ -949,6 +959,12 @@ def __init__(self, self.libpath: Optional[str] = None self.lib_code: Optional[str] = self.update_lib_code(source) + def _lookup_type(self, dtype: Union[str, Any]) -> str: + key = dtype if isinstance(dtype, str) else str(dtype) + result = self._TYPE_MAP.get(key) + assert result is not None, f"Unsupported dtype {dtype}" + 
return result + def create_call_func(self, code, function_informations): # Extract the set of dynamic symbolic names used in the primary function dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) @@ -960,10 +976,10 @@ def create_call_func(self, code, function_informations): buffer = self.prim_func.buffer_map[param] function_args.append({ "name": buffer.name, - "type": self._TYPE_MAP[buffer.dtype] + "*", + "type": self._lookup_type(buffer.dtype) + "*", }) elif isinstance(param, tvm.tir.Var): - function_args.append({"name": param.name, "type": self._TYPE_MAP[param.dtype]}) + function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) else: raise ValueError( f"Parameter {param} is not in the buffer map of the primary function.") diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index 10ca7ca93..cd87f691b 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -1106,7 +1106,6 @@ def ptx_wgmma_ss( def ptx_wgmma_rs( dtype, wgmma_prefix, - a_is_k_major, b_is_k_major, a_dtype_abbrv, b_dtype_abbrv, @@ -1126,7 +1125,6 @@ def ptx_wgmma_rs( dtype, _tvm_op.Op.get("tl.ptx_wgmma_rs"), wgmma_prefix, - a_is_k_major, b_is_k_major, a_dtype_abbrv, b_dtype_abbrv, From 4a9603261e86e263f1cd1ad3c62ccfd0857fa9e2 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 12 Oct 2025 12:24:08 +0800 Subject: [PATCH 05/10] warp warpgroup_fence_operand --- src/op/builtin.cc | 5 ++ src/op/builtin.h | 8 ++ src/target/codegen_cuda.cc | 22 ++++- src/tl_templates/cuda/instruction/mma.h | 51 ++++++----- src/tl_templates/cuda/instruction/wgmma.h | 91 ++++++++++---------- src/tl_templates/cuda/intrin.h | 14 +++ tilelang/intrinsics/wgmma_macro_generator.py | 13 ++- tilelang/jit/adapter/wrapper.py | 4 +- tilelang/language/builtin.py | 63 +++++++++++++- 9 files changed, 191 insertions(+), 80 deletions(-) diff --git a/src/op/builtin.cc b/src/op/builtin.cc index e2aeea3ee..748e84094 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -218,6 +218,11 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_wait) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand) + .set_num_inputs(4) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(wait_wgmma) .set_num_inputs(1) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index a331dfe90..44bbc21ff 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -358,6 +358,14 @@ TVM_DLL const Op &warpgroup_commit_batch(); */ TVM_DLL const Op &warpgroup_wait(); +/*! + * \brief Fence accumulator operand registers for upcoming WGMMA operations + * + * warpgroup_fence_operand(dtype, ptr, offset, num_regs) + * + */ +TVM_DLL const Op &warpgroup_fence_operand(); + /*! 
* \brief Wait the previous wgmma to finish * diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 4918bfcb3..6f0b6db50 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1389,6 +1389,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { int num_mma = Downcast(op->args[0])->value; this->stream << "tl::warpgroup_wait<" << std::to_string(num_mma) << ">();\n"; + } else if (op->op.same_as(tl::warpgroup_fence_operand())) { + ICHECK_EQ(op->args.size(), 4U); + std::string dtype = Downcast(op->args[0])->value; + std::string data_ptr = this->PrintExpr(op->args[1]); + std::string offset = this->PrintExpr(op->args[2]); + std::string num_regs = this->PrintExpr(op->args[3]); + auto dtype_enum = tl::codegen::ptx::DTypeFromString(dtype); + std::string cast_type = "uint32_t"; + if (dtype_enum == tl::codegen::ptx::DataType::kFloat32 || + dtype_enum == tl::codegen::ptx::DataType::kTensorFloat32) { + cast_type = "float"; + } + this->PrintIndent(); + this->stream << "tl::warpgroup_fence_operand(reinterpret_cast<" << cast_type + << "*>(" << data_ptr << " + " << offset << "), " << num_regs + << ");\n"; } else if (op->op.same_as(tl::set_max_nreg())) { this->PrintIndent(); int nreg = Downcast(op->args[0])->value; @@ -1679,8 +1695,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { this->PrintIndent(); std::string wgmma_call = "tl::wgmma_rs<(AType), (BType), (CType), (M), (N), (K), (tnspA), " - "(tnspB), (scaleA), (scaleB)>(reinterpret_cast((A_ptr) + (A_offset)), " - "uint64_t((desc_b) + (B_offset)), reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), " + "(tnspB), (scaleA), (scaleB)>(reinterpret_cast((A_ptr) + (A_offset)), " + "uint64_t((desc_b) + (B_offset)), " + "reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), " "(scale_out));\n"; tl::codegen::Replacer replacer; diff --git a/src/tl_templates/cuda/instruction/mma.h b/src/tl_templates/cuda/instruction/mma.h index d114c58f3..8346b7a1f 100644 --- a/src/tl_templates/cuda/instruction/mma.h +++ b/src/tl_templates/cuda/instruction/mma.h @@ -30,13 +30,13 @@ template struct MmaImplTraits { template -TL_DEVICE void call_fma_impl( - typename MmaImplTraits::DReg *d, - const typename MmaImplTraits::AReg *a, - const typename MmaImplTraits::BReg *b, - const typename MmaImplTraits::CReg *c, std::index_sequence, - std::index_sequence, std::index_sequence, - std::index_sequence) { +TL_DEVICE void +call_fma_impl(typename MmaImplTraits::DReg *d, + const typename MmaImplTraits::AReg *a, + const typename MmaImplTraits::BReg *b, + const typename MmaImplTraits::CReg *c, + std::index_sequence, std::index_sequence, + std::index_sequence, std::index_sequence) { Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]...); } @@ -45,12 +45,11 @@ TL_DEVICE void call_fma(typename MmaImplTraits::DReg *d, const typename MmaImplTraits::AReg *a, const typename MmaImplTraits::BReg *b, const typename MmaImplTraits::CReg *c) { - call_fma_impl( - d, a, b, c, - std::make_index_sequence::kDRegs>{}, - std::make_index_sequence::kARegs>{}, - std::make_index_sequence::kBRegs>{}, - std::make_index_sequence::kCRegs>{}); + call_fma_impl(d, a, b, c, + std::make_index_sequence::kDRegs>{}, + std::make_index_sequence::kARegs>{}, + std::make_index_sequence::kBRegs>{}, + std::make_index_sequence::kCRegs>{}); } template ,\ - "tl::mma_sync requires matching accumulator/output regs"); \ + static_assert( \ + std::is_same_v, \ + "tl::mma_sync requires matching accumulator/output regs"); \ static TL_DEVICE void 
exec(CRegType *d, const ARegType *a, \ const BRegType *b, const CRegType *c) { \ call_fma(d, a, b, c); \ @@ -133,18 +133,15 @@ TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 16, 8, 32, false, template -TL_DEVICE void -mma_sync(typename detail::MmaDispatcher::CRegType *c, - const typename detail::MmaDispatcher::ARegType *a, - const typename detail::MmaDispatcher::BRegType *b) { - using Dispatcher = - detail::MmaDispatcher; +TL_DEVICE void mma_sync( + typename detail::MmaDispatcher::CRegType *c, + const typename detail::MmaDispatcher::ARegType *a, + const typename detail::MmaDispatcher::BRegType *b) { + using Dispatcher = detail::MmaDispatcher; static_assert(!std::is_void_v, "tl::mma_sync: unsupported configuration"); Dispatcher::exec(c, a, b, c); diff --git a/src/tl_templates/cuda/instruction/wgmma.h b/src/tl_templates/cuda/instruction/wgmma.h index 743be7379..b5ef59c26 100644 --- a/src/tl_templates/cuda/instruction/wgmma.h +++ b/src/tl_templates/cuda/instruction/wgmma.h @@ -16,19 +16,16 @@ template inline constexpr bool always_false_v = false; namespace detail { -template -struct MajorValue { +template struct MajorValue { static constexpr auto value = IsMnMajor ? cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K; }; -template -struct ScaleInValue { +template struct ScaleInValue { static_assert(Scale == 1 || Scale == -1, "tl::wgmma requires scale factors of +1 or -1."); - static constexpr auto value = - Scale == 1 ? cute::SM90::GMMA::ScaleIn::One - : cute::SM90::GMMA::ScaleIn::Neg; + static constexpr auto value = Scale == 1 ? cute::SM90::GMMA::ScaleIn::One + : cute::SM90::GMMA::ScaleIn::Neg; }; template @@ -63,14 +60,14 @@ template struct CallWgmmaRS { static constexpr int kCRegs = std::extent_v; static_assert(sizeof(AReg) == sizeof(uint32_t), "tl::wgmma_rs expects 32-bit register operands for A."); - static_assert(sizeof(CReg) == sizeof(uint32_t) || sizeof(CReg) == sizeof(float), + static_assert(sizeof(CReg) == sizeof(uint32_t) || + sizeof(CReg) == sizeof(float), "tl::wgmma_rs expects 32-bit accumulator registers."); template - TL_DEVICE static void Run(const AReg *a, uint64_t desc_b, CReg *c, - cute::SM90::GMMA::ScaleOut scale, - std::index_sequence, - std::index_sequence) { + TL_DEVICE static void + Run(const AReg *a, uint64_t desc_b, CReg *c, cute::SM90::GMMA::ScaleOut scale, + std::index_sequence, std::index_sequence) { Impl::fma(a[AIdx]..., desc_b, c[CIdx]..., scale); } @@ -113,14 +110,15 @@ struct WgmmaRSImpl { template \ struct WgmmaSSImpl { \ - static_assert(detail::IsValidScale, \ + static_assert(detail::IsValidScale, \ "tl::wgmma_ss: invalid scaleA"); \ - static_assert(detail::IsValidScale, \ + static_assert(detail::IsValidScale, \ "tl::wgmma_ss: invalid scaleB"); \ - using Impl = cute::SM90::GMMA::ImplName< \ - detail::MajorValue::value, detail::MajorValue::value, \ - detail::ScaleInValue::value, \ - detail::ScaleInValue::value>; \ + using Impl = \ + cute::SM90::GMMA::ImplName::value, \ + detail::MajorValue::value, \ + detail::ScaleInValue::value, \ + detail::ScaleInValue::value>; \ TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \ uint32_t *c, bool scale_out) { \ detail::CallWgmmaSS::exec(desc_a, desc_b, c, scale_out); \ @@ -130,14 +128,14 @@ struct WgmmaRSImpl { #define TL_WGMMA_DEFINE_SS_TN(AType, BType, CType, M, N, K, ImplName) \ template \ struct WgmmaSSImpl { \ - static_assert(detail::IsValidScale, \ + K, false, false, scaleA, scaleB> { \ + static_assert(detail::IsValidScale, \ "tl::wgmma_ss: invalid scaleA"); \ - 
static_assert(detail::IsValidScale, \ + static_assert(detail::IsValidScale, \ "tl::wgmma_ss: invalid scaleB"); \ - using Impl = cute::SM90::GMMA::ImplName< \ - detail::ScaleInValue::value, \ - detail::ScaleInValue::value>; \ + using Impl = \ + cute::SM90::GMMA::ImplName::value, \ + detail::ScaleInValue::value>; \ TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \ uint32_t *c, bool scale_out) { \ detail::CallWgmmaSS::exec(desc_a, desc_b, c, scale_out); \ @@ -148,14 +146,14 @@ struct WgmmaRSImpl { ImplName) \ template \ struct WgmmaSSImpl { \ - static_assert(detail::IsValidScale, \ + K, false, false, scaleA, scaleB> { \ + static_assert(detail::IsValidScale, \ "tl::wgmma_ss: invalid scaleA"); \ - static_assert(detail::IsValidScale, \ + static_assert(detail::IsValidScale, \ "tl::wgmma_ss: invalid scaleB"); \ static_assert(scaleA == 1 && scaleB == 1, \ - "tl::wgmma_ss: only +1 scaling supported for this WGMMA"); \ - using Impl = cute::SM90::GMMA::ImplName; \ + "tl::wgmma_ss: only +1 scaling supported for this WGMMA"); \ + using Impl = cute::SM90::GMMA::ImplName; \ TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \ uint32_t *c, bool scale_out) { \ detail::CallWgmmaSS::exec(desc_a, desc_b, c, scale_out); \ @@ -167,14 +165,15 @@ struct WgmmaRSImpl { struct WgmmaRSImpl { \ static_assert(!tnspA, "tl::wgmma_rs: operand A must be K-major"); \ - static_assert(detail::IsValidScale, \ + static_assert(detail::IsValidScale, \ "tl::wgmma_rs: invalid scaleA"); \ - static_assert(detail::IsValidScale, \ + static_assert(detail::IsValidScale, \ "tl::wgmma_rs: invalid scaleB"); \ - using Impl = cute::SM90::GMMA::ImplName< \ - detail::MajorValue::value, detail::MajorValue::value, \ - detail::ScaleInValue::value, \ - detail::ScaleInValue::value>; \ + using Impl = \ + cute::SM90::GMMA::ImplName::value, \ + detail::MajorValue::value, \ + detail::ScaleInValue::value, \ + detail::ScaleInValue::value>; \ TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \ uint32_t *c, bool scale_out) { \ detail::CallWgmmaRS::exec(a, desc_b, c, scale_out); \ @@ -184,14 +183,14 @@ struct WgmmaRSImpl { #define TL_WGMMA_DEFINE_RS_TN(AType, BType, CType, M, N, K, ImplName) \ template \ struct WgmmaRSImpl { \ - static_assert(detail::IsValidScale, \ + K, false, false, scaleA, scaleB> { \ + static_assert(detail::IsValidScale, \ "tl::wgmma_rs: invalid scaleA"); \ - static_assert(detail::IsValidScale, \ + static_assert(detail::IsValidScale, \ "tl::wgmma_rs: invalid scaleB"); \ - using Impl = cute::SM90::GMMA::ImplName< \ - detail::ScaleInValue::value, \ - detail::ScaleInValue::value>; \ + using Impl = \ + cute::SM90::GMMA::ImplName::value, \ + detail::ScaleInValue::value>; \ TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \ uint32_t *c, bool scale_out) { \ detail::CallWgmmaRS::exec(a, desc_b, c, scale_out); \ @@ -202,14 +201,14 @@ struct WgmmaRSImpl { ImplName) \ template \ struct WgmmaRSImpl { \ - static_assert(detail::IsValidScale, \ + K, false, false, scaleA, scaleB> { \ + static_assert(detail::IsValidScale, \ "tl::wgmma_rs: invalid scaleA"); \ - static_assert(detail::IsValidScale, \ + static_assert(detail::IsValidScale, \ "tl::wgmma_rs: invalid scaleB"); \ static_assert(scaleA == 1 && scaleB == 1, \ - "tl::wgmma_rs: only +1 scaling supported for this WGMMA"); \ - using Impl = cute::SM90::GMMA::ImplName; \ + "tl::wgmma_rs: only +1 scaling supported for this WGMMA"); \ + using Impl = cute::SM90::GMMA::ImplName; \ TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \ 
uint32_t *c, bool scale_out) { \ detail::CallWgmmaRS::exec(a, desc_b, c, scale_out); \ diff --git a/src/tl_templates/cuda/intrin.h b/src/tl_templates/cuda/intrin.h index f2abc5c65..31da39323 100644 --- a/src/tl_templates/cuda/intrin.h +++ b/src/tl_templates/cuda/intrin.h @@ -14,6 +14,20 @@ template TL_DEVICE void warpgroup_wait() { cute::warpgroup_wait(); } +TL_DEVICE void warpgroup_fence_operand(uint32_t *regs, int count) { +#pragma unroll + for (int i = 0; i < count; ++i) { + cute::warpgroup_fence_operand(regs[i]); + } +} + +TL_DEVICE void warpgroup_fence_operand(float *regs, int count) { +#pragma unroll + for (int i = 0; i < count; ++i) { + cute::warpgroup_fence_operand(regs[i]); + } +} + // Template parameter: // thread_extent: the logical size (in number of threads) of each "group" // within which we want to elect exactly ONE representative diff --git a/tilelang/intrinsics/wgmma_macro_generator.py b/tilelang/intrinsics/wgmma_macro_generator.py index 7be8b6da6..49c4819ca 100644 --- a/tilelang/intrinsics/wgmma_macro_generator.py +++ b/tilelang/intrinsics/wgmma_macro_generator.py @@ -181,6 +181,8 @@ def wgmma(self, a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none( ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + accum_bits = DataType(accum_dtype).bits + accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32 # by default, we utilize non-swizzle layout offset a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * @@ -242,13 +244,13 @@ def wgmma(self, @T.macro def _warp_mma(A_buf, B_buf, C_local_buf): - # TODO(lei): inject warpgroup_fence_operand for C_local_buf desc_a = T.alloc_descriptor() desc_b = T.alloc_descriptor() T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4)) T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) + T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs) T.warpgroup_arrive() for ki in T.serial(0, (k_dim // micro_size_k)): for i in T.serial(m_dim // 64): @@ -266,6 +268,7 @@ def _warp_mma(A_buf, B_buf, C_local_buf): scale_out, scale_in_a, scale_in_b) T.warpgroup_commit_batch() T.warpgroup_wait(0) + T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs) return _warp_mma(A_buf, B_buf, C_local_buf) @@ -292,6 +295,10 @@ def wgmma_rs(self, assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" elems_in_bytes = DataType(self.a_dtype).bits // 8 + a_bits = DataType(self.a_dtype).bits + accum_bits = DataType(accum_dtype).bits + a_regs = ((warp_rows * local_size_a * (k_dim // micro_size_k)) * a_bits + 31) // 32 + accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32 b_is_k_major = self.b_transposed b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout) @@ -330,6 +337,8 @@ def _warp_mma(A_buf, B_buf, C_local_buf): desc_b = T.alloc_descriptor() T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) + T.warpgroup_fence_operand(A_buf, num_regs=a_regs) + T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs) T.warpgroup_arrive() for ki in T.serial(0, (k_dim // micro_size_k)): for i in T.serial(m_dim // 64): @@ -357,6 +366,8 @@ def _warp_mma(A_buf, B_buf, C_local_buf): ) 
T.warpgroup_commit_batch() T.warpgroup_wait(0) + T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs) + T.warpgroup_fence_operand(A_buf, num_regs=a_regs) return _warp_mma(A_buf, B_buf, C_local_buf) diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py index b7cc065af..b02becfc5 100644 --- a/tilelang/jit/adapter/wrapper.py +++ b/tilelang/jit/adapter/wrapper.py @@ -695,9 +695,7 @@ def create_dispatch_func(self, code, function_informations): "type": "ctypes.c_void_p", }) elif isinstance(param, tvm.tir.Var): - function_args.append( - {"name": param.name, "type": self._lookup_type(param.dtype)} - ) + function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)}) else: raise ValueError( f"Parameter {param} is not in the buffer map of the primary function.") diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 602c44509..de6905e56 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -4,7 +4,8 @@ from tilelang.language import ptx_arrive_barrier, evaluate from tilelang.language.kernel import get_thread_bindings, get_block_extents from tilelang.utils.target import check_hip_availability -from tvm import tir +from tvm import DataType, tir +from tvm.runtime import convert from typing import Union, Any from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad @@ -280,6 +281,66 @@ def warpgroup_wait(num_mma: int): return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma) +def warpgroup_fence_operand(buffer_or_ptr: Union[Buffer, PrimExpr], + offset: Union[int, PrimExpr] = 0, + num_regs: Union[int, PrimExpr, None] = None, + dtype: Union[str, None] = None): + """Insert a warpgroup fence for the destination accumulator registers. + + This prevents NVCC from sinking uses of accumulator fragments past the corresponding + WGMMA operations by issuing an empty inline assembly barrier on every register. + + Args: + buffer_or_ptr: Union[Buffer, PrimExpr] + Either a buffer representing the accumulator fragment or a pointer expression. + offset: Union[int, PrimExpr] + Element offset from the start of the accumulator fragment. + num_regs: Union[int, PrimExpr, None] + Number of 32-bit registers to fence. If None and a Buffer is provided, it will be + derived from the buffer shape and dtype. + dtype: Optional[str] + Data type string of the accumulator elements. Required when passing a pointer. + + Returns: + tir.Call: A handle to the warpgroup fence operation. 
+ """ + if isinstance(buffer_or_ptr, BufferLoad): + raise TypeError("Expected a buffer handle or pointer expression, got BufferLoad.") + + if isinstance(buffer_or_ptr, Buffer): + data_ptr = buffer_or_ptr.data + inferred_dtype = buffer_or_ptr.dtype + if dtype is not None and dtype != inferred_dtype: + raise ValueError(f"dtype mismatch: provided {dtype}, buffer uses {inferred_dtype}.") + dtype = inferred_dtype + if num_regs is None: + total_elems = 1 + for dim in buffer_or_ptr.shape: + if isinstance(dim, tir.IntImm): + total_elems *= int(dim) + else: + raise ValueError( + "warpgroup_fence_operand requires num_regs when buffer shape is symbolic.") + bits_per_elem = DataType(dtype).bits + num_regs = (total_elems * bits_per_elem + 31) // 32 + else: + data_ptr = buffer_or_ptr + if dtype is None: + raise ValueError("dtype must be provided when passing a pointer expression.") + if num_regs is None: + raise ValueError("num_regs must be provided when passing a pointer expression.") + + return evaluate( + tir.call_intrin( + "handle", + tir.op.Op.get("tl.warpgroup_fence_operand"), + dtype, + data_ptr, + convert(offset), + convert(num_regs), + )) + + def wait_wgmma(id: int): """Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete. From a60701b2421bcb6fb2c813ec26875ef1b3dbe73e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 12 Oct 2025 12:52:10 +0800 Subject: [PATCH 06/10] fix --- src/target/codegen_cuda.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 6f0b6db50..14a86a140 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1698,7 +1698,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { "(tnspB), (scaleA), (scaleB)>(reinterpret_cast((A_ptr) + (A_offset)), " "uint64_t((desc_b) + (B_offset)), " - "reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), " + "reinterpret_cast((C_ptr) + (C_offset)), " "(scale_out));\n"; tl::codegen::Replacer replacer; @@ -1715,8 +1715,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true"); replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1"); replacer.register_rule("(scaleB)", scale_in_b ? 
"1" : "-1"); - replacer.register_rule("(CRegType)", - tl::codegen::GetMMARegisterType(dtype_c_enum)); replacer.register_rule("(A_ptr)", a_ref); replacer.register_rule("(A_offset)", A_offset); replacer.register_rule("(desc_b)", b_desc); From c8eec620ea9f4b4791f4a504ad7b029183abf1c1 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 12 Oct 2025 13:47:48 +0800 Subject: [PATCH 07/10] fp8 dtype ptx enhance --- src/target/ptx.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/target/ptx.cc b/src/target/ptx.cc index 0710bffca..53f83ded9 100644 --- a/src/target/ptx.cc +++ b/src/target/ptx.cc @@ -74,9 +74,9 @@ DataType DTypeFromString(const std::string str) { return DataType::kInt64; } else if (str == "uint64" || str == ".u64") { return DataType::kUInt64; - } else if (str == "e4m3" || str == ".e4m3") { + } else if (str == "float8_e4m3" || str == "e4m3" || str == ".e4m3") { return DataType::kFloat8_e4m3; - } else if (str == "e5m2" || str == ".e5m2") { + } else if (str == "float8_e5m2" || str == "e5m2" || str == ".e5m2") { return DataType::kFloat8_e5m2; } else if (str == "float16" || str == "fp16" || str == ".f16") { return DataType::kFloat16; From 9fd79ebf2a372eb3a1a25d51956ffb1cadeb2fc4 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 12 Oct 2025 14:02:48 +0800 Subject: [PATCH 08/10] mma fix --- src/target/codegen_cuda.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 14a86a140..97ec86ec9 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1544,6 +1544,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { tl::codegen::GetMMARegisterType(dtype_a_enum)); replacer.register_rule("(BRegType)", tl::codegen::GetMMARegisterType(dtype_b_enum)); + replacer.register_rule("(CRegType)", + tl::codegen::GetMMARegisterType(dtype_c_enum)); replacer.register_rule("(A_ptr)", a_ref); replacer.register_rule("(A_offset)", a_bias); replacer.register_rule("(B_ptr)", b_ref); From 73fa0af3cb3b2ba24b9c0cc013805b3fddaaa58c Mon Sep 17 00:00:00 2001 From: Zhiwen Mo Date: Sun, 12 Oct 2025 14:52:42 +0800 Subject: [PATCH 09/10] TCGEN05 Interface --- src/op/gemm_py.cc | 99 ++++++++++++++++++++++++++++++-- src/op/gemm_py.h | 4 +- tilelang/language/gemm.py | 8 +++ tilelang/tileop/gemm/__init__.py | 11 +++- 4 files changed, 115 insertions(+), 7 deletions(-) diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index 4e48389ee..b64d12c42 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -13,12 +13,80 @@ #include "../target/utils.h" #include "tvm/ffi/string.h" +#include namespace tvm { namespace tl { using namespace tir; +struct TCGEN5MMAMeta { + int atom_m, atom_n, atom_k; +}; + +// Return {is_success, meta} +static inline std::pair +GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { +// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA. 
+#define FAIL \ + return { \ + false, TCGEN5MMAMeta { 0, 0, 0 } \ + } +#define SUCCESS(atom_m, atom_n, atom_k) \ + return { \ + true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \ + } + std::vector ws_valid_atom_ns = {256, 128, 64}; + if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) && + (c_dtype.is_float() && c_dtype.bits() == 32)) { + if (K % 16 != 0) + FAIL; + if (M % 128 == 0) { + for (int atom_n = 256; atom_n >= 16; atom_n -= 16) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 16); + FAIL; + } else if (M % 64 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(64, atom_n, 16); + FAIL; + } else if (M % 32 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(32, atom_n, 16); + FAIL; + } else { + FAIL; + } + } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) && + (c_dtype.is_float() && c_dtype.bits() == 32)) { + if (K % 32 != 0) + FAIL; + if (M % 128 == 0) { + for (int atom_n = 256; atom_n >= 16; atom_n -= 16) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 32); + FAIL; + } else if (M % 64 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(64, atom_n, 32); + FAIL; + } else if (M % 32 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(32, atom_n, 32); + FAIL; + } else { + FAIL; + } + } + FAIL; +#undef FAIL +#undef SUCCESS +} + /** * @brief Construct a Gemm operator from serialized TL arguments and a buffer * map. @@ -92,16 +160,37 @@ TileOperator GemmPyNode::Clone() const { return GemmPy(op); } -GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const { +bool GemmPyNode::AllowTCGEN5MMA(Target target) const { + return TargetIsSm100(target) && + ((A.scope() == "shared.dyn" || A.scope() == "shared" || + A.scope() == "shared.tmem") && + (B.scope() == "shared.dyn" || B.scope() == "shared") && + C.scope() == "shared.tmem") && + GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first; +} + +bool GemmPyNode::AllowWGMMA(int block_size, Target target) const { + tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); + int warp_size = TargetGetWarpSize(target); int num_warps = block_size / warp_size; - bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) && - (num_warps % 4 == 0) && CheckWGMMA(); - if (allow_wgmma) { + return !ctxt->GetConfig(kDisableWGMMA, Optional()).value_or(false) && + TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) && + CheckWGMMA(); +} + +GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const { + bool allow_tcgen5mma = AllowTCGEN5MMA(target); + bool allow_wgmma = AllowWGMMA(block_size, target); + if (allow_tcgen5mma) { + return GemmInst::kTCGEN5MMA; + } else if (allow_wgmma) { return GemmInst::kWGMMA; } else if (TargetIsCDNA(target)) { return GemmInst::kMFMA; - } else if (TargetIsCuda(target)) { + } else if (TargetIsVolta(target) || TargetIsAmpere(target) || + TargetIsTuring(target) || TargetIsHopper(target) || + TargetIsSm100(target)) { return GemmInst::kMMA; } else { ICHECK(0) << "Unsupported target for gemm: " << target->str(); diff --git a/src/op/gemm_py.h b/src/op/gemm_py.h index 65ed08c0f..cd4b929ea 100644 --- a/src/op/gemm_py.h +++ b/src/op/gemm_py.h @@ -19,6 +19,8 @@ using namespace tir; class GemmPyNode : public TileOperatorNode { public: bool CheckWGMMA() const; + bool AllowTCGEN5MMA(Target target) const; + bool AllowWGMMA(int block_size, Target target) const; tir::Buffer A, B, C; // pointer to the A, B, C PrimExpr Aptr, Bptr, Cptr; @@ -122,4 +124,4 @@ class GemmPy : public 
TileOperator { } // namespace tl } // namespace tvm -#endif // TVM_TL_OP_GEMM_PY_H_ \ No newline at end of file +#endif // TVM_TL_OP_GEMM_PY_H_ diff --git a/tilelang/language/gemm.py b/tilelang/language/gemm.py index 3c4aa5452..8fbb8b96b 100644 --- a/tilelang/language/gemm.py +++ b/tilelang/language/gemm.py @@ -223,6 +223,7 @@ def gemm_v2( clear_accum: bool = False, k_pack: int = 1, wg_wait: int = 0, + mbar: Optional[tir.Buffer] = None, ): """Perform a General Matrix Multiplication (GEMM) operation. @@ -239,6 +240,7 @@ def gemm_v2( clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False. k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1. wg_wait (int, optional): Warp group wait count. Defaults to 0. + mbar (tir.Buffer, optional): mbarrier for TCGEN5MMA synchronization Returns: tir.Call: A handle to the GEMM operation @@ -263,6 +265,7 @@ def legalize_arguments(arg: Union[tir.Buffer, tir.Var]): A = legalize_arguments(A) B = legalize_arguments(B) C = legalize_arguments(C) + mbar = legalize_arguments(mbar) if mbar is not None else None def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: if isinstance(object, tir.Buffer): @@ -406,6 +409,8 @@ def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr Aptr = retrieve_ptr(A, "r") Bptr = retrieve_ptr(B, "r") Cptr = retrieve_ptr(C, "rw") + mbarptr = retrieve_ptr(mbar, "rw") if mbar is not None else tir.const(0, "uint32") + C_coords = [r.min for r in C.region] if isinstance(C, tir.BufferRegion) else [0, 0] return tir.call_intrin( "handle", tir.op.Op.get("tl.gemm_py"), @@ -425,4 +430,7 @@ def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr offset_b, k_pack, wg_wait, + mbarptr, + C_coords[0], + C_coords[1], ) diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 63a999f4d..53afae42d 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -29,7 +29,8 @@ def gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var): class GemmInst(IntEnum): MMA = 0 WGMMMA = 1 - MFMA = 2 + TCGEN5MMA = 2 + MFMA = 3 def is_mma(self) -> bool: return self == GemmInst.MMA @@ -37,9 +38,15 @@ def is_mma(self) -> bool: def is_wgmma(self) -> bool: return self == GemmInst.WGMMMA + def is_tcgen5mma(self) -> bool: + return self == GemmInst.TCGEN5MMA + def is_mfma(self) -> bool: return self == GemmInst.MFMA + def __repr__(self) -> str: + return self.name + @tvm.ffi.register_object("tl.GemmPy") class GemmPy(Node, Scriptable): @@ -114,6 +121,8 @@ def _get_implementation_class(self, gemm_inst: GemmInst): return GemmMMA elif gemm_inst.is_wgmma(): return GemmWGMMA + elif gemm_inst.is_tcgen5mma(): + raise NotImplementedError("TCGEN5MMA is not implemented") elif gemm_inst.is_mfma(): raise NotImplementedError("MFMA is not implemented") else: From 3e90be7c7b024237ad56ad1677260652d77e7c18 Mon Sep 17 00:00:00 2001 From: Zhiwen Mo Date: Tue, 14 Oct 2025 14:00:30 +0800 Subject: [PATCH 10/10] tcgen05 support --- docs/compiler_internals/inject_fence_proxy.md | 6 +- src/layout/layout.cc | 6 + src/op/builtin.cc | 12 +- src/op/builtin.h | 13 +- src/op/gemm.cc | 70 +-- src/op/gemm_py.cc | 109 ++--- src/op/gemm_py.h | 8 + src/op/tcgen5_meta.h | 166 +++++++ src/target/codegen_cuda.cc | 85 +++- src/target/codegen_cuda.h | 2 + src/tl_templates/cuda/common.h | 73 ++- .../cuda/instruction/tcgen05mma.h | 3 + src/transform/inject_fence_proxy.cc | 3 +- 
src/transform/lower_shared_tmem.cc | 12 +- ...t_tilelang_transform_inject_fence_proxy.py | 4 +- .../intrinsics/tcgen05_macro_generator.py | 423 ++++++++++++++++++ tilelang/intrinsics/wgmma_macro_generator.py | 10 +- tilelang/language/ast/ir.py | 2 + tilelang/language/builtin.py | 70 ++- tilelang/language/tir/ir.py | 1 + tilelang/language/tir/op.py | 42 ++ tilelang/layout/__init__.py | 1 + tilelang/layout/swizzle.py | 15 + tilelang/tileop/gemm/__init__.py | 3 +- tilelang/tileop/gemm/gemm_base.py | 12 + tilelang/tileop/gemm/gemm_tcgen05.py | 121 +++++ tilelang/utils/__init__.py | 1 + tilelang/utils/language.py | 13 + 28 files changed, 1110 insertions(+), 176 deletions(-) create mode 100644 src/op/tcgen5_meta.h create mode 100644 src/tl_templates/cuda/instruction/tcgen05mma.h create mode 100644 tilelang/intrinsics/tcgen05_macro_generator.py create mode 100644 tilelang/tileop/gemm/gemm_tcgen05.py diff --git a/docs/compiler_internals/inject_fence_proxy.md b/docs/compiler_internals/inject_fence_proxy.md index df173bdf5..192d53848 100644 --- a/docs/compiler_internals/inject_fence_proxy.md +++ b/docs/compiler_internals/inject_fence_proxy.md @@ -17,7 +17,7 @@ The pass is conservative: unknown extern calls are treated as async so that the ### Timeline View ``` -generic initialize_descriptor → generic shared-store → async wgmma +generic initialize_wgmma_descriptor → generic shared-store → async wgmma │ │ │ └─ generic proxy ┴─ generic proxy ┴─ async proxy │ fence inserted here ↑ @@ -53,7 +53,7 @@ def kernel(): with T.Kernel(1): desc = T.decl_buffer((1,), "uint64", scope="local.descriptor") smem = T.decl_buffer((128,), "float16", scope="shared") - T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32) + T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32) smem[0] = T.float16(0) T.ptx_wgmma_ss( "float16", @@ -83,7 +83,7 @@ def kernel(): with T.Kernel(1): desc = T.decl_buffer((1,), "uint64", scope="local.descriptor") smem = T.decl_buffer((128,), "float16", scope="shared") - T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32) + T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32) smem[0] = T.float16(0) T.fence_proxy_async() T.ptx_wgmma_ss( diff --git a/src/layout/layout.cc b/src/layout/layout.cc index e58a8a04a..fed40d29e 100644 --- a/src/layout/layout.cc +++ b/src/layout/layout.cc @@ -535,6 +535,12 @@ TVM_FFI_STATIC_INIT_BLOCK({ return makeGemmABLayoutHopper(stride, mat_continuous, continuity, element_size, k_inner); }) + .def("tl.make_tcgen05mma_swizzled_layout", + [](int stride, int mat_continuous, int continuity, int element_size, + bool k_inner) { + return makeGemmABLayoutSm100(stride, mat_continuous, continuity, + element_size, k_inner); + }) .def("tl.make_full_bank_swizzled_layout", [](int stride, int continuous, int element_size) { return makeFullBankSwizzleLayout(stride, continuous, element_size); diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 748e84094..550c91186 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -154,6 +154,11 @@ TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss) + .set_num_inputs(13) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory) .set_num_inputs(2) .set_attr("TCallEffectKind", @@ -270,11 +275,16 @@ TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); -TIR_DEFINE_TL_BUILTIN(initialize_descriptor) 
+TIR_DEFINE_TL_BUILTIN(initialize_wgmma_descriptor) .set_num_inputs(5) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(initialize_tcgen05_descriptor) + .set_num_inputs(7) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset) .set_num_inputs(2) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index 44bbc21ff..edfcac7af 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -246,6 +246,11 @@ TVM_DLL const Op &ptx_wgmma_ss(); */ TVM_DLL const Op &ptx_wgmma_rs(); +/*! + * \brief tvm intrinsic for tcgen05 mma shared-shared instructions. + */ +TVM_DLL const Op &ptx_tcgen05_mma_ss(); + /*! * \brief tvm intrinsics for initializing tensor memory * @@ -467,7 +472,13 @@ TVM_DLL const Op &tl_shuffle_elect(); * This op is used to represent a descriptor initialization operation in * tilelang. */ -TVM_DLL const Op &initialize_descriptor(); +TVM_DLL const Op &initialize_wgmma_descriptor(); + +/*! + * \brief tilelang intrinsic for initializing a descriptor buffer for + * tcgen05 mma. + */ +TVM_DLL const Op &initialize_tcgen05_descriptor(); /*! * \brief tilelang intrinsic for setting the start address of a descriptor diff --git a/src/op/gemm.cc b/src/op/gemm.cc index afee0ebe4..0cb6f08ac 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -12,79 +12,13 @@ #include #include "../target/utils.h" +#include "tcgen5_meta.h" namespace tvm { namespace tl { using namespace tir; -struct TCGEN5MMAMeta { - int atom_m, atom_n, atom_k; -}; - -// Return {is_success, meta} -static inline std::pair -GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { -// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA. -#define FAIL \ - return { \ - false, TCGEN5MMAMeta { 0, 0, 0 } \ - } -#define SUCCESS(atom_m, atom_n, atom_k) \ - return { \ - true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \ - } - std::vector ws_valid_atom_ns = {256, 128, 64}; - if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) && - (c_dtype.is_float() && c_dtype.bits() == 32)) { - if (K % 16 != 0) - FAIL; - if (M % 128 == 0) { - for (int atom_n = 256; atom_n >= 16; atom_n -= 16) - if (N % atom_n == 0) - SUCCESS(128, atom_n, 16); - FAIL; - } else if (M % 64 == 0) { - for (int atom_n : ws_valid_atom_ns) - if (N % atom_n == 0) - SUCCESS(64, atom_n, 16); - FAIL; - } else if (M % 32 == 0) { - for (int atom_n : ws_valid_atom_ns) - if (N % atom_n == 0) - SUCCESS(32, atom_n, 16); - FAIL; - } else { - FAIL; - } - } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) && - (c_dtype.is_float() && c_dtype.bits() == 32)) { - if (K % 32 != 0) - FAIL; - if (M % 128 == 0) { - for (int atom_n = 256; atom_n >= 16; atom_n -= 16) - if (N % atom_n == 0) - SUCCESS(128, atom_n, 32); - FAIL; - } else if (M % 64 == 0) { - for (int atom_n : ws_valid_atom_ns) - if (N % atom_n == 0) - SUCCESS(64, atom_n, 32); - FAIL; - } else if (M % 32 == 0) { - for (int atom_n : ws_valid_atom_ns) - if (N % atom_n == 0) - SUCCESS(32, atom_n, 32); - FAIL; - } else { - FAIL; - } - } - FAIL; -#undef FAIL -#undef SUCCESS -} - /** * @brief Construct a Gemm operator from serialized TL arguments and a buffer * map. 
@@ -199,7 +133,7 @@ GemmInst GemmNode::GetGemmInst(int block_size, Target target) const { TargetIsSm100(target)) { return GemmInst::kMMA; } else { - ICHECK(0) << "Unsupported target for gemm: " << target->str(); + ICHECK(0) << "Unsupported target for gemm: " << target; } } diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index b64d12c42..e9984edd4 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -13,80 +13,13 @@ #include "../target/utils.h" #include "tvm/ffi/string.h" -#include +#include "tcgen5_meta.h" namespace tvm { namespace tl { using namespace tir; -struct TCGEN5MMAMeta { - int atom_m, atom_n, atom_k; -}; - -// Return {is_success, meta} -static inline std::pair -GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { -// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA. -#define FAIL \ - return { \ - false, TCGEN5MMAMeta { 0, 0, 0 } \ - } -#define SUCCESS(atom_m, atom_n, atom_k) \ - return { \ - true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \ - } - std::vector ws_valid_atom_ns = {256, 128, 64}; - if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) && - (c_dtype.is_float() && c_dtype.bits() == 32)) { - if (K % 16 != 0) - FAIL; - if (M % 128 == 0) { - for (int atom_n = 256; atom_n >= 16; atom_n -= 16) - if (N % atom_n == 0) - SUCCESS(128, atom_n, 16); - FAIL; - } else if (M % 64 == 0) { - for (int atom_n : ws_valid_atom_ns) - if (N % atom_n == 0) - SUCCESS(64, atom_n, 16); - FAIL; - } else if (M % 32 == 0) { - for (int atom_n : ws_valid_atom_ns) - if (N % atom_n == 0) - SUCCESS(32, atom_n, 16); - FAIL; - } else { - FAIL; - } - } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) && - (c_dtype.is_float() && c_dtype.bits() == 32)) { - if (K % 32 != 0) - FAIL; - if (M % 128 == 0) { - for (int atom_n = 256; atom_n >= 16; atom_n -= 16) - if (N % atom_n == 0) - SUCCESS(128, atom_n, 32); - FAIL; - } else if (M % 64 == 0) { - for (int atom_n : ws_valid_atom_ns) - if (N % atom_n == 0) - SUCCESS(64, atom_n, 32); - FAIL; - } else if (M % 32 == 0) { - for (int atom_n : ws_valid_atom_ns) - if (N % atom_n == 0) - SUCCESS(32, atom_n, 32); - FAIL; - } else { - FAIL; - } - } - FAIL; -#undef FAIL -#undef SUCCESS -} - /** * @brief Construct a Gemm operator from serialized TL arguments and a buffer * map. 
@@ -144,6 +77,20 @@ GemmPy::GemmPy(Array args, BufferMap vmap) { if (args.size() > 15) { node->wg_wait = args[15].as().value()->value; } + if (args.size() > 16) { + node->mbarptr = args[16]; + } else { + node->mbarptr = IntImm(DataType::UInt(32), 0); + } + if (args.size() > 18) { + node->C_coords = Array({args[17], args[18]}); + } else if (args.size() > 17) { + node->C_coords = + Array({args[17], IntImm(DataType::Int(32), 0)}); + } else { + node->C_coords = Array( + {IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), 0)}); + } data_ = std::move(node); } @@ -378,5 +325,31 @@ TVM_FFI_STATIC_INIT_BLOCK({ }); }); +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "tl.get_tcgen5_mma_meta", + [](int M, int N, int K, DataType ab_dtype, DataType c_dtype) { + auto [success, meta] = GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype); + Array result; + if (success) { + result.push_back(Integer(meta.atom_m)); + result.push_back(Integer(meta.atom_n)); + result.push_back(Integer(meta.atom_k)); + } + return result; + }); + refl::GlobalDef().def( + "tl.get_tcgen5_instr_desc", + [](int atom_m, int atom_n, int atom_k, DataType ab_dtype, + DataType c_dtype, bool a_is_k_major, bool b_is_k_major, + int scale_in_a, int scale_in_b) { + uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype, + c_dtype, a_is_k_major, b_is_k_major, + scale_in_a, scale_in_b); + return Integer(static_cast(desc)); + }); +}); + } // namespace tl } // namespace tvm diff --git a/src/op/gemm_py.h b/src/op/gemm_py.h index cd4b929ea..ce20a4417 100644 --- a/src/op/gemm_py.h +++ b/src/op/gemm_py.h @@ -29,6 +29,8 @@ class GemmPyNode : public TileOperatorNode { int stride_A, stride_B; int offset_A, offset_B; PrimExpr clear_accum = const_false(); + PrimExpr mbarptr; + Array C_coords; // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack // only will be enabled under cdna mfma instructions int kPack = 1; @@ -57,6 +59,8 @@ class GemmPyNode : public TileOperatorNode { .def_ro("offset_A", &GemmPyNode::offset_A) .def_ro("offset_B", &GemmPyNode::offset_B) .def_ro("clear_accum", &GemmPyNode::clear_accum) + .def_ro("mbarptr", &GemmPyNode::mbarptr) + .def_ro("C_coords", &GemmPyNode::C_coords) .def_ro("kPack", &GemmPyNode::kPack) .def_ro("wg_wait", &GemmPyNode::wg_wait) .def_ro("policy", &GemmPyNode::policy); @@ -73,6 +77,8 @@ class GemmPyNode : public TileOperatorNode { equal(offset_A, other->offset_B) && equal(offset_B, other->offset_B) && equal(clear_accum, other->clear_accum) && + equal(mbarptr, other->mbarptr) && + equal(C_coords, other->C_coords) && equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) && equal(policy, other->policy); } @@ -94,6 +100,8 @@ class GemmPyNode : public TileOperatorNode { hash_reduce(offset_A); hash_reduce(offset_B); hash_reduce(clear_accum); + hash_reduce(mbarptr); + hash_reduce(C_coords); hash_reduce(kPack); hash_reduce(wg_wait); hash_reduce(policy); diff --git a/src/op/tcgen5_meta.h b/src/op/tcgen5_meta.h new file mode 100644 index 000000000..56d890384 --- /dev/null +++ b/src/op/tcgen5_meta.h @@ -0,0 +1,166 @@ +#ifndef TVM_TL_OP_TCGEN5_META_H_ +#define TVM_TL_OP_TCGEN5_META_H_ + +#include +#include +#include + +#include +#include + +namespace tvm { +namespace tl { + +using runtime::DataType; + +struct TCGEN5MMAMeta { + int atom_m, atom_n, atom_k; +}; + +inline std::pair +GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { +// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA. 
+#define FAIL \ + return { \ + false, TCGEN5MMAMeta { 0, 0, 0 } \ + } +#define SUCCESS(atom_m, atom_n, atom_k) \ + return { \ + true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \ + } + std::vector ws_valid_atom_ns = {256, 128, 64}; + if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) && + (c_dtype.is_float() && c_dtype.bits() == 32)) { + if (K % 16 != 0) + FAIL; + if (M % 128 == 0) { + for (int atom_n = 256; atom_n >= 16; atom_n -= 16) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 16); + FAIL; + } else if (M % 64 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(64, atom_n, 16); + FAIL; + } else if (M % 32 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(32, atom_n, 16); + FAIL; + } else { + FAIL; + } + } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) && + (c_dtype.is_float() && c_dtype.bits() == 32)) { + if (K % 32 != 0) + FAIL; + if (M % 128 == 0) { + for (int atom_n = 256; atom_n >= 16; atom_n -= 16) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 32); + FAIL; + } else if (M % 64 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(64, atom_n, 32); + FAIL; + } else if (M % 32 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(32, atom_n, 32); + FAIL; + } else { + FAIL; + } + } + FAIL; +#undef FAIL +#undef SUCCESS +} + +inline uint32_t GetTCGEN5InstrDesc(int atom_m, int atom_n, int atom_k, + DataType ab_dtype, DataType c_dtype, + bool a_is_k_major, bool b_is_k_major, + int scale_in_a, int scale_in_b) { + ICHECK(atom_m % 16 == 0) << "atom_m must be divisible by 16"; + ICHECK(atom_n % 8 == 0) << "atom_n must be divisible by 8"; + ICHECK(atom_k == 16 || atom_k == 32) + << "Unsupported atom_k for TCGEN5MMA descriptor: " << atom_k; + ICHECK(scale_in_a == 1 || scale_in_a == -1) + << "scale_in_a must be +/-1 for TCGEN5MMA"; + ICHECK(scale_in_b == 1 || scale_in_b == -1) + << "scale_in_b must be +/-1 for TCGEN5MMA"; + + auto encode_dtype = [&](DataType dtype) -> uint32_t { + if (dtype.is_float16()) { + return static_cast(0); + } else if (dtype.is_bfloat16()) { + return static_cast(1); + } else if (dtype.is_float8_e4m3fn() || dtype.is_float8_e4m3fnuz() || + dtype.is_float8_e4m3()) { + return static_cast(0); + } else if (dtype.is_float8_e5m2fnuz() || + dtype.is_float8_e5m2()) { + return static_cast(1); + } + LOG(FATAL) << "Unsupported dtype for TCGEN5MMA descriptor: " << dtype; + return 0u; + }; + + uint32_t a_format = encode_dtype(ab_dtype); + uint32_t b_format = a_format; + + uint32_t c_format = 0; + if (c_dtype.is_float16()) { + c_format = 0; + } else if (c_dtype.is_float()) { + c_format = 1; + } else if (c_dtype.is_int()) { + c_format = 2; + } else { + LOG(FATAL) << "Unsupported accumulator dtype for TCGEN5MMA descriptor: " + << c_dtype; + } + + auto set_bits = [](uint32_t value, int start, int width) -> uint32_t { + uint32_t mask = (width == 32) ? 0xFFFFFFFFu : ((1u << width) - 1); + return (value & mask) << start; + }; + + uint32_t desc = 0; + desc |= set_bits(0, 0, 2); // sparse_id2 + desc |= set_bits(0, 2, 1); // sparse_flag + desc |= set_bits(0, 3, 1); // saturate + desc |= set_bits(c_format, 4, 2); + + desc |= set_bits(a_format, 7, 3); + desc |= set_bits(b_format, 10, 3); + + uint32_t a_neg = (scale_in_a == -1) ? 1u : 0u; + uint32_t b_neg = (scale_in_b == -1) ? 1u : 0u; + desc |= set_bits(a_neg, 13, 1); + desc |= set_bits(b_neg, 14, 1); + + uint32_t a_major = a_is_k_major ? 0u : 1u; + uint32_t b_major = b_is_k_major ? 
0u : 1u; + desc |= set_bits(a_major, 15, 1); + desc |= set_bits(b_major, 16, 1); + + uint32_t n_dim = static_cast(atom_n >> 3); + uint32_t m_dim = static_cast(atom_m >> 4); + desc |= set_bits(n_dim, 17, 6); + desc |= set_bits(0, 23, 1); + desc |= set_bits(m_dim, 24, 5); + desc |= set_bits(0, 29, 1); + + uint32_t max_shift = 0u; + desc |= set_bits(max_shift, 30, 2); + + return desc; +} + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_TCGEN5_META_H_ diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 97ec86ec9..fcdf862c4 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -265,6 +265,9 @@ std::string CodeGenTileLangCUDA::Finish() { if (need_wgmma_instruction_h_) { decl_stream << "#include \n"; } + if (need_tcgen05mma_instruction_h_) { + decl_stream << "#include \n"; + } if (enable_fp8_) { decl_stream << "#include \n"; } @@ -1726,6 +1729,66 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { replacer.register_rule("(scale_out)", scale_out ? "true" : "false"); wgmma_call = replacer.rewrite(wgmma_call); this->stream << wgmma_call; + } else if (op->op.same_as(tl::ptx_tcgen05_mma_ss())) { + ICHECK_EQ(op->args.size(), 14U) << "ptx_tcgen05_mma_ss args is " + << op->args; + std::string A_dtype = Downcast(op->args[0])->value; + std::string B_dtype = Downcast(op->args[1])->value; + std::string C_dtype = Downcast(op->args[2])->value; + std::string a_desc = this->PrintExpr(op->args[3]); + std::string A_offset = this->PrintExpr(op->args[4]); + std::string b_desc = this->PrintExpr(op->args[5]); + std::string B_offset = this->PrintExpr(op->args[6]); + std::string c_ref = this->PrintExpr(op->args[7]); + PrimExpr desc_expr = op->args[8]; + bool scale_out = Downcast(op->args[9])->value; + std::string mask0 = this->PrintExpr(op->args[10]); + std::string mask1 = this->PrintExpr(op->args[11]); + std::string mask2 = this->PrintExpr(op->args[12]); + std::string mask3 = this->PrintExpr(op->args[13]); + + auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype); + auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype); + auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype); + uint32_t instr_desc = get_const_uint32(desc_expr); + auto decoded = decode_instr_desc(instr_desc, dtype_a_enum); + + need_tcgen05mma_instruction_h_ = true; + this->PrintIndent(); + std::string tcgen05_call = + "tl::tcgen05mma_ss<(AType), (BType), (CType), (M), (N), (K), " + "(tnspA), (tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + " + "(A_offset)), uint64_t((desc_b) + (B_offset)), " + "reinterpret_cast((C)), (scale_out), " + "static_cast((desc_val)), (mask0), (mask1), (mask2), " + "(mask3));\n"; + tl::codegen::Replacer replacer; + replacer.register_rule("(AType)", + tl::codegen::ptx::DTypeEnumToString(dtype_a_enum)); + replacer.register_rule("(BType)", + tl::codegen::ptx::DTypeEnumToString(dtype_b_enum)); + replacer.register_rule("(CType)", + tl::codegen::ptx::DTypeEnumToString(dtype_c_enum)); + replacer.register_rule("(M)", std::to_string(decoded.m)); + replacer.register_rule("(N)", std::to_string(decoded.n)); + replacer.register_rule("(K)", std::to_string(decoded.k)); + replacer.register_rule("(tnspA)", decoded.a_is_k_major ? "false" : "true"); + replacer.register_rule("(tnspB)", decoded.b_is_k_major ? "false" : "true"); + replacer.register_rule("(scaleA)", decoded.scale_in_a_pos ? "1" : "-1"); + replacer.register_rule("(scaleB)", decoded.scale_in_b_pos ? 
"1" : "-1"); + replacer.register_rule("(desc_a)", a_desc); + replacer.register_rule("(A_offset)", A_offset); + replacer.register_rule("(desc_b)", b_desc); + replacer.register_rule("(B_offset)", B_offset); + replacer.register_rule("(C)", c_ref); + replacer.register_rule("(scale_out)", scale_out ? "true" : "false"); + replacer.register_rule("(desc_val)", this->PrintExpr(desc_expr)); + replacer.register_rule("(mask0)", mask0); + replacer.register_rule("(mask1)", mask1); + replacer.register_rule("(mask2)", mask2); + replacer.register_rule("(mask3)", mask3); + tcgen05_call = replacer.rewrite(tcgen05_call); + this->stream << tcgen05_call; } else if (op->op.same_as(builtin::ptx_ldmatrix())) { // arg 0: whether the matrix is loaded in column major format or not. // arg 1: number of matrices to load. @@ -2050,19 +2113,35 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { op->args, true, os); } else if (op->op.same_as(tl::tl_shuffle_elect())) { os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()"; - } else if (op->op.same_as(tl::initialize_descriptor())) { + } else if (op->op.same_as(tl::initialize_wgmma_descriptor())) { ICHECK(op->args.size() == 5) - << "tl_initialize_descriptor expects 5 arguments but got " + << "tl_initialize_wgmma_descriptor expects 5 arguments but got " << op->args.size(); auto descriptor = op->args[0]; auto start_address = op->args[1]; auto layout_type = op->args[2]; auto leading_byte_offset = op->args[3]; auto stride_byte_offset = op->args[4]; - os << "tl::initialize_descriptor<" << PrintExpr(layout_type) << ", " + os << "tl::initialize_wgmma_descriptor<" << PrintExpr(layout_type) << ", " << PrintExpr(leading_byte_offset) << ", " << PrintExpr(stride_byte_offset) << ">(" << PrintExpr(descriptor) << ", " << PrintExpr(start_address) << ")"; + } else if (op->op.same_as(tl::initialize_tcgen05_descriptor())) { + ICHECK(op->args.size() == 7) + << "tl_initialize_tcgen05_descriptor expects 7 arguments but got " + << op->args.size(); + auto descriptor = op->args[0]; + auto start_address = op->args[1]; + auto leading_byte_offset = op->args[2]; + auto stride_byte_offset = op->args[3]; + auto base_offset = op->args[4]; + auto leading_abs = op->args[5]; + auto swizzle_mode = op->args[6]; + os << "tl::initialize_tcgen05_descriptor(" << PrintExpr(descriptor) << ", " + << PrintExpr(start_address) << ", " << PrintExpr(leading_byte_offset) + << ", " << PrintExpr(stride_byte_offset) << ", " + << PrintExpr(base_offset) << ", " << PrintExpr(leading_abs) << ", " + << PrintExpr(swizzle_mode) << ")"; } else if (op->op.same_as(tl::increase_descriptor_offset())) { ICHECK(op->args.size() == 2) << "tl_increase_descriptor_offset expects 2 arguments but got " diff --git a/src/target/codegen_cuda.h b/src/target/codegen_cuda.h index 1618995e0..2087d58dc 100644 --- a/src/target/codegen_cuda.h +++ b/src/target/codegen_cuda.h @@ -110,6 +110,8 @@ class CodeGenTileLangCUDA final : public CodeGenC { bool need_mma_instruction_h_{false}; // whether need tl wgmma instruction header bool need_wgmma_instruction_h_{false}; + // whether need tl tcgen05mma instruction header + bool need_tcgen05mma_instruction_h_{false}; // whether need cast_smem_ptr_to_int helper function bool need_cast_smem_ptr_to_int_{false}; // whether need cooperative_groups.h diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index 6ff99f58f..c7193007d 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -264,6 +264,54 @@ union GmmaDescriptor { } }; 
+union Tcgen05Descriptor { + CUTE_HOST_DEVICE constexpr Tcgen05Descriptor() noexcept : desc_(0) {} + CUTE_HOST_DEVICE constexpr Tcgen05Descriptor(uint64_t desc) noexcept + : desc_(desc) {} + CUTE_HOST_DEVICE constexpr Tcgen05Descriptor(Tcgen05Descriptor const &t) noexcept + : desc_(t.desc_) {} + CUTE_HOST_DEVICE constexpr Tcgen05Descriptor(Tcgen05Descriptor &&t) noexcept + : desc_(t.desc_) {} + + CUTE_HOST_DEVICE constexpr Tcgen05Descriptor & + operator=(Tcgen05Descriptor const &t) noexcept { + desc_ = t.desc_; + return *this; + } + + CUTE_HOST_DEVICE constexpr Tcgen05Descriptor & + operator=(Tcgen05Descriptor &&t) noexcept { + desc_ = t.desc_; + return *this; + } + + uint64_t desc_; + uint32_t reg32_[2]; + + // Bitfield implementation avoids the need for shifts in assignment + struct { + // start_address, bit [0,14), 4LSB not included + uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // leading dimension byte offset, bit [16,30), 4LSB not included + uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // stride dimension byte offset, bit [32,46), 4LSB not included + uint16_t stride_byte_offset_ : 14, version_ : 2; // 14 bits [0,14), 2 bits [14,16) + // base_offset, bit [49,52). leading_byte_offset_mode, bit [52,53). + uint8_t : 1, base_offset_ : 3, lbo_mode_ : 1, : 3; // 1 bit unused, 3 bits [1,4), 1 bit [4,5), 3 bits unused + // layout type, bit [61,64), SWIZZLE_NONE matrix descriptor = 0, SWIZZLE_128B matrix descriptor = 2, SWIZZLE_64B descriptor = 4, SWIZZLE_32B descriptor = 6, SWIZZLE_128B_BASE32B = 1, N/A = 3, N/A = 5, N/A = 7 + uint8_t : 5, layout_type_ : 3; // 6 bits unused, 3 bits [5,8) + } bitfield; + // Separate the field, as we may only update one part of desc + struct { + uint32_t lo; + uint32_t hi; + } words; + + CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept { + return desc_; + } +}; + // Any template TL_DEVICE bool Any(T *a, int size) { for (int i = 0; i < size; i++) { @@ -302,8 +350,8 @@ TL_DEVICE void __sync_thread_partial() { template -TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor, - T *start_address) { +TL_DEVICE void initialize_wgmma_descriptor(GmmaDescriptor &descriptor, + T *start_address) { descriptor.bitfield.start_address_ = cute::cast_smem_ptr_to_uint(start_address) >> 4; descriptor.bitfield.layout_type_ = layout_type; @@ -312,6 +360,27 @@ TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor, descriptor.bitfield.stride_byte_offset_ = stride_byte_offset; } +template +TL_DEVICE void initialize_tcgen05_descriptor(Tcgen05Descriptor &descriptor, + T *start_address, + int leading_byte_offset, + int stride_byte_offset, + int base_offset, + bool leading_is_absolute, + int swizzle_mode) { + descriptor.desc_ = 0; + descriptor.bitfield.start_address_ = + cute::cast_smem_ptr_to_uint(start_address) >> 4; + descriptor.bitfield.leading_byte_offset_ = leading_byte_offset; + descriptor.bitfield.stride_byte_offset_ = stride_byte_offset; + descriptor.bitfield.version_ = 0; + descriptor.bitfield.base_offset_ = base_offset & 0x7; + descriptor.bitfield.lbo_mode_ = leading_is_absolute ? 
1 : 0; + descriptor.bitfield.layout_type_ = swizzle_mode & 0x7; + descriptor.words.hi |= (1u << (48 - 32)); + descriptor.words.hi |= (0xB0u << (53 - 32)); +} + template TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor, T offset) { diff --git a/src/tl_templates/cuda/instruction/tcgen05mma.h b/src/tl_templates/cuda/instruction/tcgen05mma.h new file mode 100644 index 000000000..4d22b1788 --- /dev/null +++ b/src/tl_templates/cuda/instruction/tcgen05mma.h @@ -0,0 +1,3 @@ +#pragma once + +#include "../common.h" \ No newline at end of file diff --git a/src/transform/inject_fence_proxy.cc b/src/transform/inject_fence_proxy.cc index b95780398..1f7f356c7 100644 --- a/src/transform/inject_fence_proxy.cc +++ b/src/transform/inject_fence_proxy.cc @@ -103,7 +103,8 @@ bool IsKnownGeneric(const CallNode *call) { return false; } return call->op.same_as(ptx_ldmatrix()) || call->op.same_as(ptx_stmatrix()) || - call->op.same_as(initialize_descriptor()); + call->op.same_as(initialize_wgmma_descriptor()) || + call->op.same_as(initialize_tcgen05_descriptor()); } ProxyKind ProxyFromAttrValue(const ObjectRef &value) { diff --git a/src/transform/lower_shared_tmem.cc b/src/transform/lower_shared_tmem.cc index 661b39949..afb64b459 100644 --- a/src/transform/lower_shared_tmem.cc +++ b/src/transform/lower_shared_tmem.cc @@ -88,6 +88,7 @@ class SharedTmemRewriter : public StmtExprMutator { Array new_data_vars; for (auto buffer : tmem_buffers) { auto data = buffer->data; + if (var_remap_.count(data)) continue; auto new_data = Var(data->name_hint, PointerType(PrimType(tmem_dtype_), "shared")); var_remap_.Set(data, new_data); @@ -107,6 +108,7 @@ class SharedTmemRewriter : public StmtExprMutator { buffer->buffer_type); new_buffers.push_back(new_buffer); buffer_remap_.Set(buffer, new_buffer); + buffer_data_to_buffer_.Set(new_data, new_buffer); } // remove the tmem buffers @@ -255,7 +257,15 @@ class SharedTmemRewriter : public StmtExprMutator { op->dtype, op->op, {op->args[0], new_data, op->args[2], op->args[3], op->args[4]}); } - return StmtExprMutator::VisitExpr_(op); + auto expr = StmtExprMutator::VisitExpr_(op); + return expr; + } + PrimExpr VisitExpr_(const VarNode *op) final { + Var var = GetRef(op); + if (var_remap_.count(var)) { + return var_remap_[var]; + } + return var; } Stmt VisitStmt_(const AttrStmtNode *op) final { diff --git a/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py b/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py index 6d6fbf3c3..f67e408a5 100644 --- a/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py +++ b/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py @@ -193,8 +193,8 @@ def before(): desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor") desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor") C_local = T.decl_buffer((32,), "float16", scope="local") - T.initialize_descriptor(desc_a, T.uint64(0), 2, 1, 32) - T.initialize_descriptor(desc_b, T.uint64(0), 2, 1, 32) + T.initialize_wgmma_descriptor(desc_a, T.uint64(0), 2, 1, 32) + T.initialize_wgmma_descriptor(desc_b, T.uint64(0), 2, 1, 32) T.warpgroup_arrive() T.ptx_wgmma_ss("float16", "m64n64k16", T.bool(True), T.bool(True), "fp16", "fp16", "fp16", desc_a.data, T.int32(0), desc_b.data, T.int32(0), C_local.data, diff --git a/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/intrinsics/tcgen05_macro_generator.py new file mode 100644 index 000000000..047300800 --- /dev/null +++ 
b/tilelang/intrinsics/tcgen05_macro_generator.py @@ -0,0 +1,423 @@ +from enum import IntEnum +from typing import Optional, Callable +import tilelang.language as T +from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter +from tvm import DataType +from tvm.tir import PrimExpr, Buffer, Var, IndexMap +import tvm +from tilelang import _ffi_api +from tilelang.utils import is_tensor_memory +from tilelang.layout import ( + Layout, + make_full_bank_swizzled_layout, + make_half_bank_swizzled_layout, + make_quarter_bank_swizzled_layout, + make_linear_layout, +) +from tvm.runtime import convert +from tilelang.intrinsics.mma_layout import (shared_16x8_to_mma_32x4_layout_sr_a, + shared_16x16_to_mma_32x8_layout_sr_a, + shared_16x32_to_mma_32x16_layout_sr_a) + +lift = convert + + +class SwizzleMode(IntEnum): + # SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 + NONE = 0 + SWIZZLE_128B = 1 + SWIZZLE_64B = 2 + SWIZZLE_32B = 3 + + def is_none(self) -> bool: + return self == SwizzleMode.NONE + + def is_swizzle_32b(self) -> bool: + return self == SwizzleMode.SWIZZLE_32B + + def is_swizzle_64b(self) -> bool: + return self == SwizzleMode.SWIZZLE_64B + + def is_swizzle_128b(self) -> bool: + return self == SwizzleMode.SWIZZLE_128B + + def swizzle_byte_size(self) -> int: + if self.is_swizzle_32b(): + return 32 + elif self.is_swizzle_64b(): + return 64 + elif self.is_swizzle_128b(): + return 128 + else: + return 1 + + def swizzle_atom_size(self) -> int: + if self.is_swizzle_32b(): + return 32 // 16 + elif self.is_swizzle_64b(): + return 64 // 16 + elif self.is_swizzle_128b(): + return 128 // 16 + else: + return 1 + + +# derive from MMAIntrinEmitter as some layouts are the same +class TensorCoreIntrinEmitter(MMAIntrinEmitter): + """ + To eliminate Python syntax within TIR Macro. 
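+
+    Emits TCGEN5MMA (tcgen05.mma) macros: A and B are read from shared memory
+    through tcgen05 shared-memory descriptors, while the accumulator lives in
+    tensor memory (shared.tmem). Derives from the MMA emitter because several
+    index layouts are reused unchanged.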
+ """ + + # should be rewritten to support dynamic k_dim + tcgen05_prefix: str + + a_shared_layout: Layout = None + b_shared_layout: Layout = None + + def __init__( + self, + a_dtype: str = "float16", + b_dtype: str = "float16", + accum_dtype: str = "float16", + a_transposed: bool = False, + b_transposed: bool = False, + block_row_warps: int = 2, + block_col_warps: int = 2, + warp_row_tiles: int = 8, + warp_col_tiles: int = 8, + chunk: int = 16, + reduce_k: int = 1, + num_elems_per_byte: int = 1, + is_m_first: Optional[bool] = False, + thread_var: Optional[Var] = None, + ): + super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps, + block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k, + num_elems_per_byte, is_m_first, thread_var) + self._initialize_tcgen05_prefix(self.n_dim) + + def _assign_a_shared_layout(self, layout: Layout): + self.a_shared_layout = layout + return self + + def _assign_b_shared_layout(self, layout: Layout): + self.b_shared_layout = layout + return self + + def _initialize_tcgen05_prefix(self, n_dim: int = 16): + inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles + # 256 bits per instruction + inst_k = 256 // DataType(self.a_dtype).bits + self.tcgen05_prefix = f"m{inst_m}n{inst_n}k{inst_k}" + + def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): + warp_row_tiles = self.warp_row_tiles + warp_col_tiles = self.warp_col_tiles + assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}" + assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}" + assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}" + assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}" + + # four warps per block + self.warp_rows = warp_row_tiles // m_dim + if warp_col_tiles % 16 == 0: + self.n_dim = 16 + self.micro_size_y = 16 + self.warp_cols = warp_col_tiles // 16 + else: + # must be divisible by 8 + self.n_dim = 8 + self.micro_size_y = 8 + self.warp_cols = warp_col_tiles // 8 + + self.micro_size_x = m_dim + self.micro_size_k = k_dim + + def _determinate_swizzle_mode(self, buffer: Buffer, layout: Layout) -> SwizzleMode: + # same behavior to src/layout/gemm_layouts.cc::makeGemmABLayoutHopper + if layout is None or layout.is_equal(make_linear_layout(buffer)): + return SwizzleMode.NONE + elif layout.is_equal(make_quarter_bank_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_32B + elif layout.is_equal(make_half_bank_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_64B + elif layout.is_equal(make_full_bank_swizzled_layout(buffer)): + return SwizzleMode.SWIZZLE_128B + else: + raise ValueError(f"Unsupported swizzle mode: {layout}") + + def tcgen05mma(self, + A_buf: Buffer, + B_buf: Buffer, + C_local_buf: Buffer, + mbar, + clear_accum: PrimExpr = False): + + if is_tensor_memory(A_buf): + return self.tcgen05mma_rs(A_buf, B_buf, C_local_buf, clear_accum) + + local_size_out = self.local_size_out + accum_dtype = self.accum_dtype + m_dim = self.block_row_warps * self.warp_row_tiles + warp_cols = self.warp_cols + micro_size_k = self.micro_size_k + k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles + scale_out = ~clear_accum + scale_in_a = 1 + scale_in_b = 1 + + assert k_dim >= micro_size_k, f"k_dim must be greater than or equal to {micro_size_k}, got k_dim: {k_dim}" + + a_is_k_major = not self.a_transposed + b_is_k_major = self.b_transposed + a_swizzle_mode = 
self._determinate_swizzle_mode(A_buf, self.a_shared_layout) + b_swizzle_mode = self._determinate_swizzle_mode(B_buf, self.b_shared_layout) + + elems_in_bits = DataType(self.a_dtype).bits + elems_in_bytes = elems_in_bits // 8 + a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes + b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none( + ) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes + accum_bits = DataType(accum_dtype).bits + accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32 + # by default, we utilize non-swizzle layout offset + a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * + elems_in_bytes) + a_stride_byte_offset = (8 * k_dim * elems_in_bytes) if a_is_k_major else (8 * 8 * + elems_in_bytes) + + if not a_swizzle_mode.is_none(): + # swizzle mode doesn't require LBO/SBO to be 1 + # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset + if a_is_k_major: + a_leading_byte_offset = 16 + a_stride_byte_offset = 8 * a_swizzle_mode.swizzle_byte_size() + else: + # MN Major + # LBO represents the distance between two atoms along the M dimension + # SBO represents the distance between two atoms along the K dimension + a_m_axis_atoms = m_dim // a_swizzle_atom_elems + if a_m_axis_atoms <= 1: + a_leading_byte_offset = 0 + else: + a_leading_byte_offset = 8 * a_swizzle_mode.swizzle_atom_size() * ( + a_swizzle_mode.swizzle_byte_size() // elems_in_bytes) + + if a_m_axis_atoms <= 1: + a_stride_byte_offset = 8 * elems_in_bytes * m_dim + else: + a_stride_byte_offset = 8 * elems_in_bytes * a_swizzle_atom_elems + + b_leading_byte_offset = (8 * 8 * elems_in_bytes) if b_is_k_major else (8 * n_dim * + elems_in_bytes) + b_stride_byte_offset = (8 * k_dim * + elems_in_bytes) if b_is_k_major else (0 if n_dim == 8 else + (8 * 8 * elems_in_bytes)) + if not b_swizzle_mode.is_none(): + # swizzle mode doesn't require LBO/SBO to be 1 + # https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-leading-dimension-byte-offset + if b_is_k_major: + b_leading_byte_offset = 16 + b_stride_byte_offset = 8 * b_swizzle_mode.swizzle_byte_size() + else: + # MN Major, K * N + # LBO represents the distance between two atoms along the N dimension + # SBO represents the distance between two atoms along the K dimension + b_n_axis_atoms = n_dim // b_swizzle_atom_elems + if b_n_axis_atoms <= 1: + b_leading_byte_offset = 0 + else: + b_leading_byte_offset = 8 * 8 * elems_in_bytes * k_dim + if b_n_axis_atoms <= 1: + b_stride_byte_offset = 8 * elems_in_bytes * n_dim + else: + b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems + + # for example, if [n, k] where k is 128, we should split it into 2 atoms + # where max specially handles the case when n_dim is 8. 
+        ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1)
+        bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
+
+        meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim)
+        if len(meta) != 3:
+            raise ValueError(
+                f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, "
+                f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}"
+            )
+        atom_m, atom_n, atom_k = (int(x) for x in meta)
+        layout_mode = "ss" if atom_m == 128 else "ws"
+        a_dtype_abbrv = self.a_dtype_abbrv
+        b_dtype_abbrv = self.b_dtype_abbrv
+
+        instr_desc = T.Cast(
+            "uint32",
+            self.get_tcgen5_instr_desc(
+                atom_m,
+                atom_n,
+                atom_k,
+                a_is_k_major,
+                b_is_k_major,
+                scale_in_a,
+                scale_in_b,
+            ),
+        )
+        mask_full = T.Cast("int32", -1)
+        mask_zero = T.Cast("int32", 0)
+        mask0 = mask1 = mask2 = mask3 = mask_full if layout_mode == "ss" else mask_zero
+
+        @T.macro
+        def _warp_mma(A_buf, B_buf, C_local_buf):
+            desc_a = T.alloc_descriptor()
+            desc_b = T.alloc_descriptor()
+
+            for ki in T.serial(0, (k_dim // micro_size_k)):
+                for i in T.serial(m_dim // 64):
+                    A_elem_offset = (ki % ak_atom_size) * micro_size_k + i * 64 * a_swizzle_atom_elems + (
+                        ki // ak_atom_size
+                    ) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
+                    B_elem_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
+                        ki % bk_atom_size
+                    ) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k
+                    A_byte_offset = A_elem_offset * elems_in_bytes
+                    B_byte_offset = B_elem_offset * elems_in_bytes
+                    A_ptr = A_buf.access_ptr("r")
+                    B_ptr = B_buf.access_ptr("r")
+                    T.initialize_tcgen05_descriptor(
+                        desc_a,
+                        A_ptr,
+                        int(a_leading_byte_offset >> 4),
+                        int(a_stride_byte_offset >> 4),
+                        0,
+                        False,
+                        int(a_swizzle_mode),
+                    )
+                    T.initialize_tcgen05_descriptor(
+                        desc_b,
+                        B_ptr,
+                        int(b_leading_byte_offset >> 4),
+                        int(b_stride_byte_offset >> 4),
+                        0,
+                        False,
+                        int(b_swizzle_mode),
+                    )
+                    T.ptx_tcgen05_mma_ss(
+                        a_dtype_abbrv,
+                        b_dtype_abbrv,
+                        a_dtype_abbrv,
+                        desc_a.data,
+                        A_byte_offset,
+                        desc_b.data,
+                        B_byte_offset,
+                        C_local_buf.data,
+                        instr_desc,
+                        scale_out,
+                        mask0,
+                        mask1,
+                        mask2,
+                        mask3,
+                    )
+
+        return _warp_mma(A_buf, B_buf, C_local_buf)
+
+    def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment:
+        raise NotImplementedError
+
+    def make_mma_store_layout(self, tmem_buf: Buffer) -> Layout:
+        """
+        Create the TCGEN5 tensor-memory layout used to store MMA accumulators.
+
+        Parameters
+        ----------
+        tmem_buf : tir.Buffer
+            The buffer residing in tensor memory (shared.tmem) that holds the
+            MMA output.
+
+        Returns
+        -------
+        Layout
+            Layout object describing how logical (i, j) coordinates map to the
+            swizzled tensor-memory offsets required by TCGEN5MMA.
+
+        Raises
+        ------
+        AssertionError
+            If `tmem_buf` is not detected to be a tensor-memory buffer.
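+        ValueError
+            If `tmem_buf` is not 2-D, no TCGEN5MMA atom shape exists for the
+            buffer shape and current chunk, or the shape is not divisible by
+            the selected atoms.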
+ """ + assert is_tensor_memory(tmem_buf), "tmem_buf must reside in tensor memory (shared.tmem)" + if len(tmem_buf.shape) != 2: + raise ValueError( + f"TCGEN5MMA expects a 2-D tensor-memory buffer, got shape {tmem_buf.shape}" + ) + + m = int(tmem_buf.shape[0]) + n = int(tmem_buf.shape[1]) + k = int(self.chunk) + + meta = self.get_tcgen5_mma_meta(m, n, k) + if len(meta) != 3: + raise ValueError( + f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, " + f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}" + ) + atom_m, atom_n, _ = (int(x) for x in meta) + + if m % atom_m != 0 or n % atom_n != 0: + raise ValueError( + f"Invalid TCGEN5MMA store layout for shape ({m}, {n}) with atoms ({atom_m}, {atom_n})" + ) + + def forward(i: PrimExpr, j: PrimExpr): + atom_idx = (i // atom_m) + (j // atom_n) * (m // atom_m) + ai = i % atom_m + aj = j % atom_n + + if atom_m == 128: + # Layout D + return [ + ai, + aj + atom_idx * atom_n, + ] + if atom_m == 64: + # Layout E (.ws variant) + half_atom_n = atom_n // 2 + return [ + (ai // 32) * 32 + ai % 32 + (aj // half_atom_n) * 64, + (aj % half_atom_n) + atom_idx * half_atom_n, + ] + if atom_m == 32: + # Layout G + quarter_atom_n = atom_n // 4 + return [ + ai % 32 + (aj // quarter_atom_n) * 32, + (aj % quarter_atom_n) + atom_idx * quarter_atom_n, + ] + + raise ValueError(f"Unsupported TCGEN5 atom_m={atom_m}") + + return Layout([m, n], forward) + + + def get_tcgen5_mma_meta(self, m: int, n: int, k: int): + return _ffi_api.get_tcgen5_mma_meta(int(m), int(n), int(k), DataType(self.a_dtype), DataType(self.accum_dtype)) + + def get_tcgen5_instr_desc(self, + atom_m: int, + atom_n: int, + atom_k: int, + a_is_k_major: bool, + b_is_k_major: bool, + scale_in_a: int, + scale_in_b: int) -> PrimExpr: + desc = _ffi_api.get_tcgen5_instr_desc( + atom_m, + atom_n, + atom_k, + DataType(self.a_dtype), + DataType(self.accum_dtype), + a_is_k_major, + b_is_k_major, + scale_in_a, + scale_in_b, + ) + return lift(desc) diff --git a/tilelang/intrinsics/wgmma_macro_generator.py b/tilelang/intrinsics/wgmma_macro_generator.py index 49c4819ca..7cc00849b 100644 --- a/tilelang/intrinsics/wgmma_macro_generator.py +++ b/tilelang/intrinsics/wgmma_macro_generator.py @@ -163,7 +163,7 @@ def wgmma(self, micro_size_k = self.micro_size_k k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles wgmma_prefix = self.wgmma_prefix - scale_out = not clear_accum + scale_out = ~clear_accum scale_in_a = 1 scale_in_b = 1 @@ -246,9 +246,9 @@ def wgmma(self, def _warp_mma(A_buf, B_buf, C_local_buf): desc_a = T.alloc_descriptor() desc_b = T.alloc_descriptor() - T.initialize_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode, + T.initialize_wgmma_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode, int(a_leading_byte_offset >> 4), int(a_stride_byte_offset >> 4)) - T.initialize_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode, + T.initialize_wgmma_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs) T.warpgroup_arrive() @@ -288,7 +288,7 @@ def wgmma_rs(self, micro_size_k = self.micro_size_k k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles wgmma_prefix = self.wgmma_prefix - scale_out = not clear_accum + scale_out = ~clear_accum scale_in_a = 1 scale_in_b = 1 @@ -335,7 +335,7 @@ def wgmma_rs(self, @T.macro def _warp_mma(A_buf, B_buf, C_local_buf): desc_b = T.alloc_descriptor() - T.initialize_descriptor(desc_b, 
B_buf.access_ptr("r"), b_swizzle_mode, + T.initialize_wgmma_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode, int(b_leading_byte_offset >> 4), int(b_stride_byte_offset >> 4)) T.warpgroup_fence_operand(A_buf, num_regs=a_regs) T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs) diff --git a/tilelang/language/ast/ir.py b/tilelang/language/ast/ir.py index 0948cdfa7..45b4c2df4 100644 --- a/tilelang/language/ast/ir.py +++ b/tilelang/language/ast/ir.py @@ -1894,6 +1894,7 @@ def wrapped(*args, **kwargs): ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp) ptx_wgmma_ss = _dtype_forward(_tir_op.ptx_wgmma_ss) ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs) +ptx_tcgen05_mma_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_ss) ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix) ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async) ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk) @@ -2145,6 +2146,7 @@ def wrapped(*args, **kwargs): "ptx_mma_sp", "ptx_wgmma_ss", "ptx_wgmma_rs", + "ptx_tcgen05_mma_ss", "ptx_ldmatrix", "ptx_cp_async", "ptx_cp_async_bulk", diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index de6905e56..24eaf768d 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -449,38 +449,68 @@ def sync_grid(): return tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid")) -def initialize_descriptor(descriptor: Buffer, - start_address: PrimExpr, - layout_type_: int = 0, - leading_byte_offset: int = 0, - stride_byte_offset: int = 0) -> PrimExpr: - """ - Initialize a memory descriptor with the given parameters. +def initialize_wgmma_descriptor( + descriptor: Buffer, + start_address: PrimExpr, + layout_type_: int = 0, + leading_byte_offset: int = 0, + stride_byte_offset: int = 0, +) -> PrimExpr: + """Initialize a WGMMA/UTCMMA shared-memory descriptor.""" - Parameters: - descriptor (Buffer): The memory descriptor to initialize. - start_address (PrimExpr): The starting address of the memory region. - layout_type_ (int, optional): Layout type identifier. Defaults to 0. - leading_byte_offset (int, optional): Leading byte offset. Defaults to 0. - stride_byte_offset (int, optional): Stride byte offset. Defaults to 0. + if not isinstance(descriptor, (BufferLoad, Buffer)): + raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") - Returns: - PrimExpr: A handle representing the initialized descriptor. 
- """ + if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1): + raise ValueError("Descriptor must be a 1D buffer of size 1.") + + descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( + descriptor, [0]) + + return evaluate( + tir.call_intrin( + "handle", + tir.op.Op.get("tl.initialize_wgmma_descriptor"), + descriptor, + start_address, + layout_type_, + int(leading_byte_offset), + int(stride_byte_offset), + )) + + +def initialize_tcgen05_descriptor( + descriptor: Buffer, + start_address: PrimExpr, + leading_byte_offset: int, + stride_byte_offset: int, + base_offset: int = 0, + leading_is_absolute: bool = False, + swizzle_mode: int = 0, +) -> PrimExpr: + """Initialize a TCGEN05 shared-memory descriptor.""" if not isinstance(descriptor, (BufferLoad, Buffer)): raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") - if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1: + if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1): raise ValueError("Descriptor must be a 1D buffer of size 1.") descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( descriptor, [0]) return evaluate( - tir.call_intrin("handle", tir.op.Op.get("tl.initialize_descriptor"), descriptor, - start_address, layout_type_, int(leading_byte_offset), - int(stride_byte_offset))) + tir.call_intrin( + "handle", + tir.op.Op.get("tl.initialize_tcgen05_descriptor"), + descriptor, + start_address, + int(leading_byte_offset), + int(stride_byte_offset), + int(base_offset), + tir.IntImm("int32", 1 if leading_is_absolute else 0), + int(swizzle_mode), + )) def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimExpr: diff --git a/tilelang/language/tir/ir.py b/tilelang/language/tir/ir.py index 1143f2a9e..c36ea3395 100644 --- a/tilelang/language/tir/ir.py +++ b/tilelang/language/tir/ir.py @@ -293,6 +293,7 @@ def wrapped(*args, **kwargs): ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp) ptx_wgmma_ss = _dtype_forward(_tir_op.ptx_wgmma_ss) ptx_wgmma_rs = _dtype_forward(_tir_op.ptx_wgmma_rs) +ptx_tcgen05_mma_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_ss) ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix) ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async) ptx_cp_async_bulk = _dtype_forward(_tir_op.ptx_cp_async_bulk) diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index cd87f691b..f83881111 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -1141,6 +1141,48 @@ def ptx_wgmma_rs( ) +def ptx_tcgen05_mma_ss( + a_dtype, + b_dtype, + c_dtype, + desc_a, + A_offset, + desc_b, + B_offset, + C_ptr, + desc_val, + scale_out, + mask0, + mask1, + mask2, + mask3, +): + """TVM intrinsic for tcgen05.mma shared-memory × shared-memory instructions. + + Expects exactly 14 positional arguments: + (a_dtype, b_dtype, c_dtype, desc_a, A_offset, desc_b, B_offset, C_ptr, + desc_val, scale_out, mask0, mask1, mask2, mask3). 
+ """ + return call_intrin( + "handle", + _tvm_op.Op.get("tl.ptx_tcgen05_mma_ss"), + a_dtype, + b_dtype, + c_dtype, + desc_a, + A_offset, + desc_b, + B_offset, + C_ptr, + desc_val, + scale_out, + mask0, + mask1, + mask2, + mask3, + ) + + def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride): """TVM intrinsic for storing the result of PTX MMA into a destination pointer diff --git a/tilelang/layout/__init__.py b/tilelang/layout/__init__.py index 2df0ba187..055a23520 100644 --- a/tilelang/layout/__init__.py +++ b/tilelang/layout/__init__.py @@ -6,6 +6,7 @@ from .swizzle import ( make_swizzled_layout, # noqa: F401 make_wgmma_swizzled_layout, # noqa: F401 + make_tcgen05mma_swizzled_layout, # noqa: F401 make_full_bank_swizzled_layout, # noqa: F401 make_half_bank_swizzled_layout, # noqa: F401 make_quarter_bank_swizzled_layout, # noqa: F401 diff --git a/tilelang/layout/swizzle.py b/tilelang/layout/swizzle.py index 1d3e98909..ebb9c8ea9 100644 --- a/tilelang/layout/swizzle.py +++ b/tilelang/layout/swizzle.py @@ -33,6 +33,21 @@ def make_wgmma_swizzled_layout(buffer: tvm.tir.Buffer, k_major, ) +# for TCGEN05MMA Intrinsics +def make_tcgen05mma_swizzled_layout(buffer: tvm.tir.Buffer, + continuity: int = None, + k_major: bool = True): + assert len(buffer.shape) == 2 + if continuity is None: + continuity = int(buffer.shape[1]) + return _ffi_api.make_tcgen05mma_swizzled_layout( + int(buffer.shape[0]), + int(buffer.shape[1]), + continuity, + int(tvm.DataType(buffer.dtype).bits), + k_major, + ) + # swizzle 128B # args: buffer or (stride, continuous, element_size) diff --git a/tilelang/tileop/gemm/__init__.py b/tilelang/tileop/gemm/__init__.py index 53afae42d..c0ba4a44e 100644 --- a/tilelang/tileop/gemm/__init__.py +++ b/tilelang/tileop/gemm/__init__.py @@ -8,6 +8,7 @@ from tilelang.ir import GemmWarpPolicy from .gemm_mma import GemmMMA from .gemm_wgmma import GemmWGMMA +from .gemm_tcgen05 import GemmTCGEN5 from tilelang import _ffi_api @@ -122,7 +123,7 @@ def _get_implementation_class(self, gemm_inst: GemmInst): elif gemm_inst.is_wgmma(): return GemmWGMMA elif gemm_inst.is_tcgen5mma(): - raise NotImplementedError("TCGEN5MMA is not implemented") + return GemmTCGEN5 elif gemm_inst.is_mfma(): raise NotImplementedError("MFMA is not implemented") else: diff --git a/tilelang/tileop/gemm/gemm_base.py b/tilelang/tileop/gemm/gemm_base.py index 849b6d33a..f355eaf95 100644 --- a/tilelang/tileop/gemm/gemm_base.py +++ b/tilelang/tileop/gemm/gemm_base.py @@ -118,3 +118,15 @@ def wg_wait(self) -> int: @property def policy(self) -> GemmWarpPolicy: return self.gemm_node.policy + + @property + def mbarptr(self) -> PrimExpr: + return getattr(self.gemm_node, "mbarptr", tvm.tir.const(0, "uint32")) + + @property + def C_coords(self): + coords = getattr(self.gemm_node, "C_coords", None) + if coords is None or len(coords) == 0: + zero = tvm.tir.const(0, "int32") + return [zero, zero] + return [coords[i] for i in range(len(coords))] diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/tileop/gemm/gemm_tcgen05.py new file mode 100644 index 000000000..b844b9ee6 --- /dev/null +++ b/tilelang/tileop/gemm/gemm_tcgen05.py @@ -0,0 +1,121 @@ +from typing import Tuple + +from .gemm_base import GemmBase +from tilelang.layout import make_tcgen05mma_swizzled_layout +from tilelang.intrinsics.tcgen05_macro_generator import ( + TensorCoreIntrinEmitter,) +from tilelang import language as T +from tilelang.transform.simplify import _Simplify +from tvm import tir +from tvm.tir import analysis +from tvm.target import Target + + 
+_FLOAT8_DTYPES = { + "float8_e4m3", + "float8_e4m3fn", + "float8_e4m3fnuz", + "float8_e5m2", + "float8_e5m2fn", + "float8_e5m2fnuz", +} + +class GemmTCGEN5(GemmBase): + + def infer_layout(self, target: Target, thread_nums: int): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, + True) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + ) + a_is_k_major = not self.trans_A + b_is_k_major = self.trans_B + + if self.is_gemm_ss(): + + a_continuity = self.M if a_is_k_major else 4 * self.K // m_warp + b_continuity = self.K if b_is_k_major else self.N // n_warp + + return { + # WGMMA does not support padding + self.A: + make_tcgen05mma_swizzled_layout( + self.A, continuity=a_continuity, k_major=a_is_k_major), + self.B: + make_tcgen05mma_swizzled_layout( + self.B, continuity=b_continuity, k_major=b_is_k_major), + self.C: + mma_emitter.make_mma_store_layout(self.C), + } + # No special swizzle requirement; rely on existing layout. + return {} + + def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var): + m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target, + True) + warp_row_tiles = int(self.M // m_warp) + warp_col_tiles = int(self.N // n_warp) + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=self.in_dtype, + b_dtype=self.in_dtype, + accum_dtype=self.accum_dtype, + a_transposed=self.trans_A, + b_transposed=self.trans_B, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=self.chunk, + ) + if not self.is_gemm_ss(): + raise ValueError( + f"TCGEN5MMA currently only supports gemm_ss, got " + f"A scope {self.A.scope()}, B scope {self.B.scope()}") + + atom_m, atom_n, atom_k = mma_emitter.get_tcgen5_mma_meta( + self.M, self.N, self.K) + + if self.A.scope() not in {"shared", "shared.dyn", "shared.tmem"}: + raise ValueError(f"Unsupported A scope for TCGEN5MMA: {self.A.scope()}") + if self.B.scope() not in {"shared", "shared.dyn"}: + raise ValueError(f"Unsupported B scope for TCGEN5MMA: {self.B.scope()}") + if self.C.scope() != "shared.tmem": + raise ValueError(f"TCGEN5MMA expects C in shared.tmem, got {self.C.scope()}") + if self.wg_wait != -1: + raise ValueError("TCGEN5MMA currently requires wg_wait == -1") + + mbarptr = self.mbarptr + if mbarptr == 0: + raise ValueError("TCGEN5MMA requires a valid mbarrier pointer") + + C_coords = self.C_coords + if len(C_coords) != 2: + raise ValueError("TCGEN5MMA expects 2D coordinates for C buffer access") + + accum_dtype = str(self.C.dtype) + if accum_dtype != "float32": + raise ValueError( + f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}") + + A_shared = self.A + B_shared = self.B + C_local = self.C + clear_accum = self.clear_accum + mbar = self.mbarptr + + @T.prim_func + def _gemm_ss() -> None: + mma_emitter.tcgen05mma(A_shared, B_shared, C_local, mbar, clear_accum) + + return _Simplify(_gemm_ss, inline_let=True) diff --git a/tilelang/utils/__init__.py b/tilelang/utils/__init__.py index f50aa8567..7edc4bec7 100644 --- a/tilelang/utils/__init__.py +++ b/tilelang/utils/__init__.py @@ -6,6 +6,7 @@ is_global, # noqa: F401 
is_shared, # noqa: F401 is_shared_dynamic, # noqa: F401 + is_tensor_memory, # noqa: F401 is_fragment, # noqa: F401 is_local, # noqa: F401 array_reduce, # noqa: F401 diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py index 2c0b4efad..00f79591e 100644 --- a/tilelang/utils/language.py +++ b/tilelang/utils/language.py @@ -52,6 +52,19 @@ def is_shared_dynamic(buffer: Buffer) -> bool: return buffer.scope() == "shared.dyn" +def is_tensor_memory(buffer: Buffer) -> bool: + """ + Check if the buffer is in tensor memory scope (e.g., shared.tmem). + + Args: + buffer (Buffer): The TVM buffer to check. + + Returns: + bool: True if the buffer is in tensor memory, False otherwise. + """ + return buffer.scope().startswith("shared.tmem") + + def is_local(buffer: Buffer) -> bool: """ Check if the buffer is in the local memory scope.
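As a quick sanity check of the new scope helper, a minimal standalone sketch (illustrative only; the buffers are fabricated here just to carry the relevant scopes):

    import tvm
    from tilelang.utils import is_tensor_memory

    tmem = tvm.tir.decl_buffer((128, 256), "float32", scope="shared.tmem")
    smem = tvm.tir.decl_buffer((128, 64), "float16", scope="shared.dyn")
    assert is_tensor_memory(tmem)      # scope starts with "shared.tmem"
    assert not is_tensor_memory(smem)  # plain dynamic shared memory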