From d5ba4684ed0cd5fc720f439b1ef56b6e87434e67 Mon Sep 17 00:00:00 2001
From: Rachmanino <18805904201@163.com>
Date: Mon, 13 Oct 2025 16:20:34 +0000
Subject: [PATCH 1/3] [Enhancement] Add support for symbolic dimensions in
 Cython kernel adapter and improve static shape validation in wrapper

---
 tilelang/jit/adapter/cython/adapter.py         |  2 ++
 tilelang/jit/adapter/cython/cython_wrapper.pyx | 12 +++++++++++-
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/tilelang/jit/adapter/cython/adapter.py b/tilelang/jit/adapter/cython/adapter.py
index a7bf6b4a0..1e64de34a 100644
--- a/tilelang/jit/adapter/cython/adapter.py
+++ b/tilelang/jit/adapter/cython/adapter.py
@@ -278,6 +278,8 @@ def _process_static_buffer_infos(self) -> \
             for j, s in enumerate(buffer.shape):
                 if isinstance(s, tir.IntImm):
                     static_shape.append((j, s.value))
+                elif isinstance(s, tir.Var):
+                    static_shape.append((j, -1))  # -1 for symbolic
             for j, s in enumerate(buffer.strides):
                 if isinstance(s, tir.IntImm):
                     static_strides.append((j, s.value))
diff --git a/tilelang/jit/adapter/cython/cython_wrapper.pyx b/tilelang/jit/adapter/cython/cython_wrapper.pyx
index c37cb4aa0..f6cd59276 100644
--- a/tilelang/jit/adapter/cython/cython_wrapper.pyx
+++ b/tilelang/jit/adapter/cython/cython_wrapper.pyx
@@ -107,9 +107,19 @@ cdef class CythonKernelWrapper:
             if not isinstance(tensor, torch.Tensor):
                 # otherwise, maybe torch.data_ptr() for T.ptr inputs
                 continue
+
+            # Check ndim
+            if tensor.dim() != len(shape_list):
+                raise ValueError(
+                    f"Static shape mismatch for parameter {param}: "
+                    f"expected {len(shape_list)} dimensions, "
+                    f"got {tensor.dim()}"
+                )
+
+            # Check each dimension
             for shape_idx, expected_shape in shape_list:
                 actual_shape = tensor.shape[shape_idx]
-                if actual_shape != expected_shape:
+                if expected_shape != -1 and actual_shape != expected_shape:
                     raise ValueError(
                         f"Static shape mismatch for parameter {param}: "
                         f"expected {expected_shape} at index {shape_idx}, "

From abee53cb66607cfb4ff98e8e55eda411c87631d4 Mon Sep 17 00:00:00 2001
From: Rachmanino <18805904201@163.com>
Date: Thu, 16 Oct 2025 14:25:48 +0000
Subject: [PATCH 2/3] [BugFix] Fix shape mismatch and deprecate `T.if()` in
 fused_moe example

---
 examples/fusedmoe/example_fusedmoe_tilelang.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/examples/fusedmoe/example_fusedmoe_tilelang.py b/examples/fusedmoe/example_fusedmoe_tilelang.py
index 5978d3b13..a8d684965 100644
--- a/examples/fusedmoe/example_fusedmoe_tilelang.py
+++ b/examples/fusedmoe/example_fusedmoe_tilelang.py
@@ -213,7 +213,7 @@ def kernel(
                     up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]

                 for i, j in T.Parallel(block_token, block_dexpert):
-                    with T.If(i < actual_rows), T.Then():
+                    if i < actual_rows:
                         up_logits[m_start + i, by * block_dexpert + j] = up_logits_local[i, j]

             # Step 2: Compute down logits
@@ -261,7 +261,7 @@ def kernel(
                         transpose_B=True)

                 for i, j in T.Parallel(block_token, block_dhidden):
-                    with T.If(i < actual_rows), T.Then():
+                    if i < actual_rows:
                         output[m_start + i, by * block_dhidden +
                                j] = output_local[i, j] * routed_expert_weights[m_start + i]

@@ -356,11 +356,11 @@ def __init__(self,
             dtype=torch.float16,
             device=self.device)
         self.stacked_expert_weights = torch.empty(
-            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], 1),
+            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]),
             dtype=torch.float16,
             device=self.device)
         self.stacked_expert_tokens_idxs = torch.empty(
-            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], 1),
+            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]),
             dtype=torch.int64,
             device=self.device)

@@ -389,7 +389,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, seq_len, hidden_dim = x.shape
         expert_indices, expert_scores = self.gating_network(x)
         flat_expert_indices = expert_indices.view(-1)
-        flat_expert_weights = expert_scores.view(-1, 1)
+        flat_expert_weights = expert_scores.view(-1)
         x_flat = x.view(-1, hidden_dim)

         # Prepare for grouped GEMM
@@ -412,7 +412,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 expert_tokens = x_flat[exp_token_idxs]

                 self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens
-                self.stacked_expert_tokens_idxs[start_idx:end_idx, 0] = exp_token_idxs
+                self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs
                 self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[
                     idxs[start_idx:end_idx]]

From 2e4978f1221a9069f7b62ce6a3081991527c12f4 Mon Sep 17 00:00:00 2001
From: Rachmanino <18805904201@163.com>
Date: Thu, 16 Oct 2025 14:26:30 +0000
Subject: [PATCH 3/3] [Fix] Add `is_symbolic_expr` function to check for
 symbolic expressions in TIR

- Introduced a new utility function `is_symbolic_expr` to determine if an
  expression is a symbolic expression, enhancing type checking capabilities.
- Updated shape handling in `CythonKernelAdapter` to utilize the new function,
  improving handling for symbolic shapes.
---
 tilelang/jit/adapter/cython/adapter.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/tilelang/jit/adapter/cython/adapter.py b/tilelang/jit/adapter/cython/adapter.py
index 1e64de34a..4e687bfdc 100644
--- a/tilelang/jit/adapter/cython/adapter.py
+++ b/tilelang/jit/adapter/cython/adapter.py
@@ -29,6 +29,13 @@
     raise


+def is_symbolic_expr(expr) -> bool:
+    """Check if the expression is a symbolic expression.
+    A symbolic expression can be a simple tvm.Var, or an tvm.PrimExpr containing tvm.Var.
+    """
+    return not isinstance(expr, tir.IntImm) and isinstance(expr, tir.PrimExpr)
+
+
 class CythonKernelAdapter(BaseKernelAdapter):
     """Adapter class that converts TVM/TIR functions to callable CUDA kernels using cython.

@@ -278,8 +285,10 @@ def _process_static_buffer_infos(self) -> \
             for j, s in enumerate(buffer.shape):
                 if isinstance(s, tir.IntImm):
                     static_shape.append((j, s.value))
-                elif isinstance(s, tir.Var):
+                elif is_symbolic_expr(s):
                     static_shape.append((j, -1))  # -1 for symbolic
+                else:
+                    raise ValueError(f"Unsupported shape type: {type(s)}")
             for j, s in enumerate(buffer.strides):
                 if isinstance(s, tir.IntImm):
                     static_strides.append((j, s.value))