Skip to content

Commit

Permalink
Avoid acquiring BufferSequencingEvent::mu_ twice in the case where th…
Browse files Browse the repository at this point in the history
…e thread_pool executes the callbacks inline.

PiperOrigin-RevId: 681708085
  • Loading branch information
yifjiang authored and Google-ML-Automation committed Oct 3, 2024
1 parent 92e3c7a commit 0946107
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions xla/pjrt/tracked_device_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
Expand Down Expand Up @@ -139,7 +140,6 @@ bool BufferSequencingEvent::IsComplete() {

void BufferSequencingEvent::ExecuteOrAddToFutureTasks(
const std::string& task_name, std::function<void()> task) {
absl::MutexLock lock(&mu_);
tsl::profiler::TraceMeProducer producer(
"BufferSequencingEvent::ExecuteOrAddToFutureTasks",
tsl::profiler::ContextType::kPjRt);
Expand All @@ -150,17 +150,33 @@ void BufferSequencingEvent::ExecuteOrAddToFutureTasks(
context_id);
task();
};
if (defined_status_.IsConcrete()) {
thread_pool_->Schedule(std::move(wrapped_task));
return;
{
absl::MutexLock lock(&mu_);
if (!defined_status_.IsConcrete()) {
on_ready_tasks_callback_[task_name] = std::move(wrapped_task);
return;
}
// Release the lock to avoid deadlock, in the case where the
// thread_pool_->Schedule() executes wrapped_task inline.
// This is rare but could happen. The callbacks could potentially try to
// acquire the mutex of this BufferSequencingEvent.
}
on_ready_tasks_callback_[task_name] = std::move(wrapped_task);
thread_pool_->Schedule(std::move(wrapped_task));
}

void BufferSequencingEvent::ExecuteFutureTasks() {
absl::MutexLock lock(&mu_);
absl::flat_hash_map<std::string, std::function<void()>>
on_ready_tasks_callback;
{
absl::MutexLock lock(&mu_);
on_ready_tasks_callback = std::move(on_ready_tasks_callback_);
// Release the lock to avoid deadlock, in the case where the
// thread_pool_->Schedule() executes call_all_task_callbacks inline.
// This is rare but could happen. The callbacks could potentially try to
// acquire the mutex of this BufferSequencingEvent.
}
auto call_all_task_callbacks = [on_ready_tasks_callback =
std::move(on_ready_tasks_callback_)]() {
std::move(on_ready_tasks_callback)]() {
for (auto& [task_name, task_callback] : on_ready_tasks_callback) {
task_callback();
}
Expand Down

0 comments on commit 0946107

Please sign in to comment.