Changes from all commits (16 commits):
356b50f
Add help for example
msaffari-amd Oct 15, 2025
b161cd9
Refactor the compute reference batched contraction to manage stride-…
msaffari-amd Oct 16, 2025
553c05e
Merge branch 'develop' into ck_tile_batched_contraction_kernel_genere…
msaffari-amd Oct 16, 2025
4027a92
Add stride-aware reference for batched contraction with independent D…
msaffari-amd Oct 17, 2025
9fc1a8c
Add -num_d argument for runtime D tensor count selection in batched c…
msaffari-amd Oct 17, 2025
fec8332
Add stride vector arguments in example code for testing non-contiguou…
msaffari-amd Oct 17, 2025
2ecb0bf
Add descriptor-based architecture for batched contraction multi-dimen…
msaffari-amd Oct 20, 2025
b8b56d5
Add multi-dimensional non-contiguous stride support to batched contra…
msaffari-amd Oct 20, 2025
bbfe450
Add complete multi-dimensional stride support via descriptors
msaffari-amd Oct 20, 2025
6144f5c
Enable vectorization in descriptor-based batched contraction. Add pad…
msaffari-amd Oct 21, 2025
4883883
Clean up batched contraction: remove old UniversalGemmKernel path
msaffari-amd Oct 27, 2025
670409c
merge develop
msaffari-amd Oct 29, 2025
e7f5f0b
Clean up batched contraction: remove legacy paths and finalize docs
msaffari-amd Oct 29, 2025
0eb1b55
Optimize batched contraction example: pass dimension sizes not vectors
msaffari-amd Oct 29, 2025
11514c7
Merge branch 'develop' into ck_tile_batched_contraction_kernel_genere…
msaffari-amd Oct 30, 2025
c3857ee
Merge branch 'develop' into ck_tile_batched_contraction_kernel_genere…
msaffari-amd Nov 4, 2025
@@ -219,9 +219,8 @@ float batched_contraction(const ck_tile::BatchedContractionHostArgs<DsDataType::
HANDLE_CASE(2, 1, 1, 1);
HANDLE_CASE(2, 2, 2, 1);
HANDLE_CASE(1, 2, 1, 1);
HANDLE_CASE(1, 1, 1, 2);
HANDLE_CASE(2, 1, 1, 1);

Copilot AI Nov 4, 2025

Duplicate case HANDLE_CASE(2, 1, 1, 1) at lines 219 and 222. The second occurrence at line 222 appears to replace a removed case for (1, 1, 1, 2), which may be intentional removal but the duplicate is incorrect. Remove line 222 or replace it with the intended dimension combination.

Suggested change
HANDLE_CASE(2, 1, 1, 1);

HANDLE_CASE(2, 2, 2, 2);
HANDLE_CASE(4, 4, 4, 4);
Comment on lines -222 to -224
Collaborator

Why removing those two?
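
For context on both review threads above: the HANDLE_CASE macro itself is outside this hunk. A minimal sketch of the assumed pattern (the impl and argument names here are hypothetical) is:

// Sketch only -- assumed shape of the dispatch, not the actual macro in the file.
// Each case maps the runtime dimension counts onto one compile-time instantiation,
// so a repeated (2, 1, 1, 1) entry can never be reached, and dropping the
// (2, 2, 2, 2) and (4, 4, 4, 4) cases removes those combinations from dispatch.
#define HANDLE_CASE(NG, NM, NN, NK)                                        \
    if(num_g_dims == NG && num_m_dims == NM && num_n_dims == NN &&         \
       num_k_dims == NK)                                                   \
    {                                                                      \
        return batched_contraction_impl<NG, NM, NN, NK>(args, stream_cfg); \
    }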


throw std::runtime_error(
"Unsupported dimension combination: G=" + std::to_string(num_g_dims) +
72 changes: 69 additions & 3 deletions example/ck_tile/41_batched_contraction/contraction_utils.hpp
@@ -42,17 +42,83 @@ using AccDataType = ContractionTypes::AccDataType;
using EDataType = ContractionTypes::EDataType;
using DDataType = ContractionTypes::DDataType;

void print_help(const char* program_name)
{
std::cout << "\n";
std::cout << "Batched Tensor Contraction with element-wise fusion\n";
std::cout << "E[G,M,N] = element_wise_op(contraction(A[G,M,K], B[G,N,K]), D0, D1, ...)\n";
std::cout << "(Supports multiple D tensors with configurable element-wise operations)\n\n";

std::cout << "Usage: " << program_name << " [OPTIONS]\n\n";

std::cout << "Dimension Arguments (comma-separated, no spaces):\n";
std::cout << " -g_dims=<dims> Batch dimensions (default: \"1,2\")\n";
std::cout << " -m_dims=<dims> M (row) dimensions (default: \"4,256\")\n";
std::cout << " -n_dims=<dims> N (column) dimensions (default: \"16,128\")\n";
std::cout << " -k_dims=<dims> K (contract) dims (default: \"64\")\n";
std::cout << " -num_d=<int> Number of D tensors (default: 2, range: 0-4)\n\n";

std::cout << "Custom Stride Arguments (for testing non-contiguous tensors):\n";
std::cout << " -strides_a=<s> A tensor strides (comma-separated, empty = auto)\n";
std::cout << " -strides_b=<s> B tensor strides (comma-separated, empty = auto)\n";
std::cout << " -strides_e=<s> E tensor strides (comma-separated, empty = auto)\n";
std::cout << " -strides_ds=<s> D tensors strides (semicolon-separated, empty = same as E)\n";
std::cout << " Example: -strides_a=\"32768,128,1\" -strides_ds=\"512,2,1;1024,4,1\"\n\n";

std::cout << "Layout Arguments:\n";
std::cout
<< " -a_layout=<R|C> A tensor layout (R=Row-major, C=Column-major, default: \"R\")\n";
std::cout << " -b_layout=<R|C> B tensor layout (default: \"C\")\n";
std::cout << " -e_layout=<R|C> E tensor layout (default: \"R\")\n\n";

std::cout << "Examples:\n";
std::cout << " Single batch (12 batches of 256×128):\n";
std::cout << " " << program_name
<< " -g_dims=\"12\" -m_dims=\"256\" -n_dims=\"128\" -k_dims=\"64\"\n\n";

std::cout << " 2D batch grid (2×3=6 batches):\n";
std::cout << " " << program_name
<< " -g_dims=\"2,3\" -m_dims=\"128\" -n_dims=\"128\" -k_dims=\"64\"\n\n";

std::cout << " Multi-dimensional (flattened to M=128, N=128, K=128):\n";
std::cout << " " << program_name
<< " -g_dims=\"4\" -m_dims=\"8,16\" -n_dims=\"32,4\" -k_dims=\"16,8\"\n\n";

std::cout << "Other Options:\n";
std::cout << " -v=<0|1> Validation (0=off, 1=on, default: 1)\n";
std::cout << " -split_k=<int> Split-K value (default: 1)\n";
std::cout << " -warmup=<int> Warmup iterations (default: 5)\n";
std::cout << " -repeat=<int> Benchmark iterations (default: 10)\n";
std::cout << " -log=<0|1> Logging level (default: 1)\n";
std::cout << " -help Show this help\n\n";
}

auto create_args(int argc, char* argv[])
{
// Check for --help flag
for(int i = 1; i < argc; ++i)
{
std::string arg = argv[i];
if(arg == "--help" || arg == "-h" || arg == "-help")
{
print_help(argv[0]);
std::exit(0);
}
}

ck_tile::ArgParser arg_parser;
arg_parser.insert("m_dims", "4,256", "M dimensions separated by comma (e.g., '16,32' for 2D M)")
.insert("n_dims", "16,128", "N dimensions separated by comma (e.g., '32,32' for 2D N)")
.insert("k_dims", "64", "K dimensions separated by comma (e.g., '64,32' for 2D K)")
.insert(
"g_dims", "1,2", "G dimensions separated by comma (e.g., '4,2' for 2D, '2,3,4' for 3D)")
.insert("stride_a", "0", "Custom A tensor leading dimension stride (0 = auto)")
.insert("stride_b", "0", "Custom B tensor leading dimension stride (0 = auto)")
.insert("stride_e", "0", "Custom E tensor leading dimension stride (0 = auto)")
.insert("num_d", "2", "Number of D (auxiliary input) tensors")
.insert("strides_a", "", "A tensor strides (comma-separated, empty = auto/contiguous)")
.insert("strides_b", "", "B tensor strides (comma-separated, empty = auto/contiguous)")
.insert("strides_e", "", "E tensor strides (comma-separated, empty = auto/contiguous)")
.insert("strides_ds",
"",
"D tensors strides (semicolon-separated for multiple, empty = same as E)")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Col by default")
.insert("e_layout", "R", "E tensor data layout - Row by default")
@@ -45,10 +45,10 @@ float invoke_batched_contraction_kernel(
const void* b_full_dims_dev_buf,
const std::array<const void*, DsDataType::size()>& ds_dev_buf,
void* e_full_dims_dev_buf,
const std::vector<ck_tile::index_t>& G_dims,
const std::vector<ck_tile::index_t>& M_dims,
const std::vector<ck_tile::index_t>& N_dims,
const std::vector<ck_tile::index_t>& K_dims,
ck_tile::index_t num_g_dims,
ck_tile::index_t num_m_dims,
ck_tile::index_t num_n_dims,
ck_tile::index_t num_k_dims,
const std::vector<ck_tile::index_t>& A_dims, // [G0,G1,..,M0,M1,..,K0,K1,..]
const std::vector<ck_tile::index_t>& B_dims, // [G0,G1,..,N0,N1,..,K0,K1,..]
const std::array<std::vector<ck_tile::index_t>, DsDataType::size()>&
@@ -79,9 +79,8 @@ float invoke_batched_contraction_kernel(
E_strides // E_strides
);

std::cout << "Calling batched_contraction with dimensions: G=" << G_dims.size()
<< ", M=" << M_dims.size() << ", N=" << N_dims.size() << ", K=" << K_dims.size()
<< std::endl;
std::cout << "Calling batched_contraction with dimensions: G=" << num_g_dims
<< ", M=" << num_m_dims << ", N=" << num_n_dims << ", K=" << num_k_dims << std::endl;

float ave_time = batched_contraction<ADataType,
BDataType,
@@ -95,16 +94,19 @@
CDEElementWise>(
args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
G_dims.size(), // num_g_dims
M_dims.size(), // num_m_dims
N_dims.size(), // num_n_dims
K_dims.size() // num_k_dims
);
num_g_dims,
num_m_dims,
num_n_dims,
num_k_dims);

return ave_time;
}

template <typename ALayout, typename BLayout, typename DLayout, typename ELayout>
template <typename ALayout,
typename BLayout,
typename DLayout,
typename ELayout,
ck_tile::index_t NumDTensor>
int run_batched_contraction_example_with_layouts(
int argc,
char* argv[],
@@ -122,8 +124,6 @@ int run_batched_contraction_example_with_layouts(
std::vector<ck_tile::index_t> N_dims = parse_dimensions(arg_parser.get_str("n_dims"));
std::vector<ck_tile::index_t> K_dims = parse_dimensions(arg_parser.get_str("k_dims"));

constexpr ck_tile::index_t NumDTensor = 2;

ck_tile::index_t G_total = calculate_total_elements(G_dims);
ck_tile::index_t M_total = calculate_total_elements(M_dims);
ck_tile::index_t N_total = calculate_total_elements(N_dims);
@@ -148,13 +148,105 @@ int run_batched_contraction_example_with_layouts(
return converted;
};

ck_tile::HostTensorDescriptor a_desc(A_dims);
ck_tile::HostTensorDescriptor b_desc(B_dims);
ck_tile::HostTensorDescriptor e_desc(E_dims);
// Get custom stride arguments
std::string strides_a_str = arg_parser.get_str("strides_a");
std::string strides_b_str = arg_parser.get_str("strides_b");
std::string strides_e_str = arg_parser.get_str("strides_e");
std::string strides_ds_str = arg_parser.get_str("strides_ds");

// Create A descriptor with custom or default strides
ck_tile::HostTensorDescriptor a_desc;
if(!strides_a_str.empty())
{
std::vector<ck_tile::index_t> custom_a_strides = parse_dimensions(strides_a_str);
if(custom_a_strides.size() != A_dims.size())
{
throw std::runtime_error("strides_a size must match A_dims size");
}
std::vector<std::size_t> a_strides_size_t(custom_a_strides.begin(), custom_a_strides.end());
a_desc = ck_tile::HostTensorDescriptor(A_dims, a_strides_size_t);
std::cout << "Using custom strides for A (non-contiguous)" << std::endl;
}
else
{
a_desc = ck_tile::HostTensorDescriptor(A_dims);
}

// Create B descriptor with custom or default strides
ck_tile::HostTensorDescriptor b_desc;
if(!strides_b_str.empty())
{
std::vector<ck_tile::index_t> custom_b_strides = parse_dimensions(strides_b_str);
if(custom_b_strides.size() != B_dims.size())
{
throw std::runtime_error("strides_b size must match B_dims size");
}
std::vector<std::size_t> b_strides_size_t(custom_b_strides.begin(), custom_b_strides.end());
b_desc = ck_tile::HostTensorDescriptor(B_dims, b_strides_size_t);
std::cout << "Using custom strides for B (non-contiguous)" << std::endl;
}
else
{
b_desc = ck_tile::HostTensorDescriptor(B_dims);
}

// Create E descriptor with custom or default strides
ck_tile::HostTensorDescriptor e_desc;
if(!strides_e_str.empty())
{
std::vector<ck_tile::index_t> custom_e_strides = parse_dimensions(strides_e_str);
if(custom_e_strides.size() != E_dims.size())
{
throw std::runtime_error("strides_e size must match E_dims size");
}
std::vector<std::size_t> e_strides_size_t(custom_e_strides.begin(), custom_e_strides.end());
e_desc = ck_tile::HostTensorDescriptor(E_dims, e_strides_size_t);
std::cout << "Using custom strides for E (non-contiguous)" << std::endl;
}
else
{
e_desc = ck_tile::HostTensorDescriptor(E_dims);
}
// Create D descriptors with custom or default strides (default = same as E)
std::array<ck_tile::HostTensorDescriptor, NumDTensor> ds_descs;
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
if(!strides_ds_str.empty())
{
ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], e_desc.get_strides());
// Parse semicolon-separated stride vectors for multiple D tensors
std::vector<std::vector<ck_tile::index_t>> all_ds_strides;
std::stringstream ss(strides_ds_str);
std::string d_stride_str;

while(std::getline(ss, d_stride_str, ';'))
{
all_ds_strides.push_back(parse_dimensions(d_stride_str));
}

if(all_ds_strides.size() != NumDTensor)
{
throw std::runtime_error("Number of D stride vectors must match num_d=" +
std::to_string(NumDTensor));
}

std::cout << "Using custom strides for D tensors (non-contiguous)" << std::endl;
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
{
if(all_ds_strides[d].size() != E_dims.size())
{
throw std::runtime_error("D tensor " + std::to_string(d) +
" stride size must match E_dims size");
}
std::vector<std::size_t> d_strides_size_t(all_ds_strides[d].begin(),
all_ds_strides[d].end());
ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], d_strides_size_t);
}
}
else
{
// Default: use same strides as E
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
{
ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], e_desc.get_strides());
}
}

std::vector<ck_tile::index_t> A_strides = convert_strides(a_desc.get_strides());
@@ -201,11 +293,14 @@ int run_batched_contraction_example_with_layouts(
ck_tile::HostTensor<::BDataType> b_full_dims_host(b_desc);
ck_tile::HostTensor<::EDataType> e_full_dims_host(e_desc);

std::vector<ck_tile::HostTensor<::DDataType>> ds_full_dims_host;
for(int d = 0; d < NumDTensor; ++d)
{
ds_full_dims_host.emplace_back(ck_tile::HostTensor<::DDataType>(ds_descs[d]));
}
// Helper to construct array of HostTensors using index_sequence
auto make_ds_host_tensors = []<std::size_t... Is>(const auto& descs,
std::index_sequence<Is...>) {
return std::array<ck_tile::HostTensor<::DDataType>, sizeof...(Is)>{
ck_tile::HostTensor<::DDataType>(descs[Is])...};
};

auto ds_full_dims_host = make_ds_host_tensors(ds_descs, std::make_index_sequence<NumDTensor>{});

ck_tile::FillUniformDistribution<::ADataType>{-5.f, 5.f, std::nullopt}(a_full_dims_host);
ck_tile::FillUniformDistribution<::BDataType>{-5.f, 5.f, std::nullopt}(b_full_dims_host);
@@ -260,10 +355,10 @@ int run_batched_contraction_example_with_layouts(
b_full_dims_dev_buf.GetDeviceBuffer(),
ds_ptr_buf,
e_full_dims_dev_buf.GetDeviceBuffer(),
G_dims,
M_dims,
N_dims,
K_dims,
G_dims.size(),
M_dims.size(),
N_dims.size(),
K_dims.size(),
A_dims,
B_dims,
Ds_dims,
@@ -316,20 +411,25 @@

auto start_time = std::chrono::high_resolution_clock::now();

calculate_reference_flat_indexing<ADataType,
BDataType,
DDataType,
EDataType,
AccDataType,
CDEElementWise>(a_full_dims_host,
compute_reference_batched_contraction<ADataType,
BDataType,
DDataType,
EDataType,
AccDataType,
CDEElementWise,
NumDTensor>(a_full_dims_host,
b_full_dims_host,
ds_full_dims_host,
e_full_dims_host_ref,
G_total,
M_total,
N_total,
K_total,
CDEElementWise{});
CDEElementWise{},
G_dims,
M_dims,
N_dims,
K_dims);

auto end_time = std::chrono::high_resolution_clock::now();
auto duration =
@@ -387,15 +487,45 @@ int run_batched_contraction_example(int argc, char* argv[])
if(!result)
return -1;

// Get NumDTensor to dispatch at runtime
const int num_d = arg_parser.get_int("num_d");

using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");

// Runtime dispatch based on num_d value
if(a_layout == "R" && b_layout == "C")
{
return run_batched_contraction_example_with_layouts(argc, argv, Row{}, Col{}, Row{}, Row{});
// Dispatch to appropriate template instantiation based on runtime num_d
switch(num_d)
{
case 0:
std::cout << "Running with 0 D tensors" << std::endl;
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 0>(
argc, argv, Row{}, Col{}, Row{}, Row{});
case 1:
std::cout << "Running with 1 D tensor" << std::endl;
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 1>(
argc, argv, Row{}, Col{}, Row{}, Row{});
case 2:
std::cout << "Running with 2 D tensors" << std::endl;
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 2>(
argc, argv, Row{}, Col{}, Row{}, Row{});
case 3:
std::cout << "Running with 3 D tensors" << std::endl;
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 3>(
argc, argv, Row{}, Col{}, Row{}, Row{});
case 4:
std::cout << "Running with 4 D tensors" << std::endl;
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 4>(
argc, argv, Row{}, Col{}, Row{}, Row{});
default:
throw std::runtime_error("num_d must be between 0 and 4, got: " +
std::to_string(num_d));
}
}
else
{