Skip to content

Commit

Permalink
[h2d] add individual sync for h2d
Browse files Browse the repository at this point in the history
  • Loading branch information
botbw committed Nov 5, 2024
1 parent 39ea874 commit 5da46cd
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 19 deletions.
5 changes: 4 additions & 1 deletion csrc/aio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,7 @@ void AIOAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset
void *buffer = t.data_ptr();
size_t n_bytes = t.numel() * t.element_size();
this->write(fd, buffer, n_bytes, offset, callback);
}
}

void AIOAsyncIO::register_h2d(unsigned int num_tensors) {}
void AIOAsyncIO::sync_h2d() {}
7 changes: 7 additions & 0 deletions csrc/async_file_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ void AsyncFileWriter::write_tensor(torch::Tensor tensor, unsigned long long offs
this->aio->write_tensor(this->fd, tensor, offset, callback, pinned);
}

void AsyncFileWriter::register_h2d(unsigned int num_tensors) {
this->aio->register_h2d(num_tensors);
}

void AsyncFileWriter::sync_h2d() {
this->aio->sync_h2d();
}

void AsyncFileWriter::synchronize()
{
Expand Down
42 changes: 27 additions & 15 deletions csrc/pthread_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,24 +78,36 @@ void PthreadAsyncIO::synchronize() {

void PthreadAsyncIO::register_file(int fd) {}

void PthreadAsyncIO::register_h2d(unsigned int num_tensors) {
this->h2d_in_progress.store(num_tensors); // register tensors to write for this run
}

void PthreadAsyncIO::sync_h2d() {
std::unique_lock<std::mutex> lock(this->mtx);
this->cv.wait(lock, [this] { return this->h2d_in_progress == 0; }); // block until all in-progress h2d are completed
}

void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
auto stream = c10::cuda::getCurrentCUDAStream();
at::cuda::CUDAStreamGuard guard(stream); // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html
auto event_ptr = std::make_shared<c10::Event>(torch::kCUDA); // make a shared ptr here since event is not copyable
if (t.is_cuda()) {
if (pinned.has_value()) {
pinned.value().copy_(t, /*non_blocking*/ true);
t = pinned.value();
} else {
t = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ true, /*copy*/ false); // modified from torch::Tensor::cpu()
}
}
event_ptr->record(stream);
auto fut = this->pool.submit_task(
[fd, t, offset, pinned, event_ptr] {
event_ptr->synchronize(); // sync with comm stream
void *buf = t.data_ptr();
size_t n_bytes = t.numel() * t.element_size();
[this, fd, t, offset, pinned, stream] {
at::cuda::CUDAStreamGuard guard(stream); // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html
torch::Tensor cpu_tensor;
if (t.is_cuda()) {
if (pinned.has_value()) {
pinned.value().copy_(t, /*non_blocking*/ false);
cpu_tensor = pinned.value();
} else {
cpu_tensor = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); // modified from torch::Tensor::cpu()
}
}
this->h2d_in_progress.fetch_sub(1);
if (this->h2d_in_progress.load() == 0) { // notify when all h2d are completed and safe to optimizer.step()
std::lock_guard<std::mutex> lock(this->mtx);
cv.notify_one();
}
void *buf = cpu_tensor.data_ptr();
size_t n_bytes = cpu_tensor.numel() * cpu_tensor.element_size();
return pwrite(fd, buf, n_bytes, offset);
}
);
Expand Down
4 changes: 3 additions & 1 deletion csrc/py_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def(py::init<int, unsigned int, const std::string &>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio")
.def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"), py::arg("callback") = py::none())
.def("write_tensor", &AsyncFileWriter::write_tensor, py::arg("tensor"), py::arg("offset"), py::arg("callback") = py::none(), py::arg("pinned") = py::none())
.def("synchronize", &AsyncFileWriter::synchronize);
.def("synchronize", &AsyncFileWriter::synchronize)
.def("sync_h2d", &AsyncFileWriter::sync_h2d)
.def("register_h2d", &AsyncFileWriter::register_h2d, py::arg("num_tensors"));
}
5 changes: 4 additions & 1 deletion csrc/uring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,7 @@ void UringAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offs
void *buffer = t.data_ptr<float>();
size_t n_bytes = t.numel() * t.element_size();
this->write(fd, buffer, n_bytes, offset, callback);
}
}

void UringAsyncIO::register_h2d(unsigned int num_tensors) {}
void UringAsyncIO::sync_h2d() {}
2 changes: 2 additions & 0 deletions include/aio.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class AIOAsyncIO : public AsyncIO
void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);

void register_h2d(unsigned int num_tensors);
void sync_h2d();
void sync_write_events();
void sync_read_events();
void synchronize();
Expand Down
2 changes: 2 additions & 0 deletions include/async_file_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class AsyncFileWriter
void write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
void write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
void synchronize();
void register_h2d(unsigned int num_tensors);
void sync_h2d();
~AsyncFileWriter();

private:
Expand Down
2 changes: 2 additions & 0 deletions include/asyncio.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class AsyncIO
virtual void get_event(WaitType wt) = 0;
virtual void sync_write_events() = 0;
virtual void sync_read_events() = 0;
virtual void register_h2d(unsigned int num_tensors) = 0;
virtual void sync_h2d() = 0;
virtual void synchronize() = 0;

virtual void register_file(int fd) = 0;
Expand Down
10 changes: 9 additions & 1 deletion include/pthread_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#include <iostream>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
#include <atomic>
#include <condition_variable>
#include <mutex>

#include "asyncio.h"
#include "threadpool.hpp"
Expand All @@ -21,12 +24,15 @@ class PthreadAsyncIO : public AsyncIO
{
private:
BS::thread_pool pool;
std::atomic<unsigned int> h2d_in_progress;
std::condition_variable cv;
std::mutex mtx;
std::deque<std::tuple<std::future<ssize_t>, callback_t>> write_fut;
std::deque<std::tuple<std::future<ssize_t>, callback_t>> read_fut;

public:
PthreadAsyncIO(unsigned int n_entries)
: pool(n_entries) {}
: pool(n_entries), h2d_in_progress(0) {}

~PthreadAsyncIO() {}

Expand All @@ -38,6 +44,8 @@ class PthreadAsyncIO : public AsyncIO
void get_event(WaitType wt);
void sync_write_events();
void sync_read_events();
void register_h2d(unsigned int num_tensors);
void sync_h2d();
void synchronize();

void register_file(int fd);
Expand Down
2 changes: 2 additions & 0 deletions include/uring.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class UringAsyncIO : public AsyncIO
void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);

void register_h2d(unsigned int num_tensors);
void sync_h2d();
void sync_write_events();
void sync_read_events();
void synchronize();
Expand Down
7 changes: 7 additions & 0 deletions tensornvme/async_file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,19 @@ def write_tensor(self, tensor: Tensor, pinned: Optional[Tensor] = None) -> None:
self.io.write_tensor(tensor, self.offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1), pinned)
self.offset += tensor.numel() * tensor.element_size()

def sync_h2d(self) -> None:
self.io.sync_h2d()

def register_h2d(self, num_tensors: int) -> None:
self.io.register_h2d(num_tensors)

def write_gpu_tensor(self, tensor: Tensor, pinned: Optional[Tensor] = None) -> None:
assert tensor.device.type == 'cuda', f"tensor must be on cuda device, got {tensor.device}"
with torch.cuda.stream(self.comm_stream):
self.write_tensor(tensor, pinned)

def sync_before_step(self):
self.sync_h2d()
self.comm_stream.synchronize()

@staticmethod
Expand Down

0 comments on commit 5da46cd

Please sign in to comment.