GPU TopK custom call needs the same layout for operand and output.
We need to enforce this during layout assignment. Also, the output of the
TopK custom call needs to have the default layout.

PiperOrigin-RevId: 674215617
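
For context: XLA's default layout is row-major, i.e. minor_to_major lists the dimensions in descending order, so for the rank-2 TopK outputs it is {1,0}. A minimal sketch using the same helper the commit calls (illustrative, not part of the commit):

#include "xla/layout.h"
#include "xla/layout_util.h"

// Returns the layout the new constraint pins onto both TopK outputs
// (f32[6,8] and s32[6,8] in the tests below): minor_to_major = {1, 0}.
xla::Layout TopKOutputDefaultLayout() {
  return xla::LayoutUtil::GetDefaultLayoutForRank(2);
}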
akuegel authored and Google-ML-Automation committed Sep 13, 2024
1 parent baf1170 commit 474036b
Showing 4 changed files with 76 additions and 1 deletion.
1 change: 1 addition & 0 deletions xla/service/gpu/transforms/BUILD
@@ -1890,6 +1890,7 @@ cc_library(
"//xla/service:logical_buffer",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:cublas_cudnn",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:matmul_utils",
"//xla/service/gpu:reduction_utils",
"//xla/service/gpu:stream_executor_util",
18 changes: 17 additions & 1 deletion xla/service/gpu/transforms/layout_assignment.cc
@@ -38,6 +38,7 @@ limitations under the License.
#include "xla/primitive_util.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/reduction_utils.h"
#include "xla/service/gpu/stream_executor_util.h"
@@ -441,6 +442,18 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints(
}
TF_RETURN_IF_ERROR(SetBufferLayout(keys_layout, *output_buffer));
}
} else if (IsCustomCallToTopK(*instruction)) {
// The output of the TopK custom call needs to have the default layout.
Layout default_layout = LayoutUtil::GetDefaultLayoutForRank(
instruction->operand(0)->shape().rank());
TF_ASSIGN_OR_RETURN(
auto values_buffer,
points_to_analysis_->GetBufferDefinedAt(instruction, {0}));
TF_RETURN_IF_ERROR(SetBufferLayout(default_layout, *values_buffer));
TF_ASSIGN_OR_RETURN(
auto indices_buffer,
points_to_analysis_->GetBufferDefinedAt(instruction, {1}));
TF_RETURN_IF_ERROR(SetBufferLayout(default_layout, *indices_buffer));
} else if (instruction->opcode() == HloOpcode::kTriangularSolve) {
// TODO(phawkins): Ideally we would relax this constraint. What we
// actually want is that:
@@ -579,13 +592,16 @@ bool GpuLayoutAssignment::InstructionCanChangeLayoutInstance(
// The host offloading custom calls will be eventually removed
// by the offloader, so we need to make sure that the calls do not change
// the layout and thus cause layout mismatches after the removal.
// The TopK custom call cannot handle the case where the operand and the
// output have different layouts.
const HloCustomCallInstruction* custom_call =
DynCast<HloCustomCallInstruction>(instruction);
if (custom_call != nullptr &&
(custom_call->custom_call_target() ==
host_memory_offload_annotations::kMoveToHostCustomCallTarget ||
custom_call->custom_call_target() ==
host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) {
host_memory_offload_annotations::kMoveToDeviceCustomCallTarget ||
custom_call->custom_call_target() == kTopKCustomCallTarget)) {
return false;
}

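The two edits above cooperate: the AddBackendConstraints branch pins both tuple outputs of the TopK custom call to the default layout, and returning false from InstructionCanChangeLayoutInstance forces the operand layout to match the output layout, so layout assignment inserts a copy when the operand arrives with anything else. A sketch of the resulting invariant, assuming the usual LayoutUtil/ShapeUtil helpers (illustrative, not from the commit):

#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/layout_util.h"
#include "xla/shape_util.h"

// After layout assignment, a __gpu$TopK custom call should satisfy:
// operand layout == default layout == layout of both tuple outputs.
bool TopKLayoutsConsistent(const xla::HloInstruction* topk) {
  const xla::Layout expected = xla::LayoutUtil::GetDefaultLayoutForRank(
      topk->operand(0)->shape().rank());
  return xla::LayoutUtil::Equal(topk->operand(0)->shape().layout(), expected) &&
         xla::LayoutUtil::Equal(
             xla::ShapeUtil::GetTupleElementShape(topk->shape(), 0).layout(),
             expected) &&
         xla::LayoutUtil::Equal(
             xla::ShapeUtil::GetTupleElementShape(topk->shape(), 1).layout(),
             expected);
}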
44 changes: 44 additions & 0 deletions xla/service/gpu/transforms/layout_assignment_test.cc
@@ -360,6 +360,50 @@ TEST_F(LayoutAssignmentTest, SortLayout) {
m::Op().WithShape(F32, {3, 2}, {1, 0}))));
}

TEST_F(LayoutAssignmentTest, TopKLayout) {
const char* hlo_text = R"(
HloModule topk
compare-greater-than {
p.1.lhs.3 = s32[] parameter(2)
p.1.rhs.4 = s32[] parameter(3)
p.0.lhs.1 = f32[] parameter(0)
bitcast-convert = s32[] bitcast-convert(p.0.lhs.1)
constant = s32[] constant(0)
compare = pred[] compare(bitcast-convert, constant), direction=LT
constant.2 = s32[] constant(2147483647)
xor = s32[] xor(constant.2, bitcast-convert)
select = s32[] select(compare, xor, bitcast-convert)
p.0.rhs.2 = f32[] parameter(1)
bitcast-convert.1 = s32[] bitcast-convert(p.0.rhs.2)
compare.1 = pred[] compare(bitcast-convert.1, constant), direction=LT
xor.1 = s32[] xor(constant.2, bitcast-convert.1)
select.1 = s32[] select(compare.1, xor.1, bitcast-convert.1)
ROOT compare.5 = pred[] compare(select, select.1), direction=GT
}
ENTRY main {
Arg_0.1 = f32[2048,6]{1,0} parameter(0)
t = f32[6,2048]{0,1} transpose(Arg_0.1), dimensions={1,0}
ROOT custom-call.1 = (f32[6,8]{1,0}, s32[6,8]{1,0}) custom-call(t), custom_call_target="__gpu$TopK", api_version=API_VERSION_TYPED_FFI, called_computations={compare-greater-than}
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_text));

ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape(),
/*ignore_layouts=*/false);
GpuLayoutAssignment layout_assignment(
&computation_layout, GetGpuComputeCapability(), GetDnnVersion());

EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));

EXPECT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::CustomCall(
m::Transpose(m::Copy().WithShape(F32, {2048, 6}, {0, 1}))
.WithShape(F32, {6, 2048}, {1, 0}))));
}
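
Spelled out, the matcher above encodes roughly this post-pass entry computation: layout assignment inserts a copy to layout {0,1} so that the transpose yields the default layout {1,0}, which the TopK operand must now carry (reconstructed from the matchers; instruction names are illustrative):

ENTRY main {
  Arg_0.1 = f32[2048,6]{1,0} parameter(0)
  copy = f32[2048,6]{0,1} copy(Arg_0.1)
  t = f32[6,2048]{1,0} transpose(copy), dimensions={1,0}
  ROOT custom-call.1 = (f32[6,8]{1,0}, s32[6,8]{1,0}) custom-call(t), custom_call_target="__gpu$TopK", api_version=API_VERSION_TYPED_FFI, called_computations={compare-greater-than}
}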

TEST_F(LayoutAssignmentTest, FftLayout) {
const char* hlo_text = R"(
HloModule Fft_module
14 changes: 14 additions & 0 deletions xla/tests/topk_test.cc
@@ -49,5 +49,19 @@ XLA_TEST_F(TopkTest, SmallestTopK) {
EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{1e-5, 1e-5}));
}

XLA_TEST_F(TopkTest, TopKOfTranspose) {
// Regression test for b/362565176
std::string_view hlo_text_module = R"(
HloModule topk
ENTRY main {
%Arg_0.1 = f32[2048,6]{1,0} parameter(0)
t = f32[6,2048]{0,1} transpose(%Arg_0.1), dimensions={1,0}
ROOT %topk.2 = (f32[6,8]{1,0}, s32[6,8]{1,0}) topk(t), k=8, largest=true
}
)";
EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{1e-5, 1e-5}));
}

} // namespace
} // namespace xla
