Skip to content

Commit

Permalink
PR #17257: Fix the command buffer scheduling pass return value
Browse files Browse the repository at this point in the history
Imported from GitHub PR #17257

Fixes #17216. Returns true when either parameters are moved, or command buffer is created.
Copybara import of the project:

--
f9dc3e2 by Shraiysh Vaishay <[email protected]>:

Fix the command buffer scheduling pass return value

Fixes #17216. Returns true when either parameters are moved, or command
buffer is created.

Merging this change closes #17257

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17257 from shraiysh:fix_17216 f9dc3e2
PiperOrigin-RevId: 675625838
  • Loading branch information
shraiysh authored and Google-ML-Automation committed Sep 17, 2024
1 parent 8ace4ee commit c3f4eb8
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 5 deletions.
17 changes: 13 additions & 4 deletions xla/service/gpu/transforms/command_buffer_scheduling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,9 @@ CommandBufferScheduling::CollectCommandBufferSequences(
// the beginning of the computation. This simplifies the construction of command
// buffer computations because we don't need to deal with parameters and
// constants that have users outside of a command buffer.
absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront(
// Returns true if there is a change in the order of instructions, false
// otherwise.
absl::StatusOr<bool> CommandBufferScheduling::MoveParametersAndConstantsToFront(
HloComputation* computation) {
HloInstructionSequence new_sequence;
HloSchedule& schedule = computation->parent()->schedule();
Expand Down Expand Up @@ -463,7 +465,11 @@ absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront(
}

schedule.set_sequence(computation, new_sequence);
return absl::OkStatus();
for (auto [old_i, new_i] :
llvm::zip(sequence.instructions(), new_sequence.instructions())) {
if (old_i != new_i) return true;
}
return false;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -777,6 +783,7 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
std::reverse(order.begin(), order.end());
absl::flat_hash_set<HloComputation*> processed_command_buffers;

auto changed = false;
for (HloComputation* comp : order) {
// Skip special computations that do not have lowering to thunks.
if (comp->IsFusionComputation() || comp->IsAsyncComputation() ||
Expand All @@ -786,7 +793,8 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
// Skip computations that already part of command buffers.
if (processed_command_buffers.contains(comp)) continue;

TF_RETURN_IF_ERROR(MoveParametersAndConstantsToFront(comp));
TF_ASSIGN_OR_RETURN(bool changed_, MoveParametersAndConstantsToFront(comp));
changed |= changed_;

std::vector<HloInstructionSequence> sequences =
CollectCommandBufferSequences(
Expand All @@ -799,6 +807,7 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
TF_ASSIGN_OR_RETURN(
HloComputation * command_buffer_computation,
RewriteCommandBuffer(comp, seq, std::move(command_buffer)));
changed = true;

// All computations reachable from a command buffer computation are nested
// command buffers (i.e. body computations attached to a while operation).
Expand All @@ -810,7 +819,7 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
}
TF_RETURN_IF_ERROR(module->schedule().Update());

return true;
return changed;
}

} // namespace xla::gpu
4 changes: 3 additions & 1 deletion xla/service/gpu/transforms/command_buffer_scheduling.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ class CommandBufferScheduling : public HloModulePass {
// the beginning of the computation. This simplifies the construction of
// command buffer computations because we don't need to deal with parameters
// and constants that have users outside of a command buffer.
static absl::Status MoveParametersAndConstantsToFront(
// Returns true if there is a change in the order of instructions, false
// otherwise.
static absl::StatusOr<bool> MoveParametersAndConstantsToFront(
HloComputation* computation);

struct CommandBuffer {
Expand Down
47 changes: 47 additions & 0 deletions xla/service/gpu/transforms/command_buffer_scheduling_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1071,5 +1071,52 @@ TEST_F(CommandBufferSchedulingTest, AsyncFusion) {
});
}

TEST_F(CommandBufferSchedulingTest, ReturnFalseWhenNoChange) {
const char* hlo = R"(
HloModule module, is_scheduled=true
ENTRY main {
a = s32[8,8] parameter(0)
b = s32[8,8] parameter(1)
ROOT call = s32[8,8] custom-call(a,b), custom_call_target="__cublas$gemm"
}
)";

HloModuleConfig config;
DebugOptions options = GetDebugOptionsForTest();
options.clear_xla_gpu_enable_command_buffer();
options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
config.set_debug_options(options);
TF_ASSIGN_OR_RETURN(auto m, ParseAndReturnVerifiedModule(hlo, config));
RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()),
std::nullopt);
}

TEST_F(CommandBufferSchedulingTest, ReturnTrueWhenOnlyParamMoved) {
const char* hlo = R"(
HloModule module, is_scheduled=true
ENTRY main {
a = s32[8,8] parameter(0)
b = s32[8,8] parameter(1)
call = s32[8,8] custom-call(a,b), custom_call_target="__cublas$gemm"
c = s32[8,8] parameter(2)
ROOT call2 = s32[8,8] custom-call(call, c), custom_call_target="__cublas$gemm"
}
)";

HloModuleConfig config;
DebugOptions options = GetDebugOptionsForTest();
options.clear_xla_gpu_enable_command_buffer();
options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
config.set_debug_options(options);
TF_ASSIGN_OR_RETURN(auto m, ParseAndReturnVerifiedModule(hlo, config));
RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()), R"(
// CHECK: %{{.+}} = {{.+}} parameter(0)
// CHECK: %{{.+}} = {{.+}} parameter(1)
// CHECK: %{{.+}} = {{.+}} parameter(2)
// CHECK: %{{.+}} = {{.+}} custom-call
// CHECK: %{{.+}} = {{.+}} custom-call
)");
}

} // namespace
} // namespace xla::gpu

0 comments on commit c3f4eb8

Please sign in to comment.