
Commit cc00fb6

[Enhancement] Add support for symbolic dimensions in Cython kernel adapter and improve static shape validation in wrapper (#1024)
* [Enhancement] Add support for symbolic dimensions in the Cython kernel adapter and improve static shape validation in the wrapper
* [BugFix] Fix shape mismatch and deprecate `T.If()` in the fused_moe example
* [Fix] Add `is_symbolic_expr` function to check for symbolic expressions in TIR
  - Introduced a new utility function `is_symbolic_expr` to determine whether an expression is symbolic, enhancing type checking capabilities.
  - Updated shape handling in `CythonKernelAdapter` to use the new function, improving support for symbolic shapes.
1 parent a79bc5c commit cc00fb6

File tree

3 files changed: +28 -7 lines changed


examples/fusedmoe/example_fusedmoe_tilelang.py

Lines changed: 6 additions & 6 deletions
@@ -213,7 +213,7 @@ def kernel(
                 up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]

             for i, j in T.Parallel(block_token, block_dexpert):
-                with T.If(i < actual_rows), T.Then():
+                if i < actual_rows:
                     up_logits[m_start + i, by * block_dexpert + j] = up_logits_local[i, j]

             # Step 2: Compute down logits
@@ -261,7 +261,7 @@ def kernel(
                    transpose_B=True)

             for i, j in T.Parallel(block_token, block_dhidden):
-                with T.If(i < actual_rows), T.Then():
+                if i < actual_rows:
                     output[m_start + i, by * block_dhidden +
                            j] = output_local[i, j] * routed_expert_weights[m_start + i]

@@ -356,11 +356,11 @@ def __init__(self,
             dtype=torch.float16,
             device=self.device)
         self.stacked_expert_weights = torch.empty(
-            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], 1),
+            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]),
             dtype=torch.float16,
             device=self.device)
         self.stacked_expert_tokens_idxs = torch.empty(
-            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], 1),
+            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]),
             dtype=torch.int64,
             device=self.device)

@@ -389,7 +389,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, seq_len, hidden_dim = x.shape
         expert_indices, expert_scores = self.gating_network(x)
         flat_expert_indices = expert_indices.view(-1)
-        flat_expert_weights = expert_scores.view(-1, 1)
+        flat_expert_weights = expert_scores.view(-1)
         x_flat = x.view(-1, hidden_dim)

         # Prepare for grouped GEMM
@@ -412,7 +412,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             expert_tokens = x_flat[exp_token_idxs]

             self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens
-            self.stacked_expert_tokens_idxs[start_idx:end_idx, 0] = exp_token_idxs
+            self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs
             self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[
                 idxs[start_idx:end_idx]]

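Note on the shape fix above (a minimal sketch, not part of the commit; sizes and names are made up): with the staging buffers now 1-D, a plain slice assignment lines up with a 1-D index or weight source, so the trailing `, 1` dimension and the `[..., 0]` indexing are no longer needed.

import torch

num_stacked = 8                               # stand-in for batch_size * seq_len * n_experts_per_token
expert_scores = torch.rand(4, 2)              # (tokens, experts_per_token) stand-in

flat_expert_weights = expert_scores.view(-1)                 # shape (8,) after the fix
stacked_idxs = torch.empty(num_stacked, dtype=torch.int64)   # 1-D buffer, no trailing 1

exp_token_idxs = torch.arange(3)              # tokens routed to one expert (illustrative)
stacked_idxs[0:3] = exp_token_idxs            # 1-D slice assignment matches the 1-D source
print(stacked_idxs[:3])                       # tensor([0, 1, 2])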
tilelang/jit/adapter/cython/adapter.py

Lines changed: 11 additions & 0 deletions
@@ -29,6 +29,13 @@
     raise


+def is_symbolic_expr(expr) -> bool:
+    """Check if the expression is a symbolic expression.
+    A symbolic expression can be a simple tvm.Var, or an tvm.PrimExpr containing tvm.Var.
+    """
+    return not isinstance(expr, tir.IntImm) and isinstance(expr, tir.PrimExpr)
+
+
 class CythonKernelAdapter(BaseKernelAdapter):
     """Adapter class that converts TVM/TIR functions to callable CUDA kernels using cython.

@@ -278,6 +285,10 @@ def _process_static_buffer_infos(self) -> \
             for j, s in enumerate(buffer.shape):
                 if isinstance(s, tir.IntImm):
                     static_shape.append((j, s.value))
+                elif is_symbolic_expr(s):
+                    static_shape.append((j, -1))  # -1 for symbolic
+                else:
+                    raise ValueError(f"Unsupported shape type: {type(s)}")
             for j, s in enumerate(buffer.strides):
                 if isinstance(s, tir.IntImm):
                     static_strides.append((j, s.value))

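As a quick illustration of the helper added above (a minimal sketch assuming a standard TVM installation; not part of the commit), `is_symbolic_expr` treats any PrimExpr that is not a compile-time integer constant as symbolic, which is what lets `_process_static_buffer_infos` record such dimensions as -1:

from tvm import tir

n = tir.Var("n", "int32")          # a symbolic dimension
const = tir.IntImm("int32", 128)   # a static dimension

def is_symbolic_expr(expr) -> bool:
    # Same predicate as the one added in adapter.py above.
    return not isinstance(expr, tir.IntImm) and isinstance(expr, tir.PrimExpr)

print(is_symbolic_expr(const))   # False -> recorded as (dim_index, 128)
print(is_symbolic_expr(n))       # True  -> recorded as (dim_index, -1)
print(is_symbolic_expr(n * 2))   # True  -> derived expressions count as symbolic too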
tilelang/jit/adapter/cython/cython_wrapper.pyx

Lines changed: 11 additions & 1 deletion
@@ -107,9 +107,19 @@ cdef class CythonKernelWrapper:
             if not isinstance(tensor, torch.Tensor):
                 # otherwise, maybe torch.data_ptr() for T.ptr inputs
                 continue
+
+            # Check ndim
+            if tensor.dim() != len(shape_list):
+                raise ValueError(
+                    f"Static shape mismatch for parameter {param}: "
+                    f"expected {len(shape_list)} dimensions, "
+                    f"got {tensor.dim()}"
+                )
+
+            # Check each dimension
             for shape_idx, expected_shape in shape_list:
                 actual_shape = tensor.shape[shape_idx]
-                if actual_shape != expected_shape:
+                if expected_shape != -1 and actual_shape != expected_shape:
                     raise ValueError(
                         f"Static shape mismatch for parameter {param}: "
                         f"expected {expected_shape} at index {shape_idx}, "

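To make the wrapper change concrete, here is a plain-Python sketch of the same validation logic (illustrative only; the committed code runs inside the Cython class CythonKernelWrapper): an expected size of -1, produced by the adapter for symbolic dimensions, acts as a wildcard, while the rank and all fixed dimensions are still enforced.

import torch

def check_static_shape(param, tensor, shape_list):
    # shape_list holds (dim_index, expected_size) pairs; -1 marks a symbolic dimension.
    if tensor.dim() != len(shape_list):
        raise ValueError(
            f"Static shape mismatch for parameter {param}: "
            f"expected {len(shape_list)} dimensions, got {tensor.dim()}")
    for shape_idx, expected_shape in shape_list:
        actual_shape = tensor.shape[shape_idx]
        if expected_shape != -1 and actual_shape != expected_shape:
            raise ValueError(
                f"Static shape mismatch for parameter {param}: "
                f"expected {expected_shape} at index {shape_idx}, got {actual_shape}")

# A buffer declared as (m, 128) with symbolic m passes for any first dimension,
# but a wrong fixed dimension is still rejected.
check_static_shape("A", torch.empty(7, 128), [(0, -1), (1, 128)])   # ok
check_static_shape("A", torch.empty(3, 128), [(0, -1), (1, 128)])   # ok
try:
    check_static_shape("A", torch.empty(3, 64), [(0, -1), (1, 128)])
except ValueError as err:
    print(err)   # reports the mismatch at index 1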