[Enhancement] Add support for symbolic dimensions in Cython kernel adapter and improve static shape validation in wrapper #1024
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's required checks before requesting review. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough: Adds detection of symbolic tensor dimensions during static shape extraction and records them as -1; enhances Cython wrapper shape checks to validate tensor dimensionality and per-dimension sizes (treating -1 as a wildcard). The example kernel is updated to use 1D expert arrays and Python boolean conditions in its indexing and shapes.
Sequence Diagram(s):
sequenceDiagram
autonumber
actor User
participant Cython as "Cython Wrapper"
participant Adapter as "Adapter"
participant Launcher as "Kernel Launcher"
User->>Cython: call kernel(tensor)
Cython->>Cython: if tensor.dim() != len(shape_list) -> raise ValueError
Cython->>Adapter: request static_shape & strides
Adapter-->>Cython: return static_shape (symbolic dims returned as -1)
Cython->>Cython: loop dims i
alt shape_list[i] == -1
Cython->>Cython: skip check for dim i (wildcard)
else
Cython->>Cython: verify tensor.shape[i] == shape_list[i] or raise ValueError
end
alt all dims valid
Cython->>Launcher: launch kernel
Launcher-->>User: execution result
end
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ❌ failed checks (1 warning, 1 inconclusive); ✅ passed checks (3).
Walkthrough: Updates the CUDA half/bfloat16 absolute value helpers to use the generic abs and adds a bfloat16 overload. Extends the Cython adapter to encode symbolic tensor dims as -1 in static shapes. Adds stricter shape validation in the Cython wrapper: checks that the rank matches and treats -1 as a per-dimension wildcard.
Sequence Diagram(s):
sequenceDiagram
actor U as Python Caller
participant W as Cython Wrapper (_check_static_shape)
participant A as Cython Adapter (_process_static_buffer_infos)
participant K as Kernel Launcher
U->>A: Provide tensor buffers
A->>A: Build static_shape<br/>(IntImm -> index,value)<br/>(tir.Var -> index,-1)
A-->>W: static_shape with -1 for symbolic dims
U->>W: Launch request with tensors
W->>W: Check rank equality (dim == len(shape_list))<br/>If not, raise ValueError
alt Per-dimension
W->>W: If expected == -1, skip check
else W->>W: Enforce actual == expected<br/>If mismatch, raise ValueError
end
W-->>K: Proceed to launch on success
K-->>U: Execution result
note over W,K: -1 acts as wildcard for symbolic dimensions
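For readers skimming the diagrams, here is a minimal Python sketch of the described check. It is an illustration only, not the repository's Cython implementation, and it assumes `shape_list` is a list of `(dim_index, expected_size)` pairs in which -1 marks a symbolic dimension.

```python
def check_static_shape(tensor, param: str, shape_list) -> None:
    """Illustrative re-statement of the wrapper's validation flow."""
    # Rank check: the tensor must have as many dimensions as recorded entries.
    if tensor.dim() != len(shape_list):
        raise ValueError(
            f"Static shape mismatch for parameter {param}: "
            f"expected {len(shape_list)} dimensions, got {tensor.dim()}"
        )
    # Per-dimension check: -1 acts as a wildcard for symbolic dimensions.
    for shape_idx, expected in shape_list:
        if expected == -1:
            continue  # symbolic dimension, any runtime size is accepted
        if tensor.shape[shape_idx] != expected:
            raise ValueError(
                f"Static shape mismatch for parameter {param}: "
                f"expected {expected} at index {shape_idx}, "
                f"got {tensor.shape[shape_idx]}"
            )
```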
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ❌ failed checks (2 warnings); ✅ passed checks (3).
📜 Recent review details: Configuration used: CodeRabbit UI; Review profile: CHILL; Plan: Pro; 📒 Files selected for processing (3).
🧰 Additional context used: 🧬 Code graph analysis (1) for src/tl_templates/cuda/common.h (1).
…apter and improve static shape validation in wrapper
Actionable comments posted: 0
🧹 Nitpick comments (2)
tilelang/jit/adapter/cython/cython_wrapper.pyx (1)
Lines 120-123: Per-dimension wildcard (-1) handling — solid; consider minor hardening. The current logic is correct. Two optional improvements:
- Guard indices: assert 0 <= shape_idx < tensor.dim() to fail fast on bad metadata.
- Enrich error: include full expected vs actual shapes for easier debugging.
```diff
 for shape_idx, expected_shape in shape_list:
-    actual_shape = tensor.shape[shape_idx]
+    assert 0 <= shape_idx < tensor.dim(), (
+        f"Shape index {shape_idx} out of bounds for tensor with {tensor.dim()} dims"
+    )
+    actual_shape = tensor.shape[shape_idx]
     if expected_shape != -1 and actual_shape != expected_shape:
-        raise ValueError(
-            f"Static shape mismatch for parameter {param}: "
-            f"expected {expected_shape} at index {shape_idx}, "
-            f"got {actual_shape}"
-        )
+        raise ValueError(
+            f"Static shape mismatch for parameter {param}: "
+            f"expected dim[{shape_idx}]={expected_shape}, got {actual_shape}; "
+            f"tensor.shape={tuple(tensor.shape)}, expected_spec={shape_list}"
+        )
```

tilelang/jit/adapter/cython/adapter.py (1)
Lines 281-282: Record symbolic dims as -1 — LGTM; consider a named sentinel. Using -1 aligns with wrapper validation. To avoid magic numbers across modules, define a shared constant (e.g., SYMBOLIC_DIM = -1) and use it in both the adapter and the wrapper.
```diff
-elif isinstance(s, tir.Var):
-    static_shape.append((j, -1))  # -1 for symbolic
+elif isinstance(s, tir.Var):
+    static_shape.append((j, SYMBOLIC_DIM))  # -1 for symbolic
```

And in cython_wrapper.pyx, replace the literal -1 checks with SYMBOLIC_DIM.
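A small self-contained sketch of how such a shared sentinel could look; the module placement and helper functions below are hypothetical illustrations, not existing TileLang code.

```python
# Hypothetical shared module, e.g. tilelang/jit/adapter/constants.py
SYMBOLIC_DIM = -1  # wildcard marker for symbolic (dynamic) dimensions


def encode_dim(index: int, size) -> tuple:
    """Adapter side: record a concrete size, or SYMBOLIC_DIM for symbolic dims."""
    return (index, size) if isinstance(size, int) else (index, SYMBOLIC_DIM)


def dim_matches(expected: int, actual: int) -> bool:
    """Wrapper side: wildcard entries match any runtime size."""
    return expected == SYMBOLIC_DIM or expected == actual
```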
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- tilelang/jit/adapter/cython/adapter.py (1 hunks)
- tilelang/jit/adapter/cython/cython_wrapper.pyx (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-amd
- GitHub Check: build-test-metal
🔇 Additional comments (1)
tilelang/jit/adapter/cython/cython_wrapper.pyx (1)
Lines 111-118: NDIM check addition — LGTM. Explicit rank validation closes the gap from #1011 and works with symbolic dims.
/codex review
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
```python
if tensor.dim() != len(shape_list):
    raise ValueError(
        f"Static shape mismatch for parameter {param}: "
        f"expected {len(shape_list)} dimensions, "
        f"got {tensor.dim()}"
    )
```
Avoid rank checks based on partial static shape metadata

`_check_static_shape` now raises when `tensor.dim() != len(shape_list)`, but `shape_list` only contains entries for dimensions whose shapes were captured as `tir.IntImm` or `tir.Var`. Buffers whose shapes are other expressions (e.g., `T.Buffer((M * N,), "float32")`) produce an empty `shape_list` even though the buffer is 1-D, so every tensor passed to that parameter now fails with "expected 0 dimensions". Previously these buffers were accepted. The rank check should rely on the buffer's actual rank or run only when shape metadata covers all dimensions.
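One hedged way to address this, sketched in plain Python: have the adapter also record the buffer's true rank (a hypothetical `buffer_rank` field, not an existing attribute) and compare against that instead of `len(shape_list)`, so partially captured metadata cannot reject valid tensors.

```python
def check_rank(tensor, param: str, buffer_rank: int) -> None:
    """Rank check driven by the buffer's actual rank, not the metadata length."""
    if tensor.dim() != buffer_rank:
        raise ValueError(
            f"Static shape mismatch for parameter {param}: "
            f"expected {buffer_rank} dimensions, got {tensor.dim()}"
        )
    # Per-dimension size checks can then run over shape_list as before,
    # even when shape_list covers only a subset of the dimensions.
```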
Some tests failed for those changes; maybe we can take some insights from the Codex review.
…ns in TIR

- Introduced a new utility function `is_symbolic_expr` to determine if an expression is a symbolic expression, enhancing type checking capabilities.
- Updated shape handling in `CythonKernelAdapter` to utilize the new function, improving handling for symbolic shapes.
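The commit body does not show the helper itself here; based on the description above and the later review note about the type-check order, it presumably looks roughly like the following sketch (not necessarily the exact implementation).

```python
from tvm import tir


def is_symbolic_expr(expr) -> bool:
    """Return True for PrimExprs that are not concrete integer immediates."""
    # Exclude concrete sizes (IntImm) first, then accept any remaining
    # PrimExpr (Var, arithmetic like M * N, etc.) as symbolic.
    return not isinstance(expr, tir.IntImm) and isinstance(expr, tir.PrimExpr)
```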
Actionable comments posted: 1
🧹 Nitpick comments (1)
tilelang/jit/adapter/cython/adapter.py (1)
Lines 290-291: Consider a more descriptive error message. The error message could be more helpful by indicating what types are supported (e.g., `tir.IntImm` for concrete dimensions or symbolic `tir.PrimExpr` for dynamic dimensions). Apply this diff to improve the error message:
```diff
 else:
-    raise ValueError(f"Unsupported shape type: {type(s)}")
+    raise ValueError(
+        f"Unsupported shape type: {type(s)}. "
+        f"Expected tir.IntImm (concrete) or tir.PrimExpr (symbolic), got {s}"
+    )
```

Note: This also addresses the Ruff TRY003 hint by making the message more specific to the context.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/fusedmoe/example_fusedmoe_tilelang.py (5 hunks)
- tilelang/jit/adapter/cython/adapter.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/jit/adapter/cython/adapter.py (2)
- src/transform/lower_tile_op.cc (2): expr (433-445), expr (433-433)
- src/transform/layout_inference.cc (2): expr (389-407), expr (389-389)
🪛 Ruff (0.14.0)
tilelang/jit/adapter/cython/adapter.py
291-291: Avoid specifying long messages outside the exception class (TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (7)
examples/fusedmoe/example_fusedmoe_tilelang.py (6)
Lines 216-217: LGTM: Python boolean conditional is appropriate. The change from a TIR-style conditional to the Python boolean `if i < actual_rows:` is correct for this scalar comparison within the parallel loop.
Lines 264-266: LGTM: Consistent use of a Python boolean conditional. The conditional logic correctly guards the output assignment to prevent out-of-bounds writes when `actual_rows < block_token`.
Lines 359-361: LGTM: Shape simplification from 2D to 1D. The tensor shape change from `(N, 1)` to `(N,)` is correct and consistent with the kernel signature at line 144, which expects a 1D tensor. This eliminates an unnecessary dimension and simplifies indexing.
Lines 362-365: LGTM: Shape simplification maintains compatibility. The shape change to 1D is correct. The tensor is appropriately reshaped at line 461 via `.view(-1, 1).repeat(1, x_flat.shape[-1])` for the scatter_reduce operation, which requires 2D indexing (see the illustration below).
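As a standalone illustration (with made-up shapes, unrelated to the example's real dimensions) of why the 1D index tensor is expanded before scatter_reduce:

```python
import torch

# scatter_reduce_ requires the index tensor to have the same rank as src/self,
# so a 1D per-token index is broadcast across the hidden dimension first.
num_tokens, hidden, num_slots = 6, 4, 3
x_flat = torch.rand(num_tokens, hidden)
token_idxs = torch.tensor([0, 1, 1, 2, 0, 2])  # 1D routed-token indices
out = torch.zeros(num_slots, hidden)

index_2d = token_idxs.view(-1, 1).repeat(1, x_flat.shape[-1])  # (6, 4)
out.scatter_reduce_(0, index_2d, x_flat, reduce="sum")
```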
Line 392: LGTM: View operation aligns with the 1D shape. The change from `.view(-1, 1)` to `.view(-1)` correctly reflects that `stacked_expert_weights` is now 1D, maintaining consistency throughout the forward pass.
Line 415: LGTM: Direct assignment is appropriate. The assignment correctly populates the 1D `stacked_expert_tokens_idxs` tensor without unnecessary reshaping, since `exp_token_idxs` is already 1D.

tilelang/jit/adapter/cython/adapter.py (1)
Lines 32-36: LGTM: The helper function correctly identifies symbolic expressions. The type-check order is correct: checking `not isinstance(expr, tir.IntImm)` before `isinstance(expr, tir.PrimExpr)` ensures that concrete integer expressions are excluded while capturing all other symbolic PrimExpr instances.
```python
elif is_symbolic_expr(s):
    static_shape.append((j, -1))  # -1 for symbolic
else:
    raise ValueError(f"Unsupported shape type: {type(s)}")
```
🧩 Analysis chain
Verify stride processing handles symbolic dimensions consistently.
The shape processing now explicitly handles symbolic dimensions by marking them as -1. However, stride processing (lines 292-294) only handles `tir.IntImm` without similar symbolic dimension support. Since `_process_dynamic_symbolic` (lines 229-234) indicates that strides can also be symbolic, this may create an inconsistency.
Run the following script to check if symbolic strides are encountered in practice:
🏁 Script executed:

```bash
#!/bin/bash
# Description: Search for buffer strides usage patterns to determine if symbolic strides need similar handling

# Search for stride-related buffer operations
rg -nP --type=py -C3 'buffer\.strides'

# Search for symbolic stride handling in tests
rg -nP --type=py -C3 'stride.*Var|Var.*stride' -g '!**/test/**'
```
Length of output: 8975
Handle symbolic strides consistently with symbolic shapes.

The shape processing now handles symbolic dimensions by marking them as -1, but stride processing (lines 292-294) only handles `tir.IntImm` without similar symbolic support. However, `_process_dynamic_symbolic` (lines 228-234) explicitly detects and maps symbolic strides (`tir.Var`). This creates an inconsistency: when symbolic strides are encountered, they are silently ignored during static stride extraction, while the method already identifies them as important.

Additionally, the contiguity check at lines 296-298 iterates over `buffer.strides` without accounting for symbolic strides, which will produce incorrect results or unexpected behavior when symbolic strides are present.
Add symbolic stride handling at lines 292-294 to mark symbolic strides (similar to the -1 convention for shapes), or document why symbolic strides should be excluded from static stride processing.
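A hedged sketch of what such symbolic-stride handling could look like, mirroring the -1 convention used for shapes; `is_symbolic_expr` is the helper introduced in this PR, and the exact integration point in the adapter is assumed rather than taken from the codebase.

```python
from tvm import tir


def encode_static_strides(buffer, is_symbolic_expr) -> list:
    """Record concrete strides as ints and symbolic strides as -1 wildcards."""
    static_strides = []
    for j, stride in enumerate(buffer.strides):
        if isinstance(stride, tir.IntImm):
            static_strides.append((j, int(stride)))
        elif is_symbolic_expr(stride):
            static_strides.append((j, -1))  # -1 marks a symbolic stride
        else:
            raise ValueError(f"Unsupported stride type: {type(stride)}")
    return static_strides
```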
🧰 Tools
🪛 Ruff (0.14.0)
291-291: Avoid specifying long messages outside the exception class (TRY003)
Fixed #1011.
This pull request introduces improvements to buffer shape handling and validation in the Cython adapter layer, as well as updates to CUDA type-specific functions for numerical types. The main changes enhance the robustness of shape checks for tensors and ensure correct usage of absolute value functions for custom types.
Buffer shape handling and validation:

- Symbolic buffer dimensions are now recorded as -1 in `_process_static_buffer_infos`, allowing for more flexible shape specification.
- Static shape validation is strengthened in `CythonKernelWrapper` by checking both the number of dimensions and each dimension's size, raising clear errors for mismatches and ignoring symbolic dimensions (-1).

CUDA numerical type functions:

- Updated the `__habs` function for both `half_t` and `bfloat16_t` to use the CUTLASS library's `abs` function directly, ensuring correct absolute value computation for these types.

Summary by CodeRabbit
- New Features
- Bug Fixes
- Examples