
Conversation


@Rachmanino Rachmanino commented Oct 13, 2025

Fixed #1011.
This pull request introduces improvements to buffer shape handling and validation in the Cython adapter layer, as well as updates to CUDA type-specific functions for numerical types. The main changes enhance the robustness of shape checks for tensors and ensure correct usage of absolute value functions for custom types.

Buffer shape handling and validation:

  • Added support for symbolic dimensions in buffer shapes by marking them with -1 in _process_static_buffer_infos, allowing for more flexible shape specification.
  • Improved static shape validation in CythonKernelWrapper by checking both the number of dimensions and each dimension's size, raising clear errors for mismatches and ignoring symbolic dimensions (-1).
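The adapter-side behavior described above can be sketched as follows. This is an illustrative mock, not the real TileLang code: `IntImm` and `Var` here are stand-ins for the tvm.tir node types, and `extract_static_shape` is a hypothetical name for the scanning loop inside _process_static_buffer_infos.

```python
# Hypothetical stand-ins for the tvm.tir node types, for illustration only.
class IntImm:
    def __init__(self, value):
        self.value = value

class Var:
    def __init__(self, name):
        self.name = name

def extract_static_shape(shape):
    """Record concrete dims as (index, value) and symbolic dims as (index, -1)."""
    static_shape = []
    for j, s in enumerate(shape):
        if isinstance(s, IntImm):
            static_shape.append((j, s.value))
        elif isinstance(s, Var):
            static_shape.append((j, -1))  # -1 marks a symbolic (wildcard) dimension
        else:
            raise ValueError(f"Unsupported shape type: {type(s)}")
    return static_shape

# A buffer of shape (M, 128): first dim symbolic, second concrete.
print(extract_static_shape([Var("M"), IntImm(128)]))  # [(0, -1), (1, 128)]
```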

CUDA numerical type functions:

  • Updated the implementation of the __habs function for both half_t and bfloat16_t to use the CUTLASS library's abs function directly, ensuring correct absolute value computation for these types.

Summary by CodeRabbit

  • New Features

    • Support symbolic dimensions in static shape detection (treated as wildcards).
    • Enforce tensor rank matches expected shape, with clear errors on mismatch.
  • Bug Fixes

    • Preserve symbolic axes instead of ignoring them during static shape extraction.
    • Treat -1 as a per-dimension wildcard and provide clearer mismatch errors.
  • Examples

    • Updated fused MoE example to use Python boolean conditionals and simplified 1D tensor shapes.


👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai bot commented Oct 13, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds detection of symbolic tensor dimensions during static shape extraction and records them as -1; enhances the Cython wrapper's shape checks to validate tensor dimensionality and per-dimension sizes (treating -1 as a wildcard); updates the example kernel's indexing and shapes to use 1D expert arrays and Python boolean conditions.

Changes

  • Static shape extraction (tilelang/jit/adapter/cython/adapter.py): Added top-level helper is_symbolic_expr(expr) -> bool. _process_static_buffer_infos now recognizes symbolic dimensions (e.g., tir.Var) when scanning buffer.shape and records them as (index, -1) in static_shape; retains existing tir.IntImm handling and raises on unsupported types.
  • Cython wrapper validation (tilelang/jit/adapter/cython/cython_wrapper.pyx): _check_static_shape now checks tensor.dim() == len(shape_list) and raises ValueError if mismatched. Per-dimension validation treats -1 in shape_list as a wildcard (skips the check); otherwise it requires exact equality and raises ValueError on mismatch.
  • Example adjustments (examples/fusedmoe/example_fusedmoe_tilelang.py): Replaced TIR-style conditional checks with Python booleans for i < actual_rows in logits writes; converted several 2D tensors with a trailing 1 dimension to 1D shapes and updated indexing/reshaping (e.g., stacked_expert_weights, stacked_expert_tokens_idxs, flat_expert_weights).

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Cython as "Cython Wrapper"
  participant Adapter as "Adapter"
  participant Launcher as "Kernel Launcher"

  User->>Cython: call kernel(tensor)
  Cython->>Cython: if tensor.dim() != len(shape_list) -> raise ValueError
  Cython->>Adapter: request static_shape & strides
  Adapter-->>Cython: return static_shape (symbolic dims returned as -1)
  Cython->>Cython: loop dims i
  alt shape_list[i] == -1
    Cython->>Cython: skip check for dim i (wildcard)
  else
    Cython->>Cython: verify tensor.shape[i] == shape_list[i] or raise ValueError
  end
  alt all dims valid
    Cython->>Launcher: launch kernel
    Launcher-->>User: execution result
  end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I hop through shapes with whiskered care,
Vars turned to -1 — a sign I share.
I count your dims and call the bluff,
Wildcards skip, mismatches get tough.
Hooray — kernels launch, and carrots are enough. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 33.33%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Out of Scope Changes Check (❓ Inconclusive): The changes to adapter.py and cython_wrapper.pyx are clearly in scope, directly addressing the shape validation and symbolic dimension support requirements. However, the modifications to examples/fusedmoe/example_fusedmoe_tilelang.py, including removal of trailing dimensions from tensor shapes and changes to control flow, appear to be adjustments made to accommodate the new stricter shape validation, which could be considered tangential to the core fix. Additionally, the PR description mentions updating CUDA numerical type functions (the __habs implementation for half_t and bfloat16_t) to call CUTLASS's abs function directly, but these changes do not appear in the provided raw_summary, creating uncertainty about whether all changes are accounted for and in scope.
✅ Passed checks (3 passed)
  • Description Check (✅ Passed): Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The pull request title "[Enhancement] Add support for symbolic dimensions in Cython kernel adapter and improve static shape validation in wrapper" accurately reflects the main changes. It explicitly mentions the two primary objectives, adding symbolic dimension support in the adapter and improving static shape validation in the wrapper, which align with the modifications in adapter.py and cython_wrapper.pyx. The title is concise and specific enough for teammates scanning history to understand the primary change.
  • Linked Issues Check (✅ Passed): The code changes directly address the requirements of linked issue #1011. adapter.py introduces the is_symbolic_expr() function to identify and mark symbolic dimensions with -1 in _process_static_buffer_infos, fulfilling the requirement to support symbolic dimensions while remaining flexible. cython_wrapper.pyx adds an early dimensionality check that rejects mismatched tensor ranks and refines per-dimension validation to treat -1 as a wildcard, directly implementing the requirement to validate tensor rank and concrete dimension sizes while providing clear error messages.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.


coderabbitai bot commented Oct 13, 2025

Walkthrough

Updates the CUDA half/bfloat16 absolute value helpers to use the generic abs and adds a bfloat16 overload. Extends the Cython adapter to encode symbolic tensor dims as -1 in static shapes. Adds stricter shape validation in the Cython wrapper: checks that ranks match and treats -1 as a per-dimension wildcard.

Changes

  • CUDA numeric abs helpers (src/tl_templates/cuda/common.h): Replaced the half_t __habs implementation to return abs(x); added a bfloat16_t __habs returning abs(x). No signature changes.
  • Cython adapter, symbolic dims (tilelang/jit/adapter/cython/adapter.py): In _process_static_buffer_infos, records tir.Var dims in static_shape as (idx, -1); keeps previous handling for tir.IntImm and stride/contiguity logic unchanged.
  • Cython wrapper, shape checks (tilelang/jit/adapter/cython/cython_wrapper.pyx): In _check_static_shape, added a rank equality check (tensor.dim() vs provided shape length). Per dimension: skip when expected is -1; otherwise enforce exact equality and raise ValueError on mismatch.

Sequence Diagram(s)

sequenceDiagram
    actor U as Python Caller
    participant W as Cython Wrapper (_check_static_shape)
    participant A as Cython Adapter (_process_static_buffer_infos)
    participant K as Kernel Launcher

    U->>A: Provide tensor buffers
    A->>A: Build static_shape<br/>(IntImm -> index,value)<br/>(tir.Var -> index,-1)
    A-->>W: static_shape with -1 for symbolic dims

    U->>W: Launch request with tensors
    W->>W: Check rank equality (dim == len(shape_list))<br/>If not, raise ValueError
    alt Per-dimension
        W->>W: If expected == -1, skip check
        else W->>W: Enforce actual == expected<br/>If mismatch, raise ValueError
    end
    W-->>K: Proceed to launch on success
    K-->>U: Execution result
    note over W,K: -1 acts as wildcard for symbolic dimensions

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I twitch my whiskers, shapes aligned,
From ranks that match to dims unsigned—
A -1 hop for symbols free,
No sneaky tensors pass by me.
Half and bfloat, abs made neat,
Thump-thump: correctness at my feet. 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
  • Out of Scope Changes Check (⚠️ Warning): The pull request includes changes to the __habs implementations for half_t and bfloat16_t types in the CUDA common header, which are unrelated to the shape validation objectives of issue #1011. These modifications address CUDA numerical type functionality rather than the symbolic dimension handling or static shape checks that the linked issue targets, so they are out of scope for the linked issue's requirements. To maintain clear separation of concerns, move the __habs updates into a dedicated pull request linked to an issue focused on CUDA numeric type functionality, or explicitly document their relation to the shape validation enhancements if they must remain together.
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 25.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (3 passed)
  • Description Check (✅ Passed): Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The title succinctly highlights the enhancements made to support symbolic dimensions and improve static shape validation in the Cython adapter, which represent the primary focus of the pull request. It is concise, clear, and directly related to the key changes without unnecessary detail.
  • Linked Issues Check (✅ Passed): The pull request implements the required static shape handling by marking symbolic dimensions as -1 in the adapter and enforces both dimension-count and per-dimension size checks in the wrapper, raising errors on mismatch. These changes satisfy the objectives of issue #1011 by preventing kernel launches with incorrect tensor ranks or concrete shape mismatches while allowing symbolic dimensions. No related coding requirements from the linked issue remain unaddressed.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d89ba5b and 814b476.

📒 Files selected for processing (3)
  • src/tl_templates/cuda/common.h (1 hunks)
  • tilelang/jit/adapter/cython/adapter.py (1 hunks)
  • tilelang/jit/adapter/cython/cython_wrapper.pyx (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/tl_templates/cuda/common.h (1)
src/tl_templates/cpp/half.hpp (1)
  • abs (3242-3242)
🔇 Additional comments (3)
tilelang/jit/adapter/cython/adapter.py (1)

281-282: LGTM! Symbolic dimension handling added correctly.

The addition of tir.Var handling to encode symbolic dimensions as -1 in static_shape is correct and aligns with the downstream validation logic in cython_wrapper.pyx that treats -1 as a wildcard dimension.

tilelang/jit/adapter/cython/cython_wrapper.pyx (1)

110-127: LGTM! Enhanced shape validation correctly addresses issue #1011.

The additions properly validate tensor shapes in two stages:

  1. Rank validation (lines 111-117): Ensures the number of dimensions matches before checking individual sizes, preventing index out of bounds errors
  2. Per-dimension validation (line 122): Treats -1 as a wildcard (symbolic dimension) while enforcing concrete dimension size matches

This implementation correctly catches the bug described in issue #1011, where a kernel expecting a 1D tensor could be launched with a 2D tensor.
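The two-stage check described above can be sketched in plain Python. `check_static_shape` is a hypothetical stand-in for _check_static_shape, with a tuple standing in for torch's tensor.shape; names are illustrative:

```python
def check_static_shape(actual_shape, shape_list, param="x"):
    """Rank check first, then per-dimension sizes; -1 acts as a wildcard."""
    # Stage 1: reject tensors whose rank differs from the expected shape.
    if len(actual_shape) != len(shape_list):
        raise ValueError(
            f"Static shape mismatch for parameter {param}: "
            f"expected {len(shape_list)} dimensions, got {len(actual_shape)}")
    # Stage 2: check each recorded dimension, skipping symbolic (-1) entries.
    for idx, expected in shape_list:
        if expected != -1 and actual_shape[idx] != expected:
            raise ValueError(
                f"Static shape mismatch for parameter {param}: "
                f"expected {expected} at index {idx}, got {actual_shape[idx]}")

check_static_shape((4, 128), [(0, -1), (1, 128)])      # symbolic first dim: OK
try:
    check_static_shape((4, 128, 1), [(0, -1), (1, 128)])  # rank 3 vs 2: rejected
except ValueError as e:
    print(e)
```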

src/tl_templates/cuda/common.h (1)

57-61: Verify CUTLASS abs overloads for half_t and bfloat16_t

The __habs implementations now use unqualified abs(x); confirm that CUTLASS headers provide abs(cutlass::half_t) and abs(cutlass::bfloat16_t) overloads (via ADL). If they’re missing, qualify the calls or add the necessary overloads.



…apter and improve static shape validation in wrapper

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
tilelang/jit/adapter/cython/cython_wrapper.pyx (1)

120-123: Per-dimension wildcard (-1) handling — solid; consider minor hardening

Current logic is correct. Two optional improvements:

  • Guard indices: assert 0 <= shape_idx < tensor.dim() to fail fast on bad metadata.
  • Enrich error: include full expected vs actual shapes for easier debugging.
 for shape_idx, expected_shape in shape_list:
-    actual_shape = tensor.shape[shape_idx]
+    assert 0 <= shape_idx < tensor.dim(), (
+        f"Shape index {shape_idx} out of bounds for tensor with {tensor.dim()} dims"
+    )
+    actual_shape = tensor.shape[shape_idx]
     if expected_shape != -1 and actual_shape != expected_shape:
-        raise ValueError(
-            f"Static shape mismatch for parameter {param}: "
-            f"expected {expected_shape} at index {shape_idx}, "
-            f"got {actual_shape}"
-        )
+        raise ValueError(
+            f"Static shape mismatch for parameter {param}: "
+            f"expected dim[{shape_idx}]={expected_shape}, got {actual_shape}; "
+            f"tensor.shape={tuple(tensor.shape)}, expected_spec={shape_list}"
+        )
tilelang/jit/adapter/cython/adapter.py (1)

281-282: Record symbolic dims as -1 — LGTM; consider a named sentinel

Using -1 aligns with wrapper validation. To avoid magic numbers across modules, define a shared constant (e.g., SYMBOLIC_DIM = -1) and use it in both adapter and wrapper.

-elif isinstance(s, tir.Var):
-    static_shape.append((j, -1))  # -1 for symbolic
+elif isinstance(s, tir.Var):
+    static_shape.append((j, SYMBOLIC_DIM))  # -1 for symbolic

And in cython_wrapper.pyx replace literal -1 checks with SYMBOLIC_DIM.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d89ba5b and d5ba468.

📒 Files selected for processing (2)
  • tilelang/jit/adapter/cython/adapter.py (1 hunks)
  • tilelang/jit/adapter/cython/cython_wrapper.pyx (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-amd
  • GitHub Check: build-test-metal
🔇 Additional comments (1)
tilelang/jit/adapter/cython/cython_wrapper.pyx (1)

111-118: NDIM check addition — LGTM

Explicit rank validation closes the gap from #1011. Works with symbolic dims.

@LeiWang1999

/codex review

@LeiWang1999

@codex review


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Comment on lines +112 to +116
if tensor.dim() != len(shape_list):
    raise ValueError(
        f"Static shape mismatch for parameter {param}: "
        f"expected {len(shape_list)} dimensions, "
        f"got {tensor.dim()}"


P1 Badge Avoid rank checks based on partial static shape metadata

_check_static_shape now raises when tensor.dim() != len(shape_list), but shape_list only contains entries for dimensions whose shapes were captured as tir.IntImm or tir.Var. Buffers whose shapes are other expressions (e.g., T.Buffer((M * N,), "float32")) produce an empty shape_list even though the buffer is 1-D, so every tensor passed to that parameter now fails with “expected 0 dimensions”. Previously these buffers were accepted. The rank check should rely on the buffer’s actual rank or run only when shape metadata covers all dimensions.

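The failure mode Codex describes can be reproduced with a small mock. Here ints and strings stand in for tir.IntImm and tir.Var, and a tuple stands in for an arithmetic expression dim such as M * N; all names are illustrative, not the real TileLang code:

```python
def extract_static_shape_partial(shape):
    """Mimics the criticized extraction: only int and named-var dims are
    recorded; expression dims (tuples here) fall through and are dropped."""
    static = []
    for j, s in enumerate(shape):
        if isinstance(s, int):
            static.append((j, s))       # concrete dim
        elif isinstance(s, str):        # stand-in for tir.Var
            static.append((j, -1))      # symbolic dim
        # expression dims are silently skipped
    return static

def check_rank(ndim, shape_list):
    if ndim != len(shape_list):
        raise ValueError(f"expected {len(shape_list)} dimensions, got {ndim}")

# A 1-D buffer whose single dim is the expression M*N (modeled as a tuple):
shape_list = extract_static_shape_partial([("M", "*", "N")])
print(shape_list)  # [] -- the expression dim was never recorded
try:
    check_rank(1, shape_list)   # a valid 1-D tensor is now rejected
except ValueError as e:
    print(e)                    # expected 0 dimensions, got 1
```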

@LeiWang1999

Some tests failed for those changes; maybe we can take some insights from the Codex review.

…ns in TIR

- Introduced a new utility function `is_symbolic_expr` to determine if an expression is a symbolic expression, enhancing type checking capabilities.
- Updated shape handling in `CythonKernelAdapter` to utilize the new function, improving handling for symbolic shapes.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tilelang/jit/adapter/cython/adapter.py (1)

290-291: Consider more descriptive error message.

The error message could be more helpful by indicating what types are supported (e.g., tir.IntImm for concrete dimensions or symbolic tir.PrimExpr for dynamic dimensions).

Apply this diff to improve the error message:

                    else:
-                        raise ValueError(f"Unsupported shape type: {type(s)}")
+                        raise ValueError(
+                            f"Unsupported shape type: {type(s)}. "
+                            f"Expected tir.IntImm (concrete) or tir.PrimExpr (symbolic), got {s}"
+                        )

Note: This also addresses the Ruff TRY003 hint by making the message more specific to the context.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 221c16b and 2e4978f.

📒 Files selected for processing (2)
  • examples/fusedmoe/example_fusedmoe_tilelang.py (5 hunks)
  • tilelang/jit/adapter/cython/adapter.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/jit/adapter/cython/adapter.py (2)
src/transform/lower_tile_op.cc (2)
  • expr (433-445)
  • expr (433-433)
src/transform/layout_inference.cc (2)
  • expr (389-407)
  • expr (389-389)
🪛 Ruff (0.14.0)
tilelang/jit/adapter/cython/adapter.py

291-291: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (7)
examples/fusedmoe/example_fusedmoe_tilelang.py (6)

216-217: LGTM: Python boolean conditional is appropriate.

The change from TIR-style conditional to Python boolean if i < actual_rows: is correct for this scalar comparison within the parallel loop.


264-266: LGTM: Consistent use of Python boolean conditional.

The conditional logic correctly guards the output assignment to prevent out-of-bounds writes when actual_rows < block_token.


359-361: LGTM: Shape simplification from 2D to 1D.

The tensor shape change from (N, 1) to (N,) is correct and consistent with the kernel signature at line 144 which expects a 1D tensor. This eliminates an unnecessary dimension and simplifies indexing.


362-365: LGTM: Shape simplification maintains compatibility.

The shape change to 1D is correct. The tensor is appropriately reshaped at line 461 via .view(-1, 1).repeat(1, x_flat.shape[-1]) for the scatter_reduce operation that requires 2D indexing.


392-392: LGTM: View operation aligns with 1D shape.

The change from .view(-1, 1) to .view(-1) correctly reflects that stacked_expert_weights is now 1D, maintaining consistency throughout the forward pass.


415-415: LGTM: Direct assignment is appropriate.

The assignment correctly populates the 1D stacked_expert_tokens_idxs tensor without unnecessary reshaping since exp_token_idxs is already 1D.

tilelang/jit/adapter/cython/adapter.py (1)

32-36: LGTM: Helper function correctly identifies symbolic expressions.

The type-check order is correct—checking not isinstance(expr, tir.IntImm) before isinstance(expr, tir.PrimExpr) ensures that concrete integer expressions are excluded while capturing all other symbolic PrimExpr instances.
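The ordering argument can be demonstrated with stand-in classes that mirror the tvm.tir hierarchy (IntImm and Var both subclass PrimExpr). These are illustrative mocks, not the real tvm API:

```python
class PrimExpr: ...
class IntImm(PrimExpr):          # concrete integer dimension
    def __init__(self, value): self.value = value
class Var(PrimExpr):             # symbolic dimension
    def __init__(self, name): self.name = name

def is_symbolic_expr(expr) -> bool:
    # Exclude concrete integers first, then accept any other PrimExpr.
    # Checking isinstance(expr, PrimExpr) alone would also match IntImm.
    return not isinstance(expr, IntImm) and isinstance(expr, PrimExpr)

print(is_symbolic_expr(IntImm(128)))  # False: concrete dimension
print(is_symbolic_expr(Var("M")))     # True: symbolic dimension
print(is_symbolic_expr(42))           # False: not a PrimExpr at all
```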

Comment on lines +288 to +291
elif is_symbolic_expr(s):
    static_shape.append((j, -1))  # -1 for symbolic
else:
    raise ValueError(f"Unsupported shape type: {type(s)}")

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify stride processing handles symbolic dimensions consistently.

The shape processing now explicitly handles symbolic dimensions by marking them as -1. However, stride processing (lines 292-294) only handles tir.IntImm without similar symbolic dimension support. Since _process_dynamic_symbolic (lines 229-234) indicates that strides can also be symbolic, this may create an inconsistency.

Run the following script to check if symbolic strides are encountered in practice:


🏁 Script executed:

#!/bin/bash
# Description: Search for buffer strides usage patterns to determine if symbolic strides need similar handling

# Search for stride-related buffer operations
rg -nP --type=py -C3 'buffer\.strides' 

# Search for symbolic stride handling in tests
rg -nP --type=py -C3 'stride.*Var|Var.*stride' -g '!**/test/**'

Length of output: 8975


Handle symbolic strides consistently with symbolic shapes.

The shape processing now handles symbolic dimensions by marking them as -1, but stride processing (lines 292-294) only handles tir.IntImm without similar symbolic support. However, _process_dynamic_symbolic (lines 228-234) explicitly detects and maps symbolic strides (tir.Var). This creates an inconsistency: when symbolic strides are encountered, they are silently ignored during static stride extraction, while the method already identifies them as important.

Additionally, the contiguity check at lines 296-298 iterates over buffer.strides without accounting for symbolic strides, which will produce incorrect results or unexpected behavior when symbolic strides are present.

Add symbolic stride handling at lines 292-294 to mark symbolic strides (similar to the -1 convention for shapes), or document why symbolic strides should be excluded from static stride processing.

🧰 Tools
🪛 Ruff (0.14.0)

291-291: Avoid specifying long messages outside the exception class

(TRY003)

@LeiWang1999 LeiWang1999 merged commit cc00fb6 into tile-ai:main Oct 17, 2025
12 of 14 checks passed


Development

Successfully merging this pull request may close these issues.

[Bug] Shape mismatching tensors are valid at launch

2 participants