File tree Expand file tree Collapse file tree 2 files changed +13
-1
lines changed
tilelang/jit/adapter/cython Expand file tree Collapse file tree 2 files changed +13
-1
lines changed Original file line number Diff line number Diff line change @@ -278,6 +278,8 @@ def _process_static_buffer_infos(self) -> \
278278 for j , s in enumerate (buffer .shape ):
279279 if isinstance (s , tir .IntImm ):
280280 static_shape .append ((j , s .value ))
281+ elif isinstance (s , tir .Var ):
282+ static_shape .append ((j , - 1 )) # -1 for symbolic
281283 for j , s in enumerate (buffer .strides ):
282284 if isinstance (s , tir .IntImm ):
283285 static_strides .append ((j , s .value ))
Original file line number Diff line number Diff line change @@ -107,9 +107,19 @@ cdef class CythonKernelWrapper:
107107 if not isinstance (tensor, torch.Tensor):
108108 # otherwise, maybe torch.data_ptr() for T.ptr inputs
109109 continue
110+
111+ # Check ndim
112+ if tensor.dim() != len (shape_list):
113+ raise ValueError (
114+ f" Static shape mismatch for parameter {param}: "
115+ f" expected {len(shape_list)} dimensions, "
116+ f" got {tensor.dim()}"
117+ )
118+
119+ # Check each dimension
110120 for shape_idx, expected_shape in shape_list:
111121 actual_shape = tensor.shape[shape_idx]
112- if actual_shape != expected_shape:
122+ if expected_shape ! = - 1 and actual_shape != expected_shape:
113123 raise ValueError (
114124 f" Static shape mismatch for parameter {param}: "
115125 f" expected {expected_shape} at index {shape_idx}, "
You can’t perform that action at this time.
0 commit comments