Merged
12 changes: 6 additions & 6 deletions examples/fusedmoe/example_fusedmoe_tilelang.py
@@ -213,7 +213,7 @@ def kernel(
up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]

for i, j in T.Parallel(block_token, block_dexpert):
with T.If(i < actual_rows), T.Then():
if i < actual_rows:
up_logits[m_start + i, by * block_dexpert + j] = up_logits_local[i, j]

# Step 2: Compute down logits
@@ -261,7 +261,7 @@ def kernel(
transpose_B=True)

for i, j in T.Parallel(block_token, block_dhidden):
with T.If(i < actual_rows), T.Then():
if i < actual_rows:
output[m_start + i, by * block_dhidden +
j] = output_local[i, j] * routed_expert_weights[m_start + i]

@@ -356,11 +356,11 @@ def __init__(self,
dtype=torch.float16,
device=self.device)
self.stacked_expert_weights = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], 1),
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]),
dtype=torch.float16,
device=self.device)
self.stacked_expert_tokens_idxs = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], 1),
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]),
dtype=torch.int64,
device=self.device)

@@ -389,7 +389,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, hidden_dim = x.shape
expert_indices, expert_scores = self.gating_network(x)
flat_expert_indices = expert_indices.view(-1)
flat_expert_weights = expert_scores.view(-1, 1)
flat_expert_weights = expert_scores.view(-1)
x_flat = x.view(-1, hidden_dim)

# Prepare for grouped GEMM
@@ -412,7 +412,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
expert_tokens = x_flat[exp_token_idxs]

self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens
self.stacked_expert_tokens_idxs[start_idx:end_idx, 0] = exp_token_idxs
self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs
self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[
idxs[start_idx:end_idx]]

11 changes: 11 additions & 0 deletions tilelang/jit/adapter/cython/adapter.py
@@ -29,6 +29,13 @@
raise


def is_symbolic_expr(expr) -> bool:
"""Check if the expression is a symbolic expression.
A symbolic expression can be a simple tvm.Var, or a tvm.PrimExpr containing tvm.Var.
"""
return not isinstance(expr, tir.IntImm) and isinstance(expr, tir.PrimExpr)
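# Illustration only (not part of this PR): given n = tir.Var("n", "int32"),
# both n and n * 4 count as symbolic, while a plain constant does not:
#   is_symbolic_expr(n)                       -> True
#   is_symbolic_expr(n * 4)                   -> True  (tir.Mul is a PrimExpr, not an IntImm)
#   is_symbolic_expr(tir.IntImm("int32", 4))  -> False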


class CythonKernelAdapter(BaseKernelAdapter):
"""Adapter class that converts TVM/TIR functions to callable CUDA kernels using cython.

@@ -278,6 +285,10 @@ def _process_static_buffer_infos(self) -> \
for j, s in enumerate(buffer.shape):
if isinstance(s, tir.IntImm):
static_shape.append((j, s.value))
elif is_symbolic_expr(s):
static_shape.append((j, -1)) # -1 for symbolic
else:
raise ValueError(f"Unsupported shape type: {type(s)}")
Comment on lines +288 to +291
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify stride processing handles symbolic dimensions consistently.

The shape processing now explicitly handles symbolic dimensions by marking them as -1. However, stride processing (lines 292-294) only handles tir.IntImm without similar symbolic dimension support. Since _process_dynamic_symbolic (lines 229-234) indicates that strides can also be symbolic, this may create an inconsistency.

Run the following script to check if symbolic strides are encountered in practice:


🏁 Script executed:

#!/bin/bash
# Description: Search for buffer strides usage patterns to determine if symbolic strides need similar handling

# Search for stride-related buffer operations
rg -nP --type=py -C3 'buffer\.strides' 

# Search for symbolic stride handling in tests
rg -nP --type=py -C3 'stride.*Var|Var.*stride' -g '!**/test/**'

Length of output: 8975


Handle symbolic strides consistently with symbolic shapes.

The shape processing now handles symbolic dimensions by marking them as -1, but stride processing (lines 292-294) only handles tir.IntImm without similar symbolic support. However, _process_dynamic_symbolic (lines 228-234) explicitly detects and maps symbolic strides (tir.Var). This creates an inconsistency: when symbolic strides are encountered, they are silently ignored during static stride extraction, while the method already identifies them as important.

Additionally, the contiguity check at lines 296-298 iterates over buffer.strides without accounting for symbolic strides, which will produce incorrect results or unexpected behavior when symbolic strides are present.

Add symbolic stride handling at lines 292-294 to mark symbolic strides (similar to the -1 convention for shapes), or document why symbolic strides should be excluded from static stride processing.
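
A minimal sketch of the kind of handling this suggests, mirroring the -1 convention already used for symbolic shapes (hypothetical; it reuses buffer, static_strides, and is_symbolic_expr from the surrounding method and is not part of this PR):

for j, s in enumerate(buffer.strides):
    if isinstance(s, tir.IntImm):
        static_strides.append((j, s.value))
    elif is_symbolic_expr(s):
        # Mark symbolic strides with -1 so downstream checks can skip them,
        # matching the convention used for symbolic shapes above.
        static_strides.append((j, -1))
    else:
        raise ValueError(f"Unsupported stride type: {type(s)}")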

🧰 Tools
🪛 Ruff (0.14.0)

291-291: Avoid specifying long messages outside the exception class

(TRY003)

for j, s in enumerate(buffer.strides):
if isinstance(s, tir.IntImm):
static_strides.append((j, s.value))
12 changes: 11 additions & 1 deletion tilelang/jit/adapter/cython/cython_wrapper.pyx
@@ -107,9 +107,19 @@ cdef class CythonKernelWrapper:
if not isinstance(tensor, torch.Tensor):
# otherwise, maybe torch.data_ptr() for T.ptr inputs
continue

# Check ndim
if tensor.dim() != len(shape_list):
raise ValueError(
f"Static shape mismatch for parameter {param}: "
f"expected {len(shape_list)} dimensions, "
f"got {tensor.dim()}"
Comment on lines +112 to +116

P1: Avoid rank checks based on partial static shape metadata

_check_static_shape now raises when tensor.dim() != len(shape_list), but shape_list only contains entries for dimensions whose shapes were captured as tir.IntImm or tir.Var. Buffers whose shapes are other expressions (e.g., T.Buffer((M * N,), "float32")) produce an empty shape_list even though the buffer is 1-D, so every tensor passed to that parameter now fails with “expected 0 dimensions”. Previously these buffers were accepted. The rank check should rely on the buffer’s actual rank or run only when shape metadata covers all dimensions.
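
A minimal sketch of the guarded check described here, assuming the adapter also records each buffer's full rank (expected_ndim is a hypothetical name, e.g. captured from len(buffer.shape) at adapter build time, and is not part of this PR):

# Only enforce the rank check when the full rank is known; otherwise fall back to
# the per-dimension checks below.
if expected_ndim is not None and tensor.dim() != expected_ndim:
    raise ValueError(
        f"Static shape mismatch for parameter {param}: "
        f"expected {expected_ndim} dimensions, got {tensor.dim()}")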


)

# Check each dimension
for shape_idx, expected_shape in shape_list:
actual_shape = tensor.shape[shape_idx]
if actual_shape != expected_shape:
if expected_shape != -1 and actual_shape != expected_shape:
raise ValueError(
f"Static shape mismatch for parameter {param}: "
f"expected {expected_shape} at index {shape_idx}, "