Ck tile batched contraction kernel generalizing #3126
base: develop
Conversation
…aware calculation and some code cleanings
…s batched contraction inputs
…sional stride support
…_tensor_view to local RunGemm
Pull Request Overview
This PR adds comprehensive support for arbitrary multi-dimensional stride patterns and non-contiguous tensor layouts to the batched contraction kernel. Previously, the kernel only supported contiguous row-major layouts with hardcoded strides.
- Introduces TensorDescriptorUtils with vectorization support to create stride-aware tensor descriptors
- Implements a custom RunGemm method using descriptor-based tensor views instead of relying solely on UniversalGemmKernel
- Updates the reference implementation to use stride-aware indexing for validation
- Adds command-line support for custom stride specifications in examples
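As a point of reference for the stride-aware indexing mentioned above, here is a minimal, self-contained sketch of the offset computation a reference implementation typically relies on. The names (coords, strides, element_offset) are illustrative, not the PR's actual identifiers.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// With arbitrary (possibly non-contiguous) strides, the linear element
// offset is simply the dot product of coordinates and per-dimension strides.
inline std::int64_t element_offset(const std::vector<std::int64_t>& coords,
                                   const std::vector<std::int64_t>& strides)
{
    std::int64_t offset = 0;
    for(std::size_t d = 0; d < coords.size(); ++d)
        offset += coords[d] * strides[d];
    return offset;
}
```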
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tensor_descriptor_utils.hpp | Adds vector size template parameters and updates descriptor creation to support vectorized memory access |
| batched_contraction_kernel.hpp | Implements descriptor-based architecture with custom RunGemm, adds tensor descriptor storage to kernel args |
| reference_batched_contraction.hpp | Refactors reference computation to use stride-aware offset calculation, changes from std::vector to std::array for D tensors |
| run_batched_contraction_example.inc | Adds custom stride parsing, implements runtime dispatch for NumDTensor, creates tensors with non-contiguous layouts |
| contraction_utils.hpp | Updates argument parsing to support stride specifications and adds help documentation |
| batched_contraction.cpp | Updates dimension case handling (appears to have duplicate case) |
HANDLE_CASE(2, 2, 2, 1);
HANDLE_CASE(1, 2, 1, 1);
HANDLE_CASE(1, 1, 1, 2);
HANDLE_CASE(2, 1, 1, 1);
Copilot AI (Nov 4, 2025)
Duplicate case HANDLE_CASE(2, 1, 1, 1) at lines 219 and 222. The second occurrence at line 222 appears to replace a removed case for (1, 1, 1, 2); removing that case may be intentional, but the duplicate itself is incorrect. Remove line 222 or replace it with the intended dimension combination.
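For context, HANDLE_CASE presumably maps runtime dimension counts onto compile-time template instantiations. The sketch below shows that general pattern under assumed names (run_example, dispatch, and the parameter names are hypothetical, not the PR's actual code).

```cpp
#include <iostream>

// Hypothetical templated runner standing in for the example's kernel launcher.
template <int D0, int D1, int D2, int D3>
void run_example()
{
    std::cout << "running case <" << D0 << ", " << D1 << ", " << D2 << ", " << D3 << ">\n";
}

// Each HANDLE_CASE compares the runtime counts against one supported
// combination and, on a match, instantiates the compile-time variant.
#define HANDLE_CASE(d0, d1, d2, d3)                              \
    if(a == (d0) && b == (d1) && c == (d2) && d == (d3))         \
    {                                                            \
        run_example<(d0), (d1), (d2), (d3)>();                   \
        return;                                                  \
    }

void dispatch(int a, int b, int c, int d)
{
    HANDLE_CASE(2, 2, 2, 1);
    HANDLE_CASE(1, 2, 1, 1);
    HANDLE_CASE(1, 1, 1, 2);
    HANDLE_CASE(2, 1, 1, 1);
    std::cout << "unsupported dimension combination\n";
}
```

With this pattern, a duplicated case is harmless at runtime (the second copy is unreachable) but it silently drops whichever combination it displaced, which is why the review flags it.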
// Decode G dimensions
ck_tile::index_t temp = g_flat;
for(ck_tile::index_t i = num_g_dims - 1; i >= 0; --i)
Copilot AI (Nov 4, 2025)
Loop condition i >= 0 with an unsigned type ck_tile::index_t will always be true, causing an infinite loop when i underflows. This pattern appears in multiple offset-computation lambdas (lines 109, 117, 125, 141, 149, 157, 173, 181, 189, 208, 216, 224). Change the loop to use int i or rewrite it to avoid decrementing below zero.
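A minimal sketch of the suggested fix, assuming index_t may be unsigned as the comment states: using a signed loop counter lets i reach -1 and exit cleanly. The names (decode_flat_index, g_flat, g_lengths, num_g_dims, coords) are illustrative placeholders, not the kernel's actual variables.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

template <std::size_t MaxDims>
std::array<std::int64_t, MaxDims>
decode_flat_index(std::int64_t g_flat,
                  const std::array<std::int64_t, MaxDims>& g_lengths,
                  int num_g_dims)
{
    std::array<std::int64_t, MaxDims> coords{};
    std::int64_t temp = g_flat;
    // Signed counter: when i becomes -1 the condition fails and the loop
    // terminates, avoiding the unsigned-underflow infinite loop.
    for(int i = num_g_dims - 1; i >= 0; --i)
    {
        coords[i] = temp % g_lengths[i];
        temp /= g_lengths[i];
    }
    return coords;
}
```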
Proposed changes
Extends the ck-tile batched contraction kernel to support arbitrary multi-dimensional, non-contiguous tensor layouts using descriptors.
Extends the example to exercise this new feature: the user can pass arbitrary manual strides for inputs and outputs.
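To illustrate what "manual strides" means in practice (this is an editorial sketch, not the example's command-line interface): a logically row-major [G, M, K] tensor whose innermost dimension is padded becomes non-contiguous, and its strides no longer match the packed defaults.

```cpp
#include <array>
#include <cstdint>
#include <iostream>

int main()
{
    const std::int64_t G = 4, M = 128, K = 64, K_pad = 80; // illustrative sizes

    // Packed row-major strides for [G, M, K]: {M*K, K, 1}.
    std::array<std::int64_t, 3> packed{M * K, K, 1};

    // Padded layout: same logical shape, but rows are K_pad elements apart,
    // so the tensor is no longer contiguous and needs explicit strides.
    std::array<std::int64_t, 3> padded{M * K_pad, K_pad, 1};

    std::cout << "packed strides: " << packed[0] << " " << packed[1] << " " << packed[2] << "\n";
    std::cout << "padded strides: " << padded[0] << " " << padded[1] << " " << padded[2] << "\n";
    (void)G;
}
```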
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- [ ] clang-format on all changed files

Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered.