-
Notifications
You must be signed in to change notification settings - Fork 273
[Enhancement] Add support for symbolic dimensions in Cython kernel adapter and improve static shape validation in wrapper #1024
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d5ba468
221c16b
abee53c
2e4978f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -107,9 +107,19 @@ cdef class CythonKernelWrapper: | |
if not isinstance(tensor, torch.Tensor): | ||
# otherwise, maybe torch.data_ptr() for T.ptr inputs | ||
continue | ||
|
||
# Check ndim | ||
if tensor.dim() != len(shape_list): | ||
raise ValueError( | ||
f"Static shape mismatch for parameter {param}: " | ||
f"expected {len(shape_list)} dimensions, " | ||
f"got {tensor.dim()}" | ||
Comment on lines
+112
to
+116
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Useful? React with 👍 / 👎. |
||
) | ||
|
||
# Check each dimension | ||
for shape_idx, expected_shape in shape_list: | ||
actual_shape = tensor.shape[shape_idx] | ||
if actual_shape != expected_shape: | ||
if expected_shape != -1 and actual_shape != expected_shape: | ||
raise ValueError( | ||
f"Static shape mismatch for parameter {param}: " | ||
f"expected {expected_shape} at index {shape_idx}, " | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
Verify stride processing handles symbolic dimensions consistently.
The shape processing now explicitly handles symbolic dimensions by marking them as -1. However, stride processing (lines 292-294) only handles
tir.IntImm
without similar symbolic dimension support. Since_process_dynamic_symbolic
(lines 229-234) indicates that strides can also be symbolic, this may create an inconsistency.Run the following script to check if symbolic strides are encountered in practice:
🏁 Script executed:
Length of output: 8975
Handle symbolic strides consistently with symbolic shapes.
The shape processing now handles symbolic dimensions by marking them as -1, but stride processing (lines 292-294) only handles
tir.IntImm
without similar symbolic support. However,_process_dynamic_symbolic
(lines 228-234) explicitly detects and maps symbolic strides (tir.Var
). This creates an inconsistency: when symbolic strides are encountered, they are silently ignored during static stride extraction, while the method already identifies them as important.Additionally, the contiguity check at lines 296-298 iterates over
buffer.strides
without accounting for symbolic strides, which will produce incorrect results or unexpected behavior when symbolic strides are present.Add symbolic stride handling at lines 292-294 to mark symbolic strides (similar to the -1 convention for shapes), or document why symbolic strides should be excluded from static stride processing.
🧰 Tools
🪛 Ruff (0.14.0)
291-291: Avoid specifying long messages outside the exception class
(TRY003)