From 474036b950cbb934f09f5485633997fb379e6364 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Fri, 13 Sep 2024 02:41:18 -0700 Subject: [PATCH] GPU TopK custom call needs the same layout for operand and output. We need to enforce this during layout assignment. Also, the output of TopK custom call needs to have the default layout. PiperOrigin-RevId: 674215617 --- xla/service/gpu/transforms/BUILD | 1 + .../gpu/transforms/layout_assignment.cc | 18 +++++++- .../gpu/transforms/layout_assignment_test.cc | 44 +++++++++++++++++++ xla/tests/topk_test.cc | 14 ++++++ 4 files changed, 76 insertions(+), 1 deletion(-) diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD index d7be78888b792..c481e29083b5d 100644 --- a/xla/service/gpu/transforms/BUILD +++ b/xla/service/gpu/transforms/BUILD @@ -1890,6 +1890,7 @@ cc_library( "//xla/service:logical_buffer", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:reduction_utils", "//xla/service/gpu:stream_executor_util", diff --git a/xla/service/gpu/transforms/layout_assignment.cc b/xla/service/gpu/transforms/layout_assignment.cc index caa8d3c10f90e..99f306cfc6424 100644 --- a/xla/service/gpu/transforms/layout_assignment.cc +++ b/xla/service/gpu/transforms/layout_assignment.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/reduction_utils.h" #include "xla/service/gpu/stream_executor_util.h" @@ -441,6 +442,18 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( } TF_RETURN_IF_ERROR(SetBufferLayout(keys_layout, *output_buffer)); } + } else if (IsCustomCallToTopK(*instruction)) { + // The output of the TopK custom call needs to have default layout. 
 Layout default_layout = LayoutUtil::GetDefaultLayoutForRank( + instruction->operand(0)->shape().rank()); + TF_ASSIGN_OR_RETURN( + auto values_buffer, + points_to_analysis_->GetBufferDefinedAt(instruction, {0})); + TF_RETURN_IF_ERROR(SetBufferLayout(default_layout, *values_buffer)); + TF_ASSIGN_OR_RETURN( + auto indices_buffer, + points_to_analysis_->GetBufferDefinedAt(instruction, {1})); + TF_RETURN_IF_ERROR(SetBufferLayout(default_layout, *indices_buffer)); } else if (instruction->opcode() == HloOpcode::kTriangularSolve) { // TODO(phawkins): Ideally we would relax this constraint. What we // actually want is that: @@ -579,13 +592,16 @@ bool GpuLayoutAssignment::InstructionCanChangeLayoutInstance( // The host offloading custom calls will be eventually removed // by the offloader, so we need to make sure that the calls do not change // the layout and thus cause layout mismatches after the removal. + // The TopK custom call cannot handle the case where the operand has a + // different layout.
const HloCustomCallInstruction* custom_call = DynCast<HloCustomCallInstruction>(instruction); if (custom_call != nullptr && (custom_call->custom_call_target() == host_memory_offload_annotations::kMoveToHostCustomCallTarget || custom_call->custom_call_target() == - host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) { + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget || + custom_call->custom_call_target() == kTopKCustomCallTarget)) { return false; } diff --git a/xla/service/gpu/transforms/layout_assignment_test.cc b/xla/service/gpu/transforms/layout_assignment_test.cc index f9f90f9923158..4dbd453e1d485 100644 --- a/xla/service/gpu/transforms/layout_assignment_test.cc +++ b/xla/service/gpu/transforms/layout_assignment_test.cc @@ -360,6 +360,50 @@ TEST_F(LayoutAssignmentTest, SortLayout) { m::Op().WithShape(F32, {3, 2}, {1, 0})))); } +TEST_F(LayoutAssignmentTest, TopKLayout) { + const char* hlo_text = R"( + HloModule topk + + compare-greater-than { + p.1.lhs.3 = s32[] parameter(2) + p.1.rhs.4 = s32[] parameter(3) + p.0.lhs.1 = f32[] parameter(0) + bitcast-convert = s32[] bitcast-convert(p.0.lhs.1) + constant = s32[] constant(0) + compare = pred[] compare(bitcast-convert, constant), direction=LT + constant.2 = s32[] constant(2147483647) + xor = s32[] xor(constant.2, bitcast-convert) + select = s32[] select(compare, xor, bitcast-convert) + p.0.rhs.2 = f32[] parameter(1) + bitcast-convert.1 = s32[] bitcast-convert(p.0.rhs.2) + compare.1 = pred[] compare(bitcast-convert.1, constant), direction=LT + xor.1 = s32[] xor(constant.2, bitcast-convert.1) + select.1 = s32[] select(compare.1, xor.1, bitcast-convert.1) + ROOT compare.5 = pred[] compare(select, select.1), direction=GT + } + + ENTRY main { + Arg_0.1 = f32[2048,6]{1,0} parameter(0) + t = f32[6,2048]{0,1} transpose(Arg_0.1), dimensions={1,0} + ROOT custom-call.1 = (f32[6,8]{1,0}, s32[6,8]{1,0}) custom-call(t), custom_call_target="__gpu$TopK", api_version=API_VERSION_TYPED_FFI, 
called_computations={compare-greater-than} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module, + ParseAndReturnVerifiedModule(hlo_text)); + + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape(), + /*ignore_layouts=*/false); + GpuLayoutAssignment layout_assignment( + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + + EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::CustomCall( + m::Transpose(m::Copy().WithShape(F32, {2048, 6}, {0, 1})) + .WithShape(F32, {6, 2048}, {1, 0})))); +} + TEST_F(LayoutAssignmentTest, FftLayout) { const char* hlo_text = R"( HloModule Fft_module diff --git a/xla/tests/topk_test.cc b/xla/tests/topk_test.cc index 96237ba5a6a95..ec1059977a8b2 100644 --- a/xla/tests/topk_test.cc +++ b/xla/tests/topk_test.cc @@ -49,5 +49,19 @@ XLA_TEST_F(TopkTest, SmallestTopK) { EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{1e-5, 1e-5})); } +XLA_TEST_F(TopkTest, TopKOfTranspose) { + // Regression test for b/362565176 + std::string_view hlo_text_module = R"( + HloModule topk + + ENTRY main { + %Arg_0.1 = f32[2048,6]{1,0} parameter(0) + t = f32[6,2048]{0,1} transpose(%Arg_0.1), dimensions={1,0} + ROOT %topk.2 = (f32[6,8]{1,0}, s32[6,8]{1,0}) topk(t), k=8, largest=true + } + )"; + EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{1e-5, 1e-5})); +} + } // namespace } // namespace xla