diff --git a/xla/service/gpu/transforms/command_buffer_scheduling.cc b/xla/service/gpu/transforms/command_buffer_scheduling.cc index 3d8f11dd0dc7b..641a37a9659d3 100644 --- a/xla/service/gpu/transforms/command_buffer_scheduling.cc +++ b/xla/service/gpu/transforms/command_buffer_scheduling.cc @@ -438,7 +438,9 @@ CommandBufferScheduling::CollectCommandBufferSequences( // the beginning of the computation. This simplifies the construction of command // buffer computations because we don't need to deal with parameters and // constants that have users outside of a command buffer. -absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront( +// Returns true if there is a change in the order of instructions, false +// otherwise. +absl::StatusOr CommandBufferScheduling::MoveParametersAndConstantsToFront( HloComputation* computation) { HloInstructionSequence new_sequence; HloSchedule& schedule = computation->parent()->schedule(); @@ -468,7 +470,11 @@ absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront( } schedule.set_sequence(computation, new_sequence); - return absl::OkStatus(); + for (auto [old_i, new_i] : + llvm::zip(sequence.instructions(), new_sequence.instructions())) { + if (old_i != new_i) return true; + } + return false; } //===----------------------------------------------------------------------===// @@ -767,7 +773,7 @@ absl::StatusOr CommandBufferScheduling::Run( if (std::min(device_description_.runtime_version(), device_description_.driver_version()) < se::SemanticVersion{12, 3, 0}) { - erase(kRequireTracing); // cuStreamBeginCaptureToGraph + erase(kRequireTracing); // cuStreamBeginCaptureToGraph } if (std::min(device_description_.runtime_version(), device_description_.driver_version()) < @@ -787,6 +793,7 @@ absl::StatusOr CommandBufferScheduling::Run( std::reverse(order.begin(), order.end()); absl::flat_hash_set processed_command_buffers; + auto changed = false; for (HloComputation* comp : order) { // Skip special computations that do not have lowering to thunks. if (comp->IsFusionComputation() || comp->IsAsyncComputation() || @@ -796,7 +803,8 @@ absl::StatusOr CommandBufferScheduling::Run( // Skip computations that already part of command buffers. if (processed_command_buffers.contains(comp)) continue; - TF_RETURN_IF_ERROR(MoveParametersAndConstantsToFront(comp)); + TF_ASSIGN_OR_RETURN(bool changed_, MoveParametersAndConstantsToFront(comp)); + changed |= changed_; std::vector sequences = CollectCommandBufferSequences( @@ -809,6 +817,7 @@ absl::StatusOr CommandBufferScheduling::Run( TF_ASSIGN_OR_RETURN( HloComputation * command_buffer_computation, RewriteCommandBuffer(comp, seq, std::move(command_buffer))); + changed = true; // All computations reachable from a command buffer computation are nested // command buffers (i.e. body computations attached to a while operation). @@ -820,7 +829,7 @@ absl::StatusOr CommandBufferScheduling::Run( } TF_RETURN_IF_ERROR(module->schedule().Update()); - return true; + return changed; } } // namespace xla::gpu diff --git a/xla/service/gpu/transforms/command_buffer_scheduling.h b/xla/service/gpu/transforms/command_buffer_scheduling.h index 15f0b2dd4d4da..71d5b421c1ee5 100644 --- a/xla/service/gpu/transforms/command_buffer_scheduling.h +++ b/xla/service/gpu/transforms/command_buffer_scheduling.h @@ -99,7 +99,9 @@ class CommandBufferScheduling : public HloModulePass { // the beginning of the computation. This simplifies the construction of // command buffer computations because we don't need to deal with parameters // and constants that have users outside of a command buffer. - static absl::Status MoveParametersAndConstantsToFront( + // Returns true if there is a change in the order of instructions, false + // otherwise. + static absl::StatusOr MoveParametersAndConstantsToFront( HloComputation* computation); struct CommandBuffer { diff --git a/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/xla/service/gpu/transforms/command_buffer_scheduling_test.cc index 6c79316a75a51..be29a09897b54 100644 --- a/xla/service/gpu/transforms/command_buffer_scheduling_test.cc +++ b/xla/service/gpu/transforms/command_buffer_scheduling_test.cc @@ -1228,5 +1228,52 @@ TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionStaticSlicing) { false, true, std::nullopt)); } +TEST_F(CommandBufferSchedulingTest, ReturnFalseWhenNoChange) { + const char* hlo = R"( + HloModule module, is_scheduled=true + ENTRY main { + a = s32[8,8] parameter(0) + b = s32[8,8] parameter(1) + ROOT call = s32[8,8] custom-call(a,b), custom_call_target="__cublas$gemm" + } + )"; + + HloModuleConfig config; + DebugOptions options = GetDebugOptionsForTest(); + options.clear_xla_gpu_enable_command_buffer(); + options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES); + config.set_debug_options(options); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo, config)); + RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()), + std::nullopt); +} + +TEST_F(CommandBufferSchedulingTest, ReturnTrueWhenOnlyParamMoved) { + const char* hlo = R"( + HloModule module, is_scheduled=true + ENTRY main { + a = s32[8,8] parameter(0) + b = s32[8,8] parameter(1) + call = s32[8,8] custom-call(a,b), custom_call_target="__cublas$gemm" + c = s32[8,8] parameter(2) + ROOT call2 = s32[8,8] custom-call(call, c), custom_call_target="__cublas$gemm" + } + )"; + + HloModuleConfig config; + DebugOptions options = GetDebugOptionsForTest(); + options.clear_xla_gpu_enable_command_buffer(); + options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES); + config.set_debug_options(options); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo, config)); + RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()), R"( + // CHECK: %{{.+}} = {{.+}} parameter(0) + // CHECK: %{{.+}} = {{.+}} parameter(1) + // CHECK: %{{.+}} = {{.+}} parameter(2) + // CHECK: %{{.+}} = {{.+}} custom-call + // CHECK: %{{.+}} = {{.+}} custom-call + )"); +} + } // namespace } // namespace xla::gpu diff --git a/xla/stream_executor/tpu/BUILD b/xla/stream_executor/tpu/BUILD index 3e5a9786889d6..f09de10311057 100644 --- a/xla/stream_executor/tpu/BUILD +++ b/xla/stream_executor/tpu/BUILD @@ -377,6 +377,7 @@ cc_library( "//xla/service:backend", "//xla/service:stream_pool", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/xla/stream_executor/tpu/tpu_executor.cc b/xla/stream_executor/tpu/tpu_executor.cc index fdf7fdf67fdc2..1317e238336a8 100644 --- a/xla/stream_executor/tpu/tpu_executor.cc +++ b/xla/stream_executor/tpu/tpu_executor.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include "absl/cleanup/cleanup.h" -#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/types/span.h" #include "xla/stream_executor/allocator_stats.h" @@ -37,7 +36,6 @@ limitations under the License. #include "xla/stream_executor/tpu/tpu_executor_api.h" #include "xla/stream_executor/tpu/tpu_stream.h" #include "xla/stream_executor/tpu/tpu_topology.h" -#include "xla/tsl/c/tsl_status.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep diff --git a/xla/stream_executor/tpu/tpu_node_context.cc b/xla/stream_executor/tpu/tpu_node_context.cc index 9175390af712a..e991bd931577c 100644 --- a/xla/stream_executor/tpu/tpu_node_context.cc +++ b/xla/stream_executor/tpu/tpu_node_context.cc @@ -20,7 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/service/backend.h" -#include "xla/service/stream_pool.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/tpu/status_helper.h" #include "xla/stream_executor/tpu/tpu_api.h" #include "xla/stream_executor/tpu/tpu_ops_c_api.h" diff --git a/xla/stream_executor/tpu/tpu_node_context.h b/xla/stream_executor/tpu/tpu_node_context.h index ba51e611a900a..645ba49aea520 100644 --- a/xla/stream_executor/tpu/tpu_node_context.h +++ b/xla/stream_executor/tpu/tpu_node_context.h @@ -25,6 +25,7 @@ limitations under the License. #include "xla/service/backend.h" #include "xla/service/stream_pool.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/tpu/tpu_ops_c_api.h" #include "xla/stream_executor/tpu/tpu_platform_interface.h" #include "tsl/platform/macros.h" diff --git a/xla/stream_executor/tpu/tpu_platform.h b/xla/stream_executor/tpu/tpu_platform.h index 8eb6f19b7cdd6..97e4a1c176011 100644 --- a/xla/stream_executor/tpu/tpu_platform.h +++ b/xla/stream_executor/tpu/tpu_platform.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/executor_cache.h" #include "xla/stream_executor/platform.h"