diff --git a/csrc/pthread_backend.cpp b/csrc/pthread_backend.cpp index a5e2d07..62f040d 100644 --- a/csrc/pthread_backend.cpp +++ b/csrc/pthread_backend.cpp @@ -117,14 +117,14 @@ void PthreadAsyncIO::register_file(int fd) {} void PthreadAsyncIO::register_h2d(unsigned int num_tensors) { - this->h2d_in_progress.store(num_tensors); // register tensors to write for this run + this->total_h2d = num_tensors; } void PthreadAsyncIO::sync_h2d() { std::unique_lock lock(this->mtx); this->cv.wait(lock, [this] - { return this->h2d_in_progress == 0; }); // block until all in-progress h2d are completed + { return this->h2d_in_progress == this->total_h2d; }); // block until all in-progress h2d are completed } void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional pinned) @@ -132,8 +132,8 @@ void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long of auto stream = c10::cuda::getCurrentCUDAStream(); if (!t.is_cuda()) { - this->h2d_in_progress.fetch_sub(1); // already moved to cpu - if (this->h2d_in_progress.load() == 0) + auto cur_h2d = this->h2d_in_progress.fetch_add(1); // already moved to cpu + if (cur_h2d + 1 == this->total_h2d) { // notify when all h2d are completed and safe to optimizer.step() std::lock_guard lock(this->mtx); cv.notify_one(); @@ -155,8 +155,8 @@ void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long of { cpu_tensor = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); // modified from torch::Tensor::cpu() } - this->h2d_in_progress.fetch_sub(1); - if (this->h2d_in_progress.load() == 0) + auto cur_h2d = this->h2d_in_progress.fetch_add(1); + if (cur_h2d + 1 == this->total_h2d) { // notify when all h2d are completed and safe to optimizer.step() std::lock_guard lock(this->mtx); cv.notify_one(); diff --git a/include/pthread_backend.h b/include/pthread_backend.h index 9acb216..27a5daa 100644 --- a/include/pthread_backend.h +++ b/include/pthread_backend.h @@ -26,6 +26,7 @@ class PthreadAsyncIO : public AsyncIO private: BS::thread_pool pool; std::atomic h2d_in_progress; + unsigned int total_h2d; std::condition_variable cv; std::mutex mtx; std::deque, callback_t>> write_fut; @@ -38,7 +39,7 @@ class PthreadAsyncIO : public AsyncIO public: PthreadAsyncIO(unsigned int n_entries, unsigned int n_tasks) - : pool(n_entries), h2d_in_progress(0), tasks_in_progress(0), total_tasks(n_tasks) {} + : pool(n_entries), h2d_in_progress(0), tasks_in_progress(0), total_tasks(n_tasks), total_h2d(0) {} ~PthreadAsyncIO() {}