diff --git a/examples/fusedmoe/example_fusedmoe_tilelang.py b/examples/fusedmoe/example_fusedmoe_tilelang.py
index 5978d3b13..a8d684965 100644
--- a/examples/fusedmoe/example_fusedmoe_tilelang.py
+++ b/examples/fusedmoe/example_fusedmoe_tilelang.py
@@ -213,7 +213,7 @@ def kernel(
                 up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]

             for i, j in T.Parallel(block_token, block_dexpert):
-                with T.If(i < actual_rows), T.Then():
+                if i < actual_rows:
                     up_logits[m_start + i, by * block_dexpert + j] = up_logits_local[i, j]

             # Step 2: Compute down logits
@@ -261,7 +261,7 @@ def kernel(
                 transpose_B=True)

             for i, j in T.Parallel(block_token, block_dhidden):
-                with T.If(i < actual_rows), T.Then():
+                if i < actual_rows:
                     output[m_start + i, by * block_dhidden +
                            j] = output_local[i, j] * routed_expert_weights[m_start + i]

@@ -356,11 +356,11 @@ def __init__(self,
             dtype=torch.float16,
             device=self.device)
         self.stacked_expert_weights = torch.empty(
-            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], 1),
+            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]),
             dtype=torch.float16,
             device=self.device)
         self.stacked_expert_tokens_idxs = torch.empty(
-            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], 1),
+            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]),
             dtype=torch.int64,
             device=self.device)

@@ -389,7 +389,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, seq_len, hidden_dim = x.shape
         expert_indices, expert_scores = self.gating_network(x)
         flat_expert_indices = expert_indices.view(-1)
-        flat_expert_weights = expert_scores.view(-1, 1)
+        flat_expert_weights = expert_scores.view(-1)
         x_flat = x.view(-1, hidden_dim)

         # Prepare for grouped GEMM
@@ -412,7 +412,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 expert_tokens = x_flat[exp_token_idxs]

                 self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens
-                self.stacked_expert_tokens_idxs[start_idx:end_idx, 0] = exp_token_idxs
+                self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs
                 self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[
                     idxs[start_idx:end_idx]]

diff --git a/tilelang/jit/adapter/cython/adapter.py b/tilelang/jit/adapter/cython/adapter.py
index a7bf6b4a0..4e687bfdc 100644
--- a/tilelang/jit/adapter/cython/adapter.py
+++ b/tilelang/jit/adapter/cython/adapter.py
@@ -29,6 +29,13 @@
     raise


+def is_symbolic_expr(expr) -> bool:
+    """Check if the expression is a symbolic expression.
+    A symbolic expression can be a simple tvm.Var, or a tvm.PrimExpr containing a tvm.Var.
+    """
+    return not isinstance(expr, tir.IntImm) and isinstance(expr, tir.PrimExpr)
+
+
 class CythonKernelAdapter(BaseKernelAdapter):
     """Adapter class that converts TVM/TIR functions to callable CUDA kernels using cython.
@@ -278,6 +285,10 @@ def _process_static_buffer_infos(self) -> \
             for j, s in enumerate(buffer.shape):
                 if isinstance(s, tir.IntImm):
                     static_shape.append((j, s.value))
+                elif is_symbolic_expr(s):
+                    static_shape.append((j, -1))  # -1 marks a symbolic dimension
+                else:
+                    raise ValueError(f"Unsupported shape type: {type(s)}")
             for j, s in enumerate(buffer.strides):
                 if isinstance(s, tir.IntImm):
                     static_strides.append((j, s.value))
diff --git a/tilelang/jit/adapter/cython/cython_wrapper.pyx b/tilelang/jit/adapter/cython/cython_wrapper.pyx
index 77fb9d5ad..6feca69dd 100644
--- a/tilelang/jit/adapter/cython/cython_wrapper.pyx
+++ b/tilelang/jit/adapter/cython/cython_wrapper.pyx
@@ -107,9 +107,19 @@ cdef class CythonKernelWrapper:
            if not isinstance(tensor, torch.Tensor):
                 # otherwise, maybe torch.data_ptr() for T.ptr inputs
                 continue
+
+            # Check that the tensor rank matches the expected number of dimensions
+            if tensor.dim() != len(shape_list):
+                raise ValueError(
+                    f"Static shape mismatch for parameter {param}: "
+                    f"expected {len(shape_list)} dimensions, "
+                    f"got {tensor.dim()}"
+                )
+
+            # Check each dimension, skipping symbolic extents recorded as -1
             for shape_idx, expected_shape in shape_list:
                 actual_shape = tensor.shape[shape_idx]
-                if actual_shape != expected_shape:
+                if expected_shape != -1 and actual_shape != expected_shape:
                     raise ValueError(
                         f"Static shape mismatch for parameter {param}: "
                         f"expected {expected_shape} at index {shape_idx}, "
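
A quick sanity check of the new predicate (a hedged sketch: it assumes a TVM install and this patch applied; the import path is taken from the diff above). is_symbolic_expr treats any PrimExpr that is not an IntImm as symbolic, so a bare variable and arithmetic over one both qualify:

from tvm import tir
from tilelang.jit.adapter.cython.adapter import is_symbolic_expr

m = tir.Var("m", "int32")
print(is_symbolic_expr(m))                         # True: a bare symbolic var
print(is_symbolic_expr(m * 2 + 1))                 # True: a PrimExpr containing a var
print(is_symbolic_expr(tir.IntImm("int32", 128)))  # False: a static extent
print(is_symbolic_expr(128))                       # False: a plain Python int is not a PrimExpr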
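
The runtime rule the wrapper now enforces can be summarized in a self-contained sketch: static extents must match exactly, while dimensions recorded as -1 accept any size. check_tensor_shape and its arguments are illustrative stand-ins for the Cython wrapper's internals, not an actual API:

def check_tensor_shape(param, actual_shape, shape_list):
    """Mirror of the wrapper's check over (index, extent) pairs; -1 means symbolic."""
    if len(actual_shape) != len(shape_list):
        raise ValueError(f"Static shape mismatch for parameter {param}: "
                         f"expected {len(shape_list)} dimensions, got {len(actual_shape)}")
    for shape_idx, expected in shape_list:
        if expected != -1 and actual_shape[shape_idx] != expected:
            raise ValueError(f"Static shape mismatch for parameter {param}: "
                             f"expected {expected} at index {shape_idx}, "
                             f"got {actual_shape[shape_idx]}")

# A buffer declared as (m, 1024) with m symbolic is stored as [(0, -1), (1, 1024)]:
check_tensor_shape("A", (512, 1024), [(0, -1), (1, 1024)])  # passes for any m
check_tensor_shape("A", (7, 1024), [(0, -1), (1, 1024)])    # also passes
try:
    check_tensor_shape("A", (512, 512), [(0, -1), (1, 1024)])
except ValueError as e:
    print(e)  # the static dim at index 1 still rejects a mismatch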