Skip to content

Commit

Permalink
Fix race condition in dumping logic in xla::Executable
Browse files Browse the repository at this point in the history
This is adding a mutex to the lazy initialization that happens in a `const` member function. It also adds a test which hopefully ensures that this getter stays thread compatible.

PiperOrigin-RevId: 676804217
  • Loading branch information
beckerhe authored and Google-ML-Automation committed Sep 20, 2024
1 parent 5fd336b commit 7a6369b
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 1 deletion.
19 changes: 19 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1713,11 +1713,13 @@ cc_library(
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_memory_allocator",
"//xla/tsl/lib/strings:proto_serialization",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
"@tsl//tsl/platform:env",
Expand All @@ -1728,6 +1730,23 @@ cc_library(
] + internal_hlo_deps(),
)

# Test target built from executable_test.cc; it depends on the ":executable"
# library defined in this package plus the HLO test base and TSL platform
# utilities listed below.
xla_cc_test(
name = "executable_test",
srcs = ["executable_test.cc"],
deps = [
":executable",
":hlo_execution_profile",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)

cc_library(
name = "compiler",
srcs = ["compiler.cc"],
Expand Down
14 changes: 13 additions & 1 deletion xla/service/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ limitations under the License.

#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "xla/debug_options_flags.h"
Expand Down Expand Up @@ -376,6 +379,12 @@ class Executable {

// Dumping helpers.
// Replaces the stored `hlo_proto_` with `hlo_proto`, taking ownership of it.
// The assignment is performed while holding `hlo_proto_mutex_`.
void set_hlo_proto(std::unique_ptr<xla::HloProto> hlo_proto) {
// Despite the mutex lock, this function is NOT thread-safe.
// The mutex is needed for the lazy HLO module loading in `hlo_proto()`.
// Since both `hlo_proto()` and `buffer_assignment_proto()` return a
// pointer to hlo_proto_, having the mutex is not enough to make this
// function thread-safe: the move-assignment below destroys the previously
// owned proto, so any pointer handed out earlier by those getters would be
// left dangling. Callers must not invoke this concurrently with users of
// those pointers.
absl::MutexLock lock(&hlo_proto_mutex_);
hlo_proto_ = std::move(hlo_proto);
}
bool dumping_snapshot() const {
Expand All @@ -385,13 +394,15 @@ class Executable {
}

// Returns the HloProto owned by `hlo_proto_`, or nullptr when `hlo_proto_` is
// null.
//
// If a proto is present but its `hlo_module` field is unset, the field is
// filled in from `module().ToProto()` before returning. Both the check and the
// fill happen while `hlo_proto_mutex_` is held, so concurrent callers do not
// race on this lazy initialization. The returned pointer refers to the object
// owned by `hlo_proto_`; the lock is no longer held once this function returns.
HloProto const* hlo_proto() const {
  absl::MutexLock lock(&hlo_proto_mutex_);

  HloProto* const proto = hlo_proto_.get();
  if (proto == nullptr) {
    return nullptr;
  }

  // Lazily materialize the module proto the first time it is requested.
  if (!proto->has_hlo_module()) {
    *proto->mutable_hlo_module() = module().ToProto();
  }
  return proto;
}

const BufferAssignmentProto* buffer_assignment_proto() const {
absl::MutexLock lock(&hlo_proto_mutex_);
return hlo_proto_ != nullptr && hlo_proto_->has_buffer_assignment()
? &hlo_proto_->buffer_assignment()
: nullptr;
Expand Down Expand Up @@ -441,7 +452,8 @@ class Executable {
// hlo_proto_->buffer_assignment is set and hlo_proto_->hlo_module isn't, the
// hlo_module proto will be computed on the fly when requested with
// hlo_proto(). This avoids wasting CPU and memory if the proto isn't needed.
std::unique_ptr<HloProto> hlo_proto_;
std::unique_ptr<HloProto> hlo_proto_ ABSL_GUARDED_BY(hlo_proto_mutex_);
mutable absl::Mutex hlo_proto_mutex_;
};

} // namespace xla
Expand Down
86 changes: 86 additions & 0 deletions xla/service/executable_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/executable.h"

#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/hlo_execution_profile.h"
#include "xla/service/service_executable_run_options.h"
#include "xla/tests/hlo_test_base.h"
#include "tsl/platform/env.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
#include "tsl/platform/threadpool.h"

namespace xla {
namespace {

// Minimal concrete subclass of `Executable` used by the tests in this file.
// It forwards the module to the base class and provides a stub for
// `ExecuteAsyncOnStream`, which the tests here never need to succeed.
class TestExecutable : public Executable {
 public:
  // Hands `module` over to the `Executable` base class.
  explicit TestExecutable(std::shared_ptr<HloModule> module)
      : Executable(std::move(module)) {}

  // Stub implementation: always returns an Unimplemented error and ignores
  // all of its arguments.
  absl::StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
      const ServiceExecutableRunOptions* run_options,
      std::vector<ExecutionInput> arguments,
      HloExecutionProfile* hlo_execution_profile) override {
    return absl::UnimplementedError("Not needed for this test.");
  }
};

class ExecutableTest : public HloTestBase {};

// Verifies that concurrent calls to `Executable::hlo_proto()` are safe and
// that the lazily computed `hlo_module` field ends up populated.
TEST_F(ExecutableTest, HloProtoGetterIsThreadCompatible) {
  // Executable::hlo_proto() is doing some lazy initialization of a
  // part of `hlo_proto_`. This test ensures that this is done in a
  // thread-compatible way.
  // Note that this test needs to run with --config=tsan to reliably
  // detect any potential data races.
  constexpr std::string_view kHloModule = R"(
HloModule module
ENTRY main {
ROOT c = s32[] constant(1)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloModule));

  TestExecutable executable(module);

  // Install an empty proto so that the first `hlo_proto()` call has to take
  // the lazy-initialization path (the `hlo_module` field is still unset).
  auto proto = std::make_unique<HloProto>();
  executable.set_hlo_proto(std::move(proto));

  {
    // The pool lives in its own scope so that it is destroyed before the
    // checks below run; the test relies on the pool's destructor waiting for
    // the scheduled closures to finish, as the original scoping already did
    // (the closures reference `executable` on this stack frame).
    tsl::thread::ThreadPool pool(tsl::Env::Default(), "test",
                                 /*num_threads=*/2);
    for (int i = 0; i < 2; ++i) {
      // Capture only what the closure needs instead of everything by
      // reference.
      pool.Schedule(
          [&executable] { executable.hlo_proto()->SerializeAsString(); });
    }
  }

  // Functional check that holds even without TSAN: after the concurrent
  // getters have run, the proto must exist and carry the lazily filled
  // module.
  const HloProto* populated = executable.hlo_proto();
  ASSERT_NE(populated, nullptr);
  EXPECT_TRUE(populated->has_hlo_module());
}

} // namespace
} // namespace xla

0 comments on commit 7a6369b

Please sign in to comment.