diff --git a/csrc/aio.cpp b/csrc/aio.cpp index 2d51b58..7a1d10b 100644 --- a/csrc/aio.cpp +++ b/csrc/aio.cpp @@ -1,5 +1,3 @@ -#include -#include #include "aio.h" AIOAsyncIO::AIOAsyncIO(unsigned int n_entries) @@ -126,4 +124,21 @@ void AIOAsyncIO::readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned l io_submit(this->io_ctx, 1, &iocbs); /* 提交这个I/O不会堵塞 */ this->n_read_events++; -} \ No newline at end of file +} + +void AIOAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional pinned) { + if (t.is_cuda()) { + if (pinned.has_value()) { + pinned.value().copy_(t); + t = pinned.value(); + } else { + t = t.to(torch::kCPU); + } + } + void *buffer = t.data_ptr(); + size_t n_bytes = t.numel() * t.element_size(); + this->write(fd, buffer, n_bytes, offset, callback); +} + +void AIOAsyncIO::register_h2d(unsigned int num_tensors) {} +void AIOAsyncIO::sync_h2d() {} \ No newline at end of file diff --git a/csrc/async_file_io.cpp b/csrc/async_file_io.cpp index 625402c..6676c60 100644 --- a/csrc/async_file_io.cpp +++ b/csrc/async_file_io.cpp @@ -1,7 +1,4 @@ -#include "asyncio.h" #include "async_file_io.h" -#include "backend.h" -#include AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend) : fd(fd), aio(create_asyncio(n_entries, backend)) {} @@ -11,6 +8,18 @@ void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long of this->aio->write(this->fd, ptr, n_bytes, offset, callback); } +void AsyncFileWriter::write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional pinned) { + this->aio->write_tensor(this->fd, tensor, offset, callback, pinned); +} + +void AsyncFileWriter::register_h2d(unsigned int num_tensors) { + this->aio->register_h2d(num_tensors); +} + +void AsyncFileWriter::sync_h2d() { + this->aio->sync_h2d(); +} + void AsyncFileWriter::synchronize() { this->aio->synchronize(); diff --git a/csrc/pthread_backend.cpp b/csrc/pthread_backend.cpp index 601fc33..fc43425 100644 --- a/csrc/pthread_backend.cpp +++ b/csrc/pthread_backend.cpp @@ -76,4 +76,49 @@ void PthreadAsyncIO::synchronize() { this->get_event(WAIT); } -void PthreadAsyncIO::register_file(int fd) {} \ No newline at end of file +void PthreadAsyncIO::register_file(int fd) {} + +void PthreadAsyncIO::register_h2d(unsigned int num_tensors) { + this->h2d_in_progress.store(num_tensors); // register tensors to write for this run +} + +void PthreadAsyncIO::sync_h2d() { + std::unique_lock lock(this->mtx); + this->cv.wait(lock, [this] { return this->h2d_in_progress == 0; }); // block until all in-progress h2d are completed +} + +void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional pinned) { + auto stream = c10::cuda::getCurrentCUDAStream(); + if (!t.is_cuda()) { + this->h2d_in_progress.fetch_sub(1); // already moved to cpu + if (this->h2d_in_progress.load() == 0) { // notify when all h2d are completed and safe to optimizer.step() + std::lock_guard lock(this->mtx); + cv.notify_one(); + } + } + auto fut = this->pool.submit_task( + [this, fd, t, offset, pinned, stream] { + torch::Tensor cpu_tensor; + if (t.is_cuda()) { + at::cuda::CUDAStreamGuard guard(stream); // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html + if (pinned.has_value()) { + pinned.value().copy_(t, /*non_blocking*/ false); + cpu_tensor = pinned.value(); + } else { + cpu_tensor = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); // modified from torch::Tensor::cpu() + } + this->h2d_in_progress.fetch_sub(1); + if (this->h2d_in_progress.load() == 0) { // notify when all h2d are completed and safe to optimizer.step() + std::lock_guard lock(this->mtx); + cv.notify_one(); + } + } else { + cpu_tensor = t; + } + void *buf = cpu_tensor.data_ptr(); + size_t n_bytes = cpu_tensor.numel() * cpu_tensor.element_size(); + return pwrite(fd, buf, n_bytes, offset); + } + ); + this->write_fut.push_back(std::make_tuple(std::move(fut), callback)); +} \ No newline at end of file diff --git a/csrc/py_api.cpp b/csrc/py_api.cpp index 73a2fda..085a1bb 100644 --- a/csrc/py_api.cpp +++ b/csrc/py_api.cpp @@ -29,5 +29,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) py::class_(m, "AsyncFileWriter") .def(py::init(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio") .def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"), py::arg("callback") = py::none()) - .def("synchronize", &AsyncFileWriter::synchronize); + .def("write_tensor", &AsyncFileWriter::write_tensor, py::arg("tensor"), py::arg("offset"), py::arg("callback") = py::none(), py::arg("pinned") = py::none()) + .def("synchronize", &AsyncFileWriter::synchronize) + .def("sync_h2d", &AsyncFileWriter::sync_h2d) + .def("register_h2d", &AsyncFileWriter::register_h2d, py::arg("num_tensors")); } \ No newline at end of file diff --git a/csrc/uring.cpp b/csrc/uring.cpp index 1382865..255e74b 100644 --- a/csrc/uring.cpp +++ b/csrc/uring.cpp @@ -97,4 +97,21 @@ void UringAsyncIO::readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned io_uring_sqe_set_data(sqe, data); io_uring_submit(&this->ring); this->n_read_events++; -} \ No newline at end of file +} + +void UringAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional pinned) { + if (t.is_cuda()) { + if (pinned.has_value()) { + pinned.value().copy_(t); + t = pinned.value(); + } else { + t = t.to(torch::kCPU); + } + } + void *buffer = t.data_ptr(); + size_t n_bytes = t.numel() * t.element_size(); + this->write(fd, buffer, n_bytes, offset, callback); +} + +void UringAsyncIO::register_h2d(unsigned int num_tensors) {} +void UringAsyncIO::sync_h2d() {} \ No newline at end of file diff --git a/include/aio.h b/include/aio.h index f8d6a0f..fec1d24 100644 --- a/include/aio.h +++ b/include/aio.h @@ -1,6 +1,9 @@ #pragma once #include +#include +#include +#include #include "asyncio.h" class AIOAsyncIO : public AsyncIO @@ -24,9 +27,12 @@ class AIOAsyncIO : public AsyncIO void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback); void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback); + void register_h2d(unsigned int num_tensors); + void sync_h2d(); void sync_write_events(); void sync_read_events(); void synchronize(); void register_file(int fd); + void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional pinned); }; \ No newline at end of file diff --git a/include/async_file_io.h b/include/async_file_io.h index bf1f83f..e9bb26b 100644 --- a/include/async_file_io.h +++ b/include/async_file_io.h @@ -1,9 +1,15 @@ #pragma once #include +#include +#include + #include "asyncio.h" +#include "backend.h" + #ifndef DISABLE_URING #include "uring.h" #endif + #ifndef DISABLE_AIO #include "aio.h" #endif @@ -13,7 +19,10 @@ class AsyncFileWriter public: AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend); void write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback); + void write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional pinned); void synchronize(); + void register_h2d(unsigned int num_tensors); + void sync_h2d(); ~AsyncFileWriter(); private: diff --git a/include/asyncio.h b/include/asyncio.h index d479123..4981390 100644 --- a/include/asyncio.h +++ b/include/asyncio.h @@ -2,6 +2,7 @@ #include #include +#include using callback_t = std::function; @@ -44,7 +45,10 @@ class AsyncIO virtual void get_event(WaitType wt) = 0; virtual void sync_write_events() = 0; virtual void sync_read_events() = 0; + virtual void register_h2d(unsigned int num_tensors) = 0; + virtual void sync_h2d() = 0; virtual void synchronize() = 0; virtual void register_file(int fd) = 0; + virtual void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional pinned) = 0; }; \ No newline at end of file diff --git a/include/pthread_backend.h b/include/pthread_backend.h index b41d443..41f95d5 100644 --- a/include/pthread_backend.h +++ b/include/pthread_backend.h @@ -9,6 +9,12 @@ #include #include #include +#include +#include +#include +#include +#include +#include #include "asyncio.h" #include "threadpool.hpp" @@ -18,12 +24,15 @@ class PthreadAsyncIO : public AsyncIO { private: BS::thread_pool pool; + std::atomic h2d_in_progress; + std::condition_variable cv; + std::mutex mtx; std::deque, callback_t>> write_fut; std::deque, callback_t>> read_fut; public: PthreadAsyncIO(unsigned int n_entries) - : pool(n_entries) {} + : pool(n_entries), h2d_in_progress(0) {} ~PthreadAsyncIO() {} @@ -35,7 +44,11 @@ class PthreadAsyncIO : public AsyncIO void get_event(WaitType wt); void sync_write_events(); void sync_read_events(); + void register_h2d(unsigned int num_tensors); + void sync_h2d(); void synchronize(); void register_file(int fd); + + void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional pinned); }; \ No newline at end of file diff --git a/include/uring.h b/include/uring.h index fee255e..24ce52c 100644 --- a/include/uring.h +++ b/include/uring.h @@ -21,9 +21,12 @@ class UringAsyncIO : public AsyncIO void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback); void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback); + void register_h2d(unsigned int num_tensors); + void sync_h2d(); void sync_write_events(); void sync_read_events(); void synchronize(); void register_file(int fd); + void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional pinned); }; \ No newline at end of file diff --git a/tensornvme/_C/__init__.pyi b/tensornvme/_C/__init__.pyi index 2f23d8f..5bd330e 100644 --- a/tensornvme/_C/__init__.pyi +++ b/tensornvme/_C/__init__.pyi @@ -22,4 +22,5 @@ def probe_backend(backend: str) -> bool: ... class AsyncFileWriter: def __init__(self, fd: int, n_entries: int, backend: str = "aio") -> None: ... def write(self, buffer: int, n_bytes: int, offset: int, callback: Optional[Callable[[], None]] = None) -> None: ... + def write_tensor(self, tensor: Tensor, offset: int, callback: Optional[Callable[[], None]] = None, pinned: Optional[Tensor] = None) -> None: ... def synchronize(self) -> None: ... diff --git a/tensornvme/async_file_io.py b/tensornvme/async_file_io.py index 6a0ae61..64865da 100644 --- a/tensornvme/async_file_io.py +++ b/tensornvme/async_file_io.py @@ -1,7 +1,8 @@ import ctypes +import torch from functools import partial - -from typing import List +from torch import Tensor +from typing import List, Optional from io import IOBase from tensornvme._C import AsyncFileWriter as AsyncFileWriterC @@ -16,6 +17,7 @@ def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None: self.offset = 0 # must ensure the data is not garbage collected self.buffers = [] + self.comm_stream = torch.cuda.Stream() def write(self, data: bytes) -> int: ptr = ctypes.cast(data, ctypes.POINTER(ctypes.c_char)) @@ -31,6 +33,18 @@ def write_raw(self, py_ref: object, buffer: int, n_bytes: int, offset: int) -> N self.io.write(buffer, n_bytes, offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1)) self.offset += n_bytes + def write_tensor(self, tensor: Tensor, pinned: Optional[Tensor] = None) -> None: + with torch.cuda.stream(self.comm_stream): + self.buffers.append(tensor) # append before callback is called + self.io.write_tensor(tensor, self.offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1), pinned) + self.offset += tensor.numel() * tensor.element_size() + + def register_h2d(self, num_tensors: int) -> None: + self.io.register_h2d(num_tensors) + + def sync_before_step(self): + self.io.sync_h2d() + @staticmethod def gc_callback(listt: List, idx: int) -> None: listt[idx] = None