
Conversation

@botbw
Contributor

@botbw botbw commented Oct 17, 2025

Checklist

  • bf16/fp16
  • customized metadata layout
  • tf32 (with precision issue)
  • int8
  • fp8
  • different scopes
    • sss
    • srs
    • rss
    • rrs
    • metadata in register (?)
  • custom compression utils example
  • Doc

4090 mini benchmark

mnk = 16384:

  • Ref TFLOPS (cuBLAS dense, 2 experiments): 154.373 / 156.286
  • Ref TFLOPS (Torch, CUTLASS backend): 194.450
  • Ref TFLOPS (Torch, cuSPARSELt backend): 183.236
  • Best TFLOPS (TileLang sparse, fp32 accum): 278.731

Summary by CodeRabbit

Release Notes

  • New Features

    • Added sparse matrix-matrix multiplication with Tensor Core support and improved GEMM kernels (gemm_sp_v2 variant)
    • Introduced comprehensive sparse tensor compression/decompression utilities and layout helpers
    • Added support for additional data types in sparse operations
  • Documentation

    • New guide for sparse matrix-matrix multiplication with structured sparsity
    • Updated documentation index with sparse operator resources
  • Updates

    • Modernized sparse metadata layout API for better flexibility
    • Enhanced sparse layout utilities with new configurations for thread bindings
    • Updated examples and benchmarks to use latest sparse APIs
  • Tests

    • Added comprehensive test coverage for sparse GEMM operations across multiple configurations

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Oct 17, 2025

Walkthrough

Introduces a new T.gemm_sp_v2 operator for 2:4 sparse matrix multiplication with flexible warp policies and improved layout handling. Adds comprehensive sparse tensor-core intrinsic support, metadata layout generation, and extensive tooling for sparse GEMM workloads across NVIDIA architectures.
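
To make the new surface area concrete, below is a minimal, illustrative sketch of what a kernel using the new operator could look like. It is assembled from the API names that appear in this PR (T.gemm_sp_v2, make_cutlass_metadata_layout) together with standard TileLang constructs; the metadata shape/dtype, the e_factor value, and the keyword arguments of make_cutlass_metadata_layout are assumptions for illustration, not the actual signatures.

import tilelang.language as T
from tilelang.layout import make_cutlass_metadata_layout

def sparse_matmul(M, N, K, block_M=128, block_N=128, block_K=64,
                  dtype="float16", accum_dtype="float32", e_factor=8):
    # e_factor and the uint8 metadata dtype are illustrative assumptions; the real
    # values come from the sparse intrinsic emitter (E_FACTOR_MAP) per input dtype.
    @T.prim_func
    def kernel(
            A_sparse: T.Tensor((M, K // 2), dtype),        # packed non-zeros (2:4)
            E: T.Tensor((M, K // e_factor), "uint8"),       # sparsity metadata
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_s = T.alloc_shared((block_M, block_K // 2), dtype)
            E_s = T.alloc_shared((block_M, block_K // e_factor), "uint8")
            B_s = T.alloc_shared((block_K, block_N), dtype)
            C_f = T.alloc_fragment((block_M, block_N), accum_dtype)
            # architecture-specific metadata swizzle; parameter names are assumed
            T.annotate_layout({E_s: make_cutlass_metadata_layout(E_s, arch="sm80", block_k=block_K)})
            T.clear(C_f)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A_sparse[by * block_M, k * block_K // 2], A_s)
                T.copy(E[by * block_M, k * block_K // e_factor], E_s)
                T.copy(B[k * block_K, bx * block_N], B_s)
                T.gemm_sp_v2(A_s, E_s, B_s, C_f)  # (A_sparse, metadata, B, accumulator)
            T.copy(C_f, C[by * block_M, bx * block_N])
    return kernel

On the host side, the packed operand and metadata would come from the compression utilities touched in this PR (tilelang/utils/sparse.py).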

Changes

  • Core IR and FFI Registration (src/op/gemm_sp.cc, src/op/gemm_sp.h): Registers the new tl.GemmSPWarpPolicy operator and reflection for warp partition computation with bits parameter support.
  • GEMM SP Operator Implementation (src/op/gemm_sp_py.cc, src/op/gemm_sp_py.h): Introduces the GemmSPPy class implementing sparse GEMM operator deserialization, target-aware instruction selection (WGMMA/MFMA/MMA), layout inference, and FFI lowering hooks.
  • Sparse Layout Helpers (tilelang/intrinsics/mma_sp_layout.py): New module with 30+ layout functions for sparse MMA operations, including shared-to-MMA mappings, metadata load patterns, and LDMATRIX transformations across SM90/SM8x.
  • Sparse Intrinsic Emitter (tilelang/intrinsics/mma_sp_macro_generator.py): Implements the SparseTensorCoreIntrinEmitter class providing parameterized load/store pathways, thread binding, fragment mappings, and MMA operations for sparse tensor-core style operations.
  • MMA Layout Extensions (tilelang/intrinsics/mma_layout.py, tilelang/intrinsics/mma_macro_generator.py): Adds 32x8-to-16x16 load layouts for 16-bit data paths and integrates the new layouts into the non-LDMATRIX fallback logic.
  • GEMM SP Public API (tilelang/language/experimental/gemm_sp.py, tilelang/language/__init__.py): Adds the gemm_sp_v2 function with flexible transposition, accumulation control, and warp policy support; exports it from the public language module.
  • GEMM SP Tileop Layer (tilelang/tileop/gemm_sp/__init__.py, tilelang/tileop/gemm_sp/gemm_sp_base.py, tilelang/tileop/gemm_sp/gemm_sp_mma.py, tilelang/tileop/__init__.py): New tileop integration with GemmSPBase providing accessor abstractions and GemmSPMMA implementing layout inference and lowering for 2:4 sparse patterns (ss/sr/rs/rr).
  • Layout API Refactoring (tilelang/layout/__init__.py, tilelang/layout/gemm_sp.py): Renames make_metadata_layout to make_cutlass_metadata_layout, removes the backend parameter, and implements architecture-specific dispatch (SM90/SM8x) with dynamic interleaved addressing.
  • GEMM Python Typing (tilelang/tileop/gemm/__init__.py): Adds type hints to gemm_py_infer_layout and gemm_py_lower for improved IDE support and clarity.
  • IR Class Updates (tilelang/ir.py): Adds the GemmSPWarpPolicy class with a compute_warp_partition method supporting a bits parameter via FFI.
  • Debug Template (src/tl_templates/cuda/debug.h): Adds a uint16_t specialization for debug_print_buffer_value.
  • Sparse Tensor Utilities (tilelang/utils/tensor.py, tilelang/utils/sparse.py): Adds is_float8 and fp8_remove_negative_zeros_ utilities, introduces randint_semi_sparse for integer sparse tensor generation, and extends the TensorSupplyType enum.
  • Documentation (docs/deeplearning_operators/matmul_sparse.md, docs/index.md): New doc covering structured sparsity, compression workflows, kernel variants (gemm_sp/gemm_sp_v2), metadata formats, and a migration guide.
  • Benchmark Updates (benchmark/matmul/benchmark_matmul_sp.py): Updates the matmul_sp signature to include in_dtype, replaces T.gemm_sp with T.gemm_sp_v2, and uses make_cutlass_metadata_layout.
  • Example Updates (examples/gemm_sp/example_gemm_sp.py, examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py): Migrates imports from make_metadata_layout to make_cutlass_metadata_layout, renames default_config to DEFAULT_CONFIG, and switches T.gemm_sp to T.gemm_sp_v2.
  • New Examples (examples/gemm_sp/example_custom_compress.py): Comprehensive example with custom compression, metadata decoding, kernel-based encoding, and end-to-end benchmarking.
  • Example Tests (examples/gemm_sp/test_example_gemm_sp.py): Test harness invoking the example entry points.
  • Comprehensive Test Suite (testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py, testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py): Extends existing gemm_sp tests with dynamic dtype handling, dense input generation, and metadata validation; adds a new gemm_sp_v2 test module with ss/rs/sr/rr pattern variants and strict numerical comparison.
  • Module Exports (tilelang/tileop/__init__.py, tilelang/profiler/__init__.py): Exports GemmSPPy and imports the is_float8 utility.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant gemm_sp_v2
    participant GemmSPPy
    participant GemmSPMMA
    participant SparseTensorCoreIntrinEmitter

    User->>gemm_sp_v2: call with A_sparse, E, B, C, policy
    gemm_sp_v2->>gemm_sp_v2: legalize_arguments() & retrieve shapes/strides/offsets
    gemm_sp_v2->>GemmSPPy: tl.gemm_sp_py intrinsic call
    
    Note over GemmSPPy: Kernel Compilation Phase
    GemmSPPy->>GemmSPMMA: infer_layout(target, thread_nums)
    GemmSPMMA->>GemmSPMMA: compute warp partitions
    GemmSPMMA->>SparseTensorCoreIntrinEmitter: create emitter per pattern (ss/sr/rs/rr)
    SparseTensorCoreIntrinEmitter->>SparseTensorCoreIntrinEmitter: make_mma_load_layout(A, B, E)
    SparseTensorCoreIntrinEmitter-->>GemmSPMMA: fragment layouts
    GemmSPMMA-->>GemmSPPy: layout_map
    
    Note over GemmSPPy: Kernel Lowering Phase
    GemmSPPy->>GemmSPMMA: lower(target, thread_nums, thread_var)
    GemmSPMMA->>SparseTensorCoreIntrinEmitter: orchestrate loads & mma_sp
    SparseTensorCoreIntrinEmitter->>SparseTensorCoreIntrinEmitter: ldmatrix_a/e/b()
    SparseTensorCoreIntrinEmitter->>SparseTensorCoreIntrinEmitter: mma_sp() with ptx_mma_sp
    SparseTensorCoreIntrinEmitter->>SparseTensorCoreIntrinEmitter: stmatrix()
    SparseTensorCoreIntrinEmitter-->>GemmSPMMA: prim_func kernel
    GemmSPMMA-->>GemmSPPy: lowered kernel
    GemmSPPy-->>User: compiled kernel

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Areas requiring close attention:

  • tilelang/intrinsics/mma_sp_macro_generator.py — Dense 500+ line implementation of SparseTensorCoreIntrinEmitter with intricate thread binding, fragment loading/storing semantics, and PTX-level MMA operations; requires understanding of tensor core memory layouts and thread-level indexing patterns.
  • tilelang/tileop/gemm_sp/gemm_sp_mma.py — Complex warp partitioning logic and four kernel variants (ss/sr/rs/rr) with distinct dataflow patterns; validation of K-dimension alignment and micro-tile configuration.
  • src/op/gemm_sp_py.cc/gemm_sp_py.h — FFI deserialization, instruction selection logic (WGMMA/MFMA/MMA eligibility checks), and architecture-aware lowering hooks; involves C++ TVM integration and pointer/stride computation.
  • tilelang/layout/gemm_sp.py — Refactored metadata layout generation with architecture-specific dispatch; interleaved addressing calculation logic differs from previous implementation and requires verification.
  • tilelang/intrinsics/mma_sp_layout.py — 40+ layout helper functions with intricate indexing formulas for shared-to-MMA mappings, metadata encoding patterns, and LDMATRIX transformations; each variant requires careful index formula validation.
  • Test coverage coordination — Two parallel test files (gemm_sp and gemm_sp_v2) with dense setup logic; ensure consistency across dtype handling, compression workflows, and numerical validation thresholds.

Possibly related issues

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • chengyupku

Poem

🐰 Sparse tensors dance on Tensor Cores now,
With 2:4 sparsity and CUTLASS-aware pow,
Warp policies flex, layouts interweave tight,
gemm_sp_v2 brings efficiency to light! ✨
Metadata flows, fragments align so true,
A new way forward for GEMM operations through! 🚀

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check: ✅ Passed. The title accurately summarizes the main objective: adding support for T.gemm_sp_v2 on SM80 and SM89 architectures, which is the primary change across the pull request.
  • Docstring Coverage: ✅ Passed. No functions found in the changed files to evaluate docstring coverage; skipping the docstring coverage check.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d7ca20e and b2871dd.

📒 Files selected for processing (7)
  • src/op/gemm_sp.cc (1 hunks)
  • src/op/gemm_sp.h (1 hunks)
  • tilelang/intrinsics/mma_layout.py (1 hunks)
  • tilelang/intrinsics/mma_macro_generator.py (3 hunks)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/layout/__init__.py (1 hunks)
  • tilelang/tileop/gemm/__init__.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/op/gemm_sp.h
  • tilelang/tileop/gemm/__init__.py
🧰 Additional context used
🧬 Code graph analysis (4)
tilelang/layout/__init__.py (2)
tilelang/language/experimental/gemm_sp.py (1)
  • gemm_sp (10-87)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (136-150)
src/op/gemm_sp.cc (1)
src/op/gemm_sp.h (1)
  • RegisterReflection (27-33)
tilelang/intrinsics/mma_macro_generator.py (1)
tilelang/intrinsics/mma_layout.py (2)
  • mma_load_b_32x8_to_shared_16x16_layout (176-188)
  • mma_load_a_32x8_to_shared_16x16_layout (154-167)
tilelang/language/__init__.py (1)
tilelang/language/experimental/gemm_sp.py (2)
  • gemm_sp (10-87)
  • gemm_sp_v2 (91-307)
🪛 Ruff (0.14.4)
tilelang/layout/__init__.py

16-16: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/language/__init__.py

53-53: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (8)
tilelang/language/__init__.py (1)

53-53: LGTM! New gemm_sp_v2 API export.

The addition of gemm_sp_v2 to the public API is appropriate and aligns with the PR objectives for SM8x sparse GEMM support. The function implementation exists in the experimental module with proper documentation.

Note: The Ruff RUF100 warning about the unused noqa directive is a false positive. The # noqa: F401 comment is standard practice in __init__.py files that re-export symbols, as linters would otherwise flag imported-but-unused warnings.
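
For readers unfamiliar with the pattern, the re-export in question looks roughly like the following (illustrative; the exact line in tilelang/language/__init__.py may differ):

# Re-export inside a package __init__.py. The noqa marker tells the linter that the
# "unused" import is intentional because the symbol is part of the public API.
from .experimental.gemm_sp import gemm_sp, gemm_sp_v2  # noqa: F401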

tilelang/layout/__init__.py (1)

16-16: Verified: Complete migration from make_metadata_layout to make_cutlass_metadata_layout.

The codebase search confirmed no remaining imports or usages of the deprecated make_metadata_layout API. The breaking change has been successfully migrated across all files. The # noqa: F401 comment in the __init__.py re-export module is appropriate and not a false positive.

src/op/gemm_sp.cc (1)

306-307: LGTM: Clean TVM operator registration.

The registration of tl.GemmSPWarpPolicy follows standard TVM patterns.

tilelang/intrinsics/mma_layout.py (2)

154-167: LGTM: Layout function correctly implements the documented indexing scheme.

The implementation correctly translates the docstring specification into arithmetic expressions. The row and column calculations properly handle the conditional offsets based on local_id ranges.


176-188: Now let me search for CUDA PTX documentation on the correct layout to definitively confirm which is correct:

Let me search more specifically for matrix B layout in MMA/LDMATRIX documentation:

Now I have a clear picture. Let me search for one more confirmation on whether there are any tests that might validate this implementation:

Correct the docstring—implementation is correct, docstring is wrong.

The CUDA PTX documentation for matrix layout in MMA shows row uses groupID and col uses (threadID_in_group * 2) + (i & 0x1) for matrix A. The mma_load_b_32x8_to_shared_16x16_layout implementation correctly follows this same pattern: row = (thread_id // 4) + 8 * (local_id // 4) (groupID-based) and col = (thread_id % 4) * 2 + ... (threadID_in_group-based).

The docstring is incorrect—it has row and col reversed. The implementation is correct and consistent with mma_load_a_32x8_to_shared_16x16_layout (which also shows row = groupID, col = threadID_in_group formula).

Fix: Update the docstring to match the correct implementation:

"""
    groupID           = %laneid >> 2
    threadID_in_group = %laneid % 4

    row = groupID            for bi where  0 <= i < 2 || 4 <= i < 6
          groupID + 8        Otherwise

    col = (threadID_in_group * 2) + (i & 0x1)          for bi where i <  2
          (threadID_in_group * 2) + (i & 0x1) + 8      for bi where i >= 2
"""

Likely an incorrect or invalid review comment.

tilelang/intrinsics/mma_macro_generator.py (3)

23-26: LGTM: Imports correctly reference the new layout functions.

The imports properly bring in the newly added 32x8-to-16x16 layout functions for both A and B operands.


291-292: LGTM: Correctly adds 16-bit data path support for A operand.

This change properly fills the gap for 16-bit data types (fp16/bf16) when ldmatrix is unavailable, completing the non-ldmatrix fallback path for all supported bit widths (8, 16, 32).


419-420: LGTM: Correctly adds 16-bit data path support for B operand.

This change mirrors the A path update and properly completes 16-bit support for the B operand when ldmatrix is unavailable. However, note that this depends on the correctness of mma_load_b_32x8_to_shared_16x16_layout, which has a potential docstring/implementation mismatch flagged in the other file.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@botbw botbw force-pushed the gemm_sp_v2 branch 3 times, most recently from 016dd1c to 122abb5, on October 20, 2025 at 07:07
@botbw botbw changed the title from "[Language] support T.gemm_sp_v2" to "[Language] support T.gemm_sp_v2 on sm80 and sm89" on Oct 31, 2025
@botbw botbw marked this pull request as ready for review on November 5, 2025 at 07:37
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 13

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tilelang/profiler/__init__.py (1)

131-132: Fix incorrect usage of is_float8 function.

The is_float8 function expects torch.dtype but receives torch.Tensor objects. This will cause incorrect behavior.

Apply this diff:

-                    lhs if not is_float8(lhs) else lhs.to(torch.float32),
-                    rhs if not is_float8(rhs) else rhs.to(torch.float32),
+                    lhs if not is_float8(lhs.dtype) else lhs.to(torch.float32),
+                    rhs if not is_float8(rhs.dtype) else rhs.to(torch.float32),
benchmark/matmul/benchmark_matmul_sp.py (1)

270-270: Fix missing in_dtype parameter.

The matmul_sp function expects both in_dtype and accum_dtype parameters (line 89), but only accum_dtype is passed here.

Apply this diff:

-    best_result = matmul_sp(M, N, K, args.accum_dtype)
+    best_result = matmul_sp(M, N, K, "float16", args.accum_dtype)

Or add an --in_dtype argument:

+    parser.add_argument(
+        "--in_dtype",
+        type=str,
+        default="float16",
+        choices=["float16", "float8_e4m3fn"],
+        help="Input datatype")
     args = parser.parse_args()
     
     ...
-    best_result = matmul_sp(M, N, K, args.accum_dtype)
+    best_result = matmul_sp(M, N, K, args.in_dtype, args.accum_dtype)
🧹 Nitpick comments (11)
tilelang/intrinsics/mma_sp_macro_generator.py (2)

195-239: Remove leftover debug prints from initialization

These print(...) calls run every time the emitter is constructed, spamming stdout in normal builds. Please drop them (or gate behind an explicit debug flag).

-        print(f"{self.local_size_a=}, {self.local_size_e=}, {self.local_size_b=}, {self.local_size_out=}, {self.e_factor=} {n_dim=} {k_dim=}")
...
-        print(f"{self.warp_rows=}, {self.warp_cols=}, {self.n_dim=}, {self.micro_size_x=}, {self.micro_size_y=}, {self.micro_size_k=}, {self.warp_row_tiles=}, {self.warp_col_tiles=}")

551-553: Drop stray print in mma_sp

This diagnostic print executes at runtime and floods logs. Please remove it alongside the others.

-        print(f"{e_local_stride=}")
tilelang/language/experimental/gemm_sp.py (3)

191-191: Remove unused variable E_shape.

The variable E_shape is assigned but never used in the function.

Apply this diff:

 A_shape = retrieve_shape(A_sparse)
-E_shape = retrieve_shape(E)
 B_shape = retrieve_shape(B)
 C_shape = retrieve_shape(C)

143-274: Consider refactoring helper functions to reduce duplication.

The helper functions retrieve_shape, retrieve_stride, retrieve_ptr, and retrieve_offset have significant structural duplication with nearly identical type-checking patterns and branching logic for Buffer/BufferRegion/BufferLoad. Consider extracting the common pattern or creating a base dispatcher.

For example, you could create a generic dispatcher:

def _dispatch_on_buffer_type(object, buffer_handler, region_handler, load_handler):
    if isinstance(object, tir.Buffer):
        return buffer_handler(object)
    elif isinstance(object, tir.BufferRegion):
        return region_handler(object)
    elif isinstance(object, tir.BufferLoad):
        return load_handler(object)
    else:
        raise TypeError(f"Unsupported type: {type(object)}")

Then each retrieve function becomes a simple call to this dispatcher with specialized handlers.
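
As a concrete, purely illustrative usage of that dispatcher, retrieve_shape could collapse to the sketch below; the handler bodies are assumptions, not the current implementation:

def retrieve_shape(obj):
    # Buffer: use its static shape; BufferRegion: use the region extents;
    # BufferLoad: convert to a region first, then recurse.
    return _dispatch_on_buffer_type(
        obj,
        buffer_handler=lambda buf: list(buf.shape),
        region_handler=lambda region: [r.extent for r in region.region],
        load_handler=lambda load: retrieve_shape(get_buffer_region_from_load(load)),
    )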


159-160: Consider using TypeError for invalid type errors.

When raising exceptions for unsupported argument types, TypeError is more semantically appropriate than ValueError.

For example, at line 159-160:

-        raise ValueError(
-            f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}")
+        raise TypeError(
+            f"Unsupported retrieve_shape argument type: {type(object)}")

Apply similar changes to lines 187-188, 253-254, and 273-274.

Also applies to: 187-188, 253-254, 273-274

src/op/gemm_sp_py.cc (1)

143-193: Consider refactoring to reduce cognitive complexity.

The CheckWGMMA function has significant duplication with repeated checks for dtype combinations and K-alignment constraints. While correct, this makes the code harder to maintain.

Consider extracting common patterns into helper functions or using a lookup table:

// Example: Extract alignment checks
static constexpr int GetKAlignment(DataType c_dtype, DataType a_dtype, DataType b_dtype) {
  if (c_dtype == DataType::Float(16)) {
    if (a_dtype.is_float8() || b_dtype.is_float8()) return 32;
    return 16;
  }
  // ... similar patterns
}

bool GemmSPPyNode::CheckWGMMA() const {
  if (B.scope() != "shared.dyn" && B.scope() != "shared") {
    return false;
  }
  
  int required_alignment = GetKAlignment(C->dtype, A->dtype, B->dtype);
  if (required_alignment == 0) return false;
  
  bool requires_specific_transpose = /* condition */;
  if (requires_specific_transpose && (trans_A || !trans_B)) {
    return false;
  }
  
  return K % required_alignment == 0;
}
examples/gemm_sp/example_custom_compress.py (3)

119-120: Remove unused variable.

The device variable is assigned but never used in the function.

Apply this diff:

 m, k = dense.shape
-device = dense.device

281-282: TODO: Address the alloc_var issue.

The comment indicates alloc_var has buggy behavior. This should be tracked and resolved.

Do you want me to open a new issue to track this TODO?


296-296: TODO: Add device assertion after rebasing.

The comment suggests using T.device_assert(non_zero_cnt <= 2) after rebasing the main branch. This validation would help catch violations of the 2:4 sparsity constraint at runtime.

Do you want me to open a new issue to track adding this assertion?

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (1)

585-587: Minor formatting: Remove extra space.

There's an inconsistent space before the M parameter.

Apply this diff:

     program = matmul_rr(
-         M,
+        M,
         N,
tilelang/intrinsics/mma_sp_layout.py (1)

60-101: LGTM with optional style suggestion.

The metadata functions correctly handle different bit-width scenarios with appropriate logical ID transformations.

Note: The static analysis warning about unused local_id on line 96 is acceptable—the comment correctly explains that local_id is always 0 for this function, and maintaining the parameter keeps the signature consistent with similar functions.

Optional: Add explicit parentheses for clarity.

Lines 87 and 93 could benefit from explicit parentheses to improve readability, even though the operator precedence is correct:

-    col = (logical_id % 4) // 2 * 4 + local_id
+    col = ((logical_id % 4) // 2) * 4 + local_id
-    col = (logical_id % 4) // 2 * 2 + local_id
+    col = ((logical_id % 4) // 2) * 2 + local_id
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c37621c and 2efb496.

⛔ Files ignored due to path filters (1)
  • docs/_static/img/sparse_mma_storage_example.png is excluded by !**/*.png
📒 Files selected for processing (33)
  • benchmark/matmul/benchmark_matmul_sp.py (6 hunks)
  • docs/deeplearning_operators/matmul_sparse.md (1 hunks)
  • docs/index.md (1 hunks)
  • examples/gemm_sp/example_custom_compress.py (1 hunks)
  • examples/gemm_sp/example_gemm_sp.py (5 hunks)
  • examples/gemm_sp/test_example_gemm_sp.py (1 hunks)
  • examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (2 hunks)
  • src/op/gemm.h (1 hunks)
  • src/op/gemm_sp.cc (1 hunks)
  • src/op/gemm_sp.h (1 hunks)
  • src/op/gemm_sp_py.cc (1 hunks)
  • src/op/gemm_sp_py.h (1 hunks)
  • src/target/ptx.cc (1 hunks)
  • src/tl_templates/cuda/debug.h (1 hunks)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (7 hunks)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (1 hunks)
  • tilelang/intrinsics/mma_layout.py (1 hunks)
  • tilelang/intrinsics/mma_macro_generator.py (5 hunks)
  • tilelang/intrinsics/mma_sp_layout.py (1 hunks)
  • tilelang/intrinsics/mma_sp_macro_generator.py (1 hunks)
  • tilelang/ir.py (1 hunks)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/language/experimental/gemm_sp.py (2 hunks)
  • tilelang/layout/__init__.py (1 hunks)
  • tilelang/layout/gemm_sp.py (3 hunks)
  • tilelang/profiler/__init__.py (1 hunks)
  • tilelang/tileop/__init__.py (1 hunks)
  • tilelang/tileop/gemm/__init__.py (1 hunks)
  • tilelang/tileop/gemm_sp/__init__.py (1 hunks)
  • tilelang/tileop/gemm_sp/gemm_sp_base.py (1 hunks)
  • tilelang/tileop/gemm_sp/gemm_sp_mma.py (1 hunks)
  • tilelang/utils/sparse.py (5 hunks)
  • tilelang/utils/tensor.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
  • tilelang/intrinsics/mma_sp_layout.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
  • tilelang/intrinsics/mma_sp_layout.py
🧬 Code graph analysis (27)
src/op/gemm.h (1)
src/transform/layout_reducer.h (1)
  • Object (63-71)
tilelang/layout/__init__.py (1)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (137-152)
src/op/gemm_sp.cc (1)
src/op/gemm.h (2)
  • RegisterReflection (36-42)
  • RegisterReflection (123-145)
src/op/gemm_sp.h (2)
src/ir.cc (2)
  • TVM_DECLARE_FINAL_OBJECT_INFO (197-197)
  • TVM_DECLARE_FINAL_OBJECT_INFO (314-314)
src/op/gemm.h (1)
  • GemmWarpPolicyNode (27-73)
tilelang/utils/tensor.py (1)
tilelang/engine/param.py (1)
  • is_float8 (81-91)
tilelang/language/__init__.py (1)
tilelang/language/experimental/gemm_sp.py (2)
  • gemm_sp (9-86)
  • gemm_sp_v2 (89-307)
tilelang/tileop/gemm_sp/gemm_sp_mma.py (5)
tilelang/tileop/gemm_sp/gemm_sp_base.py (15)
  • infer_layout (14-15)
  • policy (126-127)
  • M (33-34)
  • N (37-38)
  • e_dtype (53-54)
  • accum_dtype (62-63)
  • trans_A (45-46)
  • trans_B (49-50)
  • K (41-42)
  • is_gemm_ss (20-21)
  • A (66-67)
  • B (74-75)
  • C (78-79)
  • lower (17-18)
  • E (70-71)
tilelang/intrinsics/mma_sp_macro_generator.py (7)
  • SparseTensorCoreIntrinEmitter (39-867)
  • make_mma_store_layout (798-867)
  • make_mma_load_layout (655-796)
  • ldmatrix_a (293-357)
  • ldmatrix_e (359-425)
  • ldmatrix_b (428-522)
  • mma_sp (524-598)
tilelang/utils/language.py (2)
  • is_shared (25-39)
  • is_fragment (68-78)
tilelang/transform/simplify.py (1)
  • _Simplify (30-49)
tilelang/tileop/gemm_sp/__init__.py (2)
  • infer_layout (52-57)
  • lower (59-65)
tilelang/tileop/gemm_sp/__init__.py (3)
tilelang/utils/target.py (1)
  • target_is_cuda (90-91)
tilelang/tileop/gemm_sp/gemm_sp_mma.py (3)
  • GemmSPMMA (12-244)
  • infer_layout (14-58)
  • lower (60-232)
tilelang/tileop/gemm_sp/gemm_sp_base.py (23)
  • infer_layout (14-15)
  • lower (17-18)
  • A (66-67)
  • E (70-71)
  • B (74-75)
  • C (78-79)
  • APtr (82-83)
  • EPtr (86-87)
  • BPtr (90-91)
  • CPtr (94-95)
  • M (33-34)
  • N (37-38)
  • K (41-42)
  • trans_A (45-46)
  • trans_B (49-50)
  • stride_A (98-99)
  • stride_B (102-103)
  • offset_A (106-107)
  • offset_B (110-111)
  • clear_accum (114-115)
  • k_pack (118-119)
  • wg_wait (122-123)
  • policy (126-127)
tilelang/language/experimental/gemm_sp.py (3)
tilelang/language/frame.py (2)
  • has_let_value (189-198)
  • get_let_value (201-210)
tilelang/utils/language.py (1)
  • get_buffer_region_from_load (124-138)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
examples/gemm_sp/test_example_gemm_sp.py (2)
examples/gemm_sp/example_custom_compress.py (1)
  • main (314-357)
examples/gemm_sp/example_gemm_sp.py (1)
  • main (107-146)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (3)
tilelang/utils/sparse.py (3)
  • compress (77-106)
  • randn_semi_sparse (109-128)
  • randint_semi_sparse (130-157)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (137-152)
tilelang/utils/tensor.py (2)
  • torch_assert_close (233-325)
  • map_torch_type (34-51)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (137-152)
examples/gemm_sp/example_gemm_sp.py (4)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (137-152)
tilelang/language/experimental/gemm_sp.py (1)
  • gemm_sp_v2 (89-307)
examples/gemm_sp/example_custom_compress.py (1)
  • kernel (258-309)
tilelang/utils/sparse.py (1)
  • compress (77-106)
tilelang/tileop/__init__.py (2)
tilelang/language/experimental/gemm_sp.py (1)
  • gemm_sp (9-86)
tilelang/tileop/gemm_sp/__init__.py (1)
  • GemmSPPy (25-65)
tilelang/utils/sparse.py (1)
tilelang/utils/tensor.py (2)
  • is_float8 (10-16)
  • fp8_remove_negative_zeros_ (18-22)
tilelang/ir.py (1)
src/op/gemm_sp.h (5)
  • tvm (13-131)
  • GemmSPWarpPolicy (29-53)
  • GemmSPWarpPolicy (34-38)
  • GemmSPWarpPolicy (40-44)
  • GemmSPWarpPolicy (46-52)
examples/gemm_sp/example_custom_compress.py (3)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (137-152)
tilelang/utils/sparse.py (3)
  • compress (77-106)
  • randn_semi_sparse (109-128)
  • arange_semi_sparse (159-182)
tilelang/utils/tensor.py (2)
  • torch_assert_close (233-325)
  • map_torch_type (34-51)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (4)
tilelang/utils/sparse.py (3)
  • compress (77-106)
  • randn_semi_sparse (109-128)
  • randint_semi_sparse (130-157)
tilelang/utils/tensor.py (2)
  • torch_assert_close (233-325)
  • map_torch_type (34-51)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (137-152)
tilelang/intrinsics/mma_sp_macro_generator.py (1)
  • SparseTensorCoreIntrinEmitter (39-867)
src/op/gemm_sp_py.h (2)
src/op/operator.h (2)
  • TileOperatorNode (56-96)
  • TileOperator (63-95)
src/op/gemm_sp_py.cc (11)
  • CheckWGMMA (143-193)
  • CheckWGMMA (143-143)
  • Lower (222-255)
  • Lower (222-222)
  • InferLayout (257-272)
  • InferLayout (257-258)
  • Clone (92-95)
  • Clone (92-92)
  • GetGemmInst (97-111)
  • GetGemmInst (97-97)
  • GemmSPPy (50-82)
tilelang/tileop/gemm/__init__.py (4)
tilelang/language/ast/ir.py (2)
  • Range (1716-1728)
  • target (1682-1713)
tilelang/ir.py (1)
  • GemmWarpPolicy (30-39)
tilelang/tileop/gemm/gemm_mma.py (2)
  • GemmMMA (13-212)
  • infer_layout (15-58)
tilelang/tileop/gemm/gemm_base.py (1)
  • infer_layout (14-15)
tilelang/profiler/__init__.py (2)
tilelang/utils/tensor.py (1)
  • is_float8 (10-16)
tilelang/engine/param.py (1)
  • is_float8 (81-91)
tilelang/tileop/gemm_sp/gemm_sp_base.py (4)
tilelang/utils/language.py (2)
  • is_shared (25-39)
  • is_fragment (68-78)
tilelang/carver/roller/node.py (1)
  • Node (93-175)
tilelang/tileop/gemm_sp/gemm_sp_mma.py (6)
  • infer_layout (14-58)
  • lower (60-232)
  • is_gemm_ss (234-235)
  • is_gemm_sr (237-238)
  • is_gemm_rs (240-241)
  • is_gemm_rr (243-244)
tilelang/tileop/gemm_sp/__init__.py (2)
  • infer_layout (52-57)
  • lower (59-65)
benchmark/matmul/benchmark_matmul_sp.py (3)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (137-152)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1)
  • matmul_sp (9-60)
tilelang/language/experimental/gemm_sp.py (1)
  • gemm_sp_v2 (89-307)
tilelang/intrinsics/mma_sp_layout.py (1)
tilelang/intrinsics/mma_layout.py (4)
  • mma_load_a_32x4_to_shared_16x8_layout (130-133)
  • mma_load_a_32x16_to_shared_16x32_layout (142-145)
  • mma_load_a_32x8_to_shared_16x16_layout (147-160)
  • ldmatrix_trans_32x8_to_shared_16x16_layout (24-27)
tilelang/intrinsics/mma_macro_generator.py (1)
tilelang/intrinsics/mma_layout.py (4)
  • mma_load_b_32x8_to_shared_16x16_layout (167-179)
  • mma_load_a_32x16_to_shared_16x32_layout (142-145)
  • mma_load_b_32x16_to_shared_16x32_layout (162-165)
  • mma_load_a_32x8_to_shared_16x16_layout (147-160)
src/op/gemm_sp_py.cc (1)
src/op/gemm_sp_py.h (2)
  • GemmSPPy (119-124)
  • RegisterReflection (41-65)
tilelang/intrinsics/mma_sp_macro_generator.py (4)
tilelang/intrinsics/utils.py (2)
  • mma_store_index_map (81-82)
  • get_ldmatrix_offset (21-63)
tilelang/utils/language.py (2)
  • is_fragment (68-78)
  • is_global (12-22)
tilelang/intrinsics/mma_sp_layout.py (20)
  • shared_16x16_to_mma_sp_layout_sr_a (14-15)
  • shared_16x16_to_mma_sp_layout_sr_b (17-19)
  • shared_16x32_to_mma_sp_layout_sr_a (21-22)
  • shared_16x32_to_mma_sp_layout_sr_b (24-26)
  • shared_16x64_to_mma_sp_layout_sr_a (28-29)
  • shared_16x64_to_mma_sp_layout_sr_b (31-33)
  • mma_sp_load_a_32x4_to_shared_16x16_layout (35-36)
  • mma_sp_load_a_32x8_to_shared_16x32_layout (38-39)
  • mma_sp_load_a_32x16_to_shared_16x64_layout (41-42)
  • mma_sp_load_b_32x8_to_shared_16x16_layout (44-47)
  • mma_sp_load_b_32x16_to_shared_16x32_layout (49-52)
  • mma_sp_load_b_32x32_to_shared_16x64_layout (54-57)
  • metadata_8bit_load_32x4_to_shared_16x4_layout_32bit (63-67)
  • metadata_16bit_load_32x2_to_shared_16x2_layout_32bit (69-73)
  • metadata_8bit_load_32x4_to_shared_16x4_layout_16bit (75-76)
  • metadata_16bit_load_32x2_to_shared_16x2_layout_16bit (78-79)
  • metadata_8bit_load_32x4_to_shared_16x4_layout_8bit (84-88)
  • metadata_16bit_load_32x2_to_shared_16x4_layout_8bit (90-94)
  • metadata_32bit_load_32x1_to_shared_16x2_layout_8bit (96-101)
  • get_ldmatrix_offset_b (123-157)
tilelang/language/tir/op.py (3)
  • ptx_ldmatrix (1123-1159)
  • address_of (463-479)
  • ptx_mma_sp (963-1061)
🪛 Clang (14.0.6)
src/op/gemm_sp.cc

[error] 306-306: variable 'TVM_FFI_STATIC_INIT_BLOCK' is non-const and globally accessible, consider making it const

(cppcoreguidelines-avoid-non-const-global-variables,-warnings-as-errors)

src/op/gemm_sp_py.cc

[error] 93-93: variable name 'op' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 97-97: method 'GetGemmInst' can be made static

(readability-convert-member-functions-to-static,-warnings-as-errors)


[error] 97-97: 2 adjacent parameters of 'GetGemmInst' of similar type ('int') are easily swapped by mistake

(bugprone-easily-swappable-parameters,-warnings-as-errors)


[error] 98-98: variable 'warp_size' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 100-100: variable 'allow_wgmma' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 104-104: do not use 'else' after 'return'

(readability-else-after-return,-warnings-as-errors)


[error] 143-143: function 'CheckWGMMA' has cognitive complexity of 40 (threshold 25)

(readability-function-cognitive-complexity,-warnings-as-errors)


[error] 149-149: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 150-150: 16 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 151-151: do not use 'else' after 'return'

(readability-else-after-return,-warnings-as-errors)


[error] 151-151: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 152-152: repeated branch in conditional chain

(bugprone-branch-clone,-warnings-as-errors)


[error] 152-152: 32 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 153-153: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 154-154: 32 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 155-155: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 156-156: 32 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 157-157: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 158-158: 32 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 159-159: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 162-162: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 163-163: repeated branch in conditional chain

(bugprone-branch-clone,-warnings-as-errors)


[error] 163-163: 16 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 164-164: do not use 'else' after 'return'

(readability-else-after-return,-warnings-as-errors)


[error] 165-165: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 166-166: 16 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 167-167: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 168-168: 8 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 169-169: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 170-170: repeated branch in conditional chain

(bugprone-branch-clone,-warnings-as-errors)


[error] 170-170: 32 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 171-171: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 172-172: 32 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 173-173: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 174-174: 32 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 175-175: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 176-176: 32 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 177-177: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 180-180: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 181-181: repeated branch in conditional chain

(bugprone-branch-clone,-warnings-as-errors)


[error] 181-181: 32 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 182-182: do not use 'else' after 'return'

(readability-else-after-return,-warnings-as-errors)


[error] 182-182: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 183-183: 32 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 184-184: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 185-185: 32 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 186-186: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 187-187: 32 is a magic number; consider replacing it with a named constant

(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers,-warnings-as-errors)


[error] 188-188: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 211-211: variable name 's' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 213-213: variable 'arch' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 224-224: variable 'gemm_inst' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 236-236: variable 'block_realize' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 239-239: variable 'n' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 239-239: variable name 'n' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 259-259: statement should be inside braces

(readability-braces-around-statements,-warnings-as-errors)


[error] 261-261: variable 'results' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)


[error] 263-263: if with identical then and else branches

(bugprone-branch-clone,-warnings-as-errors)


[error] 274-274: variable 'TVM_REGISTER_OP' is non-const and globally accessible, consider making it const

(cppcoreguidelines-avoid-non-const-global-variables,-warnings-as-errors)


[error] 274-274: variable name 'op' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 279-279: variable 'TVM_FFI_STATIC_INIT_BLOCK' is non-const and globally accessible, consider making it const

(cppcoreguidelines-avoid-non-const-global-variables,-warnings-as-errors)

🪛 LanguageTool
docs/deeplearning_operators/matmul_sparse.md

[grammar] ~50-~50: Ensure spelling is correct
Context: ... in A_sparse/A and E. (i.e. the 4-elment group at [n, k] doesn't match the 4-bit...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)


[grammar] ~137-~137: Ensure spelling is correct
Context: ...each 4-bit chunk represents two 2-bit indcies of non-zero elements within four cons...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)

🪛 markdownlint-cli2 (0.18.1)
docs/deeplearning_operators/matmul_sparse.md

39-39: Link text should be descriptive

(MD059, descriptive-link-text)


39-39: Link text should be descriptive

(MD059, descriptive-link-text)


176-176: Link text should be descriptive

(MD059, descriptive-link-text)

🪛 Ruff (0.14.3)
tilelang/layout/__init__.py

7-7: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/language/__init__.py

50-50: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/tileop/gemm_sp/gemm_sp_mma.py

57-58: Avoid specifying long messages outside the exception class

(TRY003)


231-232: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/tileop/gemm_sp/__init__.py

57-57: Avoid specifying long messages outside the exception class

(TRY003)


65-65: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/experimental/gemm_sp.py

153-153: Undefined name get_buffer_region_from_load

(F821)


159-160: Prefer TypeError exception for invalid type

(TRY004)


159-160: Avoid specifying long messages outside the exception class

(TRY003)


187-188: Prefer TypeError exception for invalid type

(TRY004)


187-188: Avoid specifying long messages outside the exception class

(TRY003)


191-191: Local variable E_shape is assigned to but never used

Remove assignment to unused variable E_shape

(F841)


239-239: Undefined name get_buffer_region_from_load

(F821)


253-254: Prefer TypeError exception for invalid type

(TRY004)


253-254: Avoid specifying long messages outside the exception class

(TRY003)


267-267: Undefined name get_buffer_region_from_load

(F821)


273-274: Prefer TypeError exception for invalid type

(TRY004)


273-274: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/tileop/__init__.py

2-2: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/layout/gemm_sp.py

116-116: Avoid specifying long messages outside the exception class

(TRY003)


119-119: Avoid specifying long messages outside the exception class

(TRY003)

examples/gemm_sp/example_custom_compress.py

115-117: Avoid specifying long messages outside the exception class

(TRY003)


120-120: Local variable device is assigned to but never used

Remove assignment to unused variable device

(F841)


128-128: Avoid specifying long messages outside the exception class

(TRY003)


131-131: Avoid specifying long messages outside the exception class

(TRY003)


135-137: Avoid specifying long messages outside the exception class

(TRY003)


140-142: Avoid specifying long messages outside the exception class

(TRY003)


144-146: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/mma_sp_layout.py

96-96: Unused function argument: local_id

(ARG001)


139-139: Avoid specifying long messages outside the exception class

(TRY003)


155-155: Avoid specifying long messages outside the exception class

(TRY003)


157-157: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/mma_sp_macro_generator.py

51-59: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


61-118: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


120-128: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


178-180: Avoid specifying long messages outside the exception class

(TRY003)


213-213: Avoid specifying long messages outside the exception class

(TRY003)


320-320: Avoid specifying long messages outside the exception class

(TRY003)


389-389: Avoid specifying long messages outside the exception class

(TRY003)


398-398: Avoid specifying long messages outside the exception class

(TRY003)


403-403: Avoid specifying long messages outside the exception class

(TRY003)


405-405: Avoid specifying long messages outside the exception class

(TRY003)


456-456: Avoid specifying long messages outside the exception class

(TRY003)


551-551: Local variable thread_binding is assigned to but never used

Remove assignment to unused variable thread_binding

(F841)


705-705: Avoid specifying long messages outside the exception class

(TRY003)


722-722: Avoid specifying long messages outside the exception class

(TRY003)


780-780: Avoid specifying long messages outside the exception class

(TRY003)


794-794: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (52)
src/target/ptx.cc (1)

608-608: PTX syntax for sparse MMA is correct and recommended.

The ".sp::ordered_metadata" syntax is the officially recommended form for sparse MMA operations, introduced in PTX ISA v8.5. Using ::ordered_metadata is recommended because the generic mma.sp form may have substantially reduced performance on some target architectures. This syntax is supported on modern SMs including sm80 and sm89. The code change correctly implements this best practice.

tilelang/layout/__init__.py (1)

7-7: LGTM: Public API export updated correctly.

The rename from make_metadata_layout to make_cutlass_metadata_layout aligns with the broader migration to CUTLASS-specific metadata layouts throughout the codebase.

tilelang/utils/tensor.py (2)

10-16: LGTM: Clean float8 dtype check.

The function correctly identifies all torch float8 variants.


18-22: Add unit tests for fp8_remove_negative_zeros_ to verify correctness across all float8 formats.

The function's bit manipulation logic is sound: it correctly identifies both positive and negative zeros using PyTorch's == operator (which follows IEEE-754 semantics where ±0.0 compare equal), then converts them all to positive zero (0x00). This works correctly across all supported float8 formats (e4m3fn, e4m3fnuz, e5m2, e5m2fnuz) since they all represent negative zero with the sign bit set (0x80) and positive zero as all bits zero (0x00).

However, no explicit unit tests were found for this critical function. Since it directly impacts sparse tensor compression, add tests that:

  • Verify negative zeros are correctly identified and converted
  • Confirm behavior is consistent across all four float8 dtypes
  • Ensure the function modifies tensors in-place as expected
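
A minimal sketch of such a test, assuming fp8_remove_negative_zeros_ takes a float8 tensor and normalizes zeros in place, and that torch's float8 dtypes can be reinterpreted as uint8 via Tensor.view:

import torch
from tilelang.utils.tensor import fp8_remove_negative_zeros_

def test_fp8_remove_negative_zeros():
    for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        t = torch.tensor([0.0, -0.0, 1.0, -1.0]).to(dtype)
        fp8_remove_negative_zeros_(t)  # expected to mutate t in place
        bits = t.view(torch.uint8)
        zero_mask = t.to(torch.float32) == 0.0
        # every zero-valued element should now carry the all-zero bit pattern (0x00)
        assert torch.all(bits[zero_mask] == 0)
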
tilelang/language/__init__.py (1)

50-50: LGTM: New gemm_sp_v2 API exported correctly.

Both gemm_sp and gemm_sp_v2 are now available, maintaining backward compatibility while introducing the enhanced v2 API.

tilelang/tileop/__init__.py (1)

2-2: LGTM: GemmSPPy operator exported correctly.

The new sparse GEMM operator class is now available through the public API.

src/op/gemm_sp.h (1)

21-22: LGTM: TVM object metadata added correctly.

The type key and reflection macros properly integrate GemmSPWarpPolicyNode into the TVM object system, following the same pattern as the base GemmWarpPolicyNode.

examples/gemm_sp/test_example_gemm_sp.py (1)

1-15: LGTM: Test structure is clean and follows conventions.

The test file properly wraps the example scripts with the testing harness. The tests rely on assertions within the main() functions of each example module.

src/op/gemm.h (1)

34-34: LGTM: Required change to enable inheritance.

Changing from TVM_DECLARE_FINAL_OBJECT_INFO to TVM_DECLARE_BASE_OBJECT_INFO is necessary to allow GemmSPWarpPolicyNode (in src/op/gemm_sp.h) to properly inherit from GemmWarpPolicyNode.

docs/deeplearning_operators/matmul_sparse.md (1)

1-261: Excellent documentation for the sparse GEMM feature.

The documentation is comprehensive and well-structured, covering:

  • Structured sparsity concepts
  • Compression workflow
  • Both T.gemm_sp and T.gemm_sp_v2 APIs
  • Practical code examples
  • Important notes on layout differences

The examples and explanations will help users understand when and how to use sparse GEMM operations.

src/op/gemm_sp.cc (1)

306-316: LGTM! FFI registration follows established patterns.

The FFI registration correctly exposes GemmSPWarpPolicyComputeWarpPartition to Python. The lambda calls ComputeWarpPartition for its side effects (mutating m_warp and n_warp fields), which the Python side then retrieves, matching the pattern used in GemmWarpPolicy.

Note: The static analysis warning about non-const global variables is a false positive for the TVM_FFI_STATIC_INIT_BLOCK macro.

tilelang/tileop/gemm/__init__.py (3)

7-7: LGTM! Type annotation improvement.

Adding explicit Range type import improves type safety and aligns with the updated function signatures.


15-17: LGTM! Proper use of Range type.

The updated signature and extraction of thread_nums from thread_bounds.extent is correct and improves type safety.


20-23: LGTM! Consistent Range usage.

The pattern matches gemm_py_infer_layout and correctly extracts thread numbers from the Range.

tilelang/ir.py (1)

41-51: LGTM! New sparse GEMM warp policy class.

The GemmSPWarpPolicy class is well-structured and follows the same pattern as GemmWarpPolicy. The additional bits parameter in compute_warp_partition is appropriate for handling different data type sizes in sparse operations.

tilelang/intrinsics/mma_layout.py (2)

147-160: LGTM! Well-documented layout helper.

The mma_load_a_32x8_to_shared_16x16_layout function is clearly documented with its mapping logic and correctly implements the 32x8 to 16x16 layout transformation.


167-179: LGTM! Complementary B-matrix layout.

The mma_load_b_32x8_to_shared_16x16_layout function properly implements the corresponding layout for the B matrix with clear documentation.

benchmark/matmul/benchmark_matmul_sp.py (4)

12-12: LGTM! Updated to new layout API.

Correctly imports the renamed make_cutlass_metadata_layout function.


89-89: LGTM! Added in_dtype parameter.

The function signature now properly includes in_dtype for flexible data type support.


206-210: LGTM! Using updated layout helper.

Correctly uses make_cutlass_metadata_layout with appropriate parameters for both global and shared metadata buffers.


222-229: LGTM! Updated to gemm_sp_v2 API.

The call correctly uses the new T.gemm_sp_v2 API with appropriate parameters.

tilelang/utils/sparse.py (4)

92-103: LGTM! Comprehensive compress_sm80 enhancement.

The addition of transposed flag support and float8 handling is well-implemented. The conversion to int8 for compression and back to the original float8 dtype maintains type correctness.


120-121: LGTM! Float32 sparsity support.

Correctly adjusts to 1:2 sparsity pattern for float32 dtype, which is appropriate for this data type.


130-157: LGTM! New integer semi-sparse generator.

The randint_semi_sparse function is well-implemented and follows the same pattern as randn_semi_sparse. The docstring is clear and comprehensive.


174-175: LGTM! Consistent float32 handling.

Correctly applies the same 1:2 sparsity pattern for float32 dtype as in other semi-sparse generators.
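
Taken together, these utilities suggest a host-side flow along the following lines (a hedged sketch; the exact signatures and return values of randn_semi_sparse and compress are assumptions based on this review, not the documented API):

import torch
from tilelang.utils.sparse import randn_semi_sparse, compress

# Generate a dense fp16 tensor that already satisfies the 2:4 pattern,
# then pack it into non-zero values plus metadata for the sparse kernel.
a_dense = randn_semi_sparse(1024, 1024, dtype=torch.float16, device="cuda")
a_sparse, e_meta = compress(a_dense)  # assumed to return (packed values, metadata)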

examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (2)

4-4: LGTM! Updated to new layout API.

Correctly imports the renamed make_cutlass_metadata_layout function.


43-50: LGTM! Consistent layout usage.

Both metadata layout annotations correctly use make_cutlass_metadata_layout with appropriate SM90 architecture specification and block_k parameter.

tilelang/tileop/gemm_sp/gemm_sp_mma.py (2)

14-58: Layout inference implementation looks correct.

The infer_layout method correctly handles the four GEMM dataflow variants (ss, sr, rs, rr) and delegates to the appropriate layout builders. The use of SparseTensorCoreIntrinEmitter for layout management is consistent with the design.


92-131: SS (shared-shared) kernel variant is well-structured.

The _gemm_ssr kernel correctly allocates local fragments for A, E, and B, loads them from shared memory, and performs the sparse MMA operation. The simplification pass with inline_let=True is appropriate for optimizing index computations.

examples/gemm_sp/example_gemm_sp.py (2)

8-8: API migration to gemm_sp_v2 and make_cutlass_metadata_layout is correct.

The changes consistently update the example to use the new v2 API:

  • Import path updated to make_cutlass_metadata_layout
  • Kernel call updated from T.gemm_sp to T.gemm_sp_v2
  • Layout factory calls updated to use make_cutlass_metadata_layout with explicit arch parameter

Also applies to: 86-93, 99-99


17-58: Configuration improvements enhance usability.

The renaming to DEFAULT_CONFIG (line 17), addition of ARCH_INFO (line 60), and providing a default value for --cfg (line 118) make the example more accessible and self-contained.

Also applies to: 60-60, 118-121

tilelang/intrinsics/mma_macro_generator.py (2)

21-24: 16-bit load layout integration is correct.

The changes properly integrate the new 32x8 → 16x16 load layouts for 16-bit A and B matrices when ldmatrix is not available. The layout functions mma_load_a_32x8_to_shared_16x16_layout and mma_load_b_32x8_to_shared_16x16_layout are imported and used consistently.

Also applies to: 223-224, 291-292


263-263: Transpose-aware indexing correctly handles memory access patterns.

The conditional indexing based on the trans flag ensures that elements are loaded from the correct memory locations for both transposed and non-transposed cases:

  • Line 263 (A): Uses A_shared_buf[wk + mk, wi + mi] when transposed, else A_shared_buf[wi + mi, wk + mk]
  • Line 337 (B): Uses B_shared_buf[wi + mi, wk + mk] when b_transposed, else B_shared_buf[wk + mk, wi + mi]

Also applies to: 337-337

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (4)

14-28: New generate_dense_input helper improves test clarity.

The centralized input generation function properly handles different data types (int vs float, 8-bit vs others, signed vs unsigned) and transpose modes. This reduces duplication and makes tests more maintainable.
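
In the same spirit, a condensed illustration of dtype- and transpose-aware input generation (a sketch, not the test's actual helper):

    import torch

    def make_dense_input(m: int, k: int, dtype=torch.float16,
                         transposed: bool = False, device: str = "cpu") -> torch.Tensor:
        shape = (k, m) if transposed else (m, k)
        if dtype in (torch.int8, torch.uint8):
            low = 0 if dtype is torch.uint8 else -4   # unsigned vs signed range
            return torch.randint(low, 4, shape, dtype=dtype, device=device)
        # float16 / bfloat16 / float8: draw in float32, then cast down
        return torch.randn(shape, device=device).to(dtype)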


111-112: Metadata dtype determination is cleaner and more robust.

The new approach using SparseTensorCoreIntrinEmitter.E_FACTOR_MAP provides a cleaner and more maintainable way to determine the correct metadata dtype based on input dtype, replacing the previous hardcoded logic.
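
A sketch of the kind of lookup this enables (the dictionary below is a hypothetical stand-in for SparseTensorCoreIntrinEmitter.E_FACTOR_MAP; its concrete keys and factors are illustrative, not the real table):

    # input dtype -> {metadata dtype: A elements covered per metadata element}
    E_FACTOR_MAP = {
        "float16": {"uint16": 16},  # illustrative values only
        "int8": {"uint32": 32},
    }

    M, K, in_dtype = 256, 256, "float16"
    e_dtype, e_factor = next(iter(E_FACTOR_MAP[in_dtype].items()))
    e_shape = (M, K // e_factor)    # metadata buffer shape used by the test kernel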


65-78: Fragment buffer naming and layout updates are correct.

The changes from C_local to C_frag (lines 65, 78) accurately reflect that these are fragment buffers, and the updated make_cutlass_metadata_layout calls (lines 68-75, 134-137) align with the new API.

Also applies to: 131-139


209-216: Test assertion improvements enhance debugging.

The use of torch_assert_close with named tensors (base_name="tilelang_sp", ref_name="ref_dense") provides better error messages and aligns with the broader test suite improvements.
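
A hedged usage sketch of that call shape (the positional order and tolerance keywords are assumptions; only base_name/ref_name come from the test):

    import torch
    from tilelang.utils.tensor import torch_assert_close

    c_sp = torch.ones(16, 16)   # stand-in for the sparse-kernel output
    c_ref = torch.ones(16, 16)  # stand-in for the dense PyTorch reference
    torch_assert_close(
        c_sp, c_ref, rtol=1e-2, atol=1e-2,
        base_name="tilelang_sp", ref_name="ref_dense")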

tilelang/language/experimental/gemm_sp.py (1)

278-281: No issues found; the offset assertion logic is correct for higher-dimensional tensors.

The code uses -2 indexing intentionally and correctly. The shape constraints (lines 121-130 in gemm.py) enforce that for tensors with more than 2 dimensions, all non-matrix dimensions must equal 1. This means the last two dimensions always contain the matrix data, making A_offset[-2] and A_offset[-1] consistently refer to the offsets of the first and second matrix dimensions, regardless of total tensor dimensionality.

For example:

  • 2D tensor: A_offset has 2 elements; [-2] → first matrix dimension ✓
  • 3D tensor: A_offset has 3 elements; [-2] → second-to-last (still first matrix dimension) ✓

The error message could be more explicit (e.g., "first matrix dimension"), but the logic is sound.
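
A tiny plain-Python illustration of why the negative indices are rank-agnostic:

    offsets_2d = [0, 16]     # 2-D case: (row offset, col offset)
    offsets_3d = [0, 0, 16]  # 3-D case: the leading non-matrix dim has extent 1
    assert offsets_2d[-2] == offsets_3d[-2] == 0    # first matrix dimension
    assert offsets_2d[-1] == offsets_3d[-1] == 16   # second matrix dimension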

tilelang/tileop/gemm_sp/gemm_sp_base.py (2)

20-30: LGTM!

The is_gemm_* helper methods clearly express the four GEMM scope combinations (shared-shared, shared-register, register-shared, register-register) and correctly delegate to is_shared and is_fragment utilities.


32-127: LGTM!

The property accessors provide a clean, read-only interface to the underlying gemm_sp_node attributes. The assertion in in_dtype (line 58) ensures A and B have matching data types, which is a good defensive check.

examples/gemm_sp/example_custom_compress.py (2)

314-361: LGTM!

The main function has a good structure: argument parsing, input generation, compression (with alternative compressor support), correctness validation, and performance benchmarking. The assertion at line 337 ensures torch_compress is only used with the naive layout, which is a good safeguard.


66-108: LGTM!

The matmul_sp_fp16_custom_compress kernel correctly implements the sparse GEMM with optional Cutlass metadata layout. The conditional layout annotation (lines 84-95) properly handles both layout variants.

tilelang/tileop/gemm_sp/__init__.py (1)

24-65: LGTM!

The GemmSPPy class provides a clean FFI-compatible wrapper with proper field annotations. The delegation to GemmSPMMA for CUDA targets is the correct pattern, allowing target-specific implementations while maintaining a uniform interface.

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (3)

10-136: LGTM!

The matmul and run_gemm_ss functions provide a clean testing framework with proper layout annotations, compilation configurations, and validation against PyTorch references. The use of SparseTensorCoreIntrinEmitter.E_FACTOR_MAP ensures correct metadata dimensions across different dtype combinations.


184-664: LGTM!

The test functions for rs, sr, and rr variants follow a consistent structure with comprehensive coverage of transpose combinations, data types (float16, int8, float8, bfloat16), and edge cases (n8 dimensions). The validation against PyTorch references ensures correctness across all scope combinations.


158-158: The review comment is incorrect. Transposed A compression is fully supported.

The compress_sm90.cu template has native transposed support with conditional layout handling (ColumnMajor for transposed, RowMajor otherwise), dimension extraction based on the transposed flag, and proper output shape handling. The tests already exercise trans_A=True cases (lines 161–162, 171–172, 175), and the compress() function passes the transposed parameter through to the implementation. The TODO comment appears outdated.

Likely an incorrect or invalid review comment.

tilelang/intrinsics/mma_sp_layout.py (6)

1-12: LGTM!

The imports are well-organized and all appear to be used in the module. Good reuse of existing layout functions from mma_layout.


14-33: LGTM!

The layout conversion functions are well-structured. The A matrix functions appropriately delegate to existing implementations, while the B matrix functions provide new coordinate transformations specific to sparse layouts.


35-42: LGTM!

The A matrix load helpers correctly delegate to existing implementations with appropriate dimension mapping.


44-57: LGTM!

The B matrix load helpers provide new coordinate transformations that follow consistent patterns with appropriate dimension handling.


123-157: LGTM!

The function has good error handling with clear, specific error messages that aid debugging. The dtype branching logic correctly handles different bit-widths and transposition modes.

Note: The static analysis warnings about "long error messages" (lines 139, 155, 157) are minor style preferences. The inline error messages are clear and appropriate for simple validation errors.


103-106: Intentional SP-specific reimplementation, not an oversight.

The different implementation is intentional and by design. The mma_sp_layout.py module explicitly does not import ldmatrix_trans_32x8_to_shared_16x16_layout from mma_layout.py—the import list shows it selectively imports 6 other functions but omits this one. Instead, mma_sp_layout.py defines its own SP-specific version (lines 103–106) and uses it locally (line 135). This is a valid design pattern: the sparse matrix module maintains its own specialized layout transformation logic separate from the general module. There is no oversight here.

Likely an incorrect or invalid review comment.

Comment on lines 209 to 220
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0) {
arch_int = std::stoi(arch.substr(3));
} else {
arch_int = 0;
}
return arch_int;
}

⚠️ Potential issue | 🟡 Minor

Add error handling for std::stoi.

The GetArchInt function uses std::stoi without error handling. If the architecture string contains non-numeric characters after "sm_" (e.g., "sm_abc"), std::stoi will throw std::invalid_argument or std::out_of_range, potentially causing unexpected crashes.

Apply this diff to add error handling:

 static int GetArchInt(Target target) {
   int arch_int = 0;
   auto s = target->GetAttr<String>("arch");
   ICHECK(s.defined());
   std::string arch = s.value();
   if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
+    try {
+      arch_int = std::stoi(arch.substr(3));
+    } catch (const std::exception& e) {
+      LOG(WARNING) << "Failed to parse architecture number from '" << arch << "': " << e.what();
+      arch_int = 0;
+    }
   } else {
     arch_int = 0;
   }
   return arch_int;
 }
🧰 Tools
🪛 Clang (14.0.6)

[error] 211-211: variable name 's' is too short, expected at least 3 characters

(readability-identifier-length,-warnings-as-errors)


[error] 213-213: variable 'arch' is not initialized

(cppcoreguidelines-init-variables,-warnings-as-errors)

Comment on lines 67 to 79
bool SEqualReduce(const GemmSPPyNode *other, SEqualReducer equal) const {
return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
equal(trans_B, other->trans_B) && equal(M, other->M) &&
equal(N, other->N) && equal(K, other->K) &&
equal(stride_A, other->stride_A) &&
equal(stride_B, other->stride_B) &&
equal(offset_A, other->offset_B) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
equal(policy, other->policy);

⚠️ Potential issue | 🔴 Critical

SEqual must compare the metadata buffer correctly

SEqualReduce skips E/Eptr entirely and even checks offset_A against other->offset_B. That lets operators with different metadata collapse together in the structural cache, yielding wrong reuse. Please include the metadata fields and fix the offset comparison.

-    return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
-           equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
-           equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
+    return equal(A, other->A) && equal(E, other->E) && equal(B, other->B) &&
+           equal(C, other->C) && equal(Aptr, other->Aptr) &&
+           equal(Eptr, other->Eptr) && equal(Bptr, other->Bptr) &&
+           equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
            equal(trans_B, other->trans_B) && equal(M, other->M) &&
            equal(N, other->N) && equal(K, other->K) &&
            equal(stride_A, other->stride_A) &&
            equal(stride_B, other->stride_B) &&
-           equal(offset_A, other->offset_B) &&
+           equal(offset_A, other->offset_A) &&
            equal(offset_B, other->offset_B) &&
            equal(clear_accum, other->clear_accum) &&
            equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
            equal(policy, other->policy);
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
bool SEqualReduce(const GemmSPPyNode *other, SEqualReducer equal) const {
return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
equal(trans_B, other->trans_B) && equal(M, other->M) &&
equal(N, other->N) && equal(K, other->K) &&
equal(stride_A, other->stride_A) &&
equal(stride_B, other->stride_B) &&
equal(offset_A, other->offset_B) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
equal(policy, other->policy);
bool SEqualReduce(const GemmSPPyNode *other, SEqualReducer equal) const {
return equal(A, other->A) && equal(E, other->E) && equal(B, other->B) &&
equal(C, other->C) && equal(Aptr, other->Aptr) &&
equal(Eptr, other->Eptr) && equal(Bptr, other->Bptr) &&
equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
equal(trans_B, other->trans_B) && equal(M, other->M) &&
equal(N, other->N) && equal(K, other->K) &&
equal(stride_A, other->stride_A) &&
equal(stride_B, other->stride_B) &&
equal(offset_A, other->offset_A) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
equal(policy, other->policy);
🤖 Prompt for AI Agents
In src/op/gemm_sp_py.h around lines 67 to 79, the SEqualReduce implementation
omits comparing the metadata fields E and Eptr and incorrectly compares offset_A
to other->offset_B; update the boolean chain to include equal(E, other->E) &&
equal(Eptr, other->Eptr) and change the incorrect comparison to equal(offset_A,
other->offset_A) so all metadata buffers and offsets are checked correctly.

Comment on lines 82 to 101
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(A);
hash_reduce(B);
hash_reduce(C);
hash_reduce(Aptr);
hash_reduce(Bptr);
hash_reduce(Cptr);
hash_reduce(trans_A);
hash_reduce(trans_B);
hash_reduce(M);
hash_reduce(N);
hash_reduce(K);
hash_reduce(stride_A);
hash_reduce(stride_B);
hash_reduce(offset_A);
hash_reduce(offset_B);
hash_reduce(clear_accum);
hash_reduce(kPack);
hash_reduce(wg_wait);
hash_reduce(policy);

⚠️ Potential issue | 🟠 Major

Hash must cover the metadata too

SHashReduce omits both E and Eptr, so nodes that only differ in metadata hash to the same bucket. After fixing SEqualReduce, this still causes excessive collisions – please fold the metadata into the hash.

-    hash_reduce(A);
-    hash_reduce(B);
-    hash_reduce(C);
-    hash_reduce(Aptr);
-    hash_reduce(Bptr);
-    hash_reduce(Cptr);
+    hash_reduce(A);
+    hash_reduce(E);
+    hash_reduce(B);
+    hash_reduce(C);
+    hash_reduce(Aptr);
+    hash_reduce(Eptr);
+    hash_reduce(Bptr);
+    hash_reduce(Cptr);
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(A);
hash_reduce(B);
hash_reduce(C);
hash_reduce(Aptr);
hash_reduce(Bptr);
hash_reduce(Cptr);
hash_reduce(trans_A);
hash_reduce(trans_B);
hash_reduce(M);
hash_reduce(N);
hash_reduce(K);
hash_reduce(stride_A);
hash_reduce(stride_B);
hash_reduce(offset_A);
hash_reduce(offset_B);
hash_reduce(clear_accum);
hash_reduce(kPack);
hash_reduce(wg_wait);
hash_reduce(policy);
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(A);
hash_reduce(E);
hash_reduce(B);
hash_reduce(C);
hash_reduce(Aptr);
hash_reduce(Eptr);
hash_reduce(Bptr);
hash_reduce(Cptr);
hash_reduce(trans_A);
hash_reduce(trans_B);
hash_reduce(M);
hash_reduce(N);
hash_reduce(K);
hash_reduce(stride_A);
hash_reduce(stride_B);
hash_reduce(offset_A);
hash_reduce(offset_B);
hash_reduce(clear_accum);
hash_reduce(kPack);
hash_reduce(wg_wait);
hash_reduce(policy);
🤖 Prompt for AI Agents
In src/op/gemm_sp_py.h around lines 82 to 101, SHashReduce currently omits
hashing the metadata fields E and Eptr causing different-metadata nodes to
collide; update SHashReduce to call hash_reduce(E) and hash_reduce(Eptr)
(placing them logically with the other pointer/metadata fields, maintaining
consistent order) so the metadata is folded into the hash and collisions are
avoided.

Comment on lines 543 to 547
a_is_fragment = is_fragment(A_local_buf)
e_is_fragment = is_fragment(E_local_buf)
b_is_fragment = is_fragment(B_local_buf)
assert not e_is_fragment, f"currently E_local_buf must be a fragment buffer, found {E_local_buf.scope()}"
a_local_stride: PrimExpr = k_inner * warp_rows * local_size_a if a_is_fragment else 0

⚠️ Potential issue | 🔴 Critical

Fix fragment assertion for metadata buffer

is_fragment(E_local_buf) must be True, but the assertion requires it to be False, so we abort as soon as the metadata fragment is passed in. Please drop the negation.

-        assert not e_is_fragment, f"currently E_local_buf must be a fragment buffer, found {E_local_buf.scope()}"
+        assert e_is_fragment, f"currently E_local_buf must be a fragment buffer, found {E_local_buf.scope()}"
🤖 Prompt for AI Agents
In tilelang/intrinsics/mma_sp_macro_generator.py around lines 543 to 547, the
assertion for E_local_buf incorrectly negates is_fragment so it fails when
E_local_buf is a fragment; change the assertion to require
is_fragment(E_local_buf) is True (remove the `not`) so the code asserts
E_local_buf is a fragment and keep the error message using E_local_buf.scope()
for context.

Comment on lines 13 to 16
@tvm.ffi.register_func("tl.gemm_sp_py.infer_layout")
def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range):
thread_nums = thread_bounds.extent
return gemm_sp_py.infer_layout(target, thread_nums)

⚠️ Potential issue | 🟠 Major

Correct the parameter type annotation.

The function is registered to receive a GemmSPPy object (based on the C++ call in src/op/gemm_sp_py.cc lines 229-231), but the type annotation indicates GemmSPMMA. This mismatch could cause confusion and type-checking failures.

Apply this diff:

 @tvm.ffi.register_func("tl.gemm_sp_py.infer_layout")
-def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range):
+def gemm_sp_py_infer_layout(gemm_sp_py: "GemmSPPy", target: Target, thread_bounds: Range):
     thread_nums = thread_bounds.extent
     return gemm_sp_py.infer_layout(target, thread_nums)

Note: Use forward reference "GemmSPPy" since the class is defined later in the file.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@tvm.ffi.register_func("tl.gemm_sp_py.infer_layout")
def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range):
thread_nums = thread_bounds.extent
return gemm_sp_py.infer_layout(target, thread_nums)
@tvm.ffi.register_func("tl.gemm_sp_py.infer_layout")
def gemm_sp_py_infer_layout(gemm_sp_py: "GemmSPPy", target: Target, thread_bounds: Range):
thread_nums = thread_bounds.extent
return gemm_sp_py.infer_layout(target, thread_nums)
🤖 Prompt for AI Agents
In tilelang/tileop/gemm_sp/__init__.py around lines 13 to 16, the parameter type
annotation for gemm_sp_py_infer_layout incorrectly uses GemmSPMMA when the
registered function actually receives a GemmSPPy; update the annotation to use
the forward reference "GemmSPPy" (a string) for the gemm_sp_py parameter so it
matches the C++ call and avoids type-checking confusion, leaving the rest of the
signature and body unchanged.

Comment on lines 18 to 22
@tvm.ffi.register_func("tl.gemm_sp_py.lower")
def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range, thread_var: tir.Var):
thread_nums = thread_bounds.extent
stmt = gemm_sp_py.lower(target, thread_nums, thread_var)
return stmt

⚠️ Potential issue | 🟠 Major

Correct the parameter type annotation.

The function is registered to receive a GemmSPPy object, but the type annotation indicates GemmSPMMA.

Apply this diff:

 @tvm.ffi.register_func("tl.gemm_sp_py.lower")
-def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range, thread_var: tir.Var):
+def gemm_sp_py_lower(gemm_sp_py: "GemmSPPy", target: Target, thread_bounds: Range, thread_var: tir.Var):
     thread_nums = thread_bounds.extent
     stmt = gemm_sp_py.lower(target, thread_nums, thread_var)
     return stmt


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

♻️ Duplicate comments (1)
src/op/gemm_sp_py.cc (1)

210-221: Add error handling for std::stoi.

The previous review correctly identified that std::stoi at line 216 can throw std::invalid_argument or std::out_of_range if the architecture string contains non-numeric characters (e.g., "sm_abc"). This remains unaddressed.

🧹 Nitpick comments (1)
benchmark/matmul/benchmark_matmul_sp.py (1)

95-114: Docstring omits the new in_dtype argument.

Please update the parameter section to describe both in_dtype and accum_dtype so the public API matches the signature.
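
A possible stub for that section (NumPy-style; the signature below is abridged and hypothetical):

    def matmul_sp(M: int, N: int, K: int, in_dtype: str, accum_dtype: str):
        """Build the sparse GEMM benchmark kernel (docstring sketch only).

        Parameters
        ----------
        in_dtype : str
            Element dtype of the compressed A operand and of B,
            e.g. "float16" or "int8".
        accum_dtype : str
            Accumulation dtype for the C fragment, e.g. "float32".
        """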

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2efb496 and 295a689.

📒 Files selected for processing (9)
  • benchmark/matmul/benchmark_matmul_sp.py (7 hunks)
  • docs/index.md (1 hunks)
  • examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (2 hunks)
  • src/op/gemm_sp.cc (1 hunks)
  • src/op/gemm_sp.h (1 hunks)
  • src/op/gemm_sp_py.cc (1 hunks)
  • src/op/gemm_sp_py.h (1 hunks)
  • src/target/ptx.cc (1 hunks)
  • src/tl_templates/cuda/debug.h (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • docs/index.md
  • src/target/ptx.cc
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
🧬 Code graph analysis (6)
src/op/gemm_sp.h (1)
src/op/gemm.h (2)
  • RegisterReflection (35-41)
  • RegisterReflection (106-128)
src/op/gemm_sp.cc (1)
src/op/gemm_sp.h (1)
  • RegisterReflection (27-33)
benchmark/matmul/benchmark_matmul_sp.py (4)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (137-152)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1)
  • matmul_sp (9-60)
tilelang/language/allocate.py (1)
  • alloc_shared (27-42)
tilelang/language/experimental/gemm_sp.py (1)
  • gemm_sp_v2 (90-308)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (137-152)
src/op/gemm_sp_py.cc (4)
src/op/gemm_sp_py.h (2)
  • GemmSPPy (79-84)
  • RegisterReflection (40-64)
tilelang/tileop/gemm_sp/__init__.py (1)
  • GemmSPPy (25-65)
src/op/gemm.h (6)
  • GemmWarpPolicy (59-83)
  • GemmWarpPolicy (64-68)
  • GemmWarpPolicy (70-74)
  • GemmWarpPolicy (76-82)
  • RegisterReflection (35-41)
  • RegisterReflection (106-128)
src/op/gemm_sp.cc (6)
  • Clone (120-123)
  • Clone (120-120)
  • Lower (143-185)
  • Lower (143-143)
  • InferLayout (215-299)
  • InferLayout (215-216)
src/op/gemm_sp_py.h (3)
src/op/gemm_sp.h (3)
  • tvm (13-112)
  • tl (15-111)
  • RegisterReflection (27-33)
src/op/operator.h (2)
  • TileOperatorNode (56-66)
  • TileOperator (68-72)
src/op/gemm_sp_py.cc (11)
  • CheckWGMMA (143-194)
  • CheckWGMMA (143-143)
  • Lower (223-256)
  • Lower (223-223)
  • InferLayout (258-273)
  • InferLayout (258-259)
  • Clone (92-95)
  • Clone (92-92)
  • GetGemmInst (97-111)
  • GetGemmInst (97-97)
  • GemmSPPy (50-82)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (14)
src/tl_templates/cuda/debug.h (1)

261-269: LGTM! Clean addition for uint16_t debugging support.

The specialization correctly follows the established pattern (mirroring the int16_t variant), uses the appropriate format specifier (%u), and casts to uint32_t for safe variadic printf usage. This addition is useful for debugging sparse GEMM metadata, which is often stored in compact 16-bit unsigned formats.

src/op/gemm_sp_py.cc (5)

92-95: LGTM!

The Clone implementation correctly follows the established pattern seen in GemmSPNode::Clone (src/op/gemm_sp.cc:120-123).


97-111: LGTM!

The instruction selection logic correctly chooses between WGMMA, MFMA, and MMA based on target architecture and constraints. The CheckWGMMA() placeholder returning false (line 101) is appropriately documented at line 144.


143-194: Placeholder implementation with extensive commented roadmap.

The method returns false with a fully documented implementation commented out. This aligns with the PR's roadmap showing WGMMA support as incomplete. The commented code and comprehensive docstring (lines 113-142) provide clear eligibility criteria for future enablement.

Is WGMMA support expected in a follow-up PR, or should this placeholder be tracked separately? The checklist in the PR description doesn't explicitly mention WGMMA as a deliverable for this PR.


258-273: LGTM!

The layout inference correctly delegates to the FFI function and guards against re-entrance with the completed_ flag, matching the pattern established in GemmSPNode::InferLayout.


275-281: LGTM!

The operator registration correctly declares 5 inputs and kOpaque call effect, and properly registers reflection in the static initialization block.

src/op/gemm_sp_py.h (3)

7-7: TODO: Address code duplication with gemm_py.h.

The TODO at line 7 notes planned refactoring to eliminate duplication with gemm_py.h. Given that this is a new sparse GEMM variant, some initial duplication is acceptable, but this should be tracked.

Is this TODO tracked in an issue or follow-up work item?


40-64: LGTM!

The reflection registration comprehensively exposes all relevant fields including the metadata buffer E and pointer Eptr, which correctly distinguishes this sparse operator from the dense GEMM variant.


79-84: LGTM!

The wrapper class follows the standard TVM object reference pattern with appropriate constructor and accessor declarations.

src/op/gemm_sp.h (1)

27-33: LGTM!

The reflection registration for GemmSPWarpPolicyNode correctly mirrors the established pattern from GemmWarpPolicyNode (src/op/gemm.h:34-40) and properly exposes policy_type, m_warp, and n_warp fields.

src/op/gemm_sp.cc (2)

306-307: LGTM!

The operator registration for tl.GemmSPWarpPolicy correctly sets the TScriptPrinterName attribute for script printing support.


309-312: LGTM!

The static initialization block correctly registers reflection for both GemmSPNode and the newly added GemmSPWarpPolicyNode.

examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (2)

4-4: LGTM! Clean API migration.

The migration from make_metadata_layout to make_cutlass_metadata_layout is consistent and correct. Removing the backend="cutlass" parameter makes sense as the backend is now implied by the function name itself.

Also applies to: 43-50


57-57: No changes required — this example is outside the PR's scope.

Both T.gemm_sp and T.gemm_sp_v2 coexist in the codebase with identical signatures. The sparse_tensorcore example targets SM90 (arch="9.0"), while this PR specifically adds support for T.gemm_sp_v2 on SM80 and SM89. Other examples in the gemm_sp/ directory have already been updated to use T.gemm_sp_v2, indicating the API direction; however, the sparse_tensorcore example remains appropriately on the legacy function for its architecture-specific use case.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (13)
docs/deeplearning_operators/matmul_sparse.md (2)

50-50: Fix typo: "elment" → "element".


137-137: Fix typo: "indcies" → "indices".

src/op/gemm_sp.cc (1)

313-319: Return the computed warp partition from the FFI binding.

The lambda currently discards the std::pair<int, int> returned by ComputeWarpPartition and returns void. The Python caller expects to receive the computed (m_warp, n_warp) values.

Apply this diff:

-      "tl.GemmSPWarpPolicyComputeWarpPartition",
-      [](GemmSPWarpPolicy policy, int M, int N, int block_size, Target target,
-         bool use_wgmma, int bits) {
-        policy->ComputeWarpPartition(M, N, block_size, target, use_wgmma, bits);
-        return;
+      "tl.GemmSPWarpPolicyComputeWarpPartition",
+      [](GemmSPWarpPolicy policy, int M, int N, int block_size, Target target,
+         bool use_wgmma, int bits) {
+        return policy->ComputeWarpPartition(M, N, block_size, target, use_wgmma, bits);
       });
tilelang/tileop/gemm_sp/gemm_sp_mma.py (1)

208-230: Fix duplicate function name in the rr variant.

Line 209 defines _gemm_rsr for the register-register (rr) case, but this name conflicts with the register-shared (rs) variant at line 174. This should be a distinct name, such as _gemm_rrr, to correctly represent the register-register semantics.

Apply this diff:

         elif self.is_gemm_rr():
             A_local = self.A
             B_local = self.B
 
             @T.prim_func
-            def _gemm_rsr() -> None:
+            def _gemm_rrr() -> None:
                 """
                 The inner macro that loads data from shared buffers A_shared and
                 B_shared into local fragments, then issues Tensor Core mma ops,
                 accumulating into C_local.
                 """
                 E_local = T.alloc_local((warp_rows * local_size_e), self.e_dtype)
 
                 for ki in T.serial(0, (self.K // micro_size_k)):
                     # Load E into fragment
                     mma_emitter.ldmatrix_e(
                         E_local,
                         E_shared,
                         ki,
                     )
 
                     # Perform Matrix Multiplication
                     mma_emitter.mma_sp(A_local, E_local, B_local, C_local, ki)
 
             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
-            return _Simplify(_gemm_rsr, inline_let=True)
+            return _Simplify(_gemm_rrr, inline_let=True)
src/op/gemm_sp_py.cc (3)

67-67: Use GemmSPWarpPolicy instead of GemmWarpPolicy.

Line 67 constructs a GemmWarpPolicy, but the sparse GEMM path requires GemmSPWarpPolicy. Since GemmSPWarpPolicy has specialized ComputeWarpPartition logic for sparse GEMM (as shown in src/op/gemm_sp.cc:21-63), using the base type will bypass critical sparse-specific warp-partition adjustments.

Apply this diff:

-  node->policy = GemmWarpPolicy(args[9].as<IntImm>().value()->value);
+  node->policy = GemmSPWarpPolicy(args[9].as<IntImm>().value()->value);

212-223: Add error handling for std::stoi.

The GetArchInt function uses std::stoi without error handling. If the architecture string contains non-numeric characters after "sm_" (e.g., "sm_abc"), std::stoi will throw std::invalid_argument or std::out_of_range, potentially causing unexpected crashes.

Apply this diff:

 static int GetArchInt(Target target) {
   int arch_int = 0;
   auto s = target->GetAttr<String>("arch");
-  ICHECK(s.has_value());
+  ICHECK(s.defined());
   std::string arch = s.value();
   if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
+    try {
+      arch_int = std::stoi(arch.substr(3));
+    } catch (const std::exception& e) {
+      LOG(WARNING) << "Failed to parse architecture number from '" << arch << "': " << e.what();
+      arch_int = 0;
+    }
   } else {
     arch_int = 0;
   }
   return arch_int;
 }

225-230: Fix ComputeWarpPartition call signature.

Lines 229-230 call policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst), but GemmSPWarpPolicyNode::ComputeWarpPartition expects (int M, int N, int block_size, Target target, bool use_wgmma, int bits) as shown in src/op/gemm_sp.h:21-23 and src/op/gemm_sp.cc:150-151. The current code passes a GemmInst enum where a bool use_wgmma is expected, and omits the required bits parameter.

Apply this diff:

+  bool use_wgmma = (gemm_inst == GemmInst::kWGMMA);
   auto [warp_m, warp_n] =
-      policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
+      policy->ComputeWarpPartition(M, N, block_size, T.target, use_wgmma, A->dtype.bits());
src/op/gemm_sp_py.h (1)

36-36: Declare policy as GemmSPWarpPolicy instead of GemmWarpPolicy.

Line 36 declares policy as mutable GemmWarpPolicy, but GemmSPNode (src/op/gemm_sp.h:73) uses mutable GemmSPWarpPolicy. Since GemmSPWarpPolicy::ComputeWarpPartition includes sparse-specific atom-size adjustments (src/op/gemm_sp.cc:31-62) that the base GemmWarpPolicy lacks, using the base type will skip critical sparse GEMM layout constraints.

Apply this diff:

-  mutable GemmWarpPolicy policy;
+  mutable GemmSPWarpPolicy policy;
tilelang/tileop/gemm_sp/__init__.py (2)

14-17: Correct the parameter type annotation.

The function is registered to receive a GemmSPPy object (based on the C++ call in src/op/gemm_sp_py.cc lines 232-235), but the type annotation indicates GemmSPMMA. This mismatch could cause confusion and type-checking failures.

Apply this diff:

 @tvm_ffi.register_global_func("tl.gemm_sp_py.infer_layout")
-def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range):
+def gemm_sp_py_infer_layout(gemm_sp_py: "GemmSPPy", target: Target, thread_bounds: Range):
     thread_nums = thread_bounds.extent
     return gemm_sp_py.infer_layout(target, thread_nums)

Note: Use forward reference "GemmSPPy" since the class is defined later in the file.


20-25: Correct the parameter type annotation.

The function is registered to receive a GemmSPPy object, but the type annotation indicates GemmSPMMA.

Apply this diff:

 @tvm_ffi.register_global_func("tl.gemm_sp_py.lower")
-def gemm_sp_py_lower(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds: Range,
+def gemm_sp_py_lower(gemm_sp_py: "GemmSPPy", target: Target, thread_bounds: Range,
                      thread_var: tir.Var):
     thread_nums = thread_bounds.extent
     stmt = gemm_sp_py.lower(target, thread_nums, thread_var)
     return stmt
tilelang/intrinsics/mma_sp_macro_generator.py (3)

382-390: Fix the conditional chain to prevent fallthrough error.

The second if at Line 385 should be elif to form a proper conditional ladder. Currently, when a_dtype is 8-bit, the code correctly assigns metadata_16bit_load_32x2_to_shared_16x4_layout_8bit, but then the following if evaluates to false and execution falls through to the else block at Line 389, raising an error incorrectly.

Apply this diff to fix the issue:

 elif DataType(e_dtype).bits == 16:
     if DataType(a_dtype).bits == 8:
         mma_load_layout = metadata_16bit_load_32x2_to_shared_16x4_layout_8bit
-    if DataType(a_dtype).bits == 16:
+    elif DataType(a_dtype).bits == 16:
         mma_load_layout = metadata_16bit_load_32x2_to_shared_16x2_layout_16bit
     elif DataType(a_dtype).bits == 32:
         mma_load_layout = metadata_16bit_load_32x2_to_shared_16x2_layout_32bit
     else:
         raise ValueError(f"Unsupported a_dtype for e_dtype 16bit: {a_dtype}")

431-432: Fix inverted ldmatrix availability logic for int8 transpose.

The current condition enables ldmatrix for the unsupported int8+transpose case. The predicate should be not (DataType(b_dtype).bits != 16 and b_transposed) to correctly disable ldmatrix when B is int8 and transposed.

Apply this diff to fix the logic:

-ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
+ldmatrix_available = not (DataType(b_dtype).bits != 16 and b_transposed)

537-540: Fix negated fragment assertion for E_local_buf.

The assertion requires not e_is_fragment while the error message states E_local_buf "must be a fragment buffer". This contradiction causes the assertion to fail when E_local_buf is correctly passed as a fragment. Remove the not to match the intended behavior.

Apply this diff to fix the assertion:

-assert not e_is_fragment, f"currently E_local_buf must be a fragment buffer, found {E_local_buf.scope()}"
+assert e_is_fragment, f"currently E_local_buf must be a fragment buffer, found {E_local_buf.scope()}"
🧹 Nitpick comments (1)
tilelang/intrinsics/mma_sp_macro_generator.py (1)

52-129: Consider annotating constant class dictionaries with ClassVar.

The dtype_abbrv, E_FACTOR_MAP, and E_REPLICATE_FACTOR dictionaries are shared across all instances and act as constant lookup tables. Annotating them with typing.ClassVar clarifies that they are class-level constants rather than instance attributes.

Apply this diff to add the annotations:

+from typing import ClassVar
+
 class SparseTensorCoreIntrinEmitter:
     ...
-    dtype_abbrv = {
+    dtype_abbrv: ClassVar[dict[str, str]] = {
         ...
     }
 
-    E_FACTOR_MAP = {
+    E_FACTOR_MAP: ClassVar[dict[str, dict[str, int]]] = {
         ...
     }
 
-    E_REPLICATE_FACTOR = {
+    E_REPLICATE_FACTOR: ClassVar[dict[str, int]] = {
         ...
     }
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 295a689 and 82261e7.

📒 Files selected for processing (24)
  • benchmark/matmul/benchmark_matmul_sp.py (7 hunks)
  • docs/deeplearning_operators/matmul_sparse.md (1 hunks)
  • examples/gemm_sp/example_custom_compress.py (1 hunks)
  • examples/gemm_sp/example_gemm_sp.py (5 hunks)
  • examples/gemm_sp/test_example_gemm_sp.py (1 hunks)
  • examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (2 hunks)
  • src/op/gemm_sp.cc (1 hunks)
  • src/op/gemm_sp_py.cc (1 hunks)
  • src/op/gemm_sp_py.h (1 hunks)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (7 hunks)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (1 hunks)
  • tilelang/intrinsics/mma_layout.py (1 hunks)
  • tilelang/intrinsics/mma_macro_generator.py (5 hunks)
  • tilelang/intrinsics/mma_sp_layout.py (1 hunks)
  • tilelang/intrinsics/mma_sp_macro_generator.py (1 hunks)
  • tilelang/ir.py (1 hunks)
  • tilelang/language/experimental/gemm_sp.py (2 hunks)
  • tilelang/layout/gemm_sp.py (2 hunks)
  • tilelang/tileop/gemm/__init__.py (2 hunks)
  • tilelang/tileop/gemm_sp/__init__.py (1 hunks)
  • tilelang/tileop/gemm_sp/gemm_sp_base.py (1 hunks)
  • tilelang/tileop/gemm_sp/gemm_sp_mma.py (1 hunks)
  • tilelang/utils/sparse.py (5 hunks)
  • tilelang/utils/tensor.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (7)
  • tilelang/utils/tensor.py
  • tilelang/tileop/gemm/__init__.py
  • examples/gemm_sp/example_gemm_sp.py
  • tilelang/ir.py
  • tilelang/tileop/gemm_sp/gemm_sp_base.py
  • tilelang/intrinsics/mma_macro_generator.py
  • tilelang/intrinsics/mma_layout.py
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
  • src/op/gemm_sp_py.cc
  • tilelang/intrinsics/mma_sp_macro_generator.py
  • tilelang/layout/gemm_sp.py
  • tilelang/intrinsics/mma_sp_layout.py
  • docs/deeplearning_operators/matmul_sparse.md
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
  • src/op/gemm_sp_py.cc
  • tilelang/intrinsics/mma_sp_macro_generator.py
  • tilelang/layout/gemm_sp.py
  • tilelang/intrinsics/mma_sp_layout.py
🧬 Code graph analysis (15)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (136-150)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (5)
tilelang/utils/sparse.py (3)
  • compress (77-106)
  • randn_semi_sparse (109-128)
  • randint_semi_sparse (131-158)
tilelang/utils/tensor.py (2)
  • torch_assert_close (237-329)
  • map_torch_type (37-54)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (136-150)
tilelang/intrinsics/mma_sp_macro_generator.py (1)
  • SparseTensorCoreIntrinEmitter (40-858)
tilelang/language/experimental/gemm_sp.py (1)
  • gemm_sp_v2 (92-308)
examples/gemm_sp/test_example_gemm_sp.py (2)
examples/gemm_sp/example_custom_compress.py (1)
  • main (305-359)
examples/gemm_sp/example_gemm_sp.py (1)
  • main (105-144)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (4)
tilelang/utils/sparse.py (3)
  • compress (77-106)
  • randn_semi_sparse (109-128)
  • randint_semi_sparse (131-158)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (136-150)
tilelang/utils/tensor.py (2)
  • torch_assert_close (237-329)
  • map_torch_type (37-54)
tilelang/intrinsics/mma_sp_macro_generator.py (1)
  • SparseTensorCoreIntrinEmitter (40-858)
src/op/gemm_sp.cc (2)
src/op/gemm_sp_py.h (1)
  • RegisterReflection (41-65)
src/op/gemm_sp.h (1)
  • RegisterReflection (27-33)
benchmark/matmul/benchmark_matmul_sp.py (3)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (136-150)
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (1)
  • matmul_sp (9-57)
tilelang/language/experimental/gemm_sp.py (1)
  • gemm_sp_v2 (92-308)
src/op/gemm_sp_py.cc (4)
src/op/gemm_sp_py.h (2)
  • GemmSPPy (80-86)
  • RegisterReflection (41-65)
tilelang/tileop/gemm_sp/__init__.py (1)
  • GemmSPPy (29-69)
src/op/gemm.h (6)
  • GemmWarpPolicy (59-83)
  • GemmWarpPolicy (64-68)
  • GemmWarpPolicy (70-74)
  • GemmWarpPolicy (76-82)
  • RegisterReflection (35-41)
  • RegisterReflection (106-128)
src/op/gemm_sp.cc (6)
  • Clone (120-123)
  • Clone (120-120)
  • Lower (143-185)
  • Lower (143-143)
  • InferLayout (215-299)
  • InferLayout (215-216)
tilelang/utils/sparse.py (1)
tilelang/utils/tensor.py (2)
  • is_float8 (11-17)
  • fp8_remove_negative_zeros_ (20-24)
examples/gemm_sp/example_custom_compress.py (6)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (136-150)
tilelang/utils/sparse.py (1)
  • randn_semi_sparse (109-128)
tilelang/utils/tensor.py (1)
  • torch_assert_close (237-329)
tilelang/language/allocate.py (3)
  • alloc_shared (27-42)
  • alloc_fragment (59-70)
  • alloc_local (45-56)
tilelang/language/annotations.py (1)
  • annotate_layout (25-36)
tilelang/language/experimental/gemm_sp.py (1)
  • gemm_sp_v2 (92-308)
tilelang/intrinsics/mma_sp_macro_generator.py (4)
tilelang/intrinsics/utils.py (2)
  • mma_store_index_map (81-82)
  • get_ldmatrix_offset (21-63)
tilelang/utils/language.py (1)
  • is_fragment (81-91)
tilelang/intrinsics/mma_sp_layout.py (20)
  • shared_16x16_to_mma_sp_layout_sr_a (14-15)
  • shared_16x16_to_mma_sp_layout_sr_b (18-20)
  • shared_16x32_to_mma_sp_layout_sr_a (23-24)
  • shared_16x32_to_mma_sp_layout_sr_b (27-29)
  • shared_16x64_to_mma_sp_layout_sr_a (32-33)
  • shared_16x64_to_mma_sp_layout_sr_b (36-38)
  • mma_sp_load_a_32x4_to_shared_16x16_layout (41-42)
  • mma_sp_load_a_32x8_to_shared_16x32_layout (45-46)
  • mma_sp_load_a_32x16_to_shared_16x64_layout (49-50)
  • mma_sp_load_b_32x8_to_shared_16x16_layout (53-56)
  • mma_sp_load_b_32x16_to_shared_16x32_layout (59-62)
  • mma_sp_load_b_32x32_to_shared_16x64_layout (65-68)
  • metadata_8bit_load_32x4_to_shared_16x4_layout_32bit (75-80)
  • metadata_16bit_load_32x2_to_shared_16x2_layout_32bit (83-88)
  • metadata_8bit_load_32x4_to_shared_16x4_layout_16bit (91-94)
  • metadata_16bit_load_32x2_to_shared_16x2_layout_16bit (97-100)
  • metadata_8bit_load_32x4_to_shared_16x4_layout_8bit (107-112)
  • metadata_16bit_load_32x2_to_shared_16x4_layout_8bit (115-120)
  • metadata_32bit_load_32x1_to_shared_16x2_layout_8bit (123-129)
  • get_ldmatrix_offset_b (156-190)
tilelang/language/tir/op.py (3)
  • ptx_ldmatrix (1313-1349)
  • address_of (464-480)
  • ptx_mma_sp (964-1062)
src/op/gemm_sp_py.h (3)
src/op/gemm_sp.h (3)
  • tvm (13-112)
  • tl (15-111)
  • RegisterReflection (27-33)
src/op/operator.h (2)
  • TileOperatorNode (56-66)
  • TileOperator (68-72)
src/op/gemm_sp.cc (6)
  • Lower (143-185)
  • Lower (143-143)
  • InferLayout (215-299)
  • InferLayout (215-216)
  • Clone (120-123)
  • Clone (120-120)
tilelang/layout/gemm_sp.py (2)
tilelang/layout/layout.py (1)
  • Layout (13-148)
tilelang/contrib/nvcc.py (2)
  • get_target_compute_version (258-299)
  • parse_compute_version (302-324)
tilelang/tileop/gemm_sp/__init__.py (6)
src/op/gemm_sp_py.h (2)
  • tvm (15-89)
  • GemmSPPy (80-86)
tilelang/utils/target.py (1)
  • target_is_cuda (126-127)
tilelang/ir.py (1)
  • GemmWarpPolicy (30-39)
tilelang/tileop/gemm_sp/gemm_sp_mma.py (3)
  • GemmSPMMA (12-245)
  • infer_layout (14-58)
  • lower (60-233)
tilelang/tileop/gemm_sp/gemm_sp_base.py (23)
  • infer_layout (14-15)
  • lower (17-18)
  • A (66-67)
  • E (70-71)
  • B (74-75)
  • C (78-79)
  • APtr (82-83)
  • EPtr (86-87)
  • BPtr (90-91)
  • CPtr (94-95)
  • M (33-34)
  • N (37-38)
  • K (41-42)
  • trans_A (45-46)
  • trans_B (49-50)
  • stride_A (98-99)
  • stride_B (102-103)
  • offset_A (106-107)
  • offset_B (110-111)
  • clear_accum (114-115)
  • k_pack (118-119)
  • wg_wait (122-123)
  • policy (126-127)
src/op/gemm_sp_py.cc (1)
  • GemmSPPy (51-83)
tilelang/intrinsics/mma_sp_layout.py (1)
tilelang/intrinsics/mma_layout.py (4)
  • mma_load_a_32x4_to_shared_16x8_layout (130-133)
  • mma_load_a_32x16_to_shared_16x32_layout (142-145)
  • mma_load_a_32x8_to_shared_16x16_layout (148-161)
  • ldmatrix_trans_32x8_to_shared_16x16_layout (24-27)
tilelang/language/experimental/gemm_sp.py (2)
tilelang/utils/language.py (1)
  • get_buffer_region_from_load (137-159)
tilelang/language/gemm.py (10)
  • legalize_arguments (48-59)
  • legalize_arguments (251-262)
  • retrieve_shape (66-83)
  • retrieve_shape (269-286)
  • retrieve_stride (85-111)
  • retrieve_stride (288-314)
  • retrieve_ptr (140-175)
  • retrieve_ptr (343-378)
  • retrieve_offset (177-195)
  • retrieve_offset (380-398)
🪛 LanguageTool
docs/deeplearning_operators/matmul_sparse.md

[grammar] ~50-~50: Ensure spelling is correct
Context: ... in A_sparse/A and E. (i.e. the 4-elment group at [n, k] doesn't match the 4-bit...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)


[grammar] ~137-~137: Ensure spelling is correct
Context: ...each 4-bit chunk represents two 2-bit indcies of non-zero elements within four cons...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)

🪛 markdownlint-cli2 (0.18.1)
docs/deeplearning_operators/matmul_sparse.md

39-39: Link text should be descriptive

(MD059, descriptive-link-text)


39-39: Link text should be descriptive

(MD059, descriptive-link-text)


176-176: Link text should be descriptive

(MD059, descriptive-link-text)

🪛 Ruff (0.14.3)
tilelang/tileop/gemm_sp/gemm_sp_mma.py

57-58: Avoid specifying long messages outside the exception class

(TRY003)


232-233: Avoid specifying long messages outside the exception class

(TRY003)

examples/gemm_sp/example_custom_compress.py

111-112: Avoid specifying long messages outside the exception class

(TRY003)


122-122: Avoid specifying long messages outside the exception class

(TRY003)


125-125: Avoid specifying long messages outside the exception class

(TRY003)


129-129: Avoid specifying long messages outside the exception class

(TRY003)


132-132: Avoid specifying long messages outside the exception class

(TRY003)


134-136: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/mma_sp_macro_generator.py

52-60: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


62-119: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


121-129: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


178-180: Avoid specifying long messages outside the exception class

(TRY003)


212-212: Avoid specifying long messages outside the exception class

(TRY003)


313-313: Avoid specifying long messages outside the exception class

(TRY003)


381-381: Avoid specifying long messages outside the exception class

(TRY003)


390-390: Avoid specifying long messages outside the exception class

(TRY003)


395-395: Avoid specifying long messages outside the exception class

(TRY003)


397-397: Avoid specifying long messages outside the exception class

(TRY003)


445-445: Avoid specifying long messages outside the exception class

(TRY003)


696-696: Avoid specifying long messages outside the exception class

(TRY003)


713-713: Avoid specifying long messages outside the exception class

(TRY003)


771-771: Avoid specifying long messages outside the exception class

(TRY003)


785-785: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/layout/gemm_sp.py

115-115: Avoid specifying long messages outside the exception class

(TRY003)


118-118: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/tileop/gemm_sp/__init__.py

61-61: Avoid specifying long messages outside the exception class

(TRY003)


69-69: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/mma_sp_layout.py

124-124: Unused function argument: local_id

(ARG001)


172-172: Avoid specifying long messages outside the exception class

(TRY003)


188-188: Avoid specifying long messages outside the exception class

(TRY003)


190-190: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/experimental/gemm_sp.py

162-163: Prefer TypeError exception for invalid type

(TRY004)


162-163: Avoid specifying long messages outside the exception class

(TRY003)


190-191: Prefer TypeError exception for invalid type

(TRY004)


190-191: Avoid specifying long messages outside the exception class

(TRY003)


254-255: Prefer TypeError exception for invalid type

(TRY004)


254-255: Avoid specifying long messages outside the exception class

(TRY003)


274-275: Prefer TypeError exception for invalid type

(TRY004)


274-275: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (30)
tilelang/intrinsics/mma_sp_layout.py (1)

1-190: LGTM! Layout helpers follow established patterns.

The SP layout mapping functions correctly delegate to existing MMA layout helpers and maintain API consistency.

examples/gemm_sp/test_example_gemm_sp.py (1)

1-16: LGTM! Clean test structure.

The test module follows standard patterns for exercising example code.

tilelang/utils/sparse.py (2)

92-103: LGTM! Correct float8 and transposition handling.

The float8 handling correctly normalizes negative zeros and the transposition logic maintains correctness for the sparse compression workflow.


131-159: LGTM! New utility follows existing patterns.

randint_semi_sparse correctly mirrors the structure and behavior of randn_semi_sparse for integer tensors.

tilelang/language/experimental/gemm_sp.py (1)

91-308: LGTM! gemm_sp_v2 implementation is comprehensive.

The new gemm_sp_v2 function correctly handles buffer normalization, shape/stride computation, and offset calculation. The implementation properly validates matrix dimensions and generates the appropriate intrinsic call.

benchmark/matmul/benchmark_matmul_sp.py (2)

12-12: LGTM! Correct API migration to make_cutlass_metadata_layout.


220-227: LGTM! Successful migration to gemm_sp_v2.

The transition from T.gemm_sp to T.gemm_sp_v2 aligns with the PR objectives and maintains all required parameters.

examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (2)

4-4: LGTM! Consistent API migration.


43-47: LGTM! Correct usage of make_cutlass_metadata_layout.

The migration removes the backend parameter and maintains all required arguments for the new API.

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (3)

5-11: LGTM! Clean test infrastructure updates.

The imports and configuration changes (disabling TF32, importing new utilities) align well with the v2 pathway requirements and metadata layout updates.


14-42: LGTM! Well-structured input generation helper.

The generate_dense_input function cleanly handles dynamic dtype selection and transposition for both integer and floating-point types, supporting the test harness's expanded dtype coverage.


122-123: LGTM! Correct metadata dtype and layout usage.

The dynamic metadata dtype calculation via SparseTensorCoreIntrinEmitter.E_FACTOR_MAP and the switch to make_cutlass_metadata_layout align with the PR's standardized metadata layout approach for SM8x.

Also applies to: 141-145

tilelang/layout/gemm_sp.py (3)

21-104: LGTM! SM90 metadata layout correctly promoted to public API.

The function is now exposed as make_cutlass_metadata_layout_sm90 with proper error handling and dynamic shape/stride computation. The layout logic correctly implements CUTLASS-compatible interleaving for SM90 metadata.


107-133: LGTM! SM8x metadata layout correctly implements CUTLASS conventions.

The renamed make_cutlass_metadata_layout_sm8x correctly computes group and interweave constants based on metadata dtype bit-width, and the interleaved column-major transformation aligns with the CUTLASS SM8x metadata format referenced in the PyTorch link (line 109).

Based on learnings


136-150: LGTM! Clean architecture-based dispatch.

The make_cutlass_metadata_layout wrapper correctly determines the compute version and dispatches to the appropriate SM-specific layout generator, providing a unified public API.


examples/gemm_sp/example_custom_compress.py (3)

18-62: LGTM! Well-structured configuration constants.

The DEFAULT_CONFIG and ARCH_INFO dictionaries provide clear, architecture-specific configurations for different GPU families and data types, making the example easy to adapt and tune.


106-221: LGTM! Comprehensive naive compression implementation.

The torch_compress function provides a clear, well-documented reference implementation of 2:4 sparsity encoding with proper validation and dtype handling. The detailed encoding comments (lines 148-180) are particularly helpful for understanding the metadata format.
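
To make the encoding concrete, here is a deliberately naive sketch of the same 2:4 idea (not the torch_compress implementation; the real metadata bit layout may pack indices differently): for each group of four values along K, keep the two non-zeros and record their positions as 2-bit indices.

import torch

def naive_compress_2to4(a: torch.Tensor):
    m, k = a.shape
    groups = a.view(m, k // 4, 4)
    nz = groups != 0
    pos = torch.arange(4, device=a.device).expand_as(groups)
    lo = torch.where(nz, pos, 4).amin(dim=-1)    # index of first kept element
    hi = torch.where(nz, pos, -1).amax(dim=-1)   # index of second kept element
    idx = torch.stack([lo, hi], dim=-1)          # assumes exactly two non-zeros per group
    vals = torch.gather(groups, -1, idx)         # (m, k//4, 2) compressed values
    meta = (lo | (hi << 2)).to(torch.uint8)      # 4 bits of index information per group
    return vals.reshape(m, k // 2), meta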


305-363: LGTM! Complete end-to-end example with benchmarking.

The main function demonstrates a full workflow: argument parsing, input generation, compression (with selectable compressor), correctness verification via torch_assert_close, and performance benchmarking. This provides excellent guidance for users implementing custom compression.

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (3)

11-65: LGTM! Clean SS variant test implementation.

The matmul function correctly constructs a gemm_sp_v2 kernel with proper shape handling, metadata layout annotation via make_cutlass_metadata_layout, and pipelined execution. This establishes a solid foundation for the test suite.


136-165: LGTM! Robust input generation helper.

The generate_dense_input function handles diverse dtypes (int8, float8, float16, bfloat16) with appropriate value ranges and semi-sparse generation via randint_semi_sparse/randn_semi_sparse, supporting comprehensive test coverage.


198-256: LGTM! Complete coverage of RS, SR, and RR variants.

The matmul_rs, matmul_sr, and matmul_rr functions correctly implement fragment-based loading patterns for each variant, with appropriate alloc_fragment allocations, T.copy operations, and layout annotations. This ensures all dataflow patterns (shared-shared, register-shared, shared-register, register-register) are tested.

Also applies to: 352-410, 506-568

tilelang/tileop/gemm_sp/__init__.py (1)

28-69: LGTM! Clean GemmSPPy class definition with proper delegation.

The GemmSPPy class correctly defines all required fields with proper type hints and delegates infer_layout and lower to GemmSPMMA for CUDA targets, with clear error handling for unsupported targets. This provides a clean separation between the FFI boundary and the implementation.

tilelang/intrinsics/mma_sp_macro_generator.py (8)

134-180: LGTM!

The constructor properly initializes all instance attributes and validates the warp configuration. The initialization helpers correctly set up dimensions, abbreviations, and MMA prefixes based on data types.


182-288: LGTM!

The initialization and helper methods correctly compute dimensions, local sizes, and thread bindings. The extract_thread_binding method properly handles both is_m_first orderings for flexible thread layout.
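
The two orderings reduce to a small index split. A sketch of the idea (which warp axis varies fastest under is_m_first is an assumption here, not read from the emitter):

WARP_SIZE = 32

def split_thread(tid, block_row_warps, block_col_warps, is_m_first):
    lane = tid % WARP_SIZE
    warp = tid // WARP_SIZE
    if is_m_first:
        warp_m, warp_n = warp % block_row_warps, warp // block_row_warps
    else:
        warp_n, warp_m = warp % block_col_warps, warp // block_col_warps
    return lane, warp_m, warp_n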


290-354: LGTM!

The ldmatrix_a method correctly handles data loading for the A matrix with proper fallback when ldmatrix is unavailable for transposed int8. The physical K dimension adjustment (dividing by SPARSE_FACTOR) correctly accounts for the 2:4 sparsity pattern.


447-516: LGTM!

The _warp_ldmatrix_b macro correctly handles both ldmatrix and fallback paths. The replicate_b logic properly doubles the load for 16x16 output tiles, and the fallback path correctly handles transposed and non-transposed layouts.


545-589: LGTM!

The _warp_mma_sp macro correctly implements the sparse MMA operation. The stride calculations properly handle fragment buffers, and the replicate_b path correctly performs two MMA operations with appropriate offsets for 16x16 output tiles. Metadata handling with SPARSE_SELECTOR = 0 is correct.
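
For orientation, the per-warp fragment bookkeeping for one common fp16 sparse MMA shape (PTX m16n8k32); the numbers below are plain arithmetic, not values read from the emitter:

WARP_SIZE = 32
M, N, K = 16, 8, 32                   # logical mma.sp tile (fp16)
a_elems = M * (K // 2) // WARP_SIZE   # 8 compressed A values per thread
b_elems = K * N // WARP_SIZE          # 8 B values per thread
c_elems = M * N // WARP_SIZE          # 4 accumulators per thread
# with replicate_b, two such MMAs reuse the same A/metadata to cover a 16x16 output tile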


591-644: LGTM!

The stmatrix method correctly implements both shared and global memory stores. Thread-local indices are properly mapped to 2D positions using mma_store_index_map, and vectorized stores optimize memory access. The global variant correctly offsets by block coordinates.
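
The underlying mapping is the standard PTX m16n8 accumulator layout: each of the 32 lanes owns four results arranged as two 8x8 halves. A standalone sketch of those coordinates (the helper in tilelang/intrinsics/utils.py may factor this differently):

def accum_coords(lane: int, local_id: int):
    # lane in [0, 32), local_id in [0, 4) for an m16n8 accumulator tile
    row = lane // 4 + 8 * (local_id // 2)
    col = (lane % 4) * 2 + (local_id % 2)
    return row, col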


646-787: LGTM!

The make_mma_load_layout method correctly constructs fragment layouts for A and B matrices. The axis order determination (sr vs rs) correctly handles transposed and non-transposed cases, and the sparse factor division for matrix A (Line 744) properly accounts for 2:4 sparsity. The repeat/replicate operations correctly build block-level fragments with appropriate thread mapping.


789-858: LGTM!

The make_mma_store_layout method correctly constructs the output fragment layout. The forward_thread and forward_index functions properly decompose 2D coordinates into block, warp, and MMA components, correctly handling both is_m_first thread orderings. The local offset calculation accurately reflects the warp-column-major layout of the output buffer.

from tilelang.utils.language import get_buffer_region_from_load
import tilelang.language as T
from tvm import tir
from typing import list

⚠️ Potential issue | 🟠 Major

Fix typing import: use List instead of list.

The import from typing import list is incorrect. Use from typing import List for Python < 3.9 compatibility, or if targeting Python 3.9+, remove this import and use the built-in list type directly.

Apply this diff for Python < 3.9 compatibility:

-from typing import list
+from typing import List

Then update the return type annotations:

  • Line 146: -> List[int]
  • Line 165: -> List[int]

Alternatively, if Python 3.9+ is required, remove the import entirely and keep the lowercase list annotations.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-from typing import list
+from typing import List
🤖 Prompt for AI Agents
In tilelang/language/experimental/gemm_sp.py around line 8 and affecting
annotations at lines 146 and 165, the typing import is incorrect: replace "from
typing import list" with "from typing import List" (for Python <3.9) and update
the two return type annotations at line 146 and line 165 from "-> list[int]" to
"-> List[int]". If the codebase targets Python 3.9+, instead remove the import
entirely and keep the built-in "list" annotations as-is; ensure imports and
annotations are consistent across the file.

@botbw botbw marked this pull request as draft November 5, 2025 15:43
@botbw botbw marked this pull request as ready for review November 5, 2025 15:43
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
docs/deeplearning_operators/matmul_sparse.md (1)

39-39: Improve link text for accessibility.

The duplicate [here] links violate MD059 and make the doc harder to navigate with assistive tech. Please replace them with descriptive titles (e.g., [PyTorch sparse kernel], [vLLM sparse kernel]).

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 82261e7 and d7ca20e.

📒 Files selected for processing (5)
  • docs/deeplearning_operators/matmul_sparse.md (1 hunks)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (1 hunks)
  • tilelang/intrinsics/mma_sp_macro_generator.py (1 hunks)
  • tilelang/language/experimental/gemm_sp.py (2 hunks)
  • tilelang/tileop/gemm_sp/gemm_sp_mma.py (1 hunks)
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • docs/deeplearning_operators/matmul_sparse.md
  • tilelang/intrinsics/mma_sp_macro_generator.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • tilelang/intrinsics/mma_sp_macro_generator.py
🧬 Code graph analysis (4)
tilelang/tileop/gemm_sp/gemm_sp_mma.py (6)
tilelang/tileop/gemm_sp/gemm_sp_base.py (20)
  • GemmSPBase (11-127)
  • infer_layout (14-15)
  • policy (126-127)
  • M (33-34)
  • N (37-38)
  • in_dtype (57-59)
  • e_dtype (53-54)
  • accum_dtype (62-63)
  • trans_A (45-46)
  • trans_B (49-50)
  • K (41-42)
  • is_gemm_ss (20-21)
  • A (66-67)
  • B (74-75)
  • C (78-79)
  • is_gemm_sr (23-24)
  • is_gemm_rs (26-27)
  • is_gemm_rr (29-30)
  • lower (17-18)
  • E (70-71)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (10-18)
tilelang/intrinsics/mma_sp_macro_generator.py (6)
  • make_mma_store_layout (789-858)
  • make_mma_load_layout (646-787)
  • ldmatrix_a (290-354)
  • ldmatrix_e (356-418)
  • ldmatrix_b (420-516)
  • mma_sp (518-589)
tilelang/utils/language.py (2)
  • is_shared (25-39)
  • is_fragment (81-91)
tilelang/transform/simplify.py (1)
  • _Simplify (31-49)
tilelang/tileop/gemm_sp/__init__.py (2)
  • infer_layout (56-61)
  • lower (63-69)
tilelang/language/experimental/gemm_sp.py (2)
tilelang/utils/language.py (1)
  • get_buffer_region_from_load (137-159)
tilelang/language/gemm.py (10)
  • legalize_arguments (48-59)
  • legalize_arguments (251-262)
  • retrieve_shape (66-83)
  • retrieve_shape (269-286)
  • retrieve_stride (85-111)
  • retrieve_stride (288-314)
  • retrieve_ptr (140-175)
  • retrieve_ptr (343-378)
  • retrieve_offset (177-195)
  • retrieve_offset (380-398)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (6)
tilelang/utils/sparse.py (3)
  • compress (77-106)
  • randn_semi_sparse (109-128)
  • randint_semi_sparse (131-158)
tilelang/utils/tensor.py (2)
  • torch_assert_close (237-329)
  • map_torch_type (37-54)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (136-150)
tilelang/intrinsics/mma_sp_macro_generator.py (1)
  • SparseTensorCoreIntrinEmitter (40-858)
tilelang/language/experimental/gemm_sp.py (1)
  • gemm_sp_v2 (91-307)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (10-18)
tilelang/intrinsics/mma_sp_macro_generator.py (5)
tilelang/intrinsics/utils.py (2)
  • mma_store_index_map (81-82)
  • get_ldmatrix_offset (21-63)
tilelang/utils/language.py (1)
  • is_fragment (81-91)
tilelang/intrinsics/mma_sp_layout.py (20)
  • shared_16x16_to_mma_sp_layout_sr_a (14-15)
  • shared_16x16_to_mma_sp_layout_sr_b (18-20)
  • shared_16x32_to_mma_sp_layout_sr_a (23-24)
  • shared_16x32_to_mma_sp_layout_sr_b (27-29)
  • shared_16x64_to_mma_sp_layout_sr_a (32-33)
  • shared_16x64_to_mma_sp_layout_sr_b (36-38)
  • mma_sp_load_a_32x4_to_shared_16x16_layout (41-42)
  • mma_sp_load_a_32x8_to_shared_16x32_layout (45-46)
  • mma_sp_load_a_32x16_to_shared_16x64_layout (49-50)
  • mma_sp_load_b_32x8_to_shared_16x16_layout (53-56)
  • mma_sp_load_b_32x16_to_shared_16x32_layout (59-62)
  • mma_sp_load_b_32x32_to_shared_16x64_layout (65-68)
  • metadata_8bit_load_32x4_to_shared_16x4_layout_32bit (75-80)
  • metadata_16bit_load_32x2_to_shared_16x2_layout_32bit (83-88)
  • metadata_8bit_load_32x4_to_shared_16x4_layout_16bit (91-94)
  • metadata_16bit_load_32x2_to_shared_16x2_layout_16bit (97-100)
  • metadata_8bit_load_32x4_to_shared_16x4_layout_8bit (107-112)
  • metadata_16bit_load_32x2_to_shared_16x4_layout_8bit (115-120)
  • metadata_32bit_load_32x1_to_shared_16x2_layout_8bit (123-129)
  • get_ldmatrix_offset_b (156-190)
tilelang/language/tir/op.py (2)
  • ptx_ldmatrix (1313-1349)
  • ptx_mma_sp (964-1062)
tilelang/layout/fragment.py (2)
  • replicate (147-161)
  • repeat (124-145)
🪛 markdownlint-cli2 (0.18.1)
docs/deeplearning_operators/matmul_sparse.md

39-39: Link text should be descriptive

(MD059, descriptive-link-text)


39-39: Link text should be descriptive

(MD059, descriptive-link-text)


176-176: Link text should be descriptive

(MD059, descriptive-link-text)

🪛 Ruff (0.14.3)
tilelang/tileop/gemm_sp/gemm_sp_mma.py

57-58: Avoid specifying long messages outside the exception class

(TRY003)


232-233: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/experimental/gemm_sp.py

161-162: Prefer TypeError exception for invalid type

(TRY004)


161-162: Avoid specifying long messages outside the exception class

(TRY003)


189-190: Prefer TypeError exception for invalid type

(TRY004)


189-190: Avoid specifying long messages outside the exception class

(TRY003)


253-254: Prefer TypeError exception for invalid type

(TRY004)


253-254: Avoid specifying long messages outside the exception class

(TRY003)


273-274: Prefer TypeError exception for invalid type

(TRY004)


273-274: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/mma_sp_macro_generator.py

52-60: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


62-119: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


121-129: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


178-180: Avoid specifying long messages outside the exception class

(TRY003)


212-212: Avoid specifying long messages outside the exception class

(TRY003)


313-313: Avoid specifying long messages outside the exception class

(TRY003)


381-381: Avoid specifying long messages outside the exception class

(TRY003)


390-390: Avoid specifying long messages outside the exception class

(TRY003)


395-395: Avoid specifying long messages outside the exception class

(TRY003)


397-397: Avoid specifying long messages outside the exception class

(TRY003)


445-445: Avoid specifying long messages outside the exception class

(TRY003)


696-696: Avoid specifying long messages outside the exception class

(TRY003)


713-713: Avoid specifying long messages outside the exception class

(TRY003)


771-771: Avoid specifying long messages outside the exception class

(TRY003)


785-785: Avoid specifying long messages outside the exception class

(TRY003)

Comment on lines +52 to +129
dtype_abbrv = {
    "float16": "fp16",
    "bfloat16": "bf16",
    "float32": "fp32",
    "int8": "int8",
    "int32": "int32",
    "float8_e4m3": "e4m3",
    "float8_e5m2": "e5m2",
}

E_FACTOR_MAP = {  # e_kdim = mma_kdim // e_factor
    "float": {
        "int16": 8,
        "uint16": 8,
    },
    "float32": {
        "int16": 8,
        "uint16": 8,
    },
    "float16": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
    "bfloat16": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
    "int8": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
    "uint8": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
    "float8_e4m3": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
    "float8_e5m2": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
}

E_REPLICATE_FACTOR = {  # metadata replicate every 4 consecutive threads
    "float32": 2,
    "float16": 2,  # 2 of 4 consecutive threads provides
    "bfloat16": 2,
    "int8": 1,  # 4 of 4 consecutive threads provides
    "uint8": 1,
    "float8_e4m3": 1,
    "float8_e5m2": 1,
}

⚠️ Potential issue | 🔴 Critical

Add tf32 dtype metadata to emitter tables.

E_FACTOR_MAP already exposes "float" entries, but dtype_abbrv and E_REPLICATE_FACTOR don’t, so creating the emitter with in_dtype="float" (tf32) raises a KeyError. That breaks the advertised tf32 path. Please add the missing mappings so we can actually emit tf32 kernels.

Apply this diff:

     dtype_abbrv = {
+        "float": "tf32",
         "float16": "fp16",
         "bfloat16": "bf16",
         "float32": "fp32",
@@
     E_REPLICATE_FACTOR = {  # metadata replicate every 4 consecutive threads
+        "float": 2,
         "float32": 2,
         "float16": 2,  # 2 of 4 consecutive threads provides
         "bfloat16": 2,
🧰 Tools
🪛 Ruff (0.14.3)

52-60: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


62-119: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


121-129: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)

🤖 Prompt for AI Agents
In tilelang/intrinsics/mma_sp_macro_generator.py around lines 52 to 129, the
emitter tables are missing entries for the tf32 dtype key "float", causing a
KeyError when in_dtype="float"; add "float": "tf32" to the dtype_abbrv mapping
and add "float": 2 to E_REPLICATE_FACTOR (matching float32 behavior) so the
existing E_FACTOR_MAP "float" entries can be used without error.

@botbw botbw marked this pull request as draft November 10, 2025 04:10
@LeiWang1999 LeiWang1999 marked this pull request as ready for review November 11, 2025 07:38
@LeiWang1999
Copy link
Member

We're good to go once we resolve the conflict; I think we can then let this PR in.
