Skip to content

Commit

Permalink
[hotfix] fix h2d count
Browse files Browse the repository at this point in the history
  • Loading branch information
ver217 committed Nov 27, 2024
1 parent fcfb031 commit c852f6c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
12 changes: 6 additions & 6 deletions csrc/pthread_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,23 +117,23 @@ void PthreadAsyncIO::register_file(int fd) {}

void PthreadAsyncIO::register_h2d(unsigned int num_tensors)
{
this->h2d_in_progress.store(num_tensors); // register tensors to write for this run
this->total_h2d = num_tensors;
}

void PthreadAsyncIO::sync_h2d()
{
std::unique_lock<std::mutex> lock(this->mtx);
this->cv.wait(lock, [this]
{ return this->h2d_in_progress == 0; }); // block until all in-progress h2d are completed
{ return this->h2d_in_progress == this->total_h2d; }); // block until all in-progress h2d are completed
}

void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned)
{
auto stream = c10::cuda::getCurrentCUDAStream();
if (!t.is_cuda())
{
this->h2d_in_progress.fetch_sub(1); // already moved to cpu
if (this->h2d_in_progress.load() == 0)
auto cur_h2d = this->h2d_in_progress.fetch_add(1); // already moved to cpu
if (cur_h2d + 1 == this->total_h2d)
{ // notify when all h2d are completed and safe to optimizer.step()
std::lock_guard<std::mutex> lock(this->mtx);
cv.notify_one();
Expand All @@ -155,8 +155,8 @@ void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long of
{
cpu_tensor = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); // modified from torch::Tensor::cpu()
}
this->h2d_in_progress.fetch_sub(1);
if (this->h2d_in_progress.load() == 0)
auto cur_h2d = this->h2d_in_progress.fetch_add(1);
if (cur_h2d + 1 == this->total_h2d)
{ // notify when all h2d are completed and safe to optimizer.step()
std::lock_guard<std::mutex> lock(this->mtx);
cv.notify_one();
Expand Down
3 changes: 2 additions & 1 deletion include/pthread_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class PthreadAsyncIO : public AsyncIO
private:
BS::thread_pool pool;
std::atomic<unsigned int> h2d_in_progress;
unsigned int total_h2d;
std::condition_variable cv;
std::mutex mtx;
std::deque<std::tuple<std::future<ssize_t>, callback_t>> write_fut;
Expand All @@ -38,7 +39,7 @@ class PthreadAsyncIO : public AsyncIO

public:
PthreadAsyncIO(unsigned int n_entries, unsigned int n_tasks)
: pool(n_entries), h2d_in_progress(0), tasks_in_progress(0), total_tasks(n_tasks) {}
: pool(n_entries), h2d_in_progress(0), tasks_in_progress(0), total_tasks(n_tasks), total_h2d(0) {}

~PthreadAsyncIO() {}

Expand Down

0 comments on commit c852f6c

Please sign in to comment.