diff --git a/csrc/aio.cpp b/csrc/aio.cpp index 6f29252..b83a20f 100644 --- a/csrc/aio.cpp +++ b/csrc/aio.cpp @@ -1,6 +1,6 @@ #include "aio.h" -AIOAsyncIO::AIOAsyncIO(unsigned int n_entries) +AIOAsyncIO::AIOAsyncIO(unsigned int n_entries, unsigned int n_tasks) { // printf("Initializing the io Context\n"); this->max_nr = n_entries; diff --git a/csrc/async_file_io.cpp b/csrc/async_file_io.cpp index fc9b3cd..3d26f0b 100644 --- a/csrc/async_file_io.cpp +++ b/csrc/async_file_io.cpp @@ -1,6 +1,6 @@ #include "async_file_io.h" -AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend) : fd(fd), aio(create_asyncio(n_entries, backend)) {} +AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend, unsigned int n_tasks) : fd(fd), aio(create_asyncio(n_entries, backend, n_tasks)) {} void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback) { diff --git a/csrc/backend.cpp b/csrc/backend.cpp index 5afee31..85ab355 100644 --- a/csrc/backend.cpp +++ b/csrc/backend.cpp @@ -44,7 +44,7 @@ void probe_asyncio(const std::string &backend) if (backend == "uring") { #ifndef DISABLE_URING - aio.reset(new UringAsyncIO(2)); + aio.reset(new UringAsyncIO(2, 0)); #else throw std::runtime_error("backend uring is not installed\n"); #endif @@ -52,7 +52,7 @@ void probe_asyncio(const std::string &backend) else if (backend == "aio") { #ifndef DISABLE_AIO - aio.reset(new AIOAsyncIO(2)); + aio.reset(new AIOAsyncIO(2, 0)); #else throw std::runtime_error("backend aio is not installed\n"); #endif @@ -60,7 +60,7 @@ void probe_asyncio(const std::string &backend) else if (backend == "pthread") { #ifndef DISABLE_PTHREAD - aio.reset(new PthreadAsyncIO(2)); + aio.reset(new PthreadAsyncIO(2, 0)); #else throw std::runtime_error("backend pthread is not installed\n"); #endif @@ -160,7 +160,7 @@ std::string get_debug_log() return std::string(env_); } -AsyncIO *create_asyncio(unsigned int n_entries, 
std::string backend) +AsyncIO *create_asyncio(unsigned int n_entries, std::string backend, unsigned int n_tasks) { std::unordered_set<std::string> backends = get_backends(); std::string default_backend = get_default_backend(); @@ -188,15 +188,15 @@ AsyncIO *create_asyncio(unsigned int n_entries, std::string backend) #ifndef DISABLE_URING if (backend == "uring") - return new UringAsyncIO(n_entries); + return new UringAsyncIO(n_entries, n_tasks); #endif #ifndef DISABLE_AIO if (backend == "aio") - return new AIOAsyncIO(n_entries); + return new AIOAsyncIO(n_entries, n_tasks); #endif #ifndef DISABLE_PTHREAD if (backend == "pthread") - return new PthreadAsyncIO(n_entries); + return new PthreadAsyncIO(n_entries, n_tasks); #endif throw std::runtime_error("Unsupported backend: " + backend); } \ No newline at end of file diff --git a/csrc/offload.cpp b/csrc/offload.cpp index 35c1881..b43bde1 100644 --- a/csrc/offload.cpp +++ b/csrc/offload.cpp @@ -28,7 +28,7 @@ iovec *tensors_to_iovec(const std::vector<at::Tensor> &tensors) Offloader::Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend) : filename(filename), space_mgr(SpaceManager(0)) { - this->aio = create_asyncio(n_entries, backend); + this->aio = create_asyncio(n_entries, backend, 0); this->fd = open(filename.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR); this->aio->register_file(fd); } diff --git a/csrc/pthread_backend.cpp b/csrc/pthread_backend.cpp index 0fbc15e..62f040d 100644 --- a/csrc/pthread_backend.cpp +++ b/csrc/pthread_backend.cpp @@ -16,8 +16,8 @@ void PthreadAsyncIO::write(int fd, void *buffer, size_t n_bytes, unsigned long l auto val = pwrite(fd, buffer, n_bytes, offset); if (this->is_debug) { - auto cur_tasks = this->tasks_in_progress.fetch_sub(1); - if (cur_tasks == 1) + auto cur_tasks = this->tasks_in_progress.fetch_add(1); + if (cur_tasks + 1 == this->total_tasks) { if (this->debug_log.empty()) { @@ -117,14 +117,14 @@ void PthreadAsyncIO::register_file(int fd) {} void 
PthreadAsyncIO::register_h2d(unsigned int num_tensors) { - this->h2d_in_progress.store(num_tensors); // register tensors to write for this run + this->total_h2d = num_tensors; } void PthreadAsyncIO::sync_h2d() { std::unique_lock<std::mutex> lock(this->mtx); this->cv.wait(lock, [this] - { return this->h2d_in_progress == 0; }); // block until all in-progress h2d are completed + { return this->h2d_in_progress == this->total_h2d; }); // block until all in-progress h2d are completed } void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) @@ -132,8 +132,8 @@ void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long of auto stream = c10::cuda::getCurrentCUDAStream(); if (!t.is_cuda()) { - this->h2d_in_progress.fetch_sub(1); // already moved to cpu - if (this->h2d_in_progress.load() == 0) + auto cur_h2d = this->h2d_in_progress.fetch_add(1); // already moved to cpu + if (cur_h2d + 1 == this->total_h2d) { // notify when all h2d are completed and safe to optimizer.step() std::lock_guard<std::mutex> lock(this->mtx); cv.notify_one(); @@ -155,8 +155,8 @@ void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long of { cpu_tensor = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); // modified from torch::Tensor::cpu() } - this->h2d_in_progress.fetch_sub(1); - if (this->h2d_in_progress.load() == 0) + auto cur_h2d = this->h2d_in_progress.fetch_add(1); + if (cur_h2d + 1 == this->total_h2d) { // notify when all h2d are completed and safe to optimizer.step() std::lock_guard<std::mutex> lock(this->mtx); cv.notify_one(); @@ -171,8 +171,8 @@ void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long of auto val = pwrite(fd, buf, n_bytes, offset); if (this->is_debug) { - auto cur_tasks = this->tasks_in_progress.fetch_sub(1); - if (cur_tasks == 1) + auto cur_tasks = this->tasks_in_progress.fetch_add(1); + if (cur_tasks + 1 == this->total_tasks) { if 
(this->debug_log.empty()) { diff --git a/csrc/py_api.cpp b/csrc/py_api.cpp index ced0c3b..2bfb26e 100644 --- a/csrc/py_api.cpp +++ b/csrc/py_api.cpp @@ -27,7 +27,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("get_backends", get_backends); m.def("probe_backend", probe_backend, py::arg("backend")); py::class_<AsyncFileWriter>(m, "AsyncFileWriter") - .def(py::init<int, unsigned int, const std::string &>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio") + .def(py::init<int, unsigned int, const std::string &, unsigned int>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio", py::arg("n_tasks") = 0) .def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"), py::arg("callback") = py::none()) .def("write_tensor", &AsyncFileWriter::write_tensor, py::arg("tensor"), py::arg("offset"), py::arg("callback") = py::none(), py::arg("pinned") = py::none()) .def("synchronize", &AsyncFileWriter::synchronize) diff --git a/csrc/uring.cpp b/csrc/uring.cpp index 28a73fb..4637892 100644 --- a/csrc/uring.cpp +++ b/csrc/uring.cpp @@ -2,7 +2,7 @@ #include #include "uring.h" -UringAsyncIO::UringAsyncIO(unsigned int n_entries) : n_write_events(0), n_read_events(0), n_entries(n_entries) +UringAsyncIO::UringAsyncIO(unsigned int n_entries, unsigned int n_tasks) : n_write_events(0), n_read_events(0), n_entries(n_entries) { io_uring_queue_init(n_entries, &this->ring, 0); } diff --git a/include/aio.h b/include/aio.h index 7b9b996..1516ece 100644 --- a/include/aio.h +++ b/include/aio.h @@ -19,7 +19,7 @@ class AIOAsyncIO : public AsyncIO void get_event(WaitType wt); public: - AIOAsyncIO(unsigned int n_entries); + AIOAsyncIO(unsigned int n_entries, unsigned int n_tasks); ~AIOAsyncIO(); void write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback); diff --git a/include/async_file_io.h b/include/async_file_io.h index 6077ec3..bfd4117 100644 --- a/include/async_file_io.h +++ b/include/async_file_io.h @@ -17,7 +17,7 @@ class AsyncFileWriter { public: - AsyncFileWriter(int fd, unsigned int n_entries, const 
std::string &backend); + AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend, unsigned int n_tasks); void write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback); void write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned); void synchronize(); diff --git a/include/backend.h b/include/backend.h index 09ee07a..a135d2f 100644 --- a/include/backend.h +++ b/include/backend.h @@ -14,6 +14,6 @@ std::string get_default_backend(); bool get_debug_flag(); -AsyncIO *create_asyncio(unsigned int n_entries, std::string backend); +AsyncIO *create_asyncio(unsigned int n_entries, std::string backend, unsigned int n_tasks); std::string get_debug_log(); \ No newline at end of file diff --git a/include/pthread_backend.h b/include/pthread_backend.h index 1eb049e..27a5daa 100644 --- a/include/pthread_backend.h +++ b/include/pthread_backend.h @@ -26,6 +26,7 @@ class PthreadAsyncIO : public AsyncIO private: BS::thread_pool pool; std::atomic<unsigned int> h2d_in_progress; + unsigned int total_h2d; std::condition_variable cv; std::mutex mtx; std::deque<std::tuple<std::future<ssize_t>, callback_t>> write_fut; @@ -34,10 +35,11 @@ class PthreadAsyncIO : public AsyncIO const std::string debug_log = get_debug_log(); std::atomic<unsigned int> tasks_in_progress; + unsigned int total_tasks; public: - PthreadAsyncIO(unsigned int n_entries) - : pool(n_entries), h2d_in_progress(0) {} + PthreadAsyncIO(unsigned int n_entries, unsigned int n_tasks) + : pool(n_entries), h2d_in_progress(0), tasks_in_progress(0), total_tasks(n_tasks), total_h2d(0) {} ~PthreadAsyncIO() {} diff --git a/include/uring.h b/include/uring.h index 9df91b4..add7b38 100644 --- a/include/uring.h +++ b/include/uring.h @@ -13,7 +13,7 @@ class UringAsyncIO : public AsyncIO void get_event(WaitType wt); public: - UringAsyncIO(unsigned int n_entries); + UringAsyncIO(unsigned int n_entries, unsigned int n_tasks); ~UringAsyncIO(); void write(int fd, void *buffer, size_t n_bytes, unsigned long long 
offset, callback_t callback); diff --git a/tensornvme/_C/__init__.pyi b/tensornvme/_C/__init__.pyi index 41bdd30..8f40ec3 100644 --- a/tensornvme/_C/__init__.pyi +++ b/tensornvme/_C/__init__.pyi @@ -20,7 +20,7 @@ def get_backends() -> Set[str]: ... def probe_backend(backend: str) -> bool: ... class AsyncFileWriter: - def __init__(self, fd: int, n_entries: int, backend: str = "aio") -> None: ... + def __init__(self, fd: int, n_entries: int, backend: str = "aio", n_tasks: int = 0) -> None: ... def write(self, buffer: int, n_bytes: int, offset: int, callback: Optional[Callable[[], None]] = None) -> None: ... def write_tensor( self, diff --git a/tensornvme/async_file_io.py b/tensornvme/async_file_io.py index e01a09a..2a19210 100644 --- a/tensornvme/async_file_io.py +++ b/tensornvme/async_file_io.py @@ -10,14 +10,14 @@ class AsyncFileWriter: - def __init__(self, path: str, n_entries: int = 16, backend=None) -> None: + def __init__(self, path: str, n_entries: int = 16, backend=None, n_tasks: int = 0) -> None: # this still takes ram buffer, which may lead to OOM # self.f = open(path, "wb", buffering=0) self.fd = os.open(path, os.O_WRONLY | os.O_CREAT, mode=0o664) if backend is not None: - self.io = AsyncFileWriterC(self.fd, n_entries, backend=backend) + self.io = AsyncFileWriterC(self.fd, n_entries, backend=backend, n_tasks=n_tasks) else: - self.io = AsyncFileWriterC(self.fd, n_entries) + self.io = AsyncFileWriterC(self.fd, n_entries, n_tasks=n_tasks) self.offset = 0 # must ensure the data is not garbage collected self.buffers = []