From b2f9944fff5588fec83ee39d8820b8731d411c83 Mon Sep 17 00:00:00 2001 From: botbw Date: Mon, 14 Oct 2024 18:25:22 +0800 Subject: [PATCH] [fio] implement async io with safetensors format (#48) * [fio] implement async io with safetensors format * [fio] use raw tensor ptr instead of numpy * [chore] refactor * [fio] add callback * [chore] refactor --- csrc/async_file_io.cpp | 5 +++-- csrc/py_api.cpp | 2 +- include/async_file_io.h | 2 +- tensornvme/_C/__init__.pyi | 2 +- tensornvme/async_file_io.py | 21 +++++++++++++++------ 5 files changed, 21 insertions(+), 11 deletions(-) diff --git a/csrc/async_file_io.cpp b/csrc/async_file_io.cpp index b474e56..625402c 100644 --- a/csrc/async_file_io.cpp +++ b/csrc/async_file_io.cpp @@ -1,13 +1,14 @@ +#include "asyncio.h" #include "async_file_io.h" #include "backend.h" #include AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend) : fd(fd), aio(create_asyncio(n_entries, backend)) {} -void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long offset) +void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback) { void *ptr = reinterpret_cast<void *>(buffer); - this->aio->write(this->fd, ptr, n_bytes, offset, nullptr); + this->aio->write(this->fd, ptr, n_bytes, offset, callback); } void AsyncFileWriter::synchronize() diff --git a/csrc/py_api.cpp b/csrc/py_api.cpp index 17152d9..73a2fda 100644 --- a/csrc/py_api.cpp +++ b/csrc/py_api.cpp @@ -28,6 +28,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("probe_backend", probe_backend, py::arg("backend")); py::class_<AsyncFileWriter>(m, "AsyncFileWriter") .def(py::init<int, unsigned int, const std::string &>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio") - .def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset")) + .def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"), py::arg("callback") = py::none()) .def("synchronize", 
&AsyncFileWriter::synchronize); } \ No newline at end of file diff --git a/include/async_file_io.h b/include/async_file_io.h index 210e30e..bf1f83f 100644 --- a/include/async_file_io.h +++ b/include/async_file_io.h @@ -12,7 +12,7 @@ class AsyncFileWriter { public: AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend); - void write(size_t buffer, size_t n_bytes, unsigned long long offset); + void write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback); void synchronize(); ~AsyncFileWriter(); diff --git a/tensornvme/_C/__init__.pyi b/tensornvme/_C/__init__.pyi index 3825d96..2f23d8f 100644 --- a/tensornvme/_C/__init__.pyi +++ b/tensornvme/_C/__init__.pyi @@ -21,5 +21,5 @@ def probe_backend(backend: str) -> bool: ... class AsyncFileWriter: def __init__(self, fd: int, n_entries: int, backend: str = "aio") -> None: ... - def write(self, buffer, n_bytes: int, offset: int) -> None: ... + def write(self, buffer: int, n_bytes: int, offset: int, callback: Optional[Callable[[], None]] = None) -> None: ... def synchronize(self) -> None: ... 
diff --git a/tensornvme/async_file_io.py b/tensornvme/async_file_io.py index 027807c..6a0ae61 100644 --- a/tensornvme/async_file_io.py +++ b/tensornvme/async_file_io.py @@ -1,9 +1,10 @@ import ctypes -from io import IOBase +from functools import partial +from typing import List +from io import IOBase from tensornvme._C import AsyncFileWriter as AsyncFileWriterC - class AsyncFileWriter: def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None: fd = fp.fileno() @@ -17,15 +18,23 @@ def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None: self.buffers = [] def write(self, data: bytes) -> int: - if isinstance(data, memoryview): - data = data.tobytes() ptr = ctypes.cast(data, ctypes.POINTER(ctypes.c_char)) addr = ctypes.addressof(ptr.contents) - self.io.write(addr, len(data), self.offset) + self.buffers.append(data) # append before callback is called + self.io.write(addr, len(data), self.offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1)) self.offset += len(data) - self.buffers.append(data) + return len(data) + def write_raw(self, py_ref: object, buffer: int, n_bytes: int, offset: int) -> None: + self.buffers.append(py_ref) # append before callback is called + self.io.write(buffer, n_bytes, offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1)) + self.offset += n_bytes + + @staticmethod + def gc_callback(listt: List, idx: int) -> None: + listt[idx] = None + def flush(self) -> None: pass