[Language] support T.gemm_sp_v2 on sm80 and sm89
#1056
base: main
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's format checks before submitting. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough
Introduces a new `T.gemm_sp_v2` tile-language operator for 2:4 sparse GEMM on sm80/sm89, together with the layout helpers, sparse MMA intrinsic emitter, and tile-operator infrastructure covered by the files listed below.
Changes
Sequence Diagram(s)
```mermaid
sequenceDiagram
participant User
participant gemm_sp_v2
participant GemmSPPy
participant GemmSPMMA
participant SparseTensorCoreIntrinEmitter
User->>gemm_sp_v2: call with A_sparse, E, B, C, policy
gemm_sp_v2->>gemm_sp_v2: legalize_arguments() & retrieve shapes/strides/offsets
gemm_sp_v2->>GemmSPPy: tl.gemm_sp_py intrinsic call
Note over GemmSPPy: Kernel Compilation Phase
GemmSPPy->>GemmSPMMA: infer_layout(target, thread_nums)
GemmSPMMA->>GemmSPMMA: compute warp partitions
GemmSPMMA->>SparseTensorCoreIntrinEmitter: create emitter per pattern (ss/sr/rs/rr)
SparseTensorCoreIntrinEmitter->>SparseTensorCoreIntrinEmitter: make_mma_load_layout(A, B, E)
SparseTensorCoreIntrinEmitter-->>GemmSPMMA: fragment layouts
GemmSPMMA-->>GemmSPPy: layout_map
Note over GemmSPPy: Kernel Lowering Phase
GemmSPPy->>GemmSPMMA: lower(target, thread_nums, thread_var)
GemmSPMMA->>SparseTensorCoreIntrinEmitter: orchestrate loads & mma_sp
SparseTensorCoreIntrinEmitter->>SparseTensorCoreIntrinEmitter: ldmatrix_a/e/b()
SparseTensorCoreIntrinEmitter->>SparseTensorCoreIntrinEmitter: mma_sp() with ptx_mma_sp
SparseTensorCoreIntrinEmitter->>SparseTensorCoreIntrinEmitter: stmatrix()
SparseTensorCoreIntrinEmitter-->>GemmSPMMA: prim_func kernel
GemmSPMMA-->>GemmSPPy: lowered kernel
GemmSPPy-->>User: compiled kernel
```

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes. Areas requiring close attention:
Possibly related issues
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
🚧 Files skipped from review as they are similar to previous changes (2)
🧰 Additional context used
🧬 Code graph analysis (4)
tilelang/layout/__init__.py (2)
src/op/gemm_sp.cc (1)
tilelang/intrinsics/mma_macro_generator.py (1)
tilelang/language/__init__.py (1)
🪛 Ruff (0.14.4)
tilelang/layout/__init__.py
16-16: Unused `noqa` directive. Remove unused `noqa` directive. (RUF100)
tilelang/language/__init__.py
53-53: Unused `noqa` directive. Remove unused `noqa` directive. (RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
🔇 Additional comments (8)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Force-pushed from 016dd1c to 122abb5 (Compare)
Changed the title: T.gemm_sp_v2 → T.gemm_sp_v2 on sm80 and sm89
Actionable comments posted: 13
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tilelang/profiler/__init__.py (1)
131-132: Fix incorrect usage of the `is_float8` function.
The `is_float8` function expects `torch.dtype` but receives `torch.Tensor` objects. This will cause incorrect behavior. Apply this diff:
```diff
- lhs if not is_float8(lhs) else lhs.to(torch.float32),
- rhs if not is_float8(rhs) else rhs.to(torch.float32),
+ lhs if not is_float8(lhs.dtype) else lhs.to(torch.float32),
+ rhs if not is_float8(rhs.dtype) else rhs.to(torch.float32),
```
benchmark/matmul/benchmark_matmul_sp.py (1)
270-270: Fix missing `in_dtype` parameter.
The `matmul_sp` function expects both `in_dtype` and `accum_dtype` parameters (line 89), but only `accum_dtype` is passed here. Apply this diff:
```diff
- best_result = matmul_sp(M, N, K, args.accum_dtype)
+ best_result = matmul_sp(M, N, K, "float16", args.accum_dtype)
```
Or add an `--in_dtype` argument:
```diff
+ parser.add_argument(
+     "--in_dtype",
+     type=str,
+     default="float16",
+     choices=["float16", "float8_e4m3fn"],
+     help="Input datatype")
  args = parser.parse_args()
  ...
- best_result = matmul_sp(M, N, K, args.accum_dtype)
+ best_result = matmul_sp(M, N, K, args.in_dtype, args.accum_dtype)
```
🧹 Nitpick comments (11)
tilelang/intrinsics/mma_sp_macro_generator.py (2)
195-239: Remove leftover debug prints from initialization.
These `print(...)` calls run every time the emitter is constructed, spamming stdout in normal builds. Please drop them (or gate them behind an explicit debug flag).
```diff
- print(f"{self.local_size_a=}, {self.local_size_e=}, {self.local_size_b=}, {self.local_size_out=}, {self.e_factor=} {n_dim=} {k_dim=}")
  ...
- print(f"{self.warp_rows=}, {self.warp_cols=}, {self.n_dim=}, {self.micro_size_x=}, {self.micro_size_y=}, {self.micro_size_k=}, {self.warp_row_tiles=}, {self.warp_col_tiles=}")
```
551-553: Drop stray print in `mma_sp`.
This diagnostic print should be removed:
```diff
- print(f"{e_local_stride=}")
```
tilelang/language/experimental/gemm_sp.py (3)
191-191: Remove unused variable `E_shape`.
The variable `E_shape` is assigned but never used in the function. Apply this diff:
```diff
  A_shape = retrieve_shape(A_sparse)
- E_shape = retrieve_shape(E)
  B_shape = retrieve_shape(B)
  C_shape = retrieve_shape(C)
```
143-274: Consider refactoring helper functions to reduce duplication.
The helper functions `retrieve_shape`, `retrieve_stride`, `retrieve_ptr`, and `retrieve_offset` have significant structural duplication, with nearly identical type-checking patterns and branching logic for Buffer/BufferRegion/BufferLoad. Consider extracting the common pattern or creating a base dispatcher. For example, you could create a generic dispatcher:
```python
def _dispatch_on_buffer_type(object, buffer_handler, region_handler, load_handler):
    if isinstance(object, tir.Buffer):
        return buffer_handler(object)
    elif isinstance(object, tir.BufferRegion):
        return region_handler(object)
    elif isinstance(object, tir.BufferLoad):
        return load_handler(object)
    else:
        raise TypeError(f"Unsupported type: {type(object)}")
```
Then each retrieve function becomes a simple call to this dispatcher with specialized handlers.
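As a follow-up illustration (not part of the suggested change), one of the retrieve helpers could then look roughly like the sketch below. The handler bodies are assumptions about the existing logic; `get_buffer_region_from_load` is the utility referenced in the code graph section of this review.
```python
# Hypothetical usage of the dispatcher sketched above.
def retrieve_shape(obj):
    return _dispatch_on_buffer_type(
        obj,
        buffer_handler=lambda buf: list(buf.shape),
        region_handler=lambda region: [rng.extent for rng in region.region],
        load_handler=lambda load: retrieve_shape(get_buffer_region_from_load(load)),
    )
```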
159-160: Consider using `TypeError` for invalid type errors.
When raising exceptions for unsupported argument types, `TypeError` is more semantically appropriate than `ValueError`. For example, at lines 159-160:
```diff
- raise ValueError(
-     f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}")
+ raise TypeError(
+     f"Unsupported retrieve_shape argument type: {type(object)}")
```
Apply similar changes to lines 187-188, 253-254, and 273-274.
Also applies to: 187-188, 253-254, 273-274
src/op/gemm_sp_py.cc (1)
143-193: Consider refactoring to reduce cognitive complexity.
The `CheckWGMMA` function has significant duplication, with repeated checks for dtype combinations and K-alignment constraints. While correct, this makes the code harder to maintain. Consider extracting common patterns into helper functions or using a lookup table:
```cpp
// Example: Extract alignment checks
static constexpr int GetKAlignment(DataType c_dtype, DataType a_dtype, DataType b_dtype) {
  if (c_dtype == DataType::Float(16)) {
    if (a_dtype.is_float8() || b_dtype.is_float8()) return 32;
    return 16;
  }
  // ... similar patterns
}

bool GemmSPPyNode::CheckWGMMA() const {
  if (B.scope() != "shared.dyn" && B.scope() != "shared") {
    return false;
  }
  int required_alignment = GetKAlignment(C->dtype, A->dtype, B->dtype);
  if (required_alignment == 0) return false;
  bool requires_specific_transpose = /* condition */;
  if (requires_specific_transpose && (trans_A || !trans_B)) {
    return false;
  }
  return K % required_alignment == 0;
}
```
examples/gemm_sp/example_custom_compress.py (3)
119-120: Remove unused variable.
The `device` variable is assigned but never used in the function. Apply this diff:
```diff
  m, k = dense.shape
- device = dense.device
```
281-282: TODO: Address the `alloc_var` issue.
The comment indicates `alloc_var` has buggy behavior. This should be tracked and resolved. Do you want me to open a new issue to track this TODO?
296-296: TODO: Add device assertion after rebasing.
The comment suggests using `T.device_assert(non_zero_cnt <= 2)` after rebasing onto the main branch. This validation would help catch violations of the 2:4 sparsity constraint at runtime. Do you want me to open a new issue to track adding this assertion?
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (1)
585-587: Minor formatting: Remove extra space.
There's an inconsistent space before the `M` parameter. Apply this diff:
```diff
  program = matmul_rr(
-      M,
+     M,
      N,
```
tilelang/intrinsics/mma_sp_layout.py (1)
60-101: LGTM with optional style suggestion.
The metadata functions correctly handle different bit-width scenarios with appropriate logical ID transformations.
Note: The static analysis warning about unused `local_id` on line 96 is acceptable; the comment correctly explains that `local_id` is always 0 for this function, and keeping the parameter keeps the signature consistent with similar functions.
Optional: Add explicit parentheses for clarity. Lines 87 and 93 could benefit from explicit parentheses to improve readability, even though the operator precedence is correct:
```diff
- col = (logical_id % 4) // 2 * 4 + local_id
+ col = ((logical_id % 4) // 2) * 4 + local_id
```
```diff
- col = (logical_id % 4) // 2 * 2 + local_id
+ col = ((logical_id % 4) // 2) * 2 + local_id
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
docs/_static/img/sparse_mma_storage_example.png is excluded by !**/*.png
📒 Files selected for processing (33)
- benchmark/matmul/benchmark_matmul_sp.py (6 hunks)
- docs/deeplearning_operators/matmul_sparse.md (1 hunks)
- docs/index.md (1 hunks)
- examples/gemm_sp/example_custom_compress.py (1 hunks)
- examples/gemm_sp/example_gemm_sp.py (5 hunks)
- examples/gemm_sp/test_example_gemm_sp.py (1 hunks)
- examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (2 hunks)
- src/op/gemm.h (1 hunks)
- src/op/gemm_sp.cc (1 hunks)
- src/op/gemm_sp.h (1 hunks)
- src/op/gemm_sp_py.cc (1 hunks)
- src/op/gemm_sp_py.h (1 hunks)
- src/target/ptx.cc (1 hunks)
- src/tl_templates/cuda/debug.h (1 hunks)
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (7 hunks)
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (1 hunks)
- tilelang/intrinsics/mma_layout.py (1 hunks)
- tilelang/intrinsics/mma_macro_generator.py (5 hunks)
- tilelang/intrinsics/mma_sp_layout.py (1 hunks)
- tilelang/intrinsics/mma_sp_macro_generator.py (1 hunks)
- tilelang/ir.py (1 hunks)
- tilelang/language/__init__.py (1 hunks)
- tilelang/language/experimental/gemm_sp.py (2 hunks)
- tilelang/layout/__init__.py (1 hunks)
- tilelang/layout/gemm_sp.py (3 hunks)
- tilelang/profiler/__init__.py (1 hunks)
- tilelang/tileop/__init__.py (1 hunks)
- tilelang/tileop/gemm/__init__.py (1 hunks)
- tilelang/tileop/gemm_sp/__init__.py (1 hunks)
- tilelang/tileop/gemm_sp/gemm_sp_base.py (1 hunks)
- tilelang/tileop/gemm_sp/gemm_sp_mma.py (1 hunks)
- tilelang/utils/sparse.py (5 hunks)
- tilelang/utils/tensor.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
tilelang/intrinsics/mma_sp_layout.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
tilelang/intrinsics/mma_sp_layout.py
🧬 Code graph analysis (27)
src/op/gemm.h (1)
src/transform/layout_reducer.h (1)
Object(63-71)
tilelang/layout/__init__.py (1)
tilelang/layout/gemm_sp.py (1)
make_cutlass_metadata_layout(137-152)
src/op/gemm_sp.cc (1)
src/op/gemm.h (2)
RegisterReflection(36-42), RegisterReflection(123-145)
src/op/gemm_sp.h (2)
src/ir.cc (2)
TVM_DECLARE_FINAL_OBJECT_INFO(197-197), TVM_DECLARE_FINAL_OBJECT_INFO(314-314)
src/op/gemm.h (1)
GemmWarpPolicyNode(27-73)
tilelang/utils/tensor.py (1)
tilelang/engine/param.py (1)
is_float8(81-91)
tilelang/language/__init__.py (1)
tilelang/language/experimental/gemm_sp.py (2)
gemm_sp(9-86), gemm_sp_v2(89-307)
tilelang/tileop/gemm_sp/gemm_sp_mma.py (5)
tilelang/tileop/gemm_sp/gemm_sp_base.py (15)
infer_layout(14-15), policy(126-127), M(33-34), N(37-38), e_dtype(53-54), accum_dtype(62-63), trans_A(45-46), trans_B(49-50), K(41-42), is_gemm_ss(20-21), A(66-67), B(74-75), C(78-79), lower(17-18), E(70-71)
tilelang/intrinsics/mma_sp_macro_generator.py (7)
SparseTensorCoreIntrinEmitter(39-867), make_mma_store_layout(798-867), make_mma_load_layout(655-796), ldmatrix_a(293-357), ldmatrix_e(359-425), ldmatrix_b(428-522), mma_sp(524-598)
tilelang/utils/language.py (2)
is_shared(25-39), is_fragment(68-78)
tilelang/transform/simplify.py (1)
_Simplify(30-49)
tilelang/tileop/gemm_sp/__init__.py (2)
infer_layout(52-57), lower(59-65)
tilelang/tileop/gemm_sp/__init__.py (3)
tilelang/utils/target.py (1)
target_is_cuda(90-91)
tilelang/tileop/gemm_sp/gemm_sp_mma.py (3)
GemmSPMMA(12-244), infer_layout(14-58), lower(60-232)
tilelang/tileop/gemm_sp/gemm_sp_base.py (23)
infer_layout(14-15), lower(17-18), A(66-67), E(70-71), B(74-75), C(78-79), APtr(82-83), EPtr(86-87), BPtr(90-91), CPtr(94-95), M(33-34), N(37-38), K(41-42), trans_A(45-46), trans_B(49-50), stride_A(98-99), stride_B(102-103), offset_A(106-107), offset_B(110-111), clear_accum(114-115), k_pack(118-119), wg_wait(122-123), policy(126-127)
tilelang/language/experimental/gemm_sp.py (3)
tilelang/language/frame.py (2)
has_let_value(189-198), get_let_value(201-210)
tilelang/utils/language.py (1)
get_buffer_region_from_load(124-138)
tilelang/language/tir/op.py (1)
call_intrin(119-144)
examples/gemm_sp/test_example_gemm_sp.py (2)
examples/gemm_sp/example_custom_compress.py (1)
main(314-357)
examples/gemm_sp/example_gemm_sp.py (1)
main(107-146)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (3)
tilelang/utils/sparse.py (3)
compress(77-106), randn_semi_sparse(109-128), randint_semi_sparse(130-157)
tilelang/layout/gemm_sp.py (1)
make_cutlass_metadata_layout(137-152)
tilelang/utils/tensor.py (2)
torch_assert_close(233-325), map_torch_type(34-51)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1)
tilelang/layout/gemm_sp.py (1)
make_cutlass_metadata_layout(137-152)
examples/gemm_sp/example_gemm_sp.py (4)
tilelang/layout/gemm_sp.py (1)
make_cutlass_metadata_layout(137-152)
tilelang/language/experimental/gemm_sp.py (1)
gemm_sp_v2(89-307)
examples/gemm_sp/example_custom_compress.py (1)
kernel(258-309)
tilelang/utils/sparse.py (1)
compress(77-106)
tilelang/tileop/__init__.py (2)
tilelang/language/experimental/gemm_sp.py (1)
gemm_sp(9-86)
tilelang/tileop/gemm_sp/__init__.py (1)
GemmSPPy(25-65)
tilelang/utils/sparse.py (1)
tilelang/utils/tensor.py (2)
is_float8(10-16), fp8_remove_negative_zeros_(18-22)
tilelang/ir.py (1)
src/op/gemm_sp.h (5)
tvm(13-131), GemmSPWarpPolicy(29-53), GemmSPWarpPolicy(34-38), GemmSPWarpPolicy(40-44), GemmSPWarpPolicy(46-52)
examples/gemm_sp/example_custom_compress.py (3)
tilelang/layout/gemm_sp.py (1)
make_cutlass_metadata_layout(137-152)
tilelang/utils/sparse.py (3)
compress(77-106), randn_semi_sparse(109-128), arange_semi_sparse(159-182)
tilelang/utils/tensor.py (2)
torch_assert_close(233-325), map_torch_type(34-51)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (4)
tilelang/utils/sparse.py (3)
compress(77-106), randn_semi_sparse(109-128), randint_semi_sparse(130-157)
tilelang/utils/tensor.py (2)
torch_assert_close(233-325), map_torch_type(34-51)
tilelang/layout/gemm_sp.py (1)
make_cutlass_metadata_layout(137-152)
tilelang/intrinsics/mma_sp_macro_generator.py (1)
SparseTensorCoreIntrinEmitter(39-867)
src/op/gemm_sp_py.h (2)
src/op/operator.h (2)
TileOperatorNode(56-96), TileOperator(63-95)
src/op/gemm_sp_py.cc (11)
CheckWGMMA(143-193), CheckWGMMA(143-143), Lower(222-255), Lower(222-222), InferLayout(257-272), InferLayout(257-258), Clone(92-95), Clone(92-92), GetGemmInst(97-111), GetGemmInst(97-97), GemmSPPy(50-82)
tilelang/tileop/gemm/__init__.py (4)
tilelang/language/ast/ir.py (2)
Range(1716-1728), target(1682-1713)
tilelang/ir.py (1)
GemmWarpPolicy(30-39)
tilelang/tileop/gemm/gemm_mma.py (2)
GemmMMA(13-212), infer_layout(15-58)
tilelang/tileop/gemm/gemm_base.py (1)
infer_layout(14-15)
tilelang/profiler/__init__.py (2)
tilelang/utils/tensor.py (1)
is_float8(10-16)
tilelang/engine/param.py (1)
is_float8(81-91)
tilelang/tileop/gemm_sp/gemm_sp_base.py (4)
tilelang/utils/language.py (2)
is_shared(25-39), is_fragment(68-78)
tilelang/carver/roller/node.py (1)
Node(93-175)
tilelang/tileop/gemm_sp/gemm_sp_mma.py (6)
infer_layout(14-58), lower(60-232), is_gemm_ss(234-235), is_gemm_sr(237-238), is_gemm_rs(240-241), is_gemm_rr(243-244)
tilelang/tileop/gemm_sp/__init__.py (2)
infer_layout(52-57), lower(59-65)
benchmark/matmul/benchmark_matmul_sp.py (3)
tilelang/layout/gemm_sp.py (1)
make_cutlass_metadata_layout(137-152)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1)
matmul_sp(9-60)
tilelang/language/experimental/gemm_sp.py (1)
gemm_sp_v2(89-307)
tilelang/intrinsics/mma_sp_layout.py (1)
tilelang/intrinsics/mma_layout.py (4)
mma_load_a_32x4_to_shared_16x8_layout(130-133), mma_load_a_32x16_to_shared_16x32_layout(142-145), mma_load_a_32x8_to_shared_16x16_layout(147-160), ldmatrix_trans_32x8_to_shared_16x16_layout(24-27)
tilelang/intrinsics/mma_macro_generator.py (1)
tilelang/intrinsics/mma_layout.py (4)
mma_load_b_32x8_to_shared_16x16_layout(167-179), mma_load_a_32x16_to_shared_16x32_layout(142-145), mma_load_b_32x16_to_shared_16x32_layout(162-165), mma_load_a_32x8_to_shared_16x16_layout(147-160)
src/op/gemm_sp_py.cc (1)
src/op/gemm_sp_py.h (2)
GemmSPPy(119-124), RegisterReflection(41-65)
tilelang/intrinsics/mma_sp_macro_generator.py (4)
tilelang/intrinsics/utils.py (2)
mma_store_index_map(81-82), get_ldmatrix_offset(21-63)
tilelang/utils/language.py (2)
is_fragment(68-78), is_global(12-22)
tilelang/intrinsics/mma_sp_layout.py (20)
shared_16x16_to_mma_sp_layout_sr_a(14-15), shared_16x16_to_mma_sp_layout_sr_b(17-19), shared_16x32_to_mma_sp_layout_sr_a(21-22), shared_16x32_to_mma_sp_layout_sr_b(24-26), shared_16x64_to_mma_sp_layout_sr_a(28-29), shared_16x64_to_mma_sp_layout_sr_b(31-33), mma_sp_load_a_32x4_to_shared_16x16_layout(35-36), mma_sp_load_a_32x8_to_shared_16x32_layout(38-39), mma_sp_load_a_32x16_to_shared_16x64_layout(41-42), mma_sp_load_b_32x8_to_shared_16x16_layout(44-47), mma_sp_load_b_32x16_to_shared_16x32_layout(49-52), mma_sp_load_b_32x32_to_shared_16x64_layout(54-57), metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(63-67), metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(69-73), metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(75-76), metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(78-79), metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(84-88), metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(90-94), metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(96-101), get_ldmatrix_offset_b(123-157)
tilelang/language/tir/op.py (3)
ptx_ldmatrix(1123-1159), address_of(463-479), ptx_mma_sp(963-1061)
🪛 Clang (14.0.6)
src/op/gemm_sp.cc
[error] 306-306: variable 'TVM_FFI_STATIC_INIT_BLOCK' is non-const and globally accessible, consider making it const
(cppcoreguidelines-avoid-non-const-global-variables,-warnings-as-errors)
src/op/gemm_sp_py.cc
[error] 93-93: variable name 'op' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 97-97: method 'GetGemmInst' can be made static
(readability-convert-member-functions-to-static,-warnings-as-errors)
[error] 97-97: 2 adjacent parameters of 'GetGemmInst' of similar type ('int') are easily swapped by mistake
(bugprone-easily-swappable-parameters,-warnings-as-errors)
[error] 98-98: variable 'warp_size' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 100-100: variable 'allow_wgmma' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 104-104: do not use 'else' after 'return'
(readability-else-after-return,-warnings-as-errors)
[error] 143-143: function 'CheckWGMMA' has cognitive complexity of 40 (threshold 25)
(readability-function-cognitive-complexity,-warnings-as-errors)
[error] 149-149: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 150-150: 16 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 151-151: do not use 'else' after 'return'
(readability-else-after-return,-warnings-as-errors)
[error] 151-151: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 152-152: repeated branch in conditional chain
(bugprone-branch-clone,-warnings-as-errors)
[error] 152-152: 32 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 153-153: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 154-154: 32 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 155-155: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 156-156: 32 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 157-157: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 158-158: 32 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 159-159: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 162-162: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 163-163: repeated branch in conditional chain
(bugprone-branch-clone,-warnings-as-errors)
[error] 163-163: 16 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 164-164: do not use 'else' after 'return'
(readability-else-after-return,-warnings-as-errors)
[error] 165-165: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 166-166: 16 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 167-167: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 168-168: 8 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 169-169: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 170-170: repeated branch in conditional chain
(bugprone-branch-clone,-warnings-as-errors)
[error] 170-170: 32 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 171-171: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 172-172: 32 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 173-173: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 174-174: 32 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 175-175: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 176-176: 32 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 177-177: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 180-180: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 181-181: repeated branch in conditional chain
(bugprone-branch-clone,-warnings-as-errors)
[error] 181-181: 32 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 182-182: do not use 'else' after 'return'
(readability-else-after-return,-warnings-as-errors)
[error] 182-182: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 183-183: 32 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 184-184: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 185-185: 32 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 186-186: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 187-187: 32 is a magic number; consider replacing it with a named constant
(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)
[error] 188-188: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 211-211: variable name 's' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 213-213: variable 'arch' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 224-224: variable 'gemm_inst' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 236-236: variable 'block_realize' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 239-239: variable 'n' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 239-239: variable name 'n' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 259-259: statement should be inside braces
(readability-braces-around-statements,-warnings-as-errors)
[error] 261-261: variable 'results' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
[error] 263-263: if with identical then and else branches
(bugprone-branch-clone,-warnings-as-errors)
[error] 274-274: variable 'TVM_REGISTER_OP' is non-const and globally accessible, consider making it const
(cppcoreguidelines-avoid-non-const-global-variables,-warnings-as-errors)
[error] 274-274: variable name 'op' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 279-279: variable 'TVM_FFI_STATIC_INIT_BLOCK' is non-const and globally accessible, consider making it const
(cppcoreguidelines-avoid-non-const-global-variables,-warnings-as-errors)
🪛 LanguageTool
docs/deeplearning_operators/matmul_sparse.md
[grammar] ~50-~50: Ensure spelling is correct
Context: ... in A_sparse/A and E. (i.e. the 4-elment group at [n, k] doesn't match the 4-bit...
(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)
[grammar] ~137-~137: Ensure spelling is correct
Context: ...each 4-bit chunk represents two 2-bit indcies of non-zero elements within four cons...
(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)
🪛 markdownlint-cli2 (0.18.1)
docs/deeplearning_operators/matmul_sparse.md
39-39: Link text should be descriptive
(MD059, descriptive-link-text)
39-39: Link text should be descriptive
(MD059, descriptive-link-text)
176-176: Link text should be descriptive
(MD059, descriptive-link-text)
🪛 Ruff (0.14.3)
tilelang/layout/__init__.py
7-7: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/language/__init__.py
50-50: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/tileop/gemm_sp/gemm_sp_mma.py
57-58: Avoid specifying long messages outside the exception class
(TRY003)
231-232: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/tileop/gemm_sp/__init__.py
57-57: Avoid specifying long messages outside the exception class
(TRY003)
65-65: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/language/experimental/gemm_sp.py
153-153: Undefined name get_buffer_region_from_load
(F821)
159-160: Prefer TypeError exception for invalid type
(TRY004)
159-160: Avoid specifying long messages outside the exception class
(TRY003)
187-188: Prefer TypeError exception for invalid type
(TRY004)
187-188: Avoid specifying long messages outside the exception class
(TRY003)
191-191: Local variable E_shape is assigned to but never used
Remove assignment to unused variable E_shape
(F841)
239-239: Undefined name get_buffer_region_from_load
(F821)
253-254: Prefer TypeError exception for invalid type
(TRY004)
253-254: Avoid specifying long messages outside the exception class
(TRY003)
267-267: Undefined name get_buffer_region_from_load
(F821)
273-274: Prefer TypeError exception for invalid type
(TRY004)
273-274: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/tileop/__init__.py
2-2: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/layout/gemm_sp.py
116-116: Avoid specifying long messages outside the exception class
(TRY003)
119-119: Avoid specifying long messages outside the exception class
(TRY003)
examples/gemm_sp/example_custom_compress.py
115-117: Avoid specifying long messages outside the exception class
(TRY003)
120-120: Local variable device is assigned to but never used
Remove assignment to unused variable device
(F841)
128-128: Avoid specifying long messages outside the exception class
(TRY003)
131-131: Avoid specifying long messages outside the exception class
(TRY003)
135-137: Avoid specifying long messages outside the exception class
(TRY003)
140-142: Avoid specifying long messages outside the exception class
(TRY003)
144-146: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/mma_sp_layout.py
96-96: Unused function argument: local_id
(ARG001)
139-139: Avoid specifying long messages outside the exception class
(TRY003)
155-155: Avoid specifying long messages outside the exception class
(TRY003)
157-157: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/mma_sp_macro_generator.py
51-59: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
61-118: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
120-128: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
178-180: Avoid specifying long messages outside the exception class
(TRY003)
213-213: Avoid specifying long messages outside the exception class
(TRY003)
320-320: Avoid specifying long messages outside the exception class
(TRY003)
389-389: Avoid specifying long messages outside the exception class
(TRY003)
398-398: Avoid specifying long messages outside the exception class
(TRY003)
403-403: Avoid specifying long messages outside the exception class
(TRY003)
405-405: Avoid specifying long messages outside the exception class
(TRY003)
456-456: Avoid specifying long messages outside the exception class
(TRY003)
551-551: Local variable thread_binding is assigned to but never used
Remove assignment to unused variable thread_binding
(F841)
705-705: Avoid specifying long messages outside the exception class
(TRY003)
722-722: Avoid specifying long messages outside the exception class
(TRY003)
780-780: Avoid specifying long messages outside the exception class
(TRY003)
794-794: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (52)
src/target/ptx.cc (1)
608-608: PTX syntax for sparse MMA is correct and recommended.
The ".sp::ordered_metadata" syntax is the officially recommended form for sparse MMA operations, introduced in PTX ISA v8.5. Using `::ordered_metadata` is recommended because the generic `mma.sp` form may have substantially reduced performance on some target architectures. This syntax is supported on modern SMs including sm80 and sm89. The code change correctly implements this best practice.
tilelang/layout/__init__.py (1)
7-7: LGTM: Public API export updated correctly.
The rename from `make_metadata_layout` to `make_cutlass_metadata_layout` aligns with the broader migration to CUTLASS-specific metadata layouts throughout the codebase.
tilelang/utils/tensor.py (2)
10-16: LGTM: Clean float8 dtype check.
The function correctly identifies all torch float8 variants.
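For context, a minimal sketch of what such a dtype check can look like; this is an assumption-based illustration, not the actual body of `tilelang/utils/tensor.py`.
```python
import torch

def is_float8(dtype: torch.dtype) -> bool:
    # Collect whichever float8 dtypes this torch build exposes and test membership.
    fp8_names = ("float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz")
    return any(hasattr(torch, name) and dtype is getattr(torch, name) for name in fp8_names)
```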
18-22: Add unit tests for `fp8_remove_negative_zeros_` to verify correctness across all float8 formats.
The function's bit manipulation logic is sound: it correctly identifies both positive and negative zeros using PyTorch's == operator (which follows IEEE-754 semantics where ±0.0 compare equal), then converts them all to positive zero (0x00). This works correctly across all supported float8 formats (e4m3fn, e4m3fnuz, e5m2, e5m2fnuz) since they all represent negative zero with the sign bit set (0x80) and positive zero as all bits zero (0x00).
However, no explicit unit tests were found for this critical function. Since it directly impacts sparse tensor compression, add tests that:
- Verify negative zeros are correctly identified and converted
- Confirm behavior is consistent across all four float8 dtypes
- Ensure the function modifies tensors in-place as expected
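A minimal sketch of such a test, assuming the helper is importable from `tilelang.utils.tensor` as listed in this review's code graph; the dtypes to parametrize over depend on the torch build.
```python
import torch
import pytest

from tilelang.utils.tensor import fp8_remove_negative_zeros_  # path per this PR

@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
def test_fp8_remove_negative_zeros(dtype):
    t = torch.tensor([0.0, -0.0, 1.0, -1.0]).to(dtype)
    fp8_remove_negative_zeros_(t)  # expected to rewrite -0.0 in place
    bits = t.view(torch.uint8)
    assert not (bits == 0x80).any()  # no negative-zero bit pattern remains
    assert t.float()[2] == 1.0 and t.float()[3] == -1.0  # non-zeros untouched
```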
tilelang/language/__init__.py (1)
50-50: LGTM: New gemm_sp_v2 API exported correctly.
Both `gemm_sp` and `gemm_sp_v2` are now available, maintaining backward compatibility while introducing the enhanced v2 API.
tilelang/tileop/__init__.py (1)
2-2: LGTM: GemmSPPy operator exported correctly.
The new sparse GEMM operator class is now available through the public API.
src/op/gemm_sp.h (1)
21-22: LGTM: TVM object metadata added correctly.
The type key and reflection macros properly integrate `GemmSPWarpPolicyNode` into the TVM object system, following the same pattern as the base `GemmWarpPolicyNode`.
examples/gemm_sp/test_example_gemm_sp.py (1)
1-15: LGTM: Test structure is clean and follows conventions.
The test file properly wraps the example scripts with the testing harness. The tests rely on assertions within the `main()` functions of each example module.
src/op/gemm.h (1)
34-34: LGTM: Required change to enable inheritance.
Changing from `TVM_DECLARE_FINAL_OBJECT_INFO` to `TVM_DECLARE_BASE_OBJECT_INFO` is necessary to allow `GemmSPWarpPolicyNode` (in `src/op/gemm_sp.h`) to properly inherit from `GemmWarpPolicyNode`.
docs/deeplearning_operators/matmul_sparse.md (1)
1-261: Excellent documentation for the sparse GEMM feature.
The documentation is comprehensive and well-structured, covering:
- Structured sparsity concepts
- Compression workflow
- Both `T.gemm_sp` and `T.gemm_sp_v2` APIs
- Practical code examples
- Important notes on layout differences
The examples and explanations will help users understand when and how to use sparse GEMM operations.
src/op/gemm_sp.cc (1)
306-316: LGTM! FFI registration follows established patterns.
The FFI registration correctly exposes `GemmSPWarpPolicyComputeWarpPartition` to Python. The lambda calls `ComputeWarpPartition` for its side effects (mutating the `m_warp` and `n_warp` fields), which the Python side then retrieves, matching the pattern used in `GemmWarpPolicy`.
Note: The static analysis warning about non-const global variables is a false positive for the `TVM_FFI_STATIC_INIT_BLOCK` macro.
tilelang/tileop/gemm/__init__.py (3)
7-7: LGTM! Type annotation improvement.
Adding an explicit `Range` type import improves type safety and aligns with the updated function signatures.
15-17: LGTM! Proper use of Range type.
The updated signature and extraction of `thread_nums` from `thread_bounds.extent` are correct and improve type safety.
20-23: LGTM! Consistent Range usage.
The pattern matches `gemm_py_infer_layout` and correctly extracts thread numbers from the Range.
tilelang/ir.py (1)
41-51: LGTM! New sparse GEMM warp policy class.
The `GemmSPWarpPolicy` class is well-structured and follows the same pattern as `GemmWarpPolicy`. The additional `bits` parameter in `compute_warp_partition` is appropriate for handling different data type sizes in sparse operations.
tilelang/intrinsics/mma_layout.py (2)
147-160: LGTM! Well-documented layout helper.
The `mma_load_a_32x8_to_shared_16x16_layout` function is clearly documented with its mapping logic and correctly implements the 32x8 to 16x16 layout transformation.
167-179: LGTM! Complementary B-matrix layout.
The `mma_load_b_32x8_to_shared_16x16_layout` function properly implements the corresponding layout for the B matrix with clear documentation.
benchmark/matmul/benchmark_matmul_sp.py (4)
12-12: LGTM! Updated to new layout API.
Correctly imports the renamed `make_cutlass_metadata_layout` function.
89-89: LGTM! Added in_dtype parameter.
The function signature now properly includes `in_dtype` for flexible data type support.
206-210: LGTM! Using updated layout helper.
Correctly uses `make_cutlass_metadata_layout` with appropriate parameters for both global and shared metadata buffers.
222-229: LGTM! Updated to gemm_sp_v2 API.
The call correctly uses the new `T.gemm_sp_v2` API with appropriate parameters.
tilelang/utils/sparse.py (4)
92-103: LGTM! Comprehensive compress_sm80 enhancement.
The addition of transposed-flag support and float8 handling is well implemented. The conversion to int8 for compression and back to the original float8 dtype maintains type correctness.
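A small sketch of the reinterpretation pattern described above, under the assumption that the underlying compressor only accepts int8 values; the function names here are hypothetical.
```python
import torch

def compress_fp8_via_int8(a: torch.Tensor, compress_int8):
    # Reinterpret fp8 storage as int8, run the int8 compressor, then
    # reinterpret the compressed values back to the original fp8 dtype.
    orig_dtype = a.dtype
    a_sparse_i8, metadata = compress_int8(a.view(torch.int8))
    return a_sparse_i8.view(orig_dtype), metadata
```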
120-121: LGTM! Float32 sparsity support.
Correctly adjusts to a 1:2 sparsity pattern for the float32 dtype, which is appropriate for this data type.
130-157: LGTM! New integer semi-sparse generator.
The `randint_semi_sparse` function is well implemented and follows the same pattern as `randn_semi_sparse`. The docstring is clear and comprehensive.
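For readers unfamiliar with the 2:4 pattern these generators produce, a hedged sketch of the core idea follows. This is not the library's implementation, only an illustration; it assumes a 2-D input whose last dimension is divisible by 4.
```python
import torch

def sparsify_2_to_4(dense: torch.Tensor) -> torch.Tensor:
    # Keep the two largest-magnitude entries in every group of four along
    # the last dimension and zero out the other two.
    m, k = dense.shape
    groups = dense.reshape(m, k // 4, 4)
    keep = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, keep, True)
    return torch.where(mask, groups, torch.zeros_like(groups)).reshape(m, k)
```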
174-175: LGTM! Consistent float32 handling.
Correctly applies the same 1:2 sparsity pattern for the float32 dtype as in the other semi-sparse generators.
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (2)
4-4: LGTM! Updated to new layout API.
Correctly imports the renamed `make_cutlass_metadata_layout` function.
43-50: LGTM! Consistent layout usage.
Both metadata layout annotations correctly use `make_cutlass_metadata_layout` with the appropriate SM90 architecture specification and block_k parameter.
tilelang/tileop/gemm_sp/gemm_sp_mma.py (2)
14-58: Layout inference implementation looks correct.
The `infer_layout` method correctly handles the four GEMM dataflow variants (ss, sr, rs, rr) and delegates to the appropriate layout builders. The use of `SparseTensorCoreIntrinEmitter` for layout management is consistent with the design.
92-131: SS (shared-shared) kernel variant is well-structured.
The `_gemm_ssr` kernel correctly allocates local fragments for A, E, and B, loads them from shared memory, and performs the sparse MMA operation. The simplification pass with `inline_let=True` is appropriate for optimizing index computations.
examples/gemm_sp/example_gemm_sp.py (2)
8-8: API migration to `gemm_sp_v2` and `make_cutlass_metadata_layout` is correct.
The changes consistently update the example to use the new v2 API:
- Import path updated to `make_cutlass_metadata_layout`
- Kernel call updated from `T.gemm_sp` to `T.gemm_sp_v2`
- Layout factory calls updated to use `make_cutlass_metadata_layout` with an explicit `arch` parameter
Also applies to: 86-93, 99-99
17-58: Configuration improvements enhance usability.
The renaming to `DEFAULT_CONFIG` (line 17), the addition of `ARCH_INFO` (line 60), and providing a default value for `--cfg` (line 118) make the example more accessible and self-contained.
Also applies to: 60-60, 118-121
tilelang/intrinsics/mma_macro_generator.py (2)
21-24: 16-bit load layout integration is correct.
The changes properly integrate the new 32x8 → 16x16 load layouts for 16-bit A and B matrices when `ldmatrix` is not available. The layout functions `mma_load_a_32x8_to_shared_16x16_layout` and `mma_load_b_32x8_to_shared_16x16_layout` are imported and used consistently.
Also applies to: 223-224, 291-292
263-263: Transpose-aware indexing correctly handles memory access patterns.
The conditional indexing based on the `trans` flag ensures that elements are loaded from the correct memory locations for both transposed and non-transposed cases:
- Line 263 (A): uses `A_shared_buf[wk + mk, wi + mi]` when transposed, else `A_shared_buf[wi + mi, wk + mk]`
- Line 337 (B): uses `B_shared_buf[wi + mi, wk + mk]` when `b_transposed`, else `B_shared_buf[wk + mk, wi + mi]`
Also applies to: 337-337
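To make the index swap concrete, here is a tiny, self-contained illustration in plain PyTorch (not the emitter code) of why the order flips on a transposed buffer:
```python
import torch

# A logical (M, K) tile and the same data stored transposed as (K, M).
A = torch.arange(12).reshape(3, 4)
A_t = A.t().contiguous()

mi, mk = 1, 2
# The same logical element is addressed with swapped indices on the
# transposed buffer, mirroring the conditional indexing described above.
assert A[mi, mk] == A_t[mk, mi]
```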
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (4)
14-28: New `generate_dense_input` helper improves test clarity.
The centralized input-generation function properly handles different data types (int vs float, 8-bit vs others, signed vs unsigned) and transpose modes. This reduces duplication and makes the tests more maintainable.
111-112: Metadata dtype determination is cleaner and more robust.
The new approach using `SparseTensorCoreIntrinEmitter.E_FACTOR_MAP` provides a cleaner and more maintainable way to determine the correct metadata dtype based on the input dtype, replacing the previous hardcoded logic.
65-78: Fragment buffer naming and layout updates are correct.
The changes from `C_local` to `C_frag` (lines 65, 78) accurately reflect that these are fragment buffers, and the updated `make_cutlass_metadata_layout` calls (lines 68-75, 134-137) align with the new API.
Also applies to: 131-139
209-216: Test assertion improvements enhance debugging.
The use of `torch_assert_close` with named tensors (`base_name="tilelang_sp"`, `ref_name="ref_dense"`) provides better error messages and aligns with the broader test suite improvements.
278-281: No issues found; the offset assertion logic is correct for higher-dimensional tensors.
The code uses `-2` indexing intentionally and correctly. The shape constraints (lines 121-130 in gemm.py) enforce that for tensors with more than 2 dimensions, all non-matrix dimensions must equal 1. This means the last two dimensions always contain the matrix data, making `A_offset[-2]` and `A_offset[-1]` consistently refer to the offsets of the first and second matrix dimensions, regardless of total tensor dimensionality. For example:
- 2D tensor: `A_offset` has 2 elements; `[-2]` → first matrix dimension ✓
- 3D tensor: `A_offset` has 3 elements; `[-2]` → second-to-last (still the first matrix dimension) ✓
The error message could be more explicit (e.g., "first matrix dimension"), but the logic is sound.
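A short, self-contained illustration of the `[-2]` invariant described above (pure Python, values are arbitrary):
```python
# With non-matrix dimensions forced to 1, [-2]/[-1] always address the two
# matrix dimensions regardless of how many leading dimensions exist.
off_m, off_k = 16, 32
offset_2d = [off_m, off_k]        # 2-D case
offset_3d = [0, off_m, off_k]     # 3-D case with a leading (size-1) batch dim
assert offset_2d[-2] == off_m and offset_2d[-1] == off_k
assert offset_3d[-2] == off_m and offset_3d[-1] == off_k
```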
tilelang/tileop/gemm_sp/gemm_sp_base.py (2)
20-30: LGTM!
The `is_gemm_*` helper methods clearly express the four GEMM scope combinations (shared-shared, shared-register, register-shared, register-register) and correctly delegate to the `is_shared` and `is_fragment` utilities.
32-127: LGTM!
The property accessors provide a clean, read-only interface to the underlying `gemm_sp_node` attributes. The assertion in `in_dtype` (line 58) ensures A and B have matching data types, which is a good defensive check.
examples/gemm_sp/example_custom_compress.py (2)
314-361: LGTM!
The main function has a good structure: argument parsing, input generation, compression (with alternative compressor support), correctness validation, and performance benchmarking. The assertion at line 337 ensures torch_compress is only used with the naive layout, which is a good safeguard.
66-108: LGTM!
The `matmul_sp_fp16_custom_compress` kernel correctly implements the sparse GEMM with an optional Cutlass metadata layout. The conditional layout annotation (lines 84-95) properly handles both layout variants.
tilelang/tileop/gemm_sp/__init__.py (1)
24-65: LGTM!
The `GemmSPPy` class provides a clean FFI-compatible wrapper with proper field annotations. The delegation to `GemmSPMMA` for CUDA targets is the correct pattern, allowing target-specific implementations while maintaining a uniform interface.
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (3)
10-136: LGTM!
The `matmul` and `run_gemm_ss` functions provide a clean testing framework with proper layout annotations, compilation configurations, and validation against PyTorch references. The use of `SparseTensorCoreIntrinEmitter.E_FACTOR_MAP` ensures correct metadata dimensions across different dtype combinations.
184-664: LGTM!
The test functions for the `rs`, `sr`, and `rr` variants follow a consistent structure with comprehensive coverage of transpose combinations, data types (float16, int8, float8, bfloat16), and edge cases (n8 dimensions). The validation against PyTorch references ensures correctness across all scope combinations.
158-158: The review comment is incorrect. Transposed A compression is fully supported.
The compress_sm90.cu template has native transposed support with conditional layout handling (ColumnMajor for transposed, RowMajor otherwise), dimension extraction based on the transposed flag, and proper output shape handling. The tests already exercise trans_A=True cases (lines 161-162, 171-172, 175), and the compress() function passes the transposed parameter through to the implementation. The TODO comment appears outdated.
Likely an incorrect or invalid review comment.
tilelang/intrinsics/mma_sp_layout.py (6)
1-12: LGTM!
The imports are well-organized and all appear to be used in the module. Good reuse of existing layout functions from `mma_layout`.
14-33: LGTM!
The layout conversion functions are well-structured. The A matrix functions appropriately delegate to existing implementations, while the B matrix functions provide new coordinate transformations specific to sparse layouts.
35-42: LGTM!
The A matrix load helpers correctly delegate to existing implementations with appropriate dimension mapping.
44-57: LGTM!
The B matrix load helpers provide new coordinate transformations that follow consistent patterns with appropriate dimension handling.
123-157: LGTM!
The function has good error handling with clear, specific error messages that aid debugging. The dtype branching logic correctly handles different bit-widths and transposition modes.
Note: The static analysis warnings about "long error messages" (lines 139, 155, 157) are minor style preferences. The inline error messages are clear and appropriate for simple validation errors.
103-106: The different implementation is intentional and by design.
The `mma_sp_layout.py` module explicitly does not import `ldmatrix_trans_32x8_to_shared_16x16_layout` from `mma_layout.py`; the import list shows it selectively imports 6 other functions but omits this one. Instead, `mma_sp_layout.py` defines its own SP-specific version (lines 103-106) and uses it locally (line 135). This is a valid design pattern: the sparse matrix module maintains its own specialized layout transformation logic separate from the general module. There is no oversight here.
Likely an incorrect or invalid review comment.
```cpp
static int GetArchInt(Target target) {
  int arch_int = 0;
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  std::string arch = s.value();
  if (arch.rfind("sm_", 0) == 0) {
    arch_int = std::stoi(arch.substr(3));
  } else {
    arch_int = 0;
  }
  return arch_int;
}
```
Add error handling for std::stoi.
The GetArchInt function uses std::stoi without error handling. If the architecture string contains non-numeric characters after "sm_" (e.g., "sm_abc"), std::stoi will throw std::invalid_argument or std::out_of_range, potentially causing unexpected crashes.
Apply this diff to add error handling:
```diff
 static int GetArchInt(Target target) {
   int arch_int = 0;
   auto s = target->GetAttr<String>("arch");
   ICHECK(s.defined());
   std::string arch = s.value();
   if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
+    try {
+      arch_int = std::stoi(arch.substr(3));
+    } catch (const std::exception& e) {
+      LOG(WARNING) << "Failed to parse architecture number from '" << arch << "': " << e.what();
+      arch_int = 0;
+    }
   } else {
     arch_int = 0;
   }
   return arch_int;
 }
```
🧰 Tools
🪛 Clang (14.0.6)
[error] 211-211: variable name 's' is too short, expected at least 3 characters
(readability-identifier-length,-warnings-as-errors)
[error] 213-213: variable 'arch' is not initialized
(cppcoreguidelines-init-variables,-warnings-as-errors)
src/op/gemm_sp_py.h
Outdated
```cpp
bool SEqualReduce(const GemmSPPyNode *other, SEqualReducer equal) const {
  return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
         equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
         equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
         equal(trans_B, other->trans_B) && equal(M, other->M) &&
         equal(N, other->N) && equal(K, other->K) &&
         equal(stride_A, other->stride_A) &&
         equal(stride_B, other->stride_B) &&
         equal(offset_A, other->offset_B) &&
         equal(offset_B, other->offset_B) &&
         equal(clear_accum, other->clear_accum) &&
         equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
         equal(policy, other->policy);
```
SEqual must compare the metadata buffer correctly
SEqualReduce skips E/Eptr entirely and even checks offset_A against other->offset_B. That lets operators with different metadata collapse together in the structural cache, yielding wrong reuse. Please include the metadata fields and fix the offset comparison.
```diff
-  return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
-         equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
-         equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
+  return equal(A, other->A) && equal(E, other->E) && equal(B, other->B) &&
+         equal(C, other->C) && equal(Aptr, other->Aptr) &&
+         equal(Eptr, other->Eptr) && equal(Bptr, other->Bptr) &&
+         equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
          equal(trans_B, other->trans_B) && equal(M, other->M) &&
          equal(N, other->N) && equal(K, other->K) &&
          equal(stride_A, other->stride_A) &&
          equal(stride_B, other->stride_B) &&
-         equal(offset_A, other->offset_B) &&
+         equal(offset_A, other->offset_A) &&
          equal(offset_B, other->offset_B) &&
          equal(clear_accum, other->clear_accum) &&
          equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
          equal(policy, other->policy);
```
📝 Committable suggestion
equal(policy, other->policy);📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
bool SEqualReduce(const GemmSPPyNode *other, SEqualReducer equal) const {
  return equal(A, other->A) && equal(E, other->E) && equal(B, other->B) &&
         equal(C, other->C) && equal(Aptr, other->Aptr) &&
         equal(Eptr, other->Eptr) && equal(Bptr, other->Bptr) &&
         equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
         equal(trans_B, other->trans_B) && equal(M, other->M) &&
         equal(N, other->N) && equal(K, other->K) &&
         equal(stride_A, other->stride_A) &&
         equal(stride_B, other->stride_B) &&
         equal(offset_A, other->offset_A) &&
         equal(offset_B, other->offset_B) &&
         equal(clear_accum, other->clear_accum) &&
         equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
         equal(policy, other->policy);
```
🤖 Prompt for AI Agents
In src/op/gemm_sp_py.h around lines 67 to 79, the SEqualReduce implementation
omits comparing the metadata fields E and Eptr and incorrectly compares offset_A
to other->offset_B; update the boolean chain to include equal(E, other->E) &&
equal(Eptr, other->Eptr) and change the incorrect comparison to equal(offset_A,
other->offset_A) so all metadata buffers and offsets are checked correctly.
src/op/gemm_sp_py.h
Outdated
```cpp
void SHashReduce(SHashReducer hash_reduce) const {
  hash_reduce(A);
  hash_reduce(B);
  hash_reduce(C);
  hash_reduce(Aptr);
  hash_reduce(Bptr);
  hash_reduce(Cptr);
  hash_reduce(trans_A);
  hash_reduce(trans_B);
  hash_reduce(M);
  hash_reduce(N);
  hash_reduce(K);
  hash_reduce(stride_A);
  hash_reduce(stride_B);
  hash_reduce(offset_A);
  hash_reduce(offset_B);
  hash_reduce(clear_accum);
  hash_reduce(kPack);
  hash_reduce(wg_wait);
  hash_reduce(policy);
```
Hash must cover the metadata too
SHashReduce omits both E and Eptr, so nodes that only differ in metadata hash to the same bucket. After fixing SEqualReduce, this still causes excessive collisions – please fold the metadata into the hash.
```diff
- hash_reduce(A);
- hash_reduce(B);
- hash_reduce(C);
- hash_reduce(Aptr);
- hash_reduce(Bptr);
- hash_reduce(Cptr);
+ hash_reduce(A);
+ hash_reduce(E);
+ hash_reduce(B);
+ hash_reduce(C);
+ hash_reduce(Aptr);
+ hash_reduce(Eptr);
+ hash_reduce(Bptr);
+ hash_reduce(Cptr);
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
void SHashReduce(SHashReducer hash_reduce) const {
  hash_reduce(A);
  hash_reduce(E);
  hash_reduce(B);
  hash_reduce(C);
  hash_reduce(Aptr);
  hash_reduce(Eptr);
  hash_reduce(Bptr);
  hash_reduce(Cptr);
  hash_reduce(trans_A);
  hash_reduce(trans_B);
  hash_reduce(M);
  hash_reduce(N);
  hash_reduce(K);
  hash_reduce(stride_A);
  hash_reduce(stride_B);
  hash_reduce(offset_A);
  hash_reduce(offset_B);
  hash_reduce(clear_accum);
  hash_reduce(kPack);
  hash_reduce(wg_wait);
  hash_reduce(policy);
```
🤖 Prompt for AI Agents
In src/op/gemm_sp_py.h around lines 82 to 101, SHashReduce currently omits
hashing the metadata fields E and Eptr causing different-metadata nodes to
collide; update SHashReduce to call hash_reduce(E) and hash_reduce(Eptr)
(placing them logically with the other pointer/metadata fields, maintaining
consistent order) so the metadata is folded into the hash and collisions are
avoided.
```python
a_is_fragment = is_fragment(A_local_buf)
e_is_fragment = is_fragment(E_local_buf)
b_is_fragment = is_fragment(B_local_buf)
assert not e_is_fragment, f"currently E_local_buf must be a fragment buffer, found {E_local_buf.scope()}"
a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0
```
Fix fragment assertion for metadata buffer
is_fragment(E_local_buf) must be True, but the assertion requires it to be False, so we abort as soon as the metadata fragment is passed in. Please drop the negation.
```diff
-assert not e_is_fragment, f"currently E_local_buf must be a fragment buffer, found {E_local_buf.scope()}"
+assert e_is_fragment, f"currently E_local_buf must be a fragment buffer, found {E_local_buf.scope()}"
```
🤖 Prompt for AI Agents
In tilelang/intrinsics/mma_sp_macro_generator.py around lines 543 to 547, the
assertion for E_local_buf incorrectly negates is_fragment so it fails when
E_local_buf is a fragment; change the assertion to require
is_fragment(E_local_buf) is True (remove the `not`) so the code asserts
E_local_buf is a fragment and keep the error message using E_local_buf.scope()
for context.
tilelang/tileop/gemm_sp/__init__.py
Outdated
| @tvm.ffi.register_func("tl.gemm_sp_py.infer_layout") | ||
| def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range): | ||
| thread_nums = thread_bounds.extent | ||
| return gemm_sp_py.infer_layout(target, thread_nums) |
Correct the parameter type annotation.
The function is registered to receive a GemmSPPy object (based on the C++ call in src/op/gemm_sp_py.cc lines 229-231), but the type annotation indicates GemmSPMMA. This mismatch could cause confusion and type-checking failures.
Apply this diff:
@tvm.ffi.register_func("tl.gemm_sp_py.infer_layout")
-def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range):
+def gemm_sp_py_infer_layout(gemm_sp_py: "GemmSPPy", target: Target, thread_bounds: Range):
thread_nums = thread_bounds.extent
return gemm_sp_py.infer_layout(target, thread_nums)Note: Use forward reference "GemmSPPy" since the class is defined later in the file.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| @tvm.ffi.register_func("tl.gemm_sp_py.infer_layout") | |
| def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range): | |
| thread_nums = thread_bounds.extent | |
| return gemm_sp_py.infer_layout(target, thread_nums) | |
| @tvm.ffi.register_func("tl.gemm_sp_py.infer_layout") | |
| def gemm_sp_py_infer_layout(gemm_sp_py: "GemmSPPy", target: Target, thread_bounds: Range): | |
| thread_nums = thread_bounds.extent | |
| return gemm_sp_py.infer_layout(target, thread_nums) |
🤖 Prompt for AI Agents
In tilelang/tileop/gemm_sp/__init__.py around lines 13 to 16, the parameter type
annotation for gemm_sp_py_infer_layout incorrectly uses GemmSPMMA when the
registered function actually receives a GemmSPPy; update the annotation to use
the forward reference "GemmSPPy" (a string) for the gemm_sp_py parameter so it
matches the C++ call and avoids type-checking confusion, leaving the rest of the
signature and body unchanged.
tilelang/tileop/gemm_sp/__init__.py
Outdated
| @tvm.ffi.register_func("tl.gemm_sp_py.lower") | ||
| def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range, thread_var: tir.Var): | ||
| thread_nums = thread_bounds.extent | ||
| stmt = gemm_sp_py.lower(target, thread_nums, thread_var) | ||
| return stmt |
Correct the parameter type annotation.
The function is registered to receive a GemmSPPy object, but the type annotation indicates GemmSPMMA.
Apply this diff:
@tvm.ffi.register_func("tl.gemm_sp_py.lower")
-def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range, thread_var: tir.Var):
+def gemm_sp_py_lower(gemm_sp_py: "GemmSPPy", target: Target, thread_bounds: Range, thread_var: tir.Var):
thread_nums = thread_bounds.extent
stmt = gemm_sp_py.lower(target, thread_nums, thread_var)
return stmtThere was a problem hiding this comment.
Actionable comments posted: 4
♻️ Duplicate comments (1)
src/op/gemm_sp_py.cc (1)
210-221: Add error handling for std::stoi. The previous review correctly identified that std::stoi at line 216 can throw std::invalid_argument or std::out_of_range if the architecture string contains non-numeric characters (e.g., "sm_abc"). This remains unaddressed.
🧹 Nitpick comments (1)
benchmark/matmul/benchmark_matmul_sp.py (1)
95-114: Docstring omits the new in_dtype argument. Please update the parameter section to describe both in_dtype and accum_dtype so the public API matches the signature.
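One possible wording for the missing entries; the parameter names come from the function signature, while the stub and descriptions below are only illustrative:

```python
def matmul_sp(M, N, K, in_dtype="float16", accum_dtype="float"):
    """Build the 2:4 sparse GEMM benchmark kernel.

    Parameters
    ----------
    in_dtype : str
        Element type of the compressed A operand and of B, e.g. "float16" or "int8".
    accum_dtype : str
        Element type of the accumulator (the C fragment), e.g. "float".
    """
```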
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- benchmark/matmul/benchmark_matmul_sp.py (7 hunks)
- docs/index.md (1 hunks)
- examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (2 hunks)
- src/op/gemm_sp.cc (1 hunks)
- src/op/gemm_sp.h (1 hunks)
- src/op/gemm_sp_py.cc (1 hunks)
- src/op/gemm_sp_py.h (1 hunks)
- src/target/ptx.cc (1 hunks)
- src/tl_templates/cuda/debug.h (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- docs/index.md
- src/target/ptx.cc
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
🧬 Code graph analysis (6)
src/op/gemm_sp.h (1)
src/op/gemm.h (2)
RegisterReflection(35-41)RegisterReflection(106-128)
src/op/gemm_sp.cc (1)
src/op/gemm_sp.h (1)
RegisterReflection(27-33)
benchmark/matmul/benchmark_matmul_sp.py (4)
tilelang/layout/gemm_sp.py (1)
make_cutlass_metadata_layout(137-152)examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1)
matmul_sp(9-60)tilelang/language/allocate.py (1)
alloc_shared(27-42)tilelang/language/experimental/gemm_sp.py (1)
gemm_sp_v2(90-308)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1)
tilelang/layout/gemm_sp.py (1)
make_cutlass_metadata_layout(137-152)
src/op/gemm_sp_py.cc (4)
src/op/gemm_sp_py.h (2)
GemmSPPy(79-84)RegisterReflection(40-64)tilelang/tileop/gemm_sp/__init__.py (1)
GemmSPPy(25-65)src/op/gemm.h (6)
GemmWarpPolicy(59-83)GemmWarpPolicy(64-68)GemmWarpPolicy(70-74)GemmWarpPolicy(76-82)RegisterReflection(35-41)RegisterReflection(106-128)src/op/gemm_sp.cc (6)
Clone(120-123)Clone(120-120)Lower(143-185)Lower(143-143)InferLayout(215-299)InferLayout(215-216)
src/op/gemm_sp_py.h (3)
src/op/gemm_sp.h (3)
tvm(13-112)tl(15-111)RegisterReflection(27-33)src/op/operator.h (2)
TileOperatorNode(56-66)TileOperator(68-72)src/op/gemm_sp_py.cc (11)
CheckWGMMA(143-194)CheckWGMMA(143-143)Lower(223-256)Lower(223-223)InferLayout(258-273)InferLayout(258-259)Clone(92-95)Clone(92-92)GetGemmInst(97-111)GetGemmInst(97-97)GemmSPPy(50-82)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (14)
src/tl_templates/cuda/debug.h (1)
261-269: LGTM! Clean addition for uint16_t debugging support. The specialization correctly follows the established pattern (mirroring the int16_t variant), uses the appropriate format specifier (%u), and casts to uint32_t for safe variadic printf usage. This addition is useful for debugging sparse GEMM metadata, which is often stored in compact 16-bit unsigned formats.
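For readers unfamiliar with the pattern being praised here, a rough sketch of such a specialization is shown below; the helper name debug_print_value is hypothetical and does not match the actual API in debug.h.

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical printer: promote uint16_t to uint32_t so the value travels
// through printf's varargs safely, then print it with %u.
template <typename T>
__device__ void debug_print_value(const char* name, T value);

template <>
__device__ void debug_print_value<uint16_t>(const char* name, uint16_t value) {
  printf("%s = %u\n", name, static_cast<uint32_t>(value));
}
```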
src/op/gemm_sp_py.cc (5)

92-95: LGTM! The Clone implementation correctly follows the established pattern seen in GemmSPNode::Clone (src/op/gemm_sp.cc:120-123).

97-111: LGTM! The instruction selection logic correctly chooses between WGMMA, MFMA, and MMA based on target architecture and constraints. The CheckWGMMA() placeholder returning false (line 101) is appropriately documented at line 144.

143-194: Placeholder implementation with extensive commented roadmap. The method returns false with a fully documented implementation commented out. This aligns with the PR's roadmap showing WGMMA support as incomplete. The commented code and comprehensive docstring (lines 113-142) provide clear eligibility criteria for future enablement. Is WGMMA support expected in a follow-up PR, or should this placeholder be tracked separately? The checklist in the PR description doesn't explicitly mention WGMMA as a deliverable for this PR.

258-273: LGTM! The layout inference correctly delegates to the FFI function and guards against re-entrance with the completed_ flag, matching the pattern established in GemmSPNode::InferLayout.

275-281: LGTM! The operator registration correctly declares 5 inputs and kOpaque call effect, and properly registers reflection in the static initialization block.

src/op/gemm_sp_py.h (3)

7-7: TODO: Address code duplication with gemm_py.h. The TODO at line 7 notes planned refactoring to eliminate duplication with gemm_py.h. Given that this is a new sparse GEMM variant, some initial duplication is acceptable, but this should be tracked. Is this TODO tracked in an issue or follow-up work item?

40-64: LGTM! The reflection registration comprehensively exposes all relevant fields including the metadata buffer E and pointer Eptr, which correctly distinguishes this sparse operator from the dense GEMM variant.

79-84: LGTM! The wrapper class follows the standard TVM object reference pattern with appropriate constructor and accessor declarations.

src/op/gemm_sp.h (1)

27-33: LGTM! The reflection registration for GemmSPWarpPolicyNode correctly mirrors the established pattern from GemmWarpPolicyNode (src/op/gemm.h:34-40) and properly exposes the policy_type, m_warp, and n_warp fields.

src/op/gemm_sp.cc (2)

306-307: LGTM! The operator registration for tl.GemmSPWarpPolicy correctly sets the TScriptPrinterName attribute for script printing support.

309-312: LGTM! The static initialization block correctly registers reflection for both GemmSPNode and the newly added GemmSPWarpPolicyNode.

examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (2)

4-4: LGTM! Clean API migration. The migration from make_metadata_layout to make_cutlass_metadata_layout is consistent and correct. Removing the backend="cutlass" parameter makes sense as the backend is now implied by the function name itself.

57-57: No changes required: this example is outside the PR's scope. Both T.gemm_sp and T.gemm_sp_v2 coexist in the codebase with identical signatures. The sparse_tensorcore example targets SM90 (arch="9.0"), while this PR specifically adds support for T.gemm_sp_v2 on SM80 and SM89. Other examples in the gemm_sp/ directory have already been updated to use T.gemm_sp_v2, indicating the API direction; however, the sparse_tensorcore example remains appropriately on the legacy function for its architecture-specific use case.
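To ground the discussion of the new API, here is a minimal sketch of what a T.gemm_sp_v2 call site looks like on SM8x. The block sizes, the uint8 metadata dtype, and the factor of 8 between K and the metadata width are illustrative assumptions, and the real examples additionally annotate the metadata buffer with make_cutlass_metadata_layout; see the repository examples for the authoritative version.

```python
import tilelang
import tilelang.language as T

def matmul_sp_v2(M, N, K, block_M=128, block_N=128, block_K=64,
                 dtype="float16", accum_dtype="float"):
    e_factor = 8  # assumed: one uint8 metadata element per 8 logical K positions

    @T.prim_func
    def kernel(
            A_sparse: T.Tensor((M, K // 2), dtype),    # 2:4-compressed A
            E: T.Tensor((M, K // e_factor), "uint8"),  # sparsity metadata
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), dtype)
            E_shared = T.alloc_shared((block_M, block_K // e_factor), "uint8")
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm_sp_v2(A_shared, E_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return kernel
```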
Actionable comments posted: 1
♻️ Duplicate comments (13)
docs/deeplearning_operators/matmul_sparse.md (2)
50-50: Fix typo: "elment" → "element".
137-137: Fix typo: "indcies" → "indices".
src/op/gemm_sp.cc (1)
313-319: Return the computed warp partition from the FFI binding. The lambda currently discards the std::pair<int, int> returned by ComputeWarpPartition and returns void. The Python caller expects to receive the computed (m_warp, n_warp) values. Apply this diff:
- "tl.GemmSPWarpPolicyComputeWarpPartition", - [](GemmSPWarpPolicy policy, int M, int N, int block_size, Target target, - bool use_wgmma, int bits) { - policy->ComputeWarpPartition(M, N, block_size, target, use_wgmma, bits); - return; + "tl.GemmSPWarpPolicyComputeWarpPartition", + [](GemmSPWarpPolicy policy, int M, int N, int block_size, Target target, + bool use_wgmma, int bits) { + return policy->ComputeWarpPartition(M, N, block_size, target, use_wgmma, bits); });tilelang/tileop/gemm_sp/gemm_sp_mma.py (1)
208-230: Fix duplicate function name in the rr variant. Line 209 defines _gemm_rsr for the register-register (rr) case, but this name conflicts with the register-shared (rs) variant at line 174. This should be a distinct name, such as _gemm_rrr, to correctly represent the register-register semantics. Apply this diff:
```diff
         elif self.is_gemm_rr():
             A_local = self.A
             B_local = self.B

             @T.prim_func
-            def _gemm_rsr() -> None:
+            def _gemm_rrr() -> None:
                 """
                 The inner macro that loads data from shared buffers A_shared and
                 B_shared into local fragments, then issues Tensor Core mma ops,
                 accumulating into C_local.
                 """
                 E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype)

                 for ki in T.serial(0, (self.K // micro_size_k)):
                     # Load E into fragment
                     mma_emitter.ldmatrix_e(
                         E_local,
                         E_shared,
                         ki,
                     )
                     # Perform Matrix Multiplication
                     mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
-            return _Simplify(_gemm_rsr, inline_let=True)
+            return _Simplify(_gemm_rrr, inline_let=True)
```
src/op/gemm_sp_py.cc (3)
67-67: Use GemmSPWarpPolicy instead of GemmWarpPolicy. Line 67 constructs a GemmWarpPolicy, but the sparse GEMM path requires GemmSPWarpPolicy. Since GemmSPWarpPolicy has specialized ComputeWarpPartition logic for sparse GEMM (as shown in src/op/gemm_sp.cc:21-63), using the base type will bypass critical sparse-specific warp-partition adjustments. Apply this diff:
```diff
-  node->policy = GemmWarpPolicy(args[9].as<IntImm>().value()->value);
+  node->policy = GemmSPWarpPolicy(args[9].as<IntImm>().value()->value);
```
212-223: Add error handling for std::stoi. The GetArchInt function uses std::stoi without error handling. If the architecture string contains non-numeric characters after "sm_" (e.g., "sm_abc"), std::stoi will throw std::invalid_argument or std::out_of_range, potentially causing unexpected crashes. Apply this diff:
```diff
 static int GetArchInt(Target target) {
   int arch_int = 0;
   auto s = target->GetAttr<String>("arch");
-  ICHECK(s.has_value());
+  ICHECK(s.defined());
   std::string arch = s.value();
   if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
+    try {
+      arch_int = std::stoi(arch.substr(3));
+    } catch (const std::exception& e) {
+      LOG(WARNING) << "Failed to parse architecture number from '" << arch
+                   << "': " << e.what();
+      arch_int = 0;
+    }
   } else {
     arch_int = 0;
   }
   return arch_int;
 }
```
225-230: Fix the ComputeWarpPartition call signature. Lines 229-230 call policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst), but GemmSPWarpPolicyNode::ComputeWarpPartition expects (int M, int N, int block_size, Target target, bool use_wgmma, int bits), as shown in src/op/gemm_sp.h:21-23 and src/op/gemm_sp.cc:150-151. The current code passes a GemmInst enum where a bool use_wgmma is expected, and omits the required bits parameter. Apply this diff:
```diff
+  bool use_wgmma = (gemm_inst == GemmInst::kWGMMA);
   auto [warp_m, warp_n] =
-      policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
+      policy->ComputeWarpPartition(M, N, block_size, T.target, use_wgmma, A->dtype.bits());
```
36-36: Declare policy as GemmSPWarpPolicy instead of GemmWarpPolicy. Line 36 declares policy as mutable GemmWarpPolicy, but GemmSPNode (src/op/gemm_sp.h:73) uses mutable GemmSPWarpPolicy. Since GemmSPWarpPolicy::ComputeWarpPartition includes sparse-specific atom-size adjustments (src/op/gemm_sp.cc:31-62) that the base GemmWarpPolicy lacks, using the base type will skip critical sparse GEMM layout constraints. Apply this diff:
```diff
-  mutable GemmWarpPolicy policy;
+  mutable GemmSPWarpPolicy policy;
```
14-17: Correct the parameter type annotation. The function is registered to receive a GemmSPPy object (based on the C++ call in src/op/gemm_sp_py.cc lines 232-235), but the type annotation indicates GemmSPMMA. This mismatch could cause confusion and type-checking failures. Apply this diff:
@tvm_ffi.register_global_func("tl.gemm_sp_py.infer_layout") -def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range): +def gemm_sp_py_infer_layout(gemm_sp_py: "GemmSPPy", target: Target, thread_bounds: Range): thread_nums = thread_bounds.extent return gemm_sp_py.infer_layout(target, thread_nums)Note: Use forward reference
"GemmSPPy"since the class is defined later in the file.
20-25: Correct the parameter type annotation. The function is registered to receive a GemmSPPy object, but the type annotation indicates GemmSPMMA. Apply this diff:
@tvm_ffi.register_global_func("tl.gemm_sp_py.lower") -def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range, +def gemm_sp_py_lower(gemm_sp_py: "GemmSPPy", target: Target, thread_bounds: Range, thread_var: tir.Var): thread_nums = thread_bounds.extent stmt = gemm_sp_py.lower(target, thread_nums, thread_var) return stmttilelang/intrinsics/mma_sp_macro_generator.py (3)
382-390: Fix the conditional chain to prevent fallthrough error.The second
ifat Line 385 should beelifto form a proper conditional ladder. Currently, whena_dtypeis 8-bit, the code correctly assignsmetadata_16bit_load_32x2_to_shared_16x4_layout_8bit, but then the followingifevaluates to false and execution falls through to theelseblock at Line 389, raising an error incorrectly.Apply this diff to fix the issue:
elif DataType(e_dtype).bits == 16: if DataType(a_dtype).bits == 8: mma_load_layout = metadata_16bit_load_32x2_to_shared_16x4_layout_8bit - if DataType(a_dtype).bits == 16: + elif DataType(a_dtype).bits == 16: mma_load_layout = metadata_16bit_load_32x2_to_shared_16x2_layout_16bit elif DataType(a_dtype).bits == 32: mma_load_layout = metadata_16bit_load_32x2_to_shared_16x2_layout_32bit else: raise ValueError(f"Unsupported a_dtype for e_dtype 16bit: {a_dtype}")
431-432: Fix inverted ldmatrix availability logic for int8 transpose. The current condition enables ldmatrix for the unsupported int8+transpose case. The predicate should be not (DataType(b_dtype).bits != 16 and b_transposed) to correctly disable ldmatrix when B is int8 and transposed. Apply this diff to fix the logic:
```diff
-ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
+ldmatrix_available = not (DataType(b_dtype).bits != 16 and b_transposed)
```
537-540: Fix the negated fragment assertion for E_local_buf. The assertion requires not e_is_fragment while the error message states E_local_buf "must be a fragment buffer". This contradiction causes the assertion to fail when E_local_buf is correctly passed as a fragment. Remove the not to match the intended behavior. Apply this diff to fix the assertion:
```diff
-assert not e_is_fragment, f"currently E_local_buf must be a fragment buffer, found {E_local_buf.scope()}"
+assert e_is_fragment, f"currently E_local_buf must be a fragment buffer, found {E_local_buf.scope()}"
```
🧹 Nitpick comments (1)
tilelang/intrinsics/mma_sp_macro_generator.py (1)
52-129: Consider annotating the constant class dictionaries with ClassVar. The dtype_abbrv, E_FACTOR_MAP, and E_REPLICATE_FACTOR dictionaries are shared across all instances and act as constant lookup tables. Annotating them with typing.ClassVar clarifies that they are class-level constants rather than instance attributes. Apply this diff to add the annotations:
```diff
+from typing import ClassVar
+
 class SparseTensorCoreIntrinEmitter:
     ...
-    dtype_abbrv = {
+    dtype_abbrv: ClassVar[dict[str, str]] = {
         ...
     }
-    E_FACTOR_MAP = {
+    E_FACTOR_MAP: ClassVar[dict[str, dict[str, int]]] = {
         ...
     }
-    E_REPLICATE_FACTOR = {
+    E_REPLICATE_FACTOR: ClassVar[dict[str, int]] = {
         ...
     }
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (24)
- benchmark/matmul/benchmark_matmul_sp.py (7 hunks)
- docs/deeplearning_operators/matmul_sparse.md (1 hunks)
- examples/gemm_sp/example_custom_compress.py (1 hunks)
- examples/gemm_sp/example_gemm_sp.py (5 hunks)
- examples/gemm_sp/test_example_gemm_sp.py (1 hunks)
- examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (2 hunks)
- src/op/gemm_sp.cc (1 hunks)
- src/op/gemm_sp_py.cc (1 hunks)
- src/op/gemm_sp_py.h (1 hunks)
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (7 hunks)
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (1 hunks)
- tilelang/intrinsics/mma_layout.py (1 hunks)
- tilelang/intrinsics/mma_macro_generator.py (5 hunks)
- tilelang/intrinsics/mma_sp_layout.py (1 hunks)
- tilelang/intrinsics/mma_sp_macro_generator.py (1 hunks)
- tilelang/ir.py (1 hunks)
- tilelang/language/experimental/gemm_sp.py (2 hunks)
- tilelang/layout/gemm_sp.py (2 hunks)
- tilelang/tileop/gemm/__init__.py (2 hunks)
- tilelang/tileop/gemm_sp/__init__.py (1 hunks)
- tilelang/tileop/gemm_sp/gemm_sp_base.py (1 hunks)
- tilelang/tileop/gemm_sp/gemm_sp_mma.py (1 hunks)
- tilelang/utils/sparse.py (5 hunks)
- tilelang/utils/tensor.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (7)
- tilelang/utils/tensor.py
- tilelang/tileop/gemm/__init__.py
- examples/gemm_sp/example_gemm_sp.py
- tilelang/ir.py
- tilelang/tileop/gemm_sp/gemm_sp_base.py
- tilelang/intrinsics/mma_macro_generator.py
- tilelang/intrinsics/mma_layout.py
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.pysrc/op/gemm_sp_py.cctilelang/intrinsics/mma_sp_macro_generator.pytilelang/layout/gemm_sp.pytilelang/intrinsics/mma_sp_layout.pydocs/deeplearning_operators/matmul_sparse.md
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.pysrc/op/gemm_sp_py.cctilelang/intrinsics/mma_sp_macro_generator.pytilelang/layout/gemm_sp.pytilelang/intrinsics/mma_sp_layout.py
🧬 Code graph analysis (15)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1)
tilelang/layout/gemm_sp.py (1)
make_cutlass_metadata_layout(136-150)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (5)
tilelang/utils/sparse.py (3)
compress(77-106)randn_semi_sparse(109-128)randint_semi_sparse(131-158)tilelang/utils/tensor.py (2)
torch_assert_close(237-329)map_torch_type(37-54)tilelang/layout/gemm_sp.py (1)
make_cutlass_metadata_layout(136-150)tilelang/intrinsics/mma_sp_macro_generator.py (1)
SparseTensorCoreIntrinEmitter(40-858)tilelang/language/experimental/gemm_sp.py (1)
gemm_sp_v2(92-308)
examples/gemm_sp/test_example_gemm_sp.py (2)
examples/gemm_sp/example_custom_compress.py (1)
main(305-359)examples/gemm_sp/example_gemm_sp.py (1)
main(105-144)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (4)
tilelang/utils/sparse.py (3)
compress(77-106)randn_semi_sparse(109-128)randint_semi_sparse(131-158)tilelang/layout/gemm_sp.py (1)
make_cutlass_metadata_layout(136-150)tilelang/utils/tensor.py (2)
torch_assert_close(237-329)map_torch_type(37-54)tilelang/intrinsics/mma_sp_macro_generator.py (1)
SparseTensorCoreIntrinEmitter(40-858)
src/op/gemm_sp.cc (2)
src/op/gemm_sp_py.h (1)
RegisterReflection(41-65)src/op/gemm_sp.h (1)
RegisterReflection(27-33)
benchmark/matmul/benchmark_matmul_sp.py (3)
tilelang/layout/gemm_sp.py (1)
make_cutlass_metadata_layout(136-150)examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1)
matmul_sp(9-57)tilelang/language/experimental/gemm_sp.py (1)
gemm_sp_v2(92-308)
src/op/gemm_sp_py.cc (4)
src/op/gemm_sp_py.h (2)
GemmSPPy(80-86)RegisterReflection(41-65)tilelang/tileop/gemm_sp/__init__.py (1)
GemmSPPy(29-69)src/op/gemm.h (6)
GemmWarpPolicy(59-83)GemmWarpPolicy(64-68)GemmWarpPolicy(70-74)GemmWarpPolicy(76-82)RegisterReflection(35-41)RegisterReflection(106-128)src/op/gemm_sp.cc (6)
Clone(120-123)Clone(120-120)Lower(143-185)Lower(143-143)InferLayout(215-299)InferLayout(215-216)
tilelang/utils/sparse.py (1)
tilelang/utils/tensor.py (2)
is_float8(11-17)fp8_remove_negative_zeros_(20-24)
examples/gemm_sp/example_custom_compress.py (6)
tilelang/layout/gemm_sp.py (1)
make_cutlass_metadata_layout(136-150)tilelang/utils/sparse.py (1)
randn_semi_sparse(109-128)tilelang/utils/tensor.py (1)
torch_assert_close(237-329)tilelang/language/allocate.py (3)
alloc_shared(27-42)alloc_fragment(59-70)alloc_local(45-56)tilelang/language/annotations.py (1)
annotate_layout(25-36)tilelang/language/experimental/gemm_sp.py (1)
gemm_sp_v2(92-308)
tilelang/intrinsics/mma_sp_macro_generator.py (4)
tilelang/intrinsics/utils.py (2)
mma_store_index_map(81-82)get_ldmatrix_offset(21-63)tilelang/utils/language.py (1)
is_fragment(81-91)tilelang/intrinsics/mma_sp_layout.py (20)
shared_16x16_to_mma_sp_layout_sr_a(14-15)shared_16x16_to_mma_sp_layout_sr_b(18-20)shared_16x32_to_mma_sp_layout_sr_a(23-24)shared_16x32_to_mma_sp_layout_sr_b(27-29)shared_16x64_to_mma_sp_layout_sr_a(32-33)shared_16x64_to_mma_sp_layout_sr_b(36-38)mma_sp_load_a_32x4_to_shared_16x16_layout(41-42)mma_sp_load_a_32x8_to_shared_16x32_layout(45-46)mma_sp_load_a_32x16_to_shared_16x64_layout(49-50)mma_sp_load_b_32x8_to_shared_16x16_layout(53-56)mma_sp_load_b_32x16_to_shared_16x32_layout(59-62)mma_sp_load_b_32x32_to_shared_16x64_layout(65-68)metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(75-80)metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(83-88)metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(91-94)metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(97-100)metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(107-112)metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(115-120)metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(123-129)get_ldmatrix_offset_b(156-190)tilelang/language/tir/op.py (3)
ptx_ldmatrix(1313-1349)address_of(464-480)ptx_mma_sp(964-1062)
src/op/gemm_sp_py.h (3)
src/op/gemm_sp.h (3)
tvm(13-112)tl(15-111)RegisterReflection(27-33)src/op/operator.h (2)
TileOperatorNode(56-66)TileOperator(68-72)src/op/gemm_sp.cc (6)
Lower(143-185)Lower(143-143)InferLayout(215-299)InferLayout(215-216)Clone(120-123)Clone(120-120)
tilelang/layout/gemm_sp.py (2)
tilelang/layout/layout.py (1)
Layout(13-148)tilelang/contrib/nvcc.py (2)
get_target_compute_version(258-299)parse_compute_version(302-324)
tilelang/tileop/gemm_sp/__init__.py (6)
src/op/gemm_sp_py.h (2)
tvm(15-89)GemmSPPy(80-86)tilelang/utils/target.py (1)
target_is_cuda(126-127)tilelang/ir.py (1)
GemmWarpPolicy(30-39)tilelang/tileop/gemm_sp/gemm_sp_mma.py (3)
GemmSPMMA(12-245)infer_layout(14-58)lower(60-233)tilelang/tileop/gemm_sp/gemm_sp_base.py (23)
infer_layout(14-15)lower(17-18)A(66-67)E(70-71)B(74-75)C(78-79)APtr(82-83)EPtr(86-87)BPtr(90-91)CPtr(94-95)M(33-34)N(37-38)K(41-42)trans_A(45-46)trans_B(49-50)stride_A(98-99)stride_B(102-103)offset_A(106-107)offset_B(110-111)clear_accum(114-115)k_pack(118-119)wg_wait(122-123)policy(126-127)src/op/gemm_sp_py.cc (1)
GemmSPPy(51-83)
tilelang/intrinsics/mma_sp_layout.py (1)
tilelang/intrinsics/mma_layout.py (4)
mma_load_a_32x4_to_shared_16x8_layout(130-133)mma_load_a_32x16_to_shared_16x32_layout(142-145)mma_load_a_32x8_to_shared_16x16_layout(148-161)ldmatrix_trans_32x8_to_shared_16x16_layout(24-27)
tilelang/language/experimental/gemm_sp.py (2)
tilelang/utils/language.py (1)
get_buffer_region_from_load(137-159)tilelang/language/gemm.py (10)
legalize_arguments(48-59)legalize_arguments(251-262)retrieve_shape(66-83)retrieve_shape(269-286)retrieve_stride(85-111)retrieve_stride(288-314)retrieve_ptr(140-175)retrieve_ptr(343-378)retrieve_offset(177-195)retrieve_offset(380-398)
🪛 LanguageTool
docs/deeplearning_operators/matmul_sparse.md
[grammar] ~50-~50: Ensure spelling is correct
Context: ... in A_sparse/A and E. (i.e. the 4-elment group at [n, k] doesn't match the 4-bit...
(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)
[grammar] ~137-~137: Ensure spelling is correct
Context: ...each 4-bit chunk represents two 2-bit indcies of non-zero elements within four cons...
(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)
🪛 markdownlint-cli2 (0.18.1)
docs/deeplearning_operators/matmul_sparse.md
39-39: Link text should be descriptive
(MD059, descriptive-link-text)
39-39: Link text should be descriptive
(MD059, descriptive-link-text)
176-176: Link text should be descriptive
(MD059, descriptive-link-text)
🪛 Ruff (0.14.3)
tilelang/tileop/gemm_sp/gemm_sp_mma.py
57-58: Avoid specifying long messages outside the exception class
(TRY003)
232-233: Avoid specifying long messages outside the exception class
(TRY003)
examples/gemm_sp/example_custom_compress.py
111-112: Avoid specifying long messages outside the exception class
(TRY003)
122-122: Avoid specifying long messages outside the exception class
(TRY003)
125-125: Avoid specifying long messages outside the exception class
(TRY003)
129-129: Avoid specifying long messages outside the exception class
(TRY003)
132-132: Avoid specifying long messages outside the exception class
(TRY003)
134-136: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/mma_sp_macro_generator.py
52-60: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
62-119: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
121-129: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
178-180: Avoid specifying long messages outside the exception class
(TRY003)
212-212: Avoid specifying long messages outside the exception class
(TRY003)
313-313: Avoid specifying long messages outside the exception class
(TRY003)
381-381: Avoid specifying long messages outside the exception class
(TRY003)
390-390: Avoid specifying long messages outside the exception class
(TRY003)
395-395: Avoid specifying long messages outside the exception class
(TRY003)
397-397: Avoid specifying long messages outside the exception class
(TRY003)
445-445: Avoid specifying long messages outside the exception class
(TRY003)
696-696: Avoid specifying long messages outside the exception class
(TRY003)
713-713: Avoid specifying long messages outside the exception class
(TRY003)
771-771: Avoid specifying long messages outside the exception class
(TRY003)
785-785: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/layout/gemm_sp.py
115-115: Avoid specifying long messages outside the exception class
(TRY003)
118-118: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/tileop/gemm_sp/__init__.py
61-61: Avoid specifying long messages outside the exception class
(TRY003)
69-69: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/mma_sp_layout.py
124-124: Unused function argument: local_id
(ARG001)
172-172: Avoid specifying long messages outside the exception class
(TRY003)
188-188: Avoid specifying long messages outside the exception class
(TRY003)
190-190: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/language/experimental/gemm_sp.py
162-163: Prefer TypeError exception for invalid type
(TRY004)
162-163: Avoid specifying long messages outside the exception class
(TRY003)
190-191: Prefer TypeError exception for invalid type
(TRY004)
190-191: Avoid specifying long messages outside the exception class
(TRY003)
254-255: Prefer TypeError exception for invalid type
(TRY004)
254-255: Avoid specifying long messages outside the exception class
(TRY003)
274-275: Prefer TypeError exception for invalid type
(TRY004)
274-275: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (30)
tilelang/intrinsics/mma_sp_layout.py (1)
1-190: LGTM! Layout helpers follow established patterns.The SP layout mapping functions correctly delegate to existing MMA layout helpers and maintain API consistency.
examples/gemm_sp/test_example_gemm_sp.py (1)
1-16: LGTM! Clean test structure.The test module follows standard patterns for exercising example code.
tilelang/utils/sparse.py (2)
92-103: LGTM! Correct float8 and transposition handling.The float8 handling correctly normalizes negative zeros and the transposition logic maintains correctness for the sparse compression workflow.
131-159: LGTM! New utility follows existing patterns. randint_semi_sparse correctly mirrors the structure and behavior of randn_semi_sparse for integer tensors.
tilelang/language/experimental/gemm_sp.py (1)
91-308: LGTM! gemm_sp_v2 implementation is comprehensive.The new gemm_sp_v2 function correctly handles buffer normalization, shape/stride computation, and offset calculation. The implementation properly validates matrix dimensions and generates the appropriate intrinsic call.
benchmark/matmul/benchmark_matmul_sp.py (2)
12-12: LGTM! Correct API migration to make_cutlass_metadata_layout.
220-227: LGTM! Successful migration to gemm_sp_v2. The transition from T.gemm_sp to T.gemm_sp_v2 aligns with the PR objectives and maintains all required parameters.
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (2)
4-4: LGTM! Consistent API migration.
43-47: LGTM! Correct usage of make_cutlass_metadata_layout. The migration removes the backend parameter and maintains all required arguments for the new API.
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (3)
5-11: LGTM! Clean test infrastructure updates.The imports and configuration changes (disabling TF32, importing new utilities) align well with the v2 pathway requirements and metadata layout updates.
14-42: LGTM! Well-structured input generation helper.The
generate_dense_inputfunction cleanly handles dynamic dtype selection and transposition for both integer and floating-point types, supporting the test harness's expanded dtype coverage.
122-123: LGTM! Correct metadata dtype and layout usage.The dynamic metadata dtype calculation via
SparseTensorCoreIntrinEmitter.E_FACTOR_MAPand the switch tomake_cutlass_metadata_layoutalign with the PR's standardized metadata layout approach for SM8x.Also applies to: 141-145
tilelang/layout/gemm_sp.py (3)
21-104: LGTM! SM90 metadata layout correctly promoted to public API.The function is now exposed as
make_cutlass_metadata_layout_sm90with proper error handling and dynamic shape/stride computation. The layout logic correctly implements CUTLASS-compatible interleaving for SM9.0 metadata.
107-133: LGTM! SM8x metadata layout correctly implements CUTLASS conventions.The renamed
make_cutlass_metadata_layout_sm8xcorrectly computes group and interweave constants based on metadata dtype bit-width, and the interleaved column-major transformation aligns with the CUTLASS SM8x metadata format referenced in the PyTorch link (line 109).Based on learnings
136-150: LGTM! Clean architecture-based dispatch.The
make_cutlass_metadata_layoutwrapper correctly determines the compute version and dispatches to the appropriate SM-specific layout generator, providing a unified public API.examples/gemm_sp/example_custom_compress.py (3)
18-62: LGTM! Well-structured configuration constants.The
DEFAULT_CONFIGandARCH_INFOdictionaries provide clear, architecture-specific configurations for different GPU families and data types, making the example easy to adapt and tune.
106-221: LGTM! Comprehensive naive compression implementation.The
torch_compressfunction provides a clear, well-documented reference implementation of 2:4 sparsity encoding with proper validation and dtype handling. The detailed encoding comments (lines 148-180) are particularly helpful for understanding the metadata format.
305-363: LGTM! Complete end-to-end example with benchmarking.The
mainfunction demonstrates a full workflow: argument parsing, input generation, compression (with selectable compressor), correctness verification viatorch_assert_close, and performance benchmarking. This provides excellent guidance for users implementing custom compression.testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (3)
11-65: LGTM! Clean SS variant test implementation.The
matmulfunction correctly constructs a gemm_sp_v2 kernel with proper shape handling, metadata layout annotation viamake_cutlass_metadata_layout, and pipelined execution. This establishes a solid foundation for the test suite.
136-165: LGTM! Robust input generation helper.The
generate_dense_inputfunction handles diverse dtypes (int8, float8, float16, bfloat16) with appropriate value ranges and semi-sparse generation viarandint_semi_sparse/randn_semi_sparse, supporting comprehensive test coverage.
198-256: LGTM! Complete coverage of RS, SR, and RR variants.The
matmul_rs,matmul_sr, andmatmul_rrfunctions correctly implement fragment-based loading patterns for each variant, with appropriatealloc_fragmentallocations,T.copyoperations, and layout annotations. This ensures all dataflow patterns (shared-shared, register-shared, shared-register, register-register) are tested.Also applies to: 352-410, 506-568
tilelang/tileop/gemm_sp/__init__.py (1)
28-69: LGTM! Clean GemmSPPy class definition with proper delegation.The
GemmSPPyclass correctly defines all required fields with proper type hints and delegatesinfer_layoutandlowertoGemmSPMMAfor CUDA targets, with clear error handling for unsupported targets. This provides a clean separation between the FFI boundary and the implementation.tilelang/intrinsics/mma_sp_macro_generator.py (8)
134-180: LGTM!The constructor properly initializes all instance attributes and validates the warp configuration. The initialization helpers correctly set up dimensions, abbreviations, and MMA prefixes based on data types.
182-288: LGTM!The initialization and helper methods correctly compute dimensions, local sizes, and thread bindings. The
extract_thread_bindingmethod properly handles bothis_m_firstorderings for flexible thread layout.
290-354: LGTM!The
ldmatrix_amethod correctly handles data loading for the A matrix with proper fallback when ldmatrix is unavailable for transposed int8. The physical K dimension adjustment (dividing bySPARSE_FACTOR) correctly accounts for the 2:4 sparsity pattern.
447-516: LGTM!The
_warp_ldmatrix_bmacro correctly handles both ldmatrix and fallback paths. Thereplicate_blogic properly doubles the load for 16x16 output tiles, and the fallback path correctly handles transposed and non-transposed layouts.
545-589: LGTM!The
_warp_mma_spmacro correctly implements the sparse MMA operation. The stride calculations properly handle fragment buffers, and the replicate_b path correctly performs two MMA operations with appropriate offsets for 16x16 output tiles. Metadata handling withSPARSE_SELECTOR = 0is correct.
591-644: LGTM!The
stmatrixmethod correctly implements both shared and global memory stores. Thread-local indices are properly mapped to 2D positions usingmma_store_index_map, and vectorized stores optimize memory access. The global variant correctly offsets by block coordinates.
646-787: LGTM!The
make_mma_load_layoutmethod correctly constructs fragment layouts for A and B matrices. The axis order determination (sr vs rs) correctly handles transposed and non-transposed cases, and the sparse factor division for matrix A (Line 744) properly accounts for 2:4 sparsity. The repeat/replicate operations correctly build block-level fragments with appropriate thread mapping.
789-858: LGTM!The
make_mma_store_layoutmethod correctly constructs the output fragment layout. Theforward_threadandforward_indexfunctions properly decompose 2D coordinates into block, warp, and MMA components, correctly handling bothis_m_firstthread orderings. The local offset calculation accurately reflects the warp-column-major layout of the output buffer.
| from tilelang.utils.language import get_buffer_region_from_load | ||
| import tilelang.language as T | ||
| from tvm import tir | ||
| from typing import list |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix typing import: use List instead of list.
The import from typing import list is incorrect. Use from typing import List for Python < 3.9 compatibility, or if targeting Python 3.9+, remove this import and use the built-in list type directly.
Apply this diff for Python < 3.9 compatibility:
-from typing import list
+from typing import ListThen update the return type annotations:
- Line 146:
-> List[int] - Line 165:
-> List[int]
Alternatively, if Python 3.9+ is required, remove the import entirely and keep the lowercase list annotations.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| from typing import list | |
| from typing import List |
🤖 Prompt for AI Agents
In tilelang/language/experimental/gemm_sp.py around line 8 and affecting
annotations at lines 146 and 165, the typing import is incorrect: replace "from
typing import list" with "from typing import List" (for Python <3.9) and update
the two return type annotations at line 146 and line 165 from "-> list[int]" to
"-> List[int]". If the codebase targets Python 3.9+, instead remove the import
entirely and keep the built-in "list" annotations as-is; ensure imports and
annotations are consistent across the file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (1)
docs/deeplearning_operators/matmul_sparse.md (1)
39-39: Improve link text for accessibility.The duplicate
[here]links violate MD059 and make the doc harder to navigate with assistive tech. Please replace them with descriptive titles (e.g.,[PyTorch sparse kernel],[vLLM sparse kernel]).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
docs/deeplearning_operators/matmul_sparse.md(1 hunks)testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py(1 hunks)tilelang/intrinsics/mma_sp_macro_generator.py(1 hunks)tilelang/language/experimental/gemm_sp.py(2 hunks)tilelang/tileop/gemm_sp/gemm_sp_mma.py(1 hunks)
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
docs/deeplearning_operators/matmul_sparse.mdtilelang/intrinsics/mma_sp_macro_generator.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
tilelang/intrinsics/mma_sp_macro_generator.py
🧬 Code graph analysis (4)
tilelang/tileop/gemm_sp/gemm_sp_mma.py (6)
tilelang/tileop/gemm_sp/gemm_sp_base.py (20)
GemmSPBase(11-127)infer_layout(14-15)policy(126-127)M(33-34)N(37-38)in_dtype(57-59)e_dtype(53-54)accum_dtype(62-63)trans_A(45-46)trans_B(49-50)K(41-42)is_gemm_ss(20-21)A(66-67)B(74-75)C(78-79)is_gemm_sr(23-24)is_gemm_rs(26-27)is_gemm_rr(29-30)lower(17-18)E(70-71)tilelang/layout/swizzle.py (1)
make_swizzled_layout(10-18)tilelang/intrinsics/mma_sp_macro_generator.py (6)
make_mma_store_layout(789-858)make_mma_load_layout(646-787)ldmatrix_a(290-354)ldmatrix_e(356-418)ldmatrix_b(420-516)mma_sp(518-589)tilelang/utils/language.py (2)
is_shared(25-39)is_fragment(81-91)tilelang/transform/simplify.py (1)
_Simplify(31-49)tilelang/tileop/gemm_sp/__init__.py (2)
infer_layout(56-61)lower(63-69)
tilelang/language/experimental/gemm_sp.py (2)
tilelang/utils/language.py (1)
get_buffer_region_from_load(137-159)tilelang/language/gemm.py (10)
legalize_arguments(48-59)legalize_arguments(251-262)retrieve_shape(66-83)retrieve_shape(269-286)retrieve_stride(85-111)retrieve_stride(288-314)retrieve_ptr(140-175)retrieve_ptr(343-378)retrieve_offset(177-195)retrieve_offset(380-398)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (6)
tilelang/utils/sparse.py (3)
compress(77-106)randn_semi_sparse(109-128)randint_semi_sparse(131-158)tilelang/utils/tensor.py (2)
torch_assert_close(237-329)map_torch_type(37-54)tilelang/layout/gemm_sp.py (1)
make_cutlass_metadata_layout(136-150)tilelang/intrinsics/mma_sp_macro_generator.py (1)
SparseTensorCoreIntrinEmitter(40-858)tilelang/language/experimental/gemm_sp.py (1)
gemm_sp_v2(91-307)tilelang/layout/swizzle.py (1)
make_swizzled_layout(10-18)
tilelang/intrinsics/mma_sp_macro_generator.py (5)
tilelang/intrinsics/utils.py (2)
mma_store_index_map(81-82)get_ldmatrix_offset(21-63)tilelang/utils/language.py (1)
is_fragment(81-91)tilelang/intrinsics/mma_sp_layout.py (20)
shared_16x16_to_mma_sp_layout_sr_a(14-15)shared_16x16_to_mma_sp_layout_sr_b(18-20)shared_16x32_to_mma_sp_layout_sr_a(23-24)shared_16x32_to_mma_sp_layout_sr_b(27-29)shared_16x64_to_mma_sp_layout_sr_a(32-33)shared_16x64_to_mma_sp_layout_sr_b(36-38)mma_sp_load_a_32x4_to_shared_16x16_layout(41-42)mma_sp_load_a_32x8_to_shared_16x32_layout(45-46)mma_sp_load_a_32x16_to_shared_16x64_layout(49-50)mma_sp_load_b_32x8_to_shared_16x16_layout(53-56)mma_sp_load_b_32x16_to_shared_16x32_layout(59-62)mma_sp_load_b_32x32_to_shared_16x64_layout(65-68)metadata_8bit_load_32x4_to_shared_16x4_layout_32bit(75-80)metadata_16bit_load_32x2_to_shared_16x2_layout_32bit(83-88)metadata_8bit_load_32x4_to_shared_16x4_layout_16bit(91-94)metadata_16bit_load_32x2_to_shared_16x2_layout_16bit(97-100)metadata_8bit_load_32x4_to_shared_16x4_layout_8bit(107-112)metadata_16bit_load_32x2_to_shared_16x4_layout_8bit(115-120)metadata_32bit_load_32x1_to_shared_16x2_layout_8bit(123-129)get_ldmatrix_offset_b(156-190)tilelang/language/tir/op.py (2)
ptx_ldmatrix(1313-1349)ptx_mma_sp(964-1062)tilelang/layout/fragment.py (2)
replicate(147-161)repeat(124-145)
🪛 markdownlint-cli2 (0.18.1)
docs/deeplearning_operators/matmul_sparse.md
39-39: Link text should be descriptive
(MD059, descriptive-link-text)
39-39: Link text should be descriptive
(MD059, descriptive-link-text)
176-176: Link text should be descriptive
(MD059, descriptive-link-text)
🪛 Ruff (0.14.3)
tilelang/tileop/gemm_sp/gemm_sp_mma.py
57-58: Avoid specifying long messages outside the exception class
(TRY003)
232-233: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/language/experimental/gemm_sp.py
161-162: Prefer TypeError exception for invalid type
(TRY004)
161-162: Avoid specifying long messages outside the exception class
(TRY003)
189-190: Prefer TypeError exception for invalid type
(TRY004)
189-190: Avoid specifying long messages outside the exception class
(TRY003)
253-254: Prefer TypeError exception for invalid type
(TRY004)
253-254: Avoid specifying long messages outside the exception class
(TRY003)
273-274: Prefer TypeError exception for invalid type
(TRY004)
273-274: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/mma_sp_macro_generator.py
52-60: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
62-119: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
121-129: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
178-180: Avoid specifying long messages outside the exception class
(TRY003)
212-212: Avoid specifying long messages outside the exception class
(TRY003)
313-313: Avoid specifying long messages outside the exception class
(TRY003)
381-381: Avoid specifying long messages outside the exception class
(TRY003)
390-390: Avoid specifying long messages outside the exception class
(TRY003)
395-395: Avoid specifying long messages outside the exception class
(TRY003)
397-397: Avoid specifying long messages outside the exception class
(TRY003)
445-445: Avoid specifying long messages outside the exception class
(TRY003)
696-696: Avoid specifying long messages outside the exception class
(TRY003)
713-713: Avoid specifying long messages outside the exception class
(TRY003)
771-771: Avoid specifying long messages outside the exception class
(TRY003)
785-785: Avoid specifying long messages outside the exception class
(TRY003)
```python
dtype_abbrv = {
    "float16": "fp16",
    "bfloat16": "bf16",
    "float32": "fp32",
    "int8": "int8",
    "int32": "int32",
    "float8_e4m3": "e4m3",
    "float8_e5m2": "e5m2",
}

E_FACTOR_MAP = {  # e_kdim = mma_kdim // e_factor
    "float": {
        "int16": 8,
        "uint16": 8,
    },
    "float32": {
        "int16": 8,
        "uint16": 8,
    },
    "float16": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
    "bfloat16": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
    "int8": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
    "uint8": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
    "float8_e4m3": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
    "float8_e5m2": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
}

E_REPLICATE_FACTOR = {  # metadata replicate every 4 consecutive threads
    "float32": 2,
    "float16": 2,  # 2 of 4 consecutive threads provides
    "bfloat16": 2,
    "int8": 1,  # 4 of 4 consecutive threads provides
    "uint8": 1,
    "float8_e4m3": 1,
    "float8_e5m2": 1,
}
```
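Read together, these tables say that one metadata element covers e_factor logical K positions for a given operand dtype. A small illustrative helper (not part of the test suite) that derives the metadata shape from E_FACTOR_MAP:

```python
from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmitter

def metadata_shape(M, K, a_dtype="float16", e_dtype="uint16"):
    # One metadata element per e_factor logical K positions (assumption:
    # the per-tile relation e_kdim = mma_kdim // e_factor extends to full K).
    e_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[a_dtype][e_dtype]
    return (M, K // e_factor)

print(metadata_shape(128, 256))                    # (128, 16)
print(metadata_shape(128, 256, "int8", "uint8"))   # (128, 32)
```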
Add tf32 dtype metadata to emitter tables.
E_FACTOR_MAP already exposes "float" entries, but dtype_abbrv and E_REPLICATE_FACTOR don’t, so creating the emitter with in_dtype="float" (tf32) raises a KeyError. That breaks the advertised tf32 path. Please add the missing mappings so we can actually emit tf32 kernels.
Apply this diff:
```diff
 dtype_abbrv = {
+    "float": "tf32",
     "float16": "fp16",
     "bfloat16": "bf16",
     "float32": "fp32",
@@
 E_REPLICATE_FACTOR = {  # metadata replicate every 4 consecutive threads
+    "float": 2,
     "float32": 2,
     "float16": 2,  # 2 of 4 consecutive threads provides
     "bfloat16": 2,
```
🧰 Tools
🪛 Ruff (0.14.3)
52-60: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
62-119: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
121-129: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
🤖 Prompt for AI Agents
In tilelang/intrinsics/mma_sp_macro_generator.py around lines 52 to 129, the
emitter tables are missing entries for the tf32 dtype key "float", causing a
KeyError when in_dtype="float"; add "float": "tf32" to the dtype_abbrv mapping
and add "float": 2 to E_REPLICATE_FACTOR (matching float32 behavior) so the
existing E_FACTOR_MAP "float" entries can be used without error.
We're good to go once the conflict is resolved, and then I think we can let this PR in.
Checklist
4090 mini benchmark (2 experiments): Torch (CUTLASS backend) vs. Torch (cuSPARSELt backend) vs. TileLang Sparse (fp32 accum).
Summary by CodeRabbit
Release Notes
New Features
Documentation
Updates
Tests