From c3f4eb8f167ee0b945a8579cebc040d8af86f476 Mon Sep 17 00:00:00 2001 From: Shraiysh Date: Tue, 17 Sep 2024 10:43:35 -0700 Subject: [PATCH] PR #17257: Fix the command buffer scheduling pass return value Imported from GitHub PR https://github.com/openxla/xla/pull/17257 Fixes #17216. Returns true when either parameters are moved, or command buffer is created. Copybara import of the project: -- f9dc3e277e414e8d1de423221ca3e2d9b0ae3c9e by Shraiysh Vaishay : Fix the command buffer scheduling pass return value Fixes #17216. Returns true when either parameters are moved, or command buffer is created. Merging this change closes #17257 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/17257 from shraiysh:fix_17216 f9dc3e277e414e8d1de423221ca3e2d9b0ae3c9e PiperOrigin-RevId: 675625838 --- .../transforms/command_buffer_scheduling.cc | 17 +++++-- .../transforms/command_buffer_scheduling.h | 4 +- .../command_buffer_scheduling_test.cc | 47 +++++++++++++++++++ 3 files changed, 63 insertions(+), 5 deletions(-) diff --git a/xla/service/gpu/transforms/command_buffer_scheduling.cc b/xla/service/gpu/transforms/command_buffer_scheduling.cc index a16b908c163304..9d76231aa4ad9a 100644 --- a/xla/service/gpu/transforms/command_buffer_scheduling.cc +++ b/xla/service/gpu/transforms/command_buffer_scheduling.cc @@ -433,7 +433,9 @@ CommandBufferScheduling::CollectCommandBufferSequences( // the beginning of the computation. This simplifies the construction of command // buffer computations because we don't need to deal with parameters and // constants that have users outside of a command buffer. -absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront( +// Returns true if there is a change in the order of instructions, false +// otherwise. 
+absl::StatusOr CommandBufferScheduling::MoveParametersAndConstantsToFront( HloComputation* computation) { HloInstructionSequence new_sequence; HloSchedule& schedule = computation->parent()->schedule(); @@ -463,7 +465,11 @@ absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront( } schedule.set_sequence(computation, new_sequence); - return absl::OkStatus(); + for (auto [old_i, new_i] : + llvm::zip(sequence.instructions(), new_sequence.instructions())) { + if (old_i != new_i) return true; + } + return false; } //===----------------------------------------------------------------------===// @@ -777,6 +783,7 @@ absl::StatusOr CommandBufferScheduling::Run( std::reverse(order.begin(), order.end()); absl::flat_hash_set processed_command_buffers; + auto changed = false; for (HloComputation* comp : order) { // Skip special computations that do not have lowering to thunks. if (comp->IsFusionComputation() || comp->IsAsyncComputation() || @@ -786,7 +793,8 @@ absl::StatusOr CommandBufferScheduling::Run( // Skip computations that already part of command buffers. if (processed_command_buffers.contains(comp)) continue; - TF_RETURN_IF_ERROR(MoveParametersAndConstantsToFront(comp)); + TF_ASSIGN_OR_RETURN(bool changed_, MoveParametersAndConstantsToFront(comp)); + changed |= changed_; std::vector sequences = CollectCommandBufferSequences( @@ -799,6 +807,7 @@ absl::StatusOr CommandBufferScheduling::Run( TF_ASSIGN_OR_RETURN( HloComputation * command_buffer_computation, RewriteCommandBuffer(comp, seq, std::move(command_buffer))); + changed = true; // All computations reachable from a command buffer computation are nested // command buffers (i.e. body computations attached to a while operation). 
@@ -810,7 +819,7 @@ absl::StatusOr CommandBufferScheduling::Run( } TF_RETURN_IF_ERROR(module->schedule().Update()); - return true; + return changed; } } // namespace xla::gpu diff --git a/xla/service/gpu/transforms/command_buffer_scheduling.h b/xla/service/gpu/transforms/command_buffer_scheduling.h index edbf9bf8f95912..5731a157fdb682 100644 --- a/xla/service/gpu/transforms/command_buffer_scheduling.h +++ b/xla/service/gpu/transforms/command_buffer_scheduling.h @@ -99,7 +99,9 @@ class CommandBufferScheduling : public HloModulePass { // the beginning of the computation. This simplifies the construction of // command buffer computations because we don't need to deal with parameters // and constants that have users outside of a command buffer. - static absl::Status MoveParametersAndConstantsToFront( + // Returns true if there is a change in the order of instructions, false + // otherwise. + static absl::StatusOr MoveParametersAndConstantsToFront( HloComputation* computation); struct CommandBuffer { diff --git a/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/xla/service/gpu/transforms/command_buffer_scheduling_test.cc index 3bffa6eaa621ed..553e683e2a9b2c 100644 --- a/xla/service/gpu/transforms/command_buffer_scheduling_test.cc +++ b/xla/service/gpu/transforms/command_buffer_scheduling_test.cc @@ -1071,5 +1071,52 @@ TEST_F(CommandBufferSchedulingTest, AsyncFusion) { }); } +TEST_F(CommandBufferSchedulingTest, ReturnFalseWhenNoChange) { + const char* hlo = R"( + HloModule module, is_scheduled=true + ENTRY main { + a = s32[8,8] parameter(0) + b = s32[8,8] parameter(1) + ROOT call = s32[8,8] custom-call(a,b), custom_call_target="__cublas$gemm" + } + )"; + + HloModuleConfig config; + DebugOptions options = GetDebugOptionsForTest(); + options.clear_xla_gpu_enable_command_buffer(); + options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES); + config.set_debug_options(options); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo, config)); 
+ RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()), + std::nullopt); +} + +TEST_F(CommandBufferSchedulingTest, ReturnTrueWhenOnlyParamMoved) { + const char* hlo = R"( + HloModule module, is_scheduled=true + ENTRY main { + a = s32[8,8] parameter(0) + b = s32[8,8] parameter(1) + call = s32[8,8] custom-call(a,b), custom_call_target="__cublas$gemm" + c = s32[8,8] parameter(2) + ROOT call2 = s32[8,8] custom-call(call, c), custom_call_target="__cublas$gemm" + } + )"; + + HloModuleConfig config; + DebugOptions options = GetDebugOptionsForTest(); + options.clear_xla_gpu_enable_command_buffer(); + options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES); + config.set_debug_options(options); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo, config)); + RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()), R"( + // CHECK: %{{.+}} = {{.+}} parameter(0) + // CHECK: %{{.+}} = {{.+}} parameter(1) + // CHECK: %{{.+}} = {{.+}} parameter(2) + // CHECK: %{{.+}} = {{.+}} custom-call + // CHECK: %{{.+}} = {{.+}} custom-call + )"); +} + } // namespace } // namespace xla::gpu