diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD
index d7be78888b792..c481e29083b5d 100644
--- a/xla/service/gpu/transforms/BUILD
+++ b/xla/service/gpu/transforms/BUILD
@@ -1890,6 +1890,7 @@ cc_library(
         "//xla/service:logical_buffer",
         "//xla/service/gpu:backend_configs_cc",
         "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu:ir_emission_utils",
         "//xla/service/gpu:matmul_utils",
         "//xla/service/gpu:reduction_utils",
         "//xla/service/gpu:stream_executor_util",
diff --git a/xla/service/gpu/transforms/layout_assignment.cc b/xla/service/gpu/transforms/layout_assignment.cc
index caa8d3c10f90e..99f306cfc6424 100644
--- a/xla/service/gpu/transforms/layout_assignment.cc
+++ b/xla/service/gpu/transforms/layout_assignment.cc
@@ -38,6 +38,7 @@ limitations under the License.
 #include "xla/primitive_util.h"
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/matmul_utils.h"
 #include "xla/service/gpu/reduction_utils.h"
 #include "xla/service/gpu/stream_executor_util.h"
@@ -441,6 +442,18 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints(
         }
         TF_RETURN_IF_ERROR(SetBufferLayout(keys_layout, *output_buffer));
       }
+    } else if (IsCustomCallToTopK(*instruction)) {
+      // The output of the TopK custom call needs to have default layout.
+      Layout default_layout = LayoutUtil::GetDefaultLayoutForRank(
+          instruction->operand(0)->shape().rank());
+      TF_ASSIGN_OR_RETURN(
+          auto values_buffer,
+          points_to_analysis_->GetBufferDefinedAt(instruction, {0}));
+      TF_RETURN_IF_ERROR(SetBufferLayout(default_layout, *values_buffer));
+      TF_ASSIGN_OR_RETURN(
+          auto indices_buffer,
+          points_to_analysis_->GetBufferDefinedAt(instruction, {1}));
+      TF_RETURN_IF_ERROR(SetBufferLayout(default_layout, *indices_buffer));
     } else if (instruction->opcode() == HloOpcode::kTriangularSolve) {
       // TODO(phawkins): Ideally we would relax this constraint. What we
       // actually want is that:
@@ -579,13 +592,16 @@ bool GpuLayoutAssignment::InstructionCanChangeLayoutInstance(
   // The host offloading custom calls will be eventually removed
   // by the offloader, so we need to make sure that the calls do not change
   // the layout and thus cause layout mismatches after the removal.
+  // The TopK custom call cannot handle the case where the operand has a
+  // different layout.
   const HloCustomCallInstruction* custom_call =
       DynCast<HloCustomCallInstruction>(instruction);
   if (custom_call != nullptr &&
       (custom_call->custom_call_target() ==
            host_memory_offload_annotations::kMoveToHostCustomCallTarget ||
        custom_call->custom_call_target() ==
-           host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) {
+           host_memory_offload_annotations::kMoveToDeviceCustomCallTarget ||
+       custom_call->custom_call_target() == kTopKCustomCallTarget)) {
     return false;
   }
 
diff --git a/xla/service/gpu/transforms/layout_assignment_test.cc b/xla/service/gpu/transforms/layout_assignment_test.cc
index f9f90f9923158..4dbd453e1d485 100644
--- a/xla/service/gpu/transforms/layout_assignment_test.cc
+++ b/xla/service/gpu/transforms/layout_assignment_test.cc
@@ -360,6 +360,50 @@ TEST_F(LayoutAssignmentTest, SortLayout) {
                             m::Op().WithShape(F32, {3, 2}, {1, 0}))));
 }
 
+TEST_F(LayoutAssignmentTest, TopKLayout) {
+  const char* hlo_text = R"(
+  HloModule topk
+
+  compare-greater-than {
+    p.1.lhs.3 = s32[] parameter(2)
+    p.1.rhs.4 = s32[] parameter(3)
+    p.0.lhs.1 = f32[] parameter(0)
+    bitcast-convert = s32[] bitcast-convert(p.0.lhs.1)
+    constant = s32[] constant(0)
+    compare = pred[] compare(bitcast-convert, constant), direction=LT
+    constant.2 = s32[] constant(2147483647)
+    xor = s32[] xor(constant.2, bitcast-convert)
+    select = s32[] select(compare, xor, bitcast-convert)
+    p.0.rhs.2 = f32[] parameter(1)
+    bitcast-convert.1 = s32[] bitcast-convert(p.0.rhs.2)
+    compare.1 = pred[] compare(bitcast-convert.1, constant), direction=LT
+    xor.1 = s32[] xor(constant.2, bitcast-convert.1)
+    select.1 = s32[] select(compare.1, xor.1, bitcast-convert.1)
+    ROOT compare.5 = pred[] compare(select, select.1), direction=GT
+  }
+
+  ENTRY main {
+    Arg_0.1 = f32[2048,6]{1,0} parameter(0)
+    t = f32[6,2048]{0,1} transpose(Arg_0.1), dimensions={1,0}
+    ROOT custom-call.1 = (f32[6,8]{1,0}, s32[6,8]{1,0}) custom-call(t), custom_call_target="__gpu$TopK", api_version=API_VERSION_TYPED_FFI, called_computations={compare-greater-than}
+  })";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+
+  ComputationLayout computation_layout(
+      module->entry_computation()->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion());
+
+  EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));
+
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              GmockMatch(m::CustomCall(
+                  m::Transpose(m::Copy().WithShape(F32, {2048, 6}, {0, 1}))
+                      .WithShape(F32, {6, 2048}, {1, 0}))));
+}
+
 TEST_F(LayoutAssignmentTest, FftLayout) {
   const char* hlo_text = R"(
     HloModule Fft_module
diff --git a/xla/tests/topk_test.cc b/xla/tests/topk_test.cc
index 96237ba5a6a95..ec1059977a8b2 100644
--- a/xla/tests/topk_test.cc
+++ b/xla/tests/topk_test.cc
@@ -49,5 +49,19 @@ XLA_TEST_F(TopkTest, SmallestTopK) {
   EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{1e-5, 1e-5}));
 }
 
+XLA_TEST_F(TopkTest, TopKOfTranspose) {
+  // Regression test for b/362565176
+  std::string_view hlo_text_module = R"(
+    HloModule topk
+
+    ENTRY main {
+      %Arg_0.1 = f32[2048,6]{1,0} parameter(0)
+      t = f32[6,2048]{0,1} transpose(%Arg_0.1), dimensions={1,0}
+      ROOT %topk.2 = (f32[6,8]{1,0}, s32[6,8]{1,0}) topk(t), k=8, largest=true
+    }
+  )";
+  EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{1e-5, 1e-5}));
+}
+
 }  // namespace
 }  // namespace xla