Skip to content

Conversation

@LeiWang1999
Copy link
Member

@LeiWang1999 LeiWang1999 commented Oct 12, 2025

This pull request introduces several improvements and new features to the CUDA code generation and tensor core support in the codebase. The most significant changes include the addition of support for the new warpgroup_fence_operand intrinsic, refactoring and modernization of MMA and WGMMA code generation, and updates to header management for generated CUDA code. These changes enhance the maintainability, extensibility, and correctness of the code, particularly in handling tensor core operations and their associated data types.

Tensor Core and MMA/WGMMA Code Generation Improvements

  • Added a new header file src/tl_templates/cuda/instruction/mma.h implementing a generic, extensible MMA dispatcher and kernel for various data types and layouts, leveraging CUTE library intrinsics. This enables unified and type-safe MMA code generation for multiple tensor core configurations.
  • Refactored code generation for MMA and WGMMA intrinsics in src/target/codegen_cuda.cc to use the new dispatcher and type helpers, replacing previous string-based assembly generation with type-safe C++ calls. This includes new logic for argument parsing and type selection. [1] [2] [3]

New Intrinsic: warpgroup_fence_operand

  • Introduced the warpgroup_fence_operand intrinsic in both the operator registry (src/op/builtin.cc, src/op/builtin.h) and CUDA codegen, allowing explicit fencing of accumulator operand registers for upcoming WGMMA operations. This includes C++ overloads for both uint32_t* and float* register types. [1] [2] [3] [4]

Header Management and Codegen Infrastructure

  • Updated codegen logic and class members to conditionally include new instruction headers (mma.h, wgmma.h) only when needed, reducing unnecessary includes and improving compilation efficiency. [1] [2]

Type and Argument Handling

  • Added a helper function GetMMARegisterType to map PTX data types to C++ register types for MMA fragments, improving type safety and correctness in generated code. [1] [2]
  • Updated argument handling and documentation for WGMMA and MMA intrinsics to reflect new layouts, data type parsing, and argument order, ensuring correctness and clarity. [1] [2]

Correctness and Validation

  • Added an explicit check in GEMM lowering to ensure that the A operand is not transposed when using the gemm_rs path, preventing invalid configurations.

Summary by CodeRabbit

  • New Features

    • Added warpgroup_fence_operand API, TCGEN5 MMA support (new TCGEN5 MMA path, descriptor initializer, and swizzled layout helper), wgmma_rs endpoint, and a utility to map MMA register types.
  • Refactor

    • Emission of MMA/WGMMA now uses a dispatcher/template-based approach with generated instruction helpers and added warpgroup fence sequencing.
  • Bug Fixes

    • Enforced non‑transposed A in a GEMM lowering path and improved dtype error handling.
  • Chores

    • Unified dtype lookup in JIT wrappers and added optional barrier buffer support to gemm_v2.

@github-actions
Copy link

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Oct 12, 2025

Walkthrough

Adds TCGEN5 MMA and warpgroup-fence support across TileLang: new intrinsics, descriptors, layouts, codegen templates, and GEMM selection; refactors MMA/WGMMA emission to dispatcher/templates and updates Python/TIR/JIT bindings and utilities to expose and validate the new paths.

Changes

Cohort / File(s) Summary
Builtins & registration
src/op/builtin.cc, src/op/builtin.h
Add tl.ptx_tcgen05_mma_ss and tl.warpgroup_fence_operand; rename initialize_descriptorinitialize_wgmma_descriptor; add initialize_tcgen05_descriptor.
Python/TIR builtin wrappers
tilelang/language/builtin.py, tilelang/language/tir/op.py, tilelang/language/ast/ir.py, tilelang/language/tir/ir.py
Add warpgroup_fence_operand, add/rename descriptor initializers, add TIR wrapper ptx_tcgen05_mma_ss, update ptx_wgmma_rs signature/arity.
CUDA codegen & PTX helpers
src/target/codegen_cuda.cc, src/target/codegen_cuda.h, src/target/ptx.cc, src/target/ptx.h
Emit tl::warpgroup_fence_operand calls; switch MMA/WGMMA/TCGEN05 emission to replacer/template-based calls; add need_*instruction_h flags; add GetMMARegisterType and new float8 aliases.
CUDA device intrinsics & common descriptors
src/tl_templates/cuda/intrin.h, src/tl_templates/cuda/common.h, src/tl_templates/cuda/instruction/tcgen05mma.h
Add warpgroup_fence_operand device wrappers; add Tcgen05Descriptor and initialize_tcgen05_descriptor; rename initialize_descriptorinitialize_wgmma_descriptor; add tcgen05 header skeleton.
MMA/WGMMA template refactor
src/tl_templates/cuda/instruction/mma.h, src/tl_templates/cuda/instruction/wgmma.h
Introduce MmaDispatcher/mma_sync dispatch machinery; replace many WGMMA specializations with macro-driven Impl dispatch; add wgmma_rs (RS path) and scale/validation machinery.
TCGEN5 meta/descriptor & layout plumbing
src/op/tcgen5_meta.h, src/layout/layout.cc, tilelang/layout/swizzle.py, tilelang/layout/__init__.py
Add TCGEN5 meta helpers and instr descriptor builder; expose make_tcgen05mma_swizzled_layout via FFI and python re-exports.
GEMM selection & lowering
src/op/gemm.cc, src/op/gemm_py.cc, src/op/gemm_py.h, tilelang/tileop/gemm/__init__.py
Add TCGEN5 selection/support, include TCGEN5 meta, new GemmPy/GemmInst hooks (AllowTCGEN5MMA, AllowWGMMA), add GemmTCGEN5 routing and enum changes.
TileLang TCGEN5 emitter & gemm op
tilelang/intrinsics/tcgen05_macro_generator.py, tilelang/tileop/gemm/gemm_tcgen05.py
New TensorCoreIntrinEmitter for TCGEN5, SwizzleMode, and new GemmTCGEN5 implementation (infer_layout, lower).
Macro generators & fence/synchronization
tilelang/intrinsics/wgmma_macro_generator.py, tilelang/intrinsics/mma_macro_generator.py
wgmma emitter: switch to read descriptors, compute atom/register counts, add warpgroup_fence_operand and warpgroup sync calls, adjust swizzle/offset math; add _get_dtype_abbrv helper.
TileLang JIT/type maps & utils
tilelang/jit/adapter/wrapper.py, tilelang/utils/language.py, tilelang/utils/__init__.py
Add _lookup_type helpers, extend dtype maps (e.g., float8_e4m3fn), add is_tensor_memory helper and export.
TileLang API changes
tilelang/language/gemm.py, tilelang/tileop/gemm/gemm_base.py
Add optional mbar to gemm_v2 and expose mbarptr / C_coords properties on GemmBase to propagate barrier info.
Transforms & tests
src/transform/inject_fence_proxy.cc, src/transform/lower_shared_tmem.cc, testing/python/transform/test_tilelang_transform_inject_fence_proxy.py, docs/...
Treat new descriptor initializers as known-generic ops for inject_fence_proxy; preserve/adjust remap logic in tmem lowering; update tests/docs to use initialize_wgmma_descriptor.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant TL as tilelang.language.builtin
  participant TIR as TIR Op
  participant CG as codegen_cuda
  participant Device as tl intrin (device)
  participant GPU

  User->>TL: warpgroup_fence_operand(buf_or_ptr, offset, num_regs, dtype)
  TL->>TIR: call_intrin tl.warpgroup_fence_operand(...)
  TIR->>CG: Visit CallNode for tl.warpgroup_fence_operand
  CG->>Device: emit tl::warpgroup_fence_operand(ptr+offset, num_regs)
  Device->>GPU: cute::warpgroup_fence_operand(reg[i]) loop
Loading
sequenceDiagram
  autonumber
  participant Front as tilelang.language.tir.op
  participant CG as codegen_cuda
  participant PTX as ptx helpers
  participant Tpl as tl templates (wgmma/mma/tcgen05)
  participant GPU

  Front->>CG: ptx_* MMA/WGMMA/TCGEN05 op call (e.g., ptx_wgmma_rs / ptx_tcgen05_mma_ss)
  CG->>PTX: DType parsing & GetMMARegisterType
  CG->>CG: Build template call via Replacer, mark need_*_instruction_h_
  CG->>GPU: emit templated tl::<mma/wgmma/tcgen05> call
  GPU->>Tpl: tl wrapper -> Impl::fma / dispatcher
  Tpl->>GPU: Execute instruction path (includes warpgroup fences/syncs)
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested reviewers

  • chengyupku
  • xysmlx

Poem

A rabbit hops across the tree,
Builds fences where the regs will be.
Dispatchers hum, templates tight,
Warpgroups sync through day and night.
Code blooms — kernels leap in flight! 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 28.75% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title highlights the addition of the warpgroup_fence_operand intrinsic, which is indeed part of the pull request, but it does not reflect the larger scope of new generic MMA dispatch infrastructure, code generation refactors, and helper additions; it is therefore only partially related to the overall set of changes. According to the criteria, a title that refers to a real aspect of the changeset but does not fully summarize the main enhancements still meets the pass condition.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/jit/adapter/wrapper.py (1)

