diff --git a/xla/pjrt/tracked_device_buffer.cc b/xla/pjrt/tracked_device_buffer.cc index 16ab8a669da699..ca551c4aa81b83 100644 --- a/xla/pjrt/tracked_device_buffer.cc +++ b/xla/pjrt/tracked_device_buffer.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" @@ -139,7 +140,6 @@ bool BufferSequencingEvent::IsComplete() { void BufferSequencingEvent::ExecuteOrAddToFutureTasks( const std::string& task_name, std::function task) { - absl::MutexLock lock(&mu_); tsl::profiler::TraceMeProducer producer( "BufferSequencingEvent::ExecuteOrAddToFutureTasks", tsl::profiler::ContextType::kPjRt); @@ -150,17 +150,33 @@ void BufferSequencingEvent::ExecuteOrAddToFutureTasks( context_id); task(); }; - if (defined_status_.IsConcrete()) { - thread_pool_->Schedule(std::move(wrapped_task)); - return; + { + absl::MutexLock lock(&mu_); + if (!defined_status_.IsConcrete()) { + on_ready_tasks_callback_[task_name] = std::move(wrapped_task); + return; + } + // Release the lock to avoid deadlock, in the case where the + // thread_pool_->Schedule() executes wrapped_task inline. + // This is rare but could happen. The callbacks could potentially try to + // acquire the mutex of this BufferSequencingEvent. } - on_ready_tasks_callback_[task_name] = std::move(wrapped_task); + thread_pool_->Schedule(std::move(wrapped_task)); } void BufferSequencingEvent::ExecuteFutureTasks() { - absl::MutexLock lock(&mu_); + absl::flat_hash_map> + on_ready_tasks_callback; + { + absl::MutexLock lock(&mu_); + on_ready_tasks_callback = std::move(on_ready_tasks_callback_); + // Release the lock to avoid deadlock, in the case where the + // thread_pool_->Schedule() executes call_all_task_callbacks inline. + // This is rare but could happen. The callbacks could potentially try to + // acquire the mutex of this BufferSequencingEvent. + } auto call_all_task_callbacks = [on_ready_tasks_callback = - std::move(on_ready_tasks_callback_)]() { + std::move(on_ready_tasks_callback)]() { for (auto& [task_name, task_callback] : on_ready_tasks_callback) { task_callback(); }