Skip to content

Commit

Permalink
Fix shuffle bug in CodeGen C. (#8567)
Browse files Browse the repository at this point in the history
* Fix Shuffle-bug codegen for GPU_Codegen_C and add test.

* Improve shuffle test to support 4-wide vectors.

* Fix shuffle test.

* Two more asserts.

* Improve asserts

* Comments.

* Rename test shuffle.cpp to vector_shuffle.cpp
  • Loading branch information
mcourteaux authored Feb 18, 2025
1 parent d5681a4 commit 2e36da4
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 9 deletions.
68 changes: 59 additions & 9 deletions src/CodeGen_GPU_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,18 @@ void CodeGen_GPU_C::visit(const Shuffle *op) {
internal_assert(op->vectors[0].type() == op->vectors[i].type());
}
internal_assert(op->type.lanes() == (int)op->indices.size());
const int max_index = (int)(op->vectors[0].type().lanes() * op->vectors.size());
// We need to construct the mapping between shuffled-index,
// and source-vector-index and source-element-index-within-the-vector.
// To start, we'll figure out what the first shuffle-index is per
// source-vector. Also let's compute the total number of
// source-elements the to be able to assert that all of the
// shuffle-indices are within range.
std::vector<int> vector_first_index;
int max_index = 0;
for (const Expr &v : op->vectors) {
vector_first_index.push_back(max_index);
max_index += v.type().lanes();
}
for (int i : op->indices) {
internal_assert(i >= 0 && i < max_index);
}
Expand All @@ -162,25 +173,64 @@ void CodeGen_GPU_C::visit(const Shuffle *op) {
std::string src = vecs[0];
std::ostringstream rhs;
std::string storage_name = unique_name('_');
if (vector_declaration_style == VectorDeclarationStyle::OpenCLSyntax) {
switch (vector_declaration_style) {
case VectorDeclarationStyle::OpenCLSyntax:
rhs << "(" << print_type(op->type) << ")(";
} else if (vector_declaration_style == VectorDeclarationStyle::WGSLSyntax) {
break;
case VectorDeclarationStyle::WGSLSyntax:
rhs << print_type(op->type) << "(";
} else {
break;
case VectorDeclarationStyle::CLikeSyntax:
rhs << "{";
break;
}
int elem_num = 0;
for (int i : op->indices) {
rhs << vecs[i];
if (i < (int)(op->indices.size() - 1)) {
size_t vector_idx;
int lane_idx = -1;
// Find in which source vector this shuffle-index "i" falls:
for (vector_idx = 0; vector_idx < op->vectors.size(); ++vector_idx) {
const int first_index = vector_first_index[vector_idx];
if (i >= first_index &&
i < first_index + op->vectors[vector_idx].type().lanes()) {
lane_idx = i - first_index;
break;
}
}
internal_assert(lane_idx != -1) << "Shuffle lane index not found: i=" << i;
internal_assert(vector_idx < op->vectors.size()) << "Shuffle vector index not found: i=" << i << ", lane=" << lane_idx;
// Print the vector in which we will index.
rhs << vecs[vector_idx];
// In case we are dealing with an actual vector instead of scalar,
// print out the required indexing syntax.
if (op->vectors[vector_idx].type().lanes() > 1) {
switch (vector_declaration_style) {
case VectorDeclarationStyle::OpenCLSyntax:
rhs << ".s" << lane_idx;
break;
case VectorDeclarationStyle::WGSLSyntax:
case VectorDeclarationStyle::CLikeSyntax:
rhs << "[" << lane_idx << "]";
break;
}
}

// Elements of a vector are comma separated.
if (elem_num < (int)(op->indices.size() - 1)) {
rhs << ", ";
}
elem_num++;
}
if (vector_declaration_style == VectorDeclarationStyle::OpenCLSyntax) {
switch (vector_declaration_style) {
case VectorDeclarationStyle::OpenCLSyntax:
rhs << ")";
} else if (vector_declaration_style == VectorDeclarationStyle::WGSLSyntax) {
break;
case VectorDeclarationStyle::WGSLSyntax:
rhs << ")";
} else {
break;
case VectorDeclarationStyle::CLikeSyntax:
rhs << "}";
break;
}
print_assignment(op->type, rhs.str());
}
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ tests(GROUPS correctness
vector_math.cpp
vector_print_bug.cpp
vector_reductions.cpp
vector_shuffle.cpp
vector_tile.cpp
vectorize_guard_with_if.cpp
vectorize_mixed_widths.cpp
Expand Down
65 changes: 65 additions & 0 deletions test/correctness/vector_shuffle.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {
Target target = get_jit_target_from_environment();
if (target.has_feature(Target::Feature::Vulkan)) {
std::printf("[SKIP] Vulkan seems to be not working.\n");
return 0;
}

Var x{"x"}, y{"y"};

Func f0{"f0"}, f1{"f1"}, g{"g"};
f0(x, y) = x * (y + 1);
f1(x, y) = x * (y + 3);
Expr vec1 = Internal::Shuffle::make_concat({f0(x, 0), f0(x, 1), f0(x, 2), f0(x, 3)});
Expr vec2 = Internal::Shuffle::make_concat({f1(x, 4), f1(x, 5), f1(x, 6), f1(x, 7)});
std::vector<int> indices0;
std::vector<int> indices1;
if (!target.has_gpu_feature() || target.has_feature(Target::Feature::OpenCL) || target.has_feature(Target::Feature::CUDA)) {
indices0 = {3, 1, 6, 7, 2, 4, 0, 5};
indices1 = {1, 0, 3, 4, 7, 0, 5, 2};
} else {
indices0 = {3, 1, 6, 7};
indices1 = {1, 0, 3, 4};
}
Expr shuffle1 = Internal::Shuffle::make({vec1, vec2}, indices0);
Expr shuffle2 = Internal::Shuffle::make({vec1, vec2}, indices1);
Expr result = shuffle1 * shuffle2;

// Manual logarithmic reduce.
while (result.type().lanes() > 1) {
int half_lanes = result.type().lanes() / 2;
Expr half1 = Halide::Internal::Shuffle::make_slice(result, 0, 1, half_lanes);
Expr half2 = Halide::Internal::Shuffle::make_slice(result, half_lanes, 1, half_lanes);
result = half1 + half2;
}
g(x) = result;

f0.compute_root();
f1.compute_root();
if (target.has_gpu_feature()) {
Var xo, xi;
g.gpu_tile(x, xo, xi, 8).never_partition_all();
}

Buffer<int> im = g.realize({32}, target);
im.copy_to_host();
for (int x = 0; x < 32; x++) {
int exp = 0;
for (size_t i = 0; i < indices0.size(); ++i) {
int v0 = x * (indices0[i] + (indices0[i] >= 4 ? 3 : 1));
int v1 = x * (indices1[i] + (indices1[i] >= 4 ? 3 : 1));
exp += v0 * v1;
}
if (im(x) != exp) {
printf("im[%d] = %d (expected %d)\n", x, im(x), exp);
return 1;
}
}
printf("Success!\n");
return 0;
}

0 comments on commit 2e36da4

Please sign in to comment.