180-196: Add CUDA mapping for float8_e4m3fn.

_lookup_type now raises an assertion when the scheduled module contains buffers with dtype "float8_e4m3fn" because _TYPE_MAP lacks an entry in the CUDA wrapper. NVRTC and HIP wrappers define this mapping, so the CUDA path should do the same; otherwise, ahead-of-time builds using this dtype fail during host stub generation. Please mirror the new dtype mapping here.

         "float8_e4m3": "fp8_e4_t",
+        "float8_e4m3fn": "fp8_e4_t",
         "float8_e5m2": "fp8_e5_t",
🧹 Nitpick comments (7)
tilelang/intrinsics/mma_macro_generator.py (1)

106-110: Safer dtype lookup with explicit ValueError

Good switch to a guarded helper; clearer error than a raw KeyError.

Consider reusing _get_dtype_abbrv in TensorCoreIntrinEmitterWithLadderTransform._initialize_abbrev for consistent error handling across subclasses.

Also applies to: 111-116

tilelang/language/builtin.py (1)

284-342: API shape and validation look solid

Pointer vs Buffer cases handled; num_regs derivation correct; error messaging clear.

Small polish: use TypeError for pointer-argument validation errors (dtype/num_regs missing) to distinguish from value-domain issues.

src/target/codegen_cuda.cc (1)

1625-1657: Remove unused asm_code (dead store) in wgmma_ss path

PrintWGMMAAssembly result is computed but not used. Drop it to avoid warnings and confusion.

-    std::string asm_code = PrintWGMMAAssembly(
-        shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc,
-        A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
-        scale_in_b, a_is_shared, "", "", "", false);
tilelang/intrinsics/wgmma_macro_generator.py (2)

184-186: Fence register count: verify per‑thread 32‑bit reg calc (or infer)

accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32 must equal the number of 32‑bit registers owned by the calling thread for C_local_buf. If this multiplies in warp/block factors, the CUDA fence loop will index past the thread’s register array.

Safer alternative: when C_local_buf shape is static, omit num_regs and let T.warpgroup_fence_operand infer it from buffer shape/dtype.


339-343: Good: RS path fences and sync ordering

Fencing A and C around warpgroup_arrive/commit_batch/wait(0) is placed correctly. Minor: consider a build‑time flag to disable fences for performance A/B if future NVCC versions no longer need it.

Also applies to: 367-371

src/tl_templates/cuda/instruction/wgmma.h (2)

455-469: Public wrappers are thin and efficient

Template wrappers delegate to Impl::execute cleanly. Consider adding __forceinline__ (or equivalent) to TL_DEVICE if inlining becomes important on hot paths.


7-8: Nit: explicitly include

std::size_t, std::index_sequence, and std::extent_v are fine via current includes, but an explicit #include <cstddef> avoids transitive‑include reliance.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b0b5347 and 4a96032.

📒 Files selected for processing (15)
  • src/op/builtin.cc (1 hunks)
  • src/op/builtin.h (2 hunks)
  • src/op/gemm.cc (1 hunks)
  • src/target/codegen_cuda.cc (5 hunks)
  • src/target/codegen_cuda.h (1 hunks)
  • src/target/ptx.cc (1 hunks)
  • src/target/ptx.h (1 hunks)
  • src/tl_templates/cuda/instruction/mma.h (1 hunks)
  • src/tl_templates/cuda/instruction/wgmma.h (2 hunks)
  • src/tl_templates/cuda/intrin.h (1 hunks)
  • tilelang/intrinsics/mma_macro_generator.py (1 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (5 hunks)
  • tilelang/jit/adapter/wrapper.py (7 hunks)
  • tilelang/language/builtin.py (2 hunks)
  • tilelang/language/tir/op.py (0 hunks)
💤 Files with no reviewable changes (1)
  • tilelang/language/tir/op.py
🧰 Additional context used
🧬 Code graph analysis (11)
src/op/builtin.cc (1)
tilelang/language/builtin.py (1)
  • warpgroup_fence_operand (284-341)
tilelang/language/builtin.py (3)
src/op/builtin.h (1)
  • tvm (13-482)
tilelang/language/ast/ir.py (1)
  • evaluate (1319-1331)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
src/tl_templates/cuda/intrin.h (1)
tilelang/language/builtin.py (1)
  • warpgroup_fence_operand (284-341)
src/op/gemm.cc (1)
tilelang/tileop/gemm/gemm_base.py (1)
  • trans_A (46-47)
src/target/ptx.h (2)
src/target/ptx.cc (2)
  • GetMMARegisterType (1532-1545)
  • GetMMARegisterType (1532-1532)
src/tl_templates/cuda/common.h (1)
  • DataType (180-225)
tilelang/intrinsics/mma_macro_generator.py (2)
tilelang/tileop/gemm/gemm_base.py (1)
  • accum_dtype (59-60)
tilelang/primitives/gemm/gemm_mma.py (1)
  • accum_dtype (252-259)
src/tl_templates/cuda/instruction/mma.h (2)
src/tl_templates/cuda/instruction/wgmma.h (6)
  • tl (10-96)
  • detail (17-85)
  • void (47-53)
  • void (74-82)
  • void (92-95)
  • void (103-106)
src/target/ptx.h (1)
  • DataType (45-90)
src/target/codegen_cuda.cc (3)
src/target/codegen_cuda.h (1)
  • need_wgmma_instruction_h_ (112-149)
tilelang/language/builtin.py (1)
  • warpgroup_fence_operand (284-341)
src/target/ptx.cc (10)
  • DTypeFromString (56-106)
  • DTypeFromString (56-56)
  • ParseMMAShape (142-150)
  • ParseMMAShape (142-142)
  • DTypeEnumToString (108-110)
  • DTypeEnumToString (108-108)
  • DTypeEnumToString (112-115)
  • DTypeEnumToString (112-112)
  • GetMMARegisterType (1532-1545)
  • GetMMARegisterType (1532-1532)
src/op/builtin.h (1)
tilelang/language/builtin.py (1)
  • warpgroup_fence_operand (284-341)
tilelang/intrinsics/wgmma_macro_generator.py (3)
src/target/ptx.h (1)
  • DataType (45-90)
tilelang/language/builtin.py (5)
  • initialize_descriptor (452-483)
  • warpgroup_fence_operand (284-341)
  • warpgroup_arrive (253-259)
  • warpgroup_commit_batch (262-268)
  • warpgroup_wait (271-281)
tilelang/language/tir/op.py (1)
  • ptx_wgmma_rs (1106-1141)
src/tl_templates/cuda/instruction/wgmma.h (4)
src/tl_templates/cuda/instruction/mma.h (3)
  • tl (10-148)
  • detail (17-132)
  • void (62-66)
src/tl_templates/cuda/intrin.h (5)
  • tl (8-78)
  • void (10-10)
  • void (11-11)
  • void (17-22)
  • void (24-29)
src/tl_templates/cuda/common.h (4)
  • tl (171-265)
  • int (97-100)
  • int (135-142)
  • DataType (180-225)
tilelang/intrinsics/wgmma_macro_generator.py (1)
  • wgmma_rs (275-372)
🪛 Ruff (0.13.3)
tilelang/language/builtin.py

308-308: Avoid specifying long messages outside the exception class

(TRY003)


314-314: Avoid specifying long messages outside the exception class

(TRY003)


322-323: Prefer TypeError exception for invalid type

(TRY004)


322-323: Avoid specifying long messages outside the exception class

(TRY003)


329-329: Avoid specifying long messages outside the exception class

(TRY003)


331-331: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/mma_macro_generator.py

115-115: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-metal
  • GitHub Check: build-test-amd
🔇 Additional comments (17)
src/op/gemm.cc (1)

585-586: LGTM! Good correctness guard for gemm_rs path.

This validation check appropriately enforces the constraint that the gemm_rs code path requires a non-transposed A operand when A resides in local fragment storage. The check is well-placed in the lowering logic, and the error message clearly communicates the requirement.

src/target/ptx.h (1)

272-276: Declaration placement and signature look good

Matches usage in codegen; takes ptx::DataType by const-ref and returns std::string. No concerns.

src/target/ptx.cc (1)

1532-1546: Helper implementation is correct and minimal

Mappings align with expected MMA register classes; defaulting others to unsigned is acceptable.

src/tl_templates/cuda/intrin.h (1)

17-22: Correct warpgroup_fence_operand wrappers

Overloads for uint32_t* and float* call cute::warpgroup_fence_operand per element; matches codegen cast logic.

Also applies to: 24-29

tilelang/language/builtin.py (1)

7-9: Required imports added

Needed for dtype bit math and PrimExpr conversion. LGTM.

src/target/codegen_cuda.cc (2)

262-267: Conditional includes for MMA/WGMMA headers

Good gating; reduces unnecessary includes.


1392-1407: warpgroup_fence_operand emission aligns with intrin overloads

Cast to float for f32/tf32, otherwise uint32_t; pointer arithmetic and count forwarded. LGTM.

tilelang/intrinsics/wgmma_macro_generator.py (5)

249-254: Good: descriptors use read pointers; pre‑fence C accumulators

Switching to access_ptr("r") for A/B descriptors is correct. Pre‑fencing C_local_buf before warpgroup_arrive helps prevent NVCC code motion across WGMMA.


271-272: Good: post‑WGMMA fence on C

Fencing C_local_buf after commit_batch + wait(0) is appropriate to stop sinks of accumulator uses.


297-307: RS path: verify dtype→register mapping for fences

a_regs/accum_regs look consistent, but warpgroup_fence_operand must pick the right register pointer type (uint32_t* vs float*). For A in f16/bf16 paths, codegen should map to a 32‑bit register container (e.g., uint32_t) rather than half.

Please confirm codegen for tl.warpgroup_fence_operand uses the same GetMMARegisterType mapping as MMA/WGMMA fragments.


310-332: B descriptor stride/LBO updates look consistent with swizzle rules

Default and swizzled b_leading_byte_offset/b_stride_byte_offset logic reads correct, including the n_dim == 8 special case and K‑major swizzled stride. No issues spotted.


333-349: BK atom sizing and B_offset formula look sound

bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1) and the K‑major/N‑major B_offset expressions are coherent with atomized addressing. Please double‑check off‑by‑one at ki boundaries in tests for swizzled layouts.

src/tl_templates/cuda/instruction/wgmma.h (5)

24-33: Scale helpers are correct and constrained

ScaleInValue and IsValidScale ensure only ±1 scaling; good defensive checks.


34-53: SS call shim is safe and minimal

Index‑sequence expansion and reinterpret_cast to CReg* with 32‑bit size assert are appropriate. No issues.


56-83: RS call shim: good type checks for A/C regs

Asserts on 32‑bit A and 32‑bit C (uint32_t or float) guard misuse; expansion path is clear.


98-107: Fallback static_asserts provide clear diagnostics

Unspecialized Wgmma*Impl::execute fails compilation with informative messages. Good.


163-181: RS Impl guards K‑major A at compile time

static_assert(!tnspA, ...) enforces the RS contract early. Nice.

Comment on lines 181 to 186
a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
b_swizzle_atom_elems = n_dim if b_swizzle_mode.is_none(
) else b_swizzle_mode.swizzle_byte_size() // elems_in_bytes
accum_bits = DataType(accum_dtype).bits
accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Non‑swizzle A atom size can become 0 → wrong A_offset

For SwizzleMode.NONE, a_swizzle_mode.swizzle_byte_size() returns 1; with elems_in_bytes >= 2, a_swizzle_atom_elems = 1 // elems_in_bytes becomes 0. This zero is used in A_offset terms, collapsing offsets to 0.

Guard against 0 in the non‑swizzle case. Minimal safe fix:

- a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+ a_swizzle_atom_elems = (
+    (a_swizzle_mode.swizzle_byte_size() // elems_in_bytes)
+    if not a_swizzle_mode.is_none()
+    else 1  # ensure >=1 for non-swizzle
+)

Optionally choose a more accurate non‑swizzle atom size if required by layout, but it must be ≥1.

Also applies to: 257-259

🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around lines 181 to 186 (and
similarly at lines 257-259), the computed swizzle atom element count can
evaluate to 0 for the NON swizzle case when swizzle_byte_size() returns 1 and
elems_in_bytes >= 2; change the computation to guard against zero by ensuring
the atom element count is at least 1 (e.g., if swizzle_mode.is_none() then set
a_swizzle_atom_elems = 1 otherwise use the computed value or simply use max(1,
computed_value)); apply the same guard to b_swizzle_atom_elems and the duplicate
code at 257-259 so no offset calculations use a zero atom size.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a60701b and 9fd79eb.

📒 Files selected for processing (2)
  • src/target/codegen_cuda.cc (5 hunks)
  • src/target/ptx.cc (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/target/codegen_cuda.cc
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-metal
  • GitHub Check: build-test-amd
🔇 Additional comments (1)
src/target/ptx.cc (1)

77-80: LGTM: FP8 alias additions are consistent.

The new "float8_e4m3" and "float8_e5m2" aliases follow the established pattern and provide clearer type names.

Comment on lines +1532 to +1545
std::string GetMMARegisterType(const ptx::DataType &dtype) {
switch (dtype) {
case ptx::DataType::kInt32:
return "unsigned";
case ptx::DataType::kUInt32:
return "unsigned";
case ptx::DataType::kFloat32:
return "float";
case ptx::DataType::kFloat64:
return "double";
default:
return "unsigned";
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Verify limited type coverage and consider error handling.

The function handles only 4 data types (kInt32, kUInt32, kFloat32, kFloat64) but MMA/WGMMA operations support many more: kFloat16, kBFloat16, kInt8, kUInt8, kInt4, kUInt4, kFloat8_e4m3, kFloat8_e5m2, etc. (see valid_mma_configs/valid_wgmma_configs above).

Concerns:

  • kInt32 mapped to "unsigned": Signed int32 is mapped to "unsigned", which may be incorrect depending on usage context.
  • Silent fallback: The default case returns "unsigned" for all unsupported types instead of failing, which could mask errors if the function is called with types not yet supported by the new dispatcher.

Recommendations:

  • Verify whether the limited type coverage is intentional for the initial implementation.
  • Consider using LOG(FATAL) or CHECK(false) in the default case to catch unsupported types explicitly during development.
  • If the mapping of kInt32 to "unsigned" is intentional for PTX register constraints, add a clarifying comment.

Run the following script to verify if GetMMARegisterType is called with types beyond the four it handles:


🏁 Script executed:

#!/bin/bash
# Search for calls to GetMMARegisterType and analyze argument types
rg -n -C5 'GetMMARegisterType\(' --type=cpp

Length of output: 2440


🏁 Script executed:

#!/bin/bash
rg -nP "enum\s+DataType" -C5 --type=cpp

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Search for DataType enum definition in src/target
rg -nP "enum\s+(class\s+)?DataType" -C5 src/target
rg -n "DataType" -C2 src/target/ptx.h

Length of output: 1150


Improve type coverage and enforce unsupported types in GetMMARegisterType.

  • Extend the switch to cover all ptx::DataType enum values (e.g. kInt4, kUInt4, kInt8, kUInt8, kInt16, kUInt16, kFloat16, kBFloat16, kFloat8_e4m3, kFloat8_e5m2, etc.).
  • Replace the silent default with a runtime failure (LOG(FATAL) or CHECK(false)) to catch unsupported types.
  • If mapping signed kInt32 to "unsigned" is intentional, document the rationale with an explanatory comment.
🤖 Prompt for AI Agents
In src/target/ptx.cc around lines 1532 to 1545, the GetMMARegisterType switch
only covers a few ptx::DataType values, returns "unsigned" for signed kInt32,
and falls back to a silent default; update the switch to explicitly handle every
ptx::DataType enum value (kInt4, kUInt4, kInt8, kUInt8, kInt16, kUInt16, kInt32,
kUInt32, kFloat16, kBFloat16, kFloat32, kFloat64, kFloat8_e4m3, kFloat8_e5m2,
etc.) with the correct string mapping, replace the default branch with a runtime
failure (e.g., LOG(FATAL) or CHECK(false)) to surface unsupported types, and add
a short comment explaining why kInt32 maps to "unsigned" if that mapping is
intentional.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 10

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/language/builtin.py (1)

530-531: Correctness: operator-precedence bug can crash on BufferLoad

The condition mixes and/or without parentheses. For a BufferLoad, descriptor.shape access will raise. Guard both shape checks under the Buffer branch.

Apply this diff:

-    if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1:
-        raise ValueError("Descriptor must be a 1D buffer of size 1.")
+    if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1):
+        raise ValueError("Descriptor must be a 1D buffer of size 1.")
♻️ Duplicate comments (1)
tilelang/intrinsics/wgmma_macro_generator.py (1)

182-183: Non‑swizzle A atom size can become 0 → A_offset collapses to 0.

When SwizzleMode.NONE and elems_in_bytes >= 2, swizzle_byte_size() is 1 and a_swizzle_atom_elems computes to 0. This makes terms using it zero, collapsing A_offset to 0. Guard to ensure it’s ≥1.

Apply this minimal fix:

-        a_swizzle_atom_elems = a_swizzle_mode.swizzle_byte_size() // elems_in_bytes
+        a_swizzle_atom_elems = max(1, a_swizzle_mode.swizzle_byte_size() // elems_in_bytes)

This keeps non‑swizzle addressing correct while retaining swizzled cases.

Also applies to: 257-259

🧹 Nitpick comments (24)
tilelang/utils/__init__.py (1)

9-9: Remove unused noqa directive.

The # noqa: F401 comment is unnecessary since the F401 check is not enabled in your linter configuration.

Apply this diff:

-    is_tensor_memory,  # noqa: F401
+    is_tensor_memory,
tilelang/layout/__init__.py (1)

9-9: Remove unused noqa directive.

The # noqa: F401 comment is unnecessary since the F401 check is not enabled in your linter configuration.

Apply this diff:

-    make_tcgen05mma_swizzled_layout,  # noqa: F401
+    make_tcgen05mma_swizzled_layout,
tilelang/language/tir/ir.py (1)

296-296: Double-check dtype forwarding for tcgen05 MMA (expects 3 dtypes).

ptx_tcgen05_mma_ss takes a_dtype, b_dtype, c_dtype. _dtype_forward injects only one dtype (as first positional), which can misalign arguments or create ambiguous usage.

Recommend a dedicated wrapper that accepts either a single dtype (replicated to a/b/c) or a (a, b, c) tuple.

Apply:

 def _dtype_forward(func):
@@
     return wrapped
+
+def _dtype3_forward(func):
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        if "dtype" in kwargs:
+            dt = kwargs.pop("dtype")
+            if isinstance(dt, (tuple, list)) and len(dt) == 3:
+                args = (dt[0], dt[1], dt[2]) + args
+            else:
+                args = (dt, dt, dt) + args
+        return func(*args, **kwargs)
+    return wrapped
@@
-ptx_tcgen05_mma_ss = _dtype_forward(_tir_op.ptx_tcgen05_mma_ss)
+ptx_tcgen05_mma_ss = _dtype3_forward(_tir_op.ptx_tcgen05_mma_ss)

Please confirm the intended call convention (single dtype vs per-operand dtypes) and adjust docs accordingly.

tilelang/layout/swizzle.py (1)

37-49: Fix Optional typing for continuity (RUF013).

Annotate continuity as Optional to avoid implicit Optional.

Based on static analysis hints

-def make_tcgen05mma_swizzled_layout(buffer: tvm.tir.Buffer,
-                               continuity: int = None,
-                               k_major: bool = True):
+def make_tcgen05mma_swizzled_layout(
+    buffer: tvm.tir.Buffer, continuity: Optional[int] = None, k_major: bool = True
+):

Also add (outside this hunk):

from typing import Optional

Optionally, mirror this change for make_wgmma_swizzled_layout for consistency.

tilelang/tileop/gemm/gemm_tcgen05.py (4)

64-67: Silence unused args without changing API.

Keep override-compatible signature; mark unused to satisfy linters.

     def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
         m_warp, n_warp = self.policy.compute_warp_partition(self.M, self.N, thread_nums, target,
                                                         True)
+        # keep signature for override/keyword-arg compatibility; silence linters
+        _ = (layout_map, thread_var)

86-88: Silence unused meta tuple.

Prefix with underscores to avoid lint warnings.

-        atom_m, atom_n, atom_k = mma_emitter.get_tcgen5_mma_meta(
+        _atom_m, _atom_n, _atom_k = mma_emitter.get_tcgen5_mma_meta(
             self.M, self.N, self.K)

51-60: Update comment to match TCGEN5 context.

Mentions WGMMA in a TCGEN5 file.

-                # WGMMA does not support padding
+                # TCGEN5MMA does not support padding

1-2: Remove unused import.

Tuple is not used.

-from typing import Tuple
src/target/codegen_cuda.cc (2)

1630-1663: Remove unused variable to avoid warnings

asm_code is built then discarded. Drop it or guard with debug ifdef.

-    std::string asm_code = PrintWGMMAAssembly(
-        shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc,
-        A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
-        scale_in_b, a_is_shared, "", "", "", false);
+    // Legacy inline-PTX path retained for reference; current emission uses C++ templates.
+    // auto asm_code = PrintWGMMAAssembly(...);

1701-1732: Confirm register-pointer casts for RS path

You cast A/C to uint32_t*. If tl::wgmma_rs expects typed accumulator/register pointers (e.g., float* for f32 accum), consider mirroring MMA’s CRegType mapping. If the template is defined on uint32_t* this is fine; otherwise add dtype-based casts.

If required, adapt like:

- reinterpret_cast<uint32_t*>((C_ptr) + (C_offset))
+ reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset))

and register (CRegType) via GetMMARegisterType(dtype_c_enum), analogous to the MMA path.

tilelang/language/tir/op.py (1)

1160-1166: Replace “×” with “x” in docstring

Non-ASCII × triggers lint (RUF002) and may render poorly. Use plain "x".

-    """TVM intrinsic for tcgen05.mma shared-memory × shared-memory instructions.
+    """TVM intrinsic for tcgen05.mma shared-memory x shared-memory instructions.
src/tl_templates/cuda/common.h (2)

267-314: Tcgen05Descriptor union is well-formed; consider clarifying layout constants

Definition and decay operator are fine. Optionally add constexpr named constants for the magic bits (bit 48 and 0xB0 at 53) to document intent.

+// Descriptor signature constants (per SM100 spec)
+constexpr uint32_t kTcgen05SigBit = 1u << (48 - 32);
+constexpr uint32_t kTcgen05SigTag = 0xB0u << (53 - 32);
 ...
-  descriptor.words.hi |= (1u << (48 - 32));
-  descriptor.words.hi |= (0xB0u << (53 - 32));
+  descriptor.words.hi |= kTcgen05SigBit;
+  descriptor.words.hi |= kTcgen05SigTag;

363-383: initialize_tcgen05_descriptor LGTM; add a static_assert on field widths (optional)

Bitfield writes and signature bits look correct. For maintainability, consider static_asserts on swizzle_mode/base_offset ranges.

+  // Optional sanity checks (compile-time when constants)
+  // static_assert((swizzle_mode & ~0x7) == 0, "swizzle_mode must be 3 bits");
+  // static_assert((base_offset & ~0x7) == 0, "base_offset must be 3 bits");
src/op/builtin.h (1)

241-247: Update ptx_wgmma_rs comment to reflect new signature

Docstring still mentions a_is_k_major; new API only takes b_is_k_major. Adjust to avoid confusion.

- *  void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix,
- * bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
+ *  void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix,
+ *                    bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
tilelang/language/builtin.py (3)

284-342: Solid API for warpgroup_fence_operand; minor robustness nits

Looks correct and matches op signature (dtype, ptr, offset, num_regs). Two small improvements:

  • Accept tvm.DataType objects for dtype as well as str (normalize to str).
  • Consider permitting symbolic shapes by requiring explicit num_regs, which you already enforce; message could hint the remedy.

No blockers.


452-480: Descriptor init rename and validation LGTM; tighten typing

Behavior and Buffer/BufferLoad handling are good. Adjust the type hints to reflect accepted types.

Apply this diff to the signature and docstring:

-def initialize_wgmma_descriptor(
-    descriptor: Buffer,
+def initialize_wgmma_descriptor(
+    descriptor: Union[Buffer, BufferLoad],
     start_address: PrimExpr,

482-514: TCGEN05 descriptor init LGTM; typing consistency

Validation and argument shaping are consistent with WGMMA version. Consider the same type-hint polish.

Apply this diff:

-def initialize_tcgen05_descriptor(
-    descriptor: Buffer,
+def initialize_tcgen05_descriptor(
+    descriptor: Union[Buffer, BufferLoad],
     start_address: PrimExpr,
tilelang/intrinsics/tcgen05_macro_generator.py (6)

181-183: Drop unused variables

accum_bits/accum_regs are computed but unused.

Apply this diff:

-        accum_bits = DataType(accum_dtype).bits
-        accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32
+        # accum_regs not needed for tcgen05 path

149-155: Unused argument mbar

mbar is unused. Rename to _mbar (keep signature stability) or remove if callers don’t pass it.

Apply this diff:

-    def tcgen05mma(self,
+    def tcgen05mma(self,
               A_buf: Buffer,
               B_buf: Buffer,
               C_local_buf: Buffer,
-              mbar,
+              _mbar,
               clear_accum: PrimExpr = False):

97-111: n_dim param in _initialize_tcgen05_prefix is unused

Simplify signature and call.

Apply this diff:

-        self._initialize_tcgen05_prefix(self.n_dim)
+        self._initialize_tcgen05_prefix()

And:

-    def _initialize_tcgen05_prefix(self, n_dim: int = 16):
+    def _initialize_tcgen05_prefix(self):

74-76: Type hints: Optional for layouts

Fields can be None; mark as Optional for clarity.

Apply this diff:

-    a_shared_layout: Layout = None
-    b_shared_layout: Layout = None
+    a_shared_layout: Optional[Layout] = None
+    b_shared_layout: Optional[Layout] = None

324-326: TODO clarity

make_mma_load_layout is unimplemented. Add a brief TODO or raise a clearer error to guide users.

Apply this diff:

-    def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment:
-       raise NotImplementedError
+    def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragment:
+        raise NotImplementedError("TCGEN5 load layout is not implemented yet; implement when ready.")

25-63: Avoid duplicating SwizzleMode across modules

This mirrors wgmma_macro_generator.SwizzleMode. Consider reusing a single enum to prevent drift.

tilelang/intrinsics/wgmma_macro_generator.py (1)

353-353: Minor consistency nit: use the local variable.

You compute b_is_k_major = self.b_transposed, but pass self.b_transposed directly here. Use b_is_k_major for consistency/readability.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 73fa0af and 3e90be7.

📒 Files selected for processing (28)
  • docs/compiler_internals/inject_fence_proxy.md (3 hunks)
  • src/layout/layout.cc (1 hunks)
  • src/op/builtin.cc (3 hunks)
  • src/op/builtin.h (3 hunks)
  • src/op/gemm.cc (3 hunks)
  • src/op/gemm_py.cc (4 hunks)
  • src/op/gemm_py.h (6 hunks)
  • src/op/tcgen5_meta.h (1 hunks)
  • src/target/codegen_cuda.cc (6 hunks)
  • src/target/codegen_cuda.h (1 hunks)
  • src/tl_templates/cuda/common.h (3 hunks)
  • src/tl_templates/cuda/instruction/tcgen05mma.h (1 hunks)
  • src/transform/inject_fence_proxy.cc (1 hunks)
  • src/transform/lower_shared_tmem.cc (3 hunks)
  • testing/python/transform/test_tilelang_transform_inject_fence_proxy.py (1 hunks)
  • tilelang/intrinsics/tcgen05_macro_generator.py (1 hunks)
  • tilelang/intrinsics/wgmma_macro_generator.py (6 hunks)
  • tilelang/language/ast/ir.py (2 hunks)
  • tilelang/language/builtin.py (3 hunks)
  • tilelang/language/tir/ir.py (1 hunks)
  • tilelang/language/tir/op.py (1 hunks)
  • tilelang/layout/__init__.py (1 hunks)
  • tilelang/layout/swizzle.py (1 hunks)
  • tilelang/tileop/gemm/__init__.py (3 hunks)
  • tilelang/tileop/gemm/gemm_base.py (1 hunks)
  • tilelang/tileop/gemm/gemm_tcgen05.py (1 hunks)
  • tilelang/utils/__init__.py (1 hunks)
  • tilelang/utils/language.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • src/tl_templates/cuda/instruction/tcgen05mma.h
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/op/gemm.cc
🧰 Additional context used
🧬 Code graph analysis (23)
tilelang/utils/language.py (3)
src/transform/lower_shared_tmem.cc (1)
  • buffer (198-213)
tilelang/language/ast/ir.py (1)
  • buffer (93-161)
src/transform/legalize_safe_memory_access.cc (12)
  • buffer (80-88)
  • buffer (80-80)
  • buffer (91-130)
  • buffer (91-92)
  • buffer (236-240)
  • buffer (236-236)
  • buffer (242-245)
  • buffer (242-242)
  • buffer (247-250)
  • buffer (247-247)
  • buffer (252-257)
  • buffer (252-252)
tilelang/language/tir/ir.py (2)
tilelang/language/tir/op.py (1)
  • ptx_tcgen05_mma_ss (1144-1183)
tilelang/language/ast/ir.py (1)
  • _dtype_forward (1876-1884)
tilelang/utils/__init__.py (1)
tilelang/utils/language.py (1)
  • is_tensor_memory (55-65)
tilelang/layout/__init__.py (1)
tilelang/layout/swizzle.py (1)
  • make_tcgen05mma_swizzled_layout (37-49)
tilelang/layout/swizzle.py (3)
tilelang/language/ast/ir.py (1)
  • buffer (93-161)
src/op/builtin.h (1)
  • tvm (13-493)
src/layout/swizzle.h (1)
  • tvm (12-70)
src/layout/layout.cc (1)
src/layout/gemm_layouts.cc (2)
  • makeGemmABLayoutSm100 (768-788)
  • makeGemmABLayoutSm100 (768-769)
src/op/gemm_py.h (2)
src/op/gemm_py.cc (4)
  • AllowTCGEN5MMA (110-117)
  • AllowTCGEN5MMA (110-110)
  • AllowWGMMA (119-127)
  • AllowWGMMA (119-119)
tilelang/tileop/gemm/gemm_base.py (2)
  • mbarptr (123-124)
  • C_coords (127-132)
tilelang/tileop/gemm/gemm_base.py (2)
src/op/builtin.h (1)
  • tvm (13-493)
src/target/ptx.h (1)
  • tvm (32-278)
src/target/codegen_cuda.cc (4)
src/target/codegen_cuda.h (1)
  • need_tcgen05mma_instruction_h_ (114-151)
tilelang/language/builtin.py (3)
  • warpgroup_fence_operand (284-341)
  • initialize_wgmma_descriptor (452-479)
  • initialize_tcgen05_descriptor (482-513)
src/target/ptx.cc (10)
  • DTypeFromString (56-106)
  • DTypeFromString (56-56)
  • ParseMMAShape (142-150)
  • ParseMMAShape (142-142)
  • DTypeEnumToString (108-110)
  • DTypeEnumToString (108-108)
  • DTypeEnumToString (112-115)
  • DTypeEnumToString (112-112)
  • GetMMARegisterType (1532-1545)
  • GetMMARegisterType (1532-1532)
tilelang/language/tir/op.py (1)
  • ptx_tcgen05_mma_ss (1144-1183)
src/transform/lower_shared_tmem.cc (1)
src/transform/lower_tile_op.cc (2)
  • expr (433-445)
  • expr (433-433)
src/op/tcgen5_meta.h (1)
src/op/builtin.h (1)
  • tvm (13-493)
tilelang/language/builtin.py (3)
src/op/builtin.h (1)
  • tvm (13-493)
tilelang/language/ast/ir.py (1)
  • evaluate (1319-1331)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
tilelang/tileop/gemm/__init__.py (1)
tilelang/tileop/gemm/gemm_tcgen05.py (1)
  • GemmTCGEN5 (23-121)
tilelang/language/ast/ir.py (2)
tilelang/language/tir/op.py (1)
  • ptx_tcgen05_mma_ss (1144-1183)
tilelang/language/tir/ir.py (1)
  • _dtype_forward (156-164)
src/tl_templates/cuda/common.h (1)
tilelang/language/builtin.py (2)
  • initialize_wgmma_descriptor (452-479)
  • initialize_tcgen05_descriptor (482-513)
src/transform/inject_fence_proxy.cc (1)
tilelang/language/builtin.py (2)
  • initialize_wgmma_descriptor (452-479)
  • initialize_tcgen05_descriptor (482-513)
tilelang/intrinsics/wgmma_macro_generator.py (5)
tilelang/tileop/gemm/gemm_base.py (2)
  • clear_accum (107-108)
  • accum_dtype (59-60)
tilelang/language/builtin.py (5)
  • initialize_wgmma_descriptor (452-479)
  • warpgroup_fence_operand (284-341)
  • warpgroup_arrive (253-259)
  • warpgroup_commit_batch (262-268)
  • warpgroup_wait (271-281)
tilelang/intrinsics/tcgen05_macro_generator.py (4)
  • _determinate_swizzle_mode (136-147)
  • is_none (32-33)
  • swizzle_byte_size (44-52)
  • _warp_mma (271-320)
tilelang/language/allocate.py (1)
  • alloc_descriptor (158-164)
tilelang/language/tir/op.py (1)
  • ptx_wgmma_rs (1106-1141)
src/op/builtin.cc (2)
tilelang/language/tir/op.py (1)
  • ptx_tcgen05_mma_ss (1144-1183)
tilelang/language/builtin.py (3)
  • warpgroup_fence_operand (284-341)
  • initialize_wgmma_descriptor (452-479)
  • initialize_tcgen05_descriptor (482-513)
tilelang/tileop/gemm/gemm_tcgen05.py (4)
tilelang/tileop/gemm/gemm_base.py (20)
  • GemmBase (12-132)
  • infer_layout (15-16)
  • policy (119-120)
  • M (34-35)
  • N (38-39)
  • in_dtype (54-56)
  • accum_dtype (59-60)
  • trans_A (46-47)
  • trans_B (50-51)
  • chunk (63-64)
  • is_gemm_ss (21-22)
  • K (42-43)
  • A (67-68)
  • B (71-72)
  • C (75-76)
  • lower (18-19)
  • wg_wait (115-116)
  • mbarptr (123-124)
  • C_coords (127-132)
  • clear_accum (107-108)
tilelang/layout/swizzle.py (1)
  • make_tcgen05mma_swizzled_layout (37-49)
tilelang/intrinsics/tcgen05_macro_generator.py (4)
  • TensorCoreIntrinEmitter (66-423)
  • make_mma_store_layout (327-398)
  • get_tcgen5_mma_meta (401-402)
  • tcgen05mma (149-322)
tilelang/transform/simplify.py (1)
  • _Simplify (30-49)
src/op/builtin.h (2)
tilelang/language/tir/op.py (2)
  • ptx_wgmma_rs (1106-1141)
  • ptx_tcgen05_mma_ss (1144-1183)
tilelang/language/builtin.py (3)
  • warpgroup_fence_operand (284-341)
  • initialize_wgmma_descriptor (452-479)
  • initialize_tcgen05_descriptor (482-513)
src/op/gemm_py.cc (2)
src/op/gemm.cc (8)
  • AllowTCGEN5MMA (103-110)
  • AllowTCGEN5MMA (103-103)
  • AllowWGMMA (112-120)
  • AllowWGMMA (112-112)
  • CheckWGMMA (349-399)
  • CheckWGMMA (349-349)
  • GetGemmInst (122-138)
  • GetGemmInst (122-122)
src/target/utils.cc (14)
  • TargetIsSm100 (56-61)
  • TargetIsSm100 (56-56)
  • TargetGetWarpSize (127-132)
  • TargetGetWarpSize (127-127)
  • TargetIsHopper (49-54)
  • TargetIsHopper (49-49)
  • TargetIsCDNA (70-79)
  • TargetIsCDNA (70-70)
  • TargetIsVolta (28-33)
  • TargetIsVolta (28-28)
  • TargetIsAmpere (42-47)
  • TargetIsAmpere (42-42)
  • TargetIsTuring (35-40)
  • TargetIsTuring (35-35)
tilelang/intrinsics/tcgen05_macro_generator.py (6)
tilelang/intrinsics/wgmma_macro_generator.py (11)
  • SwizzleMode (23-60)
  • is_none (30-31)
  • is_swizzle_32b (33-34)
  • is_swizzle_64b (36-37)
  • is_swizzle_128b (39-40)
  • swizzle_byte_size (42-50)
  • swizzle_atom_size (52-60)
  • _initialize_micro_size (111-132)
  • _determinate_swizzle_mode (134-145)
  • _warp_mma (246-271)
  • _warp_mma (336-370)
src/op/builtin.h (1)
  • tvm (13-493)
tilelang/utils/language.py (1)
  • is_tensor_memory (55-65)
tilelang/layout/swizzle.py (4)
  • make_full_bank_swizzled_layout (54-74)
  • make_half_bank_swizzled_layout (79-99)
  • make_quarter_bank_swizzled_layout (104-124)
  • make_linear_layout (127-145)
tilelang/language/builtin.py (1)
  • initialize_tcgen05_descriptor (482-513)
tilelang/language/tir/op.py (1)
  • ptx_tcgen05_mma_ss (1144-1183)
testing/python/transform/test_tilelang_transform_inject_fence_proxy.py (1)
tilelang/language/builtin.py (1)
  • initialize_wgmma_descriptor (452-479)
🪛 Ruff (0.14.0)
tilelang/utils/__init__.py

9-9: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/layout/__init__.py

9-9: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/layout/swizzle.py

38-38: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

tilelang/language/builtin.py

308-308: Avoid specifying long messages outside the exception class

(TRY003)


314-314: Avoid specifying long messages outside the exception class

(TRY003)


322-323: Prefer TypeError exception for invalid type

(TRY004)


322-323: Avoid specifying long messages outside the exception class

(TRY003)


329-329: Avoid specifying long messages outside the exception class

(TRY003)


331-331: Avoid specifying long messages outside the exception class

(TRY003)


462-462: Avoid specifying long messages outside the exception class

(TRY003)


465-465: Avoid specifying long messages outside the exception class

(TRY003)


494-494: Avoid specifying long messages outside the exception class

(TRY003)


497-497: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/tileop/gemm/gemm_tcgen05.py

64-64: Unused method argument: layout_map

(ARG002)


64-64: Unused method argument: thread_var

(ARG002)


82-84: Avoid specifying long messages outside the exception class

(TRY003)


86-86: Unpacked variable atom_m is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


86-86: Unpacked variable atom_n is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


86-86: Unpacked variable atom_k is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


90-90: Avoid specifying long messages outside the exception class

(TRY003)


92-92: Avoid specifying long messages outside the exception class

(TRY003)


94-94: Avoid specifying long messages outside the exception class

(TRY003)


96-96: Avoid specifying long messages outside the exception class

(TRY003)


100-100: Avoid specifying long messages outside the exception class

(TRY003)


104-104: Avoid specifying long messages outside the exception class

(TRY003)


108-109: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/tir/op.py

1160-1160: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF002)

tilelang/intrinsics/tcgen05_macro_generator.py

107-107: Unused method argument: n_dim

(ARG002)


147-147: Avoid specifying long messages outside the exception class

(TRY003)


153-153: Unused method argument: mbar

(ARG002)


182-182: Local variable accum_regs is assigned to but never used

Remove assignment to unused variable accum_regs

(F841)


243-246: Avoid specifying long messages outside the exception class

(TRY003)


349-351: Avoid specifying long messages outside the exception class

(TRY003)


359-362: Avoid specifying long messages outside the exception class

(TRY003)


366-368: Avoid specifying long messages outside the exception class

(TRY003)


396-396: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (36)
src/transform/lower_shared_tmem.cc (3)

91-91: LGTM! Defensive check prevents duplicate remapping.

Skipping already-remapped variables prevents overwriting existing entries and ensures idempotent behavior.


111-111: LGTM! Maintains consistent mapping.

Recording the new data variable to new buffer mapping ensures subsequent lookups can find the correct buffer, supporting the remapping workflow introduced in this PR.


260-269: LGTM! Proper expression handling and variable remapping.

The changes correctly:

  • Preserve the mutated Call expression instead of directly returning
  • Add VarNode visitor to remap variables consistently across the IR

This completes the variable remapping support for tmem lowering.

src/transform/inject_fence_proxy.cc (1)

106-107: LGTM! Correctly extends known generic operations.

Adding the new descriptor initializers ensures they're treated as generic proxies, triggering proper fence injection when followed by async operations. This aligns with the PR's introduction of initialize_wgmma_descriptor (renamed from initialize_descriptor) and the new initialize_tcgen05_descriptor.

tilelang/language/ast/ir.py (1)

1897-1897: LGTM! Consistent PTX operation wrapper.

The new ptx_tcgen05_mma_ss wrapper follows the established pattern for PTX intrinsics, using _dtype_forward for dtype-first argument handling and properly exporting it in __all__.

Also applies to: 2149-2149

tilelang/utils/language.py (1)

55-65: LGTM! Well-implemented scope checker.

The new is_tensor_memory function follows the established pattern of other scope checkers in this file. Using startswith("shared.tmem") correctly handles potential scope variations like "shared.tmem_addr" or other tmem-related scopes.

src/target/codegen_cuda.h (1)

109-114: LGTM! Proper header control flags.

The new boolean flags for conditional header inclusion follow the established pattern in this class. The clear comments and consistent naming (with trailing underscore for private members) maintain good code hygiene.

testing/python/transform/test_tilelang_transform_inject_fence_proxy.py (1)

196-197: LGTM! Test updated for renamed API.

The test correctly reflects the rename from initialize_descriptor to initialize_wgmma_descriptor while preserving the same arguments and test logic, ensuring continued coverage of fence-proxy injection behavior.

src/layout/layout.cc (1)

538-543: LGTM: binding and parameter order look correct.

tl.make_tcgen05mma_swizzled_layout maps to makeGemmABLayoutSm100(stride, mat_continuous, continuity, element_size, k_inner) as expected.

tilelang/tileop/gemm/__init__.py (2)

125-127: TCGEN5 dispatch addition looks good.

Routing to GemmTCGEN5 on is_tcgen5mma() is correct and isolated; other paths unchanged.


33-35: Enum values align with C++ GemmInst (kTCGEN5MMA=2, kMFMA=3); no changes required.

src/op/builtin.cc (2)

226-230: LGTM: warpgroup_fence_operand registration.

4 inputs and opaque effect align with the Python built-in wrapper.

Please ensure codegen includes this op only when used to avoid unnecessary headers.


278-287: LGTM: Descriptor initializers.

  • initialize_wgmma_descriptor: 5 inputs matches Python wrapper.
  • initialize_tcgen05_descriptor: 7 inputs matches Python wrapper.
docs/compiler_internals/inject_fence_proxy.md (1)

20-25: Docs sync looks good.

initialize_descriptorinitialize_wgmma_descriptor updates are consistent in timeline and examples.

Also applies to: 56-56, 86-86

src/op/gemm_py.h (2)

22-23: LGTM: New feature gates.

AllowTCGEN5MMA(Target) and AllowWGMMA(int, Target) declarations align with implementations.


62-64: LGTM: Reflection for new fields.

mbarptr and C_coords are exposed read-only and are included in equality/hash.

src/op/tcgen5_meta.h (2)

20-80: LGTM: Meta selection logic.

Shape/dtype gating and atom selection look consistent with TCGEN5 constraints; clean FAIL/SUCCESS paths.

Consider adding brief comments on allowed atom_n sets for M=64/32 to document intent.


82-161: LGTM: Descriptor packing.

Bitfield packing includes sanity checks; dtype encodings for FP16/BF16/FP8 families and accum formats are clear. Masking handles width=32 safely.

src/target/codegen_cuda.cc (4)

1395-1410: warpgroup_fence_operand emission LGTM

Argument parsing, dtype-to-cast mapping (float for f32/tf32 else uint32_t), and pointer arithmetic are correct and align with Python wrapper semantics.

Ensure tl_templates provides overloads for both float* and uint32_t* to avoid ambiguous call resolution with nvcc when num_regs is constexpr vs runtime.


1733-1792: TCGEN05 path looks correct; minor pointer alignment check

Descriptor decode + template substitution are consistent. C is cast to uint32_t*, which is appropriate for 32-bit regs; ensure C_ptr is 4-byte aligned at (C) for all call sites.


2117-2128: initialize_wgmma_descriptor emission LGTM

Arity check and template invocation match the Python builtin wrapper.


2130-2144: initialize_tcgen05_descriptor emission LGTM

Matches wrapper signature and forwards args verbatim.

src/tl_templates/cuda/common.h (1)

351-361: Updated wgmma descriptor initializer LGTM

Template params and field assignments remain consistent with previous semantics.

src/op/builtin.h (3)

249-253: New intrinsic declaration ptx_tcgen05_mma_ss LGTM

Matches the Python wrapper and codegen usage.


366-373: warpgroup_fence_operand export LGTM

Signature matches Python/frontend usage and codegen path.


475-482: Descriptor initializer exports LGTM

initialize_wgmma_descriptor and initialize_tcgen05_descriptor correctly exposed.

src/op/gemm_py.cc (5)

16-17: Include tcgen5_meta.h LGTM

Matches new helpers used below.


80-93: Constructor: mbarptr and C_coords defaults LGTM

Backwards-compatible defaults when args absent; types consistent.


119-127: AllowWGMMA gate LGTM

Respects disable flag, Hopper, M>=64, and warp multiple-of-4 constraint before CheckWGMMA().


129-147: GetGemmInst selection LGTM

Prefers TCGEN5MMA, then WGMMA, then MFMA/MMA by target.


328-353: FFI exports for tcgen5 meta/desc LGTM

Useful reflection endpoints; types/returns consistent.

tilelang/language/builtin.py (1)

7-8: Imports look fine

No concerns; usage of DataType/tir and convert aligns with later code.

tilelang/intrinsics/wgmma_macro_generator.py (4)

184-186: Double-check accumulator register count used for fencing.

accum_regs = ((m_dim // 64) * warp_cols * local_size_out * accum_bits + 31) // 32 looks reasonable, but a mismatch under/over-fences registers, weakening the NVCC code‑motion barrier.

Please confirm local_size_out is per‑warp fragment elements and that the product over warps matches the accumulator register footprint actually live across the WGMMA batch.


249-253: Descriptor initialization with read pointers + fences around C look correct.

  • Switched to access_ptr("r") and initialize_wgmma_descriptor: good.
  • Fencing accumulator registers before/after batch protects against NVCC sinking.

Also applies to: 271-271


305-333: B swizzle atom/LBO/SBO computations are consistent.

  • Non‑swizzle uses b_swizzle_atom_elems = n_dim to avoid zero atom size.
  • Swizzled K‑major and MN‑major LBO/SBO follow PTX requirements (multiples of 16B).

338-348: RS path fences and atomized B offsets LGTM.

  • Descriptor init with read pointer is correct.
  • Fences for A and C before and after commit/wait bracket the WGMMA batch.
  • bk_atom_size in B_offset properly handles atom splitting.

Also applies to: 367-371

Comment on lines +157 to +161
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
.set_num_inputs(13)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix arity: ptx_tcgen05_mma_ss requires 14 inputs (not 13).

Python/TIR wrapper passes 14 args; registering 13 will error at call time. Update set_num_inputs to 14.

 TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
-    .set_num_inputs(13)
+    .set_num_inputs(14)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));

Reference: tilelang/language/tir/op.py expects 14 args for tl.ptx_tcgen05_mma_ss. Based on provided snippets.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
.set_num_inputs(13)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
.set_num_inputs(14)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
🤖 Prompt for AI Agents
In src/op/builtin.cc around lines 157 to 161, the registration for
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss) incorrectly sets .set_num_inputs(13);
update this to .set_num_inputs(14) so the C++ op registration matches the
Python/TIR wrapper which passes 14 arguments, and verify no other registrations
for this op remain with the old arity.

Comment on lines +110 to +117
bool GemmPyNode::AllowTCGEN5MMA(Target target) const {
return TargetIsSm100(target) &&
((A.scope() == "shared.dyn" || A.scope() == "shared" ||
A.scope() == "shared.tmem") &&
(B.scope() == "shared.dyn" || B.scope() == "shared") &&
C.scope() == "shared.tmem") &&
GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

AllowTCGEN5MMA gate LGTM; verify TargetIsSm100 logic

Condition looks right. However, TargetIsSm100 in src/target/utils.cc appears to use bitwise '&' instead of '&&', which can mis-gate SM100 detection.

If it shows arch >= 100 & arch <= 110, change to arch >= 100 && arch <= 110 in src/target/utils.cc. This affects AllowTCGEN5MMA().


🏁 Script executed:

#!/bin/bash
# Verify TargetIsSm100 implementation uses logical &&, not bitwise &.
rg -nC2 'bool\s+TargetIsSm100' src | sed -n '1,120p'

Length of output: 489


🏁 Script executed:

#!/bin/bash
sed -n '56,90p' src/target/utils.cc

Length of output: 1002


Replace bitwise & with logical && in TargetIsSm100
In src/target/utils.cc:58 change arch >= 100 & arch <= 110 to use && for correct SM100 detection.

🤖 Prompt for AI Agents
In src/op/gemm_py.cc around lines 110 to 117, the review points out that
TargetIsSm100 uses a bitwise & instead of logical && when checking architecture
range; update TargetIsSm100 in src/target/utils.cc (line ~58) to use `&&` (arch
>= 100 && arch <= 110) so the SM100 detection is correct, then rebuild and run
tests that exercise GemmPyNode::AllowTCGEN5MMA to ensure the change fixes
selection logic.

Comment on lines 77 to 79
equal(offset_A, other->offset_B) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Correct SEqualReduce comparisons for offsets.

Currently compares offset_A to other->offset_B twice. This breaks structural equality and can cause caching bugs.

-           equal(offset_A, other->offset_B) &&
-           equal(offset_B, other->offset_B) &&
+           equal(offset_A, other->offset_A) &&
+           equal(offset_B, other->offset_B) &&
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
equal(offset_A, other->offset_B) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
equal(offset_A, other->offset_A) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
🤖 Prompt for AI Agents
In src/op/gemm_py.h around lines 77 to 79, the SEqualReduce comparison
mistakenly compares offset_A to other->offset_B; change the comparisons so
offset_A is compared to other->offset_A and offset_B to other->offset_B (i.e.,
replace the first equal(offset_A, other->offset_B) with equal(offset_A,
other->offset_A) and keep/verify the second is equal(offset_B, other->offset_B))
to restore correct structural equality and avoid caching bugs.

Comment on lines +268 to +270
if (need_tcgen05mma_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/tcgen05gmma.h>\n";
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Typo in include path will break builds

Header name is misspelled: tcgen05gmma.h → tcgen05mma.h.

Apply this diff:

-  if (need_tcgen05mma_instruction_h_) {
-    decl_stream << "#include <tl_templates/cuda/instruction/tcgen05gmma.h>\n";
-  }
+  if (need_tcgen05mma_instruction_h_) {
+    decl_stream << "#include <tl_templates/cuda/instruction/tcgen05mma.h>\n";
+  }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (need_tcgen05mma_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/tcgen05gmma.h>\n";
}
if (need_tcgen05mma_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/tcgen05mma.h>\n";
}
🤖 Prompt for AI Agents
In src/target/codegen_cuda.cc around lines 268 to 270, the included header
filename is misspelled ("tcgen05gmma.h"); update the include to the correct
header name "tcgen05mma.h" so the line reads to include
tl_templates/cuda/instruction/tcgen05mma.h, ensuring the build can find the
proper header.

Comment on lines +156 to +158
if is_tensor_memory(A_buf):
return self.tcgen05mma_rs(A_buf, B_buf, C_local_buf, clear_accum)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify tcgen05mma_rs availability

This calls self.tcgen05mma_rs(...). Ensure it exists in MMAIntrinEmitter or this class.

Run:


🏁 Script executed:

#!/bin/bash
set -euo pipefail
rg -nP 'class\s+TensorCoreIntrinEmitter|class\s+.*MMAIntrinEmitter' -C2
echo
rg -nP 'def\s+tcgen05mma_rs\s*\(' -C3 --type=py

Length of output: 1530


Define missing tcgen05mma_rs method
In tilelang/intrinsics/tcgen05_macro_generator.py (lines 156–158), self.tcgen05mma_rs(...) isn’t implemented in this class or its superclasses. Add its definition in MMAIntrinEmitter (e.g., in mma_macro_generator.py) or correct the method call.

Comment on lines +252 to +266
print("Before get get_tcgen5_instr_desc")
instr_desc = T.Cast(
"uint32",
self.get_tcgen5_instr_desc(
atom_m,
atom_n,
atom_k,
a_is_k_major,
b_is_k_major,
scale_in_a,
scale_in_b,
),
)
print("instr_desc, ", instr_desc)
mask_full = T.Cast("int32", -1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Remove debug prints

Stray prints in hot-path codegen will pollute logs and slow kernels.

Apply this diff:

-        print("Before get get_tcgen5_instr_desc")
...
-        print("instr_desc, ", instr_desc)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print("Before get get_tcgen5_instr_desc")
instr_desc = T.Cast(
"uint32",
self.get_tcgen5_instr_desc(
atom_m,
atom_n,
atom_k,
a_is_k_major,
b_is_k_major,
scale_in_a,
scale_in_b,
),
)
print("instr_desc, ", instr_desc)
mask_full = T.Cast("int32", -1)
instr_desc = T.Cast(
"uint32",
self.get_tcgen5_instr_desc(
atom_m,
atom_n,
atom_k,
a_is_k_major,
b_is_k_major,
scale_in_a,
scale_in_b,
),
)
mask_full = T.Cast("int32", -1)
🤖 Prompt for AI Agents
In tilelang/intrinsics/tcgen05_macro_generator.py around lines 252 to 266,
remove the stray debug print statements ("Before get get_tcgen5_instr_desc" and
"instr_desc, ", instr_desc) so the hot-path codegen doesn't emit logs or slow
down kernels; simply delete those print(...) lines and leave the surrounding
code (the T.Cast call to get_tcgen5_instr_desc and the mask_full assignment)
intact.

Comment on lines +306 to +309
a_dtype_abbrv,
b_dtype_abbrv,
a_dtype_abbrv,
desc_a.data,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Blocker: wrong c_dtype passed to tcgen05 intrinsic

ptx_tcgen05_mma_ss expects (a_dtype, b_dtype, c_dtype, ...). You pass a_dtype_abbrv for c_dtype; it should be the accumulator dtype abbrv.

Apply this diff:

-        a_dtype_abbrv,
-        b_dtype_abbrv,
-        a_dtype_abbrv,
+        a_dtype_abbrv,
+        b_dtype_abbrv,
+        accum_dtype_abbrv,

Ensure accum_dtype_abbrv is defined:

-        a_dtype_abbrv = self.a_dtype_abbrv
-        b_dtype_abbrv = self.b_dtype_abbrv
+        a_dtype_abbrv = self.a_dtype_abbrv
+        b_dtype_abbrv = self.b_dtype_abbrv
+        accum_dtype_abbrv = self.accum_dtype_abbrv
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
a_dtype_abbrv,
b_dtype_abbrv,
a_dtype_abbrv,
desc_a.data,
a_dtype_abbrv = self.a_dtype_abbrv
b_dtype_abbrv = self.b_dtype_abbrv
accum_dtype_abbrv = self.accum_dtype_abbrv
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
desc_a.data,
🤖 Prompt for AI Agents
In tilelang/intrinsics/tcgen05_macro_generator.py around lines 306 to 309, the
third argument passed to ptx_tcgen05_mma_ss is incorrectly a_dtype_abbrv;
replace that third argument with accum_dtype_abbrv (the accumulator dtype
abbreviation) and make sure accum_dtype_abbrv is defined earlier in this scope
(derive it from the accumulator/output descriptor or dtype used for
accumulation, e.g., obtain the accumulator dtype abbrev from the
result/accumulator descriptor and assign it to accum_dtype_abbrv before this
call).

k_dim, n_dim = self.chunk, self.block_col_warps * self.warp_col_tiles
wgmma_prefix = self.wgmma_prefix
scale_out = not clear_accum
scale_out = ~clear_accum
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Use TIR-safe negation; avoid Python ~False-1.

~clear_accum on a Python bool (default False) evaluates to -1, which can leak as an integer into the intrinsic. Prefer a TIR-safe expression.

Apply this diff in both places:

-        scale_out = ~clear_accum
+        scale_out = 1 - lift(clear_accum)  # 0/1 in TIR regardless of Python bool or PrimExpr

Also applies to: 291-291


🏁 Script executed:

#!/bin/bash
# Locate scale_out usage and context
rg -nC3 "scale_out" tilelang/intrinsics/wgmma_macro_generator.py

# Inspect clear_accum occurrences
rg -nC3 "clear_accum" tilelang/intrinsics/wgmma_macro_generator.py

# Find lift definition or import
rg -n "def lift" tilelang/intrinsics/wgmma_macro_generator.py || rg -n "lift(" tilelang/intrinsics/wgmma_macro_generator.py

Length of output: 2639


🏁 Script executed:

#!/bin/bash
# Find definition of lift in Python files
rg -n "def lift" -g '*.py'

# Find all calls to lift(
rg -n "lift\(" .

Length of output: 2373


🏁 Script executed:

#!/bin/bash
# Search for lift definition
rg -n "def lift" .

# Search for lift import statements
rg -n "import .*lift" .

# Search for variable assignment to lift
rg -n "lift\s*=" .

Length of output: 258


Use TIR-safe negation for clear_accum
~clear_accum performs a bitwise complement on a PrimExpr, yielding –1 or unintended values. Replace with 1 - lift(clear_accum) to produce a 0/1 PrimExpr.

Apply in both occurrences (lines 166 & 291):

-        scale_out = ~clear_accum
+        scale_out = 1 - lift(clear_accum)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
scale_out = ~clear_accum
scale_out = 1 - lift(clear_accum)
🤖 Prompt for AI Agents
In tilelang/intrinsics/wgmma_macro_generator.py around line 166 (and also apply
same change at line 291), the code uses the Python bitwise complement operator
(~clear_accum) which performs an unsafe bitwise negation on a PrimExpr; replace
this with the TIR-safe expression 1 - lift(clear_accum) so the result is a 0/1
PrimExpr. Update both occurrences to use 1 - lift(clear_accum), ensuring lift is
imported/available in scope and the resulting expression type matches the
surrounding TIR expectations.

Comment on lines +122 to +132
@property
def mbarptr(self) -> PrimExpr:
return getattr(self.gemm_node, "mbarptr", tvm.tir.const(0, "uint32"))

@property
def C_coords(self):
coords = getattr(self.gemm_node, "C_coords", None)
if coords is None or len(coords) == 0:
zero = tvm.tir.const(0, "int32")
return [zero, zero]
return [coords[i] for i in range(len(coords))]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Default mbarptr width and C_coords fallback can cause subtle bugs.

  • mbarptr: defaulting to uint32 may truncate pointers on 64‑bit. Prefer pointer‑sized integer.
  • C_coords: returning [0, 0] masks missing metadata; downstream checks like len(C_coords) != 2 won’t trigger.

Apply:

 @property
 def mbarptr(self) -> PrimExpr:
-    return getattr(self.gemm_node, "mbarptr", tvm.tir.const(0, "uint32"))
+    # Use pointer-sized default; treat 0 as null-equivalent.
+    return getattr(self.gemm_node, "mbarptr", tvm.tir.const(0, "uint64"))
@@
 @property
 def C_coords(self):
     coords = getattr(self.gemm_node, "C_coords", None)
-    if coords is None or len(coords) == 0:
-        zero = tvm.tir.const(0, "int32")
-        return [zero, zero]
-    return [coords[i] for i in range(len(coords))]
+    if coords is None or len(coords) == 0:
+        # Force validation upstream; avoid silently using bogus coordinates.
+        return []
+    return [coords[i] for i in range(len(coords))]

If TVM has a dedicated null/handle type you use elsewhere, switch mbarptr to that for stronger typing. Also ensure any callers handle the empty C_coords case explicitly.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@property
def mbarptr(self) -> PrimExpr:
return getattr(self.gemm_node, "mbarptr", tvm.tir.const(0, "uint32"))
@property
def C_coords(self):
coords = getattr(self.gemm_node, "C_coords", None)
if coords is None or len(coords) == 0:
zero = tvm.tir.const(0, "int32")
return [zero, zero]
return [coords[i] for i in range(len(coords))]
@property
def mbarptr(self) -> PrimExpr:
# Use pointer-sized default; treat 0 as null-equivalent.
return getattr(self.gemm_node, "mbarptr", tvm.tir.const(0, "uint64"))
@property
def C_coords(self):
coords = getattr(self.gemm_node, "C_coords", None)
if coords is None or len(coords) == 0:
# Force validation upstream; avoid silently using bogus coordinates.
return []
return [coords[i] for i in range(len(coords))]

Comment on lines +98 to +101
mbarptr = self.mbarptr
if mbarptr == 0:
raise ValueError("TCGEN5MMA requires a valid mbarrier pointer")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Bug: PrimExpr used in boolean context.

if mbarptr == 0: constructs a TIR expr and will error in Python truthiness. Use structural equality against a 0-const.

-        mbarptr = self.mbarptr
-        if mbarptr == 0:
+        mbarptr = self.mbarptr
+        if analysis.expr_deep_equal(mbarptr, tir.const(0, "uint32")):
             raise ValueError("TCGEN5MMA requires a valid mbarrier pointer")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
mbarptr = self.mbarptr
if mbarptr == 0:
raise ValueError("TCGEN5MMA requires a valid mbarrier pointer")
mbarptr = self.mbarptr
if analysis.expr_deep_equal(mbarptr, tir.const(0, "uint32")):
raise ValueError("TCGEN5MMA requires a valid mbarrier pointer")
🧰 Tools
🪛 Ruff (0.14.0)

100-100: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In tilelang/tileop/gemm/gemm_tcgen05.py around lines 98 to 101, the code uses a
TIR PrimExpr in a Python boolean context with "if mbarptr == 0:", which triggers
errors; replace this by performing a structural equality against a TIR zero
constant, e.g. construct a 0 constant with the same dtype as mbarptr
(tir.const(0, mbarptr.dtype)) and call tir.structural_equal(mbarptr,
that_zero_const) (or the equivalent structural equality helper used in the
repo), and branch on that result instead of using Python truthiness.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants