Skip to content

Commit d5ba468

Browse files
committed
[Enhancement] Add support for symbolic dimensions in Cython kernel adapter and improve static shape validation in wrapper
1 parent d89ba5b commit d5ba468

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

tilelang/jit/adapter/cython/adapter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,8 @@ def _process_static_buffer_infos(self) -> \
278278
for j, s in enumerate(buffer.shape):
279279
if isinstance(s, tir.IntImm):
280280
static_shape.append((j, s.value))
281+
elif isinstance(s, tir.Var):
282+
static_shape.append((j, -1)) # -1 for symbolic
281283
for j, s in enumerate(buffer.strides):
282284
if isinstance(s, tir.IntImm):
283285
static_strides.append((j, s.value))

tilelang/jit/adapter/cython/cython_wrapper.pyx

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,19 @@ cdef class CythonKernelWrapper:
107107
if not isinstance(tensor, torch.Tensor):
108108
# otherwise, maybe torch.data_ptr() for T.ptr inputs
109109
continue
110+
111+
# Check ndim
112+
if tensor.dim() != len(shape_list):
113+
raise ValueError(
114+
f"Static shape mismatch for parameter {param}: "
115+
f"expected {len(shape_list)} dimensions, "
116+
f"got {tensor.dim()}"
117+
)
118+
119+
# Check each dimension
110120
for shape_idx, expected_shape in shape_list:
111121
actual_shape = tensor.shape[shape_idx]
112-
if actual_shape != expected_shape:
122+
if expected_shape != -1 and actual_shape != expected_shape:
113123
raise ValueError(
114124
f"Static shape mismatch for parameter {param}: "
115125
f"expected {expected_shape} at index {shape_idx}, "

0 commit comments

Comments
 (0)