From 410974f533c1846386f001ed97d93a8291e77b1a Mon Sep 17 00:00:00 2001 From: Shraiysh Vaishay Date: Thu, 19 Sep 2024 21:25:57 +0000 Subject: [PATCH] Fix the command buffer scheduling pass return value Fixes #17216. Returns true when either parameters are moved, or command buffer is created. --- .../transforms/command_buffer_scheduling.cc | 19 ++++++-- .../transforms/command_buffer_scheduling.h | 4 +- .../command_buffer_scheduling_test.cc | 47 +++++++++++++++++++ 3 files changed, 64 insertions(+), 6 deletions(-) diff --git a/xla/service/gpu/transforms/command_buffer_scheduling.cc b/xla/service/gpu/transforms/command_buffer_scheduling.cc index 3d8f11dd0dc7b..641a37a9659d3 100644 --- a/xla/service/gpu/transforms/command_buffer_scheduling.cc +++ b/xla/service/gpu/transforms/command_buffer_scheduling.cc @@ -438,7 +438,9 @@ CommandBufferScheduling::CollectCommandBufferSequences( // the beginning of the computation. This simplifies the construction of command // buffer computations because we don't need to deal with parameters and // constants that have users outside of a command buffer. -absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront( +// Returns true if there is a change in the order of instructions, false +// otherwise. +absl::StatusOr CommandBufferScheduling::MoveParametersAndConstantsToFront( HloComputation* computation) { HloInstructionSequence new_sequence; HloSchedule& schedule = computation->parent()->schedule(); @@ -468,7 +470,11 @@ absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront( } schedule.set_sequence(computation, new_sequence); - return absl::OkStatus(); + for (auto [old_i, new_i] : + llvm::zip(sequence.instructions(), new_sequence.instructions())) { + if (old_i != new_i) return true; + } + return false; } //===----------------------------------------------------------------------===// @@ -767,7 +773,7 @@ absl::StatusOr CommandBufferScheduling::Run( if (std::min(device_description_.runtime_version(), device_description_.driver_version()) < se::SemanticVersion{12, 3, 0}) { - erase(kRequireTracing); // cuStreamBeginCaptureToGraph + erase(kRequireTracing); // cuStreamBeginCaptureToGraph } if (std::min(device_description_.runtime_version(), device_description_.driver_version()) < @@ -787,6 +793,7 @@ absl::StatusOr CommandBufferScheduling::Run( std::reverse(order.begin(), order.end()); absl::flat_hash_set processed_command_buffers; + auto changed = false; for (HloComputation* comp : order) { // Skip special computations that do not have lowering to thunks. if (comp->IsFusionComputation() || comp->IsAsyncComputation() || @@ -796,7 +803,8 @@ absl::StatusOr CommandBufferScheduling::Run( // Skip computations that already part of command buffers. if (processed_command_buffers.contains(comp)) continue; - TF_RETURN_IF_ERROR(MoveParametersAndConstantsToFront(comp)); + TF_ASSIGN_OR_RETURN(bool changed_, MoveParametersAndConstantsToFront(comp)); + changed |= changed_; std::vector sequences = CollectCommandBufferSequences( @@ -809,6 +817,7 @@ absl::StatusOr CommandBufferScheduling::Run( TF_ASSIGN_OR_RETURN( HloComputation * command_buffer_computation, RewriteCommandBuffer(comp, seq, std::move(command_buffer))); + changed = true; // All computations reachable from a command buffer computation are nested // command buffers (i.e. body computations attached to a while operation). @@ -820,7 +829,7 @@ absl::StatusOr CommandBufferScheduling::Run( } TF_RETURN_IF_ERROR(module->schedule().Update()); - return true; + return changed; } } // namespace xla::gpu diff --git a/xla/service/gpu/transforms/command_buffer_scheduling.h b/xla/service/gpu/transforms/command_buffer_scheduling.h index 15f0b2dd4d4da..71d5b421c1ee5 100644 --- a/xla/service/gpu/transforms/command_buffer_scheduling.h +++ b/xla/service/gpu/transforms/command_buffer_scheduling.h @@ -99,7 +99,9 @@ class CommandBufferScheduling : public HloModulePass { // the beginning of the computation. This simplifies the construction of // command buffer computations because we don't need to deal with parameters // and constants that have users outside of a command buffer. - static absl::Status MoveParametersAndConstantsToFront( + // Returns true if there is a change in the order of instructions, false + // otherwise. + static absl::StatusOr MoveParametersAndConstantsToFront( HloComputation* computation); struct CommandBuffer { diff --git a/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/xla/service/gpu/transforms/command_buffer_scheduling_test.cc index 6c79316a75a51..be29a09897b54 100644 --- a/xla/service/gpu/transforms/command_buffer_scheduling_test.cc +++ b/xla/service/gpu/transforms/command_buffer_scheduling_test.cc @@ -1228,5 +1228,52 @@ TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionStaticSlicing) { false, true, std::nullopt)); } +TEST_F(CommandBufferSchedulingTest, ReturnFalseWhenNoChange) { + const char* hlo = R"( + HloModule module, is_scheduled=true + ENTRY main { + a = s32[8,8] parameter(0) + b = s32[8,8] parameter(1) + ROOT call = s32[8,8] custom-call(a,b), custom_call_target="__cublas$gemm" + } + )"; + + HloModuleConfig config; + DebugOptions options = GetDebugOptionsForTest(); + options.clear_xla_gpu_enable_command_buffer(); + options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES); + config.set_debug_options(options); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo, config)); + RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()), + std::nullopt); +} + +TEST_F(CommandBufferSchedulingTest, ReturnTrueWhenOnlyParamMoved) { + const char* hlo = R"( + HloModule module, is_scheduled=true + ENTRY main { + a = s32[8,8] parameter(0) + b = s32[8,8] parameter(1) + call = s32[8,8] custom-call(a,b), custom_call_target="__cublas$gemm" + c = s32[8,8] parameter(2) + ROOT call2 = s32[8,8] custom-call(call, c), custom_call_target="__cublas$gemm" + } + )"; + + HloModuleConfig config; + DebugOptions options = GetDebugOptionsForTest(); + options.clear_xla_gpu_enable_command_buffer(); + options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES); + config.set_debug_options(options); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo, config)); + RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()), R"( + // CHECK: %{{.+}} = {{.+}} parameter(0) + // CHECK: %{{.+}} = {{.+}} parameter(1) + // CHECK: %{{.+}} = {{.+}} parameter(2) + // CHECK: %{{.+}} = {{.+}} custom-call + // CHECK: %{{.+}} = {{.+}} custom-call + )"); +} + } // namespace } // namespace xla::gpu