Skip to content

Commit

Permalink
Automated Code Change
Browse files Browse the repository at this point in the history
FUTURE_COPYBARA_INTEGRATE_REVIEW=#17257 from shraiysh:fix_17216 410974f
PiperOrigin-RevId: 676423808
  • Loading branch information
Google-ML-Automation committed Sep 20, 2024
1 parent de13ec6 commit 4f13a5b
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 9 deletions.
19 changes: 14 additions & 5 deletions xla/service/gpu/transforms/command_buffer_scheduling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,9 @@ CommandBufferScheduling::CollectCommandBufferSequences(
// the beginning of the computation. This simplifies the construction of command
// buffer computations because we don't need to deal with parameters and
// constants that have users outside of a command buffer.
absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront(
// Returns true if there is a change in the order of instructions, false
// otherwise.
absl::StatusOr<bool> CommandBufferScheduling::MoveParametersAndConstantsToFront(
HloComputation* computation) {
HloInstructionSequence new_sequence;
HloSchedule& schedule = computation->parent()->schedule();
Expand Down Expand Up @@ -468,7 +470,11 @@ absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront(
}

schedule.set_sequence(computation, new_sequence);
return absl::OkStatus();
for (auto [old_i, new_i] :
llvm::zip(sequence.instructions(), new_sequence.instructions())) {
if (old_i != new_i) return true;
}
return false;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -767,7 +773,7 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
if (std::min(device_description_.runtime_version(),
device_description_.driver_version()) <
se::SemanticVersion{12, 3, 0}) {
erase(kRequireTracing); // cuStreamBeginCaptureToGraph
erase(kRequireTracing); // cuStreamBeginCaptureToGraph
}
if (std::min(device_description_.runtime_version(),
device_description_.driver_version()) <
Expand All @@ -787,6 +793,7 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
std::reverse(order.begin(), order.end());
absl::flat_hash_set<HloComputation*> processed_command_buffers;

auto changed = false;
for (HloComputation* comp : order) {
// Skip special computations that do not have lowering to thunks.
if (comp->IsFusionComputation() || comp->IsAsyncComputation() ||
Expand All @@ -796,7 +803,8 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
// Skip computations that already part of command buffers.
if (processed_command_buffers.contains(comp)) continue;

TF_RETURN_IF_ERROR(MoveParametersAndConstantsToFront(comp));
TF_ASSIGN_OR_RETURN(bool changed_, MoveParametersAndConstantsToFront(comp));
changed |= changed_;

std::vector<HloInstructionSequence> sequences =
CollectCommandBufferSequences(
Expand All @@ -809,6 +817,7 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
TF_ASSIGN_OR_RETURN(
HloComputation * command_buffer_computation,
RewriteCommandBuffer(comp, seq, std::move(command_buffer)));
changed = true;

// All computations reachable from a command buffer computation are nested
// command buffers (i.e. body computations attached to a while operation).
Expand All @@ -820,7 +829,7 @@ absl::StatusOr<bool> CommandBufferScheduling::Run(
}
TF_RETURN_IF_ERROR(module->schedule().Update());

return true;
return changed;
}

} // namespace xla::gpu
4 changes: 3 additions & 1 deletion xla/service/gpu/transforms/command_buffer_scheduling.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ class CommandBufferScheduling : public HloModulePass {
// the beginning of the computation. This simplifies the construction of
// command buffer computations because we don't need to deal with parameters
// and constants that have users outside of a command buffer.
static absl::Status MoveParametersAndConstantsToFront(
// Returns true if there is a change in the order of instructions, false
// otherwise.
static absl::StatusOr<bool> MoveParametersAndConstantsToFront(
HloComputation* computation);

struct CommandBuffer {
Expand Down
47 changes: 47 additions & 0 deletions xla/service/gpu/transforms/command_buffer_scheduling_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1228,5 +1228,52 @@ TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionStaticSlicing) {
false, true, std::nullopt));
}

TEST_F(CommandBufferSchedulingTest, ReturnFalseWhenNoChange) {
const char* hlo = R"(
HloModule module, is_scheduled=true
ENTRY main {
a = s32[8,8] parameter(0)
b = s32[8,8] parameter(1)
ROOT call = s32[8,8] custom-call(a,b), custom_call_target="__cublas$gemm"
}
)";

HloModuleConfig config;
DebugOptions options = GetDebugOptionsForTest();
options.clear_xla_gpu_enable_command_buffer();
options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
config.set_debug_options(options);
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo, config));
RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()),
std::nullopt);
}

TEST_F(CommandBufferSchedulingTest, ReturnTrueWhenOnlyParamMoved) {
const char* hlo = R"(
HloModule module, is_scheduled=true
ENTRY main {
a = s32[8,8] parameter(0)
b = s32[8,8] parameter(1)
call = s32[8,8] custom-call(a,b), custom_call_target="__cublas$gemm"
c = s32[8,8] parameter(2)
ROOT call2 = s32[8,8] custom-call(call, c), custom_call_target="__cublas$gemm"
}
)";

HloModuleConfig config;
DebugOptions options = GetDebugOptionsForTest();
options.clear_xla_gpu_enable_command_buffer();
options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES);
config.set_debug_options(options);
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo, config));
RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()), R"(
// CHECK: %{{.+}} = {{.+}} parameter(0)
// CHECK: %{{.+}} = {{.+}} parameter(1)
// CHECK: %{{.+}} = {{.+}} parameter(2)
// CHECK: %{{.+}} = {{.+}} custom-call
// CHECK: %{{.+}} = {{.+}} custom-call
)");
}

} // namespace
} // namespace xla::gpu
1 change: 1 addition & 0 deletions xla/stream_executor/tpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ cc_library(
"//xla/service:backend",
"//xla/service:stream_pool",
"//xla/stream_executor:device_memory_allocator",
"//xla/stream_executor:stream_executor_h",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down
2 changes: 0 additions & 2 deletions xla/stream_executor/tpu/tpu_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ limitations under the License.
#include <utility>

#include "absl/cleanup/cleanup.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "xla/stream_executor/allocator_stats.h"
Expand All @@ -37,7 +36,6 @@ limitations under the License.
#include "xla/stream_executor/tpu/tpu_executor_api.h"
#include "xla/stream_executor/tpu/tpu_stream.h"
#include "xla/stream_executor/tpu/tpu_topology.h"
#include "xla/tsl/c/tsl_status.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h" // IWYU pragma: keep

Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/tpu/tpu_node_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/service/backend.h"
#include "xla/service/stream_pool.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/tpu/status_helper.h"
#include "xla/stream_executor/tpu/tpu_api.h"
#include "xla/stream_executor/tpu/tpu_ops_c_api.h"
Expand Down
1 change: 1 addition & 0 deletions xla/stream_executor/tpu/tpu_node_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include "xla/service/backend.h"
#include "xla/service/stream_pool.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/tpu/tpu_ops_c_api.h"
#include "xla/stream_executor/tpu/tpu_platform_interface.h"
#include "tsl/platform/macros.h"
Expand Down
1 change: 1 addition & 0 deletions xla/stream_executor/tpu/tpu_platform.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/executor_cache.h"
#include "xla/stream_executor/platform.h"
Expand Down

0 comments on commit 4f13a5b

Please sign in to comment.