diff --git a/csrc/aio.cpp b/csrc/aio.cpp
index 0796c34..7a1d10b 100644
--- a/csrc/aio.cpp
+++ b/csrc/aio.cpp
@@ -138,4 +138,7 @@ void AIOAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset
     void *buffer = t.data_ptr();
     size_t n_bytes = t.numel() * t.element_size();
     this->write(fd, buffer, n_bytes, offset, callback);
-}
\ No newline at end of file
+}
+
+void AIOAsyncIO::register_h2d(unsigned int num_tensors) {}
+void AIOAsyncIO::sync_h2d() {}
\ No newline at end of file
diff --git a/csrc/async_file_io.cpp b/csrc/async_file_io.cpp
index 918dfc8..6676c60 100644
--- a/csrc/async_file_io.cpp
+++ b/csrc/async_file_io.cpp
@@ -12,6 +12,13 @@ void AsyncFileWriter::write_tensor(torch::Tensor tensor, unsigned long long offs
     this->aio->write_tensor(this->fd, tensor, offset, callback, pinned);
 }
 
+void AsyncFileWriter::register_h2d(unsigned int num_tensors) {
+    this->aio->register_h2d(num_tensors);
+}
+
+void AsyncFileWriter::sync_h2d() {
+    this->aio->sync_h2d();
+}
 
 void AsyncFileWriter::synchronize()
 {
diff --git a/csrc/pthread_backend.cpp b/csrc/pthread_backend.cpp
index 04f0221..ede6da5 100644
--- a/csrc/pthread_backend.cpp
+++ b/csrc/pthread_backend.cpp
@@ -78,24 +78,38 @@ void PthreadAsyncIO::synchronize() {
 
 void PthreadAsyncIO::register_file(int fd) {}
 
+void PthreadAsyncIO::register_h2d(unsigned int num_tensors) {
+    this->h2d_in_progress.store(num_tensors); // register tensors to write for this run
+}
+
+void PthreadAsyncIO::sync_h2d() {
+    std::unique_lock<std::mutex> lock(this->mtx);
+    this->cv.wait(lock, [this] { return this->h2d_in_progress == 0; }); // block until all in-progress h2d are completed
+}
+
 void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
     auto stream = c10::cuda::getCurrentCUDAStream();
-    at::cuda::CUDAStreamGuard guard(stream); // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html
-    auto event_ptr = std::make_shared<c10::Event>(torch::kCUDA); // make a shared ptr here since event is not copyable
-    if (t.is_cuda()) {
-        if (pinned.has_value()) {
-            pinned.value().copy_(t, /*non_blocking*/ true);
-            t = pinned.value();
-        } else {
-            t = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ true, /*copy*/ false); // modified from torch::Tensor::cpu()
-        }
-    }
-    event_ptr->record(stream);
     auto fut = this->pool.submit_task(
-        [fd, t, offset, pinned, event_ptr] {
-            event_ptr->synchronize(); // sync with comm stream
-            void *buf = t.data_ptr();
-            size_t n_bytes = t.numel() * t.element_size();
+        [this, fd, t, offset, pinned, stream] {
+            at::cuda::CUDAStreamGuard guard(stream); // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html
+            torch::Tensor cpu_tensor;
+            if (t.is_cuda()) {
+                if (pinned.has_value()) {
+                    pinned.value().copy_(t, /*non_blocking*/ false);
+                    cpu_tensor = pinned.value();
+                } else {
+                    cpu_tensor = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); // modified from torch::Tensor::cpu()
+                }
+            } else {
+                cpu_tensor = t; // already on the host, write it as-is
+            }
+            this->h2d_in_progress.fetch_sub(1);
+            if (this->h2d_in_progress.load() == 0) { // notify when all h2d are completed and safe to optimizer.step()
+                std::lock_guard<std::mutex> lock(this->mtx);
+                cv.notify_one();
+            }
+            void *buf = cpu_tensor.data_ptr();
+            size_t n_bytes = cpu_tensor.numel() * cpu_tensor.element_size();
             return pwrite(fd, buf, n_bytes, offset);
         }
    );
diff --git a/csrc/py_api.cpp b/csrc/py_api.cpp
index a305ac2..085a1bb 100644
--- a/csrc/py_api.cpp
+++ b/csrc/py_api.cpp
@@ -30,5 +30,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio") .def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"), py::arg("callback") = py::none()) .def("write_tensor", &AsyncFileWriter::write_tensor, py::arg("tensor"), py::arg("offset"), py::arg("callback") = py::none(), py::arg("pinned") = py::none()) - .def("synchronize", &AsyncFileWriter::synchronize); + .def("synchronize", &AsyncFileWriter::synchronize) + .def("sync_h2d", &AsyncFileWriter::sync_h2d) + .def("register_h2d", &AsyncFileWriter::register_h2d, py::arg("num_tensors")); } \ No newline at end of file diff --git a/csrc/uring.cpp b/csrc/uring.cpp index 8cd3dc0..255e74b 100644 --- a/csrc/uring.cpp +++ b/csrc/uring.cpp @@ -111,4 +111,7 @@ void UringAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offs void *buffer = t.data_ptr(); size_t n_bytes = t.numel() * t.element_size(); this->write(fd, buffer, n_bytes, offset, callback); -} \ No newline at end of file +} + +void UringAsyncIO::register_h2d(unsigned int num_tensors) {} +void UringAsyncIO::sync_h2d() {} \ No newline at end of file diff --git a/include/aio.h b/include/aio.h index a4aee4e..fec1d24 100644 --- a/include/aio.h +++ b/include/aio.h @@ -27,6 +27,8 @@ class AIOAsyncIO : public AsyncIO void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback); void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback); + void register_h2d(unsigned int num_tensors); + void sync_h2d(); void sync_write_events(); void sync_read_events(); void synchronize(); diff --git a/include/async_file_io.h b/include/async_file_io.h index d12e4fe..e9bb26b 100644 --- a/include/async_file_io.h +++ b/include/async_file_io.h @@ -21,6 +21,8 @@ class AsyncFileWriter void write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback); void write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional pinned); void synchronize(); + void register_h2d(unsigned int num_tensors); + void sync_h2d(); ~AsyncFileWriter(); private: diff --git a/include/asyncio.h b/include/asyncio.h index 68a501a..4981390 100644 --- a/include/asyncio.h +++ b/include/asyncio.h @@ -45,6 +45,8 @@ class AsyncIO virtual void get_event(WaitType wt) = 0; virtual void sync_write_events() = 0; virtual void sync_read_events() = 0; + virtual void register_h2d(unsigned int num_tensors) = 0; + virtual void sync_h2d() = 0; virtual void synchronize() = 0; virtual void register_file(int fd) = 0; diff --git a/include/pthread_backend.h b/include/pthread_backend.h index 75c83b9..41f95d5 100644 --- a/include/pthread_backend.h +++ b/include/pthread_backend.h @@ -12,6 +12,9 @@ #include #include #include +#include +#include +#include #include "asyncio.h" #include "threadpool.hpp" @@ -21,12 +24,15 @@ class PthreadAsyncIO : public AsyncIO { private: BS::thread_pool pool; + std::atomic h2d_in_progress; + std::condition_variable cv; + std::mutex mtx; std::deque, callback_t>> write_fut; std::deque, callback_t>> read_fut; public: PthreadAsyncIO(unsigned int n_entries) - : pool(n_entries) {} + : pool(n_entries), h2d_in_progress(0) {} ~PthreadAsyncIO() {} @@ -38,6 +44,8 @@ class PthreadAsyncIO : public AsyncIO void get_event(WaitType wt); void sync_write_events(); void sync_read_events(); + void register_h2d(unsigned int num_tensors); + void sync_h2d(); void synchronize(); void register_file(int fd); diff --git a/include/uring.h b/include/uring.h index 
index 6f95215..24ce52c 100644
--- a/include/uring.h
+++ b/include/uring.h
@@ -21,6 +21,8 @@ class UringAsyncIO : public AsyncIO
     void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
     void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
 
+    void register_h2d(unsigned int num_tensors);
+    void sync_h2d();
     void sync_write_events();
     void sync_read_events();
     void synchronize();
diff --git a/tensornvme/async_file_io.py b/tensornvme/async_file_io.py
index b223a01..45f3489 100644
--- a/tensornvme/async_file_io.py
+++ b/tensornvme/async_file_io.py
@@ -38,12 +38,19 @@ def write_tensor(self, tensor: Tensor, pinned: Optional[Tensor] = None) -> None:
         self.io.write_tensor(tensor, self.offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1), pinned)
         self.offset += tensor.numel() * tensor.element_size()
 
+    def sync_h2d(self) -> None:
+        self.io.sync_h2d()
+
+    def register_h2d(self, num_tensors: int) -> None:
+        self.io.register_h2d(num_tensors)
+
     def write_gpu_tensor(self, tensor: Tensor, pinned: Optional[Tensor] = None) -> None:
         assert tensor.device.type == 'cuda', f"tensor must be on cuda device, got {tensor.device}"
         with torch.cuda.stream(self.comm_stream):
             self.write_tensor(tensor, pinned)
 
     def sync_before_step(self):
+        self.sync_h2d()
         self.comm_stream.synchronize()
 
     @staticmethod
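
A rough usage sketch of how the new hooks are meant to fit together: register the number of tensors for a save pass, stream them through the writer, and drain the counter before the optimizer is allowed to step. The helper name save_state_async, its arguments, and the trailing writer.synchronize() call are illustrative assumptions; only register_h2d, sync_h2d, write_tensor, write_gpu_tensor, and sync_before_step come from this diff, and the no-wait behaviour only applies to the pthread backend (aio and uring stub these calls out).

from typing import Dict

import torch
from tensornvme.async_file_io import AsyncFileWriter


def save_state_async(writer: AsyncFileWriter, state: Dict[str, torch.Tensor]) -> None:
    # One slot per write_tensor() call in this pass: the pthread backend decrements
    # h2d_in_progress once for every tensor it writes, CUDA or not.
    writer.register_h2d(len(state))
    for tensor in state.values():
        if tensor.is_cuda:
            writer.write_gpu_tensor(tensor)  # copy to host + pwrite run on the worker thread
        else:
            writer.write_tensor(tensor)


# Before optimizer.step() mutates parameters that may still be copying out:
#     writer.sync_before_step()  # now also blocks in sync_h2d() until the counter reaches zero
# When the checkpoint file is about to be closed:
#     writer.synchronize()       # assumed Python wrapper over io.synchronize(); waits for the pwrites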