Convert multi-row reduction tests to hlo tests.
PiperOrigin-RevId: 668007266
jreiffers authored and copybara-github committed Aug 27, 2024
1 parent b425085 commit d3d3c22
Showing 4 changed files with 68 additions and 124 deletions.
127 changes: 3 additions & 124 deletions xla/service/gpu/fusions/reduction_mlir_test.cc
@@ -39,111 +39,12 @@ using ::testing::ElementsAre;
using ::testing::SizeIs;

template <typename EmitterType>
class ReductionTest : public MlirEmitterTestBase<EmitterType> {
 protected:
  absl::Status TestBijection(const IndexingMap& map,
                             absl::Span<int64_t const> shape) {
    std::vector<Interval> intervals;
    for (int64_t size : shape) {
      intervals.push_back({0, size - 1});
    }
    auto status = VerifyBijection(map, intervals);
    if (status.ok()) return status;
    return absl::FailedPreconditionError(
        absl::StrCat(status.message(), " in map ", map.ToString()));
  }
};
class ReductionTest : public MlirEmitterTestBase<EmitterType> {};

using MlirMultiRowReductionTest = ReductionTest<MlirMultiRowReductionFusion>;

constexpr auto kMultiRowReductionX8 = R"(
  Add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }
  fused_computation {
    param_0 = f32[1024,4] parameter(0)
    param_1 = f32[] parameter(1)
    ROOT reduce = f32[1024] reduce(param_0, param_1), dimensions={1}, to_apply=Add
  }
  ENTRY main {
    a = f32[1024,4] parameter(0)
    c = f32[] constant(0)
    ROOT fusion = f32[1024] fusion(a, c), kind=kInput, calls=fused_computation
  })";

constexpr auto kMultiRowReductionX2VectorX4 = R"(
  or {
    tmp_0 = pred[] parameter(0)
    tmp_1 = pred[] parameter(1)
    ROOT tmp_2 = pred[] or(tmp_0, tmp_1)
  }
  fusion {
    tmp_0 = f32[76800,16]{1,0} parameter(0)
    tmp_1 = f32[] constant(-1.70141173e+38)
    tmp_2 = f32[76800,16]{1,0} broadcast(tmp_1), dimensions={}
    tmp_3 = pred[76800,16]{1,0} compare(tmp_0, tmp_2), direction=GT
    tmp_4 = pred[] constant(false)
    tmp_5 = pred[76800]{0} reduce(tmp_3, tmp_4), dimensions={1}, to_apply=or
    tmp_6 = f32[76800,16]{1,0} parameter(1)
    tmp_7 = pred[76800,16]{1,0} compare(tmp_6, tmp_2), direction=GT
    tmp_8 = pred[76800]{0} reduce(tmp_7, tmp_4), dimensions={1}, to_apply=or
    ROOT tmp_9 = (pred[76800]{0}, pred[76800]{0}) tuple(tmp_5, tmp_8)
  }
  ENTRY main {
    p0 = f32[76800,16]{1,0} parameter(0)
    p1 = f32[76800,16]{1,0} parameter(1)
    ROOT fusion = (pred[76800]{0}, pred[76800]{0}) fusion(p0, p1), kind=kInput, calls=fusion
  })";

constexpr auto kMultiRowReductionX16VectorX2 = R"(
  or {
    tmp_0 = pred[] parameter(0)
    tmp_1 = pred[] parameter(1)
    ROOT tmp_2 = pred[] or(tmp_0, tmp_1)
  }
  fusion {
    p0 = pred[76800,2] parameter(0)
    c0 = pred[] constant(false)
    ROOT reduce = pred[76800] reduce(p0, c0), dimensions={1}, to_apply=or
  }
  ENTRY main {
    p0 = pred[76800,2] parameter(0)
    ROOT fusion = pred[76800] fusion(p0), kind=kInput, calls=fusion
  })";

TEST_F(MlirMultiRowReductionTest, MultiRowReductionIndexing) {
  auto fusion = GetEmitter(kMultiRowReductionX8);

  TF_EXPECT_OK(TestBijection(
      *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_),
      {1024, 4}));
  TF_EXPECT_OK(TestBijection(
      *fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), {1024}));
  EXPECT_EQ(Product(GetLoopTripCounts(
                *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_))),
            1);
}

TEST_F(MlirMultiRowReductionTest, MultiRowReductionIr) {
  // Multi-row reductions don't use shared memory.
  TF_ASSERT_OK(EmitAndCheckIR(kMultiRowReductionX8, R"(
    // CHECK: shuffle_reduce {{.*}} to 2
    // CHECK-NOT: allocate_shared
  )"));
}

TEST_F(MlirMultiRowReductionTest, MultiRowReductionCorrectness) {
  EXPECT_TRUE(RunAndCompareNoHloPasses(kMultiRowReductionX8, ErrorSpec{1e-3}));
}

TEST_F(MlirMultiRowReductionTest, TwoGroups) {
  // TODO(jreiffers): Move this test to reduction_base_test.
  auto module = ParseAndReturnVerifiedModule(R"(
    add {
      p0 = f32[] parameter(0)
@@ -176,6 +77,7 @@ TEST_F(MlirMultiRowReductionTest, TwoGroups) {
}

TEST_F(MlirMultiRowReductionTest, OneGroup) {
  // TODO(jreiffers): Move this test to reduction_base_test.
  auto module = ParseAndReturnVerifiedModule(R"(
    %add {
      %p0 = c128[] parameter(0)
@@ -204,29 +106,6 @@ TEST_F(MlirMultiRowReductionTest, OneGroup) {
  EXPECT_THAT(mlir_fusion.GetGroups().grouped_roots, SizeIs(1));
}

TEST_F(MlirMultiRowReductionTest, VectorizedX4Indexing) {
  auto fusion = GetEmitter(kMultiRowReductionX2VectorX4);

  TF_EXPECT_OK(TestBijection(
      *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_),
      {76800, 16}));
  TF_EXPECT_OK(TestBijection(
      *fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), {76800}));
  EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing(
                  0, 0, &mlir_context_)),
              ElementsAre(1 /* major reduced */, 4 /* vector size */));
}

TEST_F(MlirMultiRowReductionTest, LimitedVectorizationCorrectness) {
  EXPECT_TRUE(
      RunAndCompareNoHloPasses(kMultiRowReductionX16VectorX2, ErrorSpec{1e-3}));
}

TEST_F(MlirMultiRowReductionTest, VectorizedX4Correctness) {
  EXPECT_TRUE(
      RunAndCompareNoHloPasses(kMultiRowReductionX2VectorX4, ErrorSpec{1e-3}));
}

}  // namespace
}  // namespace gpu
}  // namespace xla
20 changes: 20 additions & 0 deletions xla/service/gpu/fusions/tests/reduce_multirow/f32_x8.hlo
@@ -0,0 +1,20 @@
// RUN: fusion_to_mlir %s | FileCheck %s
// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce
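// (These flags take over the checks from the C++ TestBijection helper deleted
// above: the thread-ID-to-input indexing for the reduce's operand 0 and the
// thread-ID-to-output indexing must each cover every element exactly once.)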

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

fused_computation {
  param_0 = f32[1024,4] parameter(0)
  c = f32[] constant(0)
  ROOT reduce = f32[1024] reduce(param_0, c), dimensions={1}, to_apply=add
}

// Multi-row reductions do not use shared memory.
// CHECK-NOT: allocate_shared
// There should be 8 output elements per warp.
// CHECK: shuffle_reduce {{.*}} to 2
// CHECK-NOT: allocate_shared
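For intuition, the shuffle_reduce pattern checked above corresponds to a
warp-level butterfly that stops at offset 2: with four-element rows, a 32-lane
warp holds eight rows, and two shuffles (offset 2, then 1) reduce each row in
place. A minimal CUDA sketch of that shape for a sum reduction; the kernel
name, lane layout, and launch geometry are illustrative assumptions, not XLA's
generated code:

#include <cuda_runtime.h>

// Hypothetical sketch: lane l holds element l % 4 of row l / 4, so each
// 32-lane warp covers 8 rows of 4 elements. Assumes blockDim.x is a
// multiple of 32 and the input is fully covered by the launch.
__global__ void MultiRowReduceX8(const float* in, float* out) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  float v = in[idx];
  // The butterfly stops at offset 2, matching "shuffle_reduce ... to 2".
  v += __shfl_down_sync(0xffffffffu, v, 2);
  v += __shfl_down_sync(0xffffffffu, v, 1);
  // The first lane of every 4-lane group now holds its row's sum.
  if (idx % 4 == 0) out[idx / 4] = v;
}

All communication stays within a warp, which is why the CHECK-NOT lines can
insist that no allocate_shared appears.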
@@ -0,0 +1,19 @@
// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-to-inline -xla-gpu-test-vectorize | FileCheck %s
// RUN: test_correctness %s

or {
  tmp_0 = pred[] parameter(0)
  tmp_1 = pred[] parameter(1)
  ROOT tmp_2 = pred[] or(tmp_0, tmp_1)
}

fusion {
  p0 = pred[76800,2] parameter(0)
  c0 = pred[] constant(false)
  ROOT reduce = pred[76800] reduce(p0, c0), dimensions={1}, to_apply=or
}

// Normally, we would attempt to vectorize this to v4, but codegen does not
// currently support a vector size larger than the row width.
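// Rows here are only two elements wide and pred lowers to an i8 element
// type, hence the vector<2xi8> read below.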

// CHECK: vector.transfer_read {{.*}} vector<2xi8>
26 changes: 26 additions & 0 deletions xla/service/gpu/fusions/tests/reduce_multirow/pred_mof_x2_v4.hlo
@@ -0,0 +1,26 @@
// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-to-inline -xla-gpu-test-vectorize | FileCheck %s
// RUN: test_correctness %s --bijection_inputs=tmp_5:0 --bijection_inputs=tmp_8:0 --bijection_outputs=tmp_5 --bijection_outputs=tmp_8

or {
  tmp_0 = pred[] parameter(0)
  tmp_1 = pred[] parameter(1)
  ROOT tmp_2 = pred[] or(tmp_0, tmp_1)
}

fusion {
  tmp_0 = f32[7680,16] parameter(0)
  tmp_1 = f32[] constant(-1.70141173e+38)
  tmp_2 = f32[7680,16] broadcast(tmp_1), dimensions={}
  tmp_3 = pred[7680,16] compare(tmp_0, tmp_2), direction=GT
  tmp_4 = pred[] constant(false)
  tmp_5 = pred[7680] reduce(tmp_3, tmp_4), dimensions={1}, to_apply=or
  tmp_6 = f32[7680,16] parameter(1)
  tmp_7 = pred[7680,16] compare(tmp_6, tmp_2), direction=GT
  tmp_8 = pred[7680] reduce(tmp_7, tmp_4), dimensions={1}, to_apply=or
  ROOT tmp_9 = (pred[7680], pred[7680]) tuple(tmp_5, tmp_8)
}

// CHECK: vector.transfer_read {{.*}} vector<4xf32>
// CHECK: xla_gpu.shuffle_reduce @or_tmp_2
// CHECK: vector.transfer_read {{.*}} vector<4xf32>
// CHECK: xla_gpu.shuffle_reduce @or_tmp_2
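// Note: the two reduces form a single multi-output fusion group sharing one
// thread mapping, so the vectorized-read-plus-shuffle_reduce pair is emitted
// once per root, which is what the repeated CHECK lines verify.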
