Skip to content

Commit

Permalink
[fio] implement async io with safetensors format (#48)
Browse files Browse the repository at this point in the history
* [fio] implement async io with safetensors format

* [fio] use raw tensor ptr instead of numpy

* [chore] refactor

* [fio] add callback

* [chore] refactor
  • Loading branch information
botbw authored Oct 14, 2024
1 parent ebc660e commit b2f9944
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 11 deletions.
5 changes: 3 additions & 2 deletions csrc/async_file_io.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#include "asyncio.h"
#include "async_file_io.h"
#include "backend.h"
#include <stdexcept>

// Stores the file descriptor and builds the async I/O object by calling
// create_asyncio(n_entries, backend).
AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend)
    : fd(fd),
      aio(create_asyncio(n_entries, backend))
{
}

// Forwards one write request to the aio object: `buffer` is an integer that is
// reinterpreted as the source pointer, while n_bytes and offset are passed
// through unchanged together with this writer's fd.
// NOTE(review): this span looks like a rendered diff -- it holds two signature
// lines and two aio->write calls (one passing nullptr, one passing `callback`).
// Confirm against the real file which pair is live.
void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long offset)
void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback)
{
// Reinterpret the integer address as a raw pointer for the aio call.
void *ptr = reinterpret_cast<void *>(buffer);
this->aio->write(this->fd, ptr, n_bytes, offset, nullptr);
this->aio->write(this->fd, ptr, n_bytes, offset, callback);
}

void AsyncFileWriter::synchronize()
Expand Down
2 changes: 1 addition & 1 deletion csrc/py_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("probe_backend", probe_backend, py::arg("backend"));
py::class_<AsyncFileWriter>(m, "AsyncFileWriter")
.def(py::init<int, unsigned int, const std::string &>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio")
.def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"))
.def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"), py::arg("callback") = py::none())
.def("synchronize", &AsyncFileWriter::synchronize);
}
2 changes: 1 addition & 1 deletion include/async_file_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class AsyncFileWriter
{
public:
AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend);
void write(size_t buffer, size_t n_bytes, unsigned long long offset);
void write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
void synchronize();
~AsyncFileWriter();

Expand Down
2 changes: 1 addition & 1 deletion tensornvme/_C/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ def probe_backend(backend: str) -> bool: ...

# Type stub for the AsyncFileWriter class: constructor, write, synchronize.
class AsyncFileWriter:
def __init__(self, fd: int, n_entries: int, backend: str = "aio") -> None: ...
# NOTE(review): two `write` signatures are listed; this looks like the old and
# new line of a rendered diff. The second one adds an optional zero-argument
# callback. Confirm that Optional and Callable are imported in the real stub.
def write(self, buffer, n_bytes: int, offset: int) -> None: ...
def write(self, buffer: int, n_bytes: int, offset: int, callback: Optional[Callable[[], None]] = None) -> None: ...
def synchronize(self) -> None: ...
21 changes: 15 additions & 6 deletions tensornvme/async_file_io.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import ctypes
from io import IOBase
from functools import partial

from typing import List
from io import IOBase
from tensornvme._C import AsyncFileWriter as AsyncFileWriterC


class AsyncFileWriter:
def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None:
fd = fp.fileno()
Expand All @@ -17,15 +18,23 @@ def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None:
self.buffers = []

# Submit `data` to self.io at the current self.offset, advance self.offset by
# len(data), and return len(data).
# NOTE(review): indentation is missing and the body appears to contain both the
# pre-change and post-change lines of a diff (two io.write calls and two
# buffers.append calls) -- confirm against the real file before relying on it.
def write(self, data: bytes) -> int:
# A memoryview is first copied into a bytes object.
if isinstance(data, memoryview):
data = data.tobytes()
# Obtain the integer address of the payload so it can be handed to self.io.
ptr = ctypes.cast(data, ctypes.POINTER(ctypes.c_char))
addr = ctypes.addressof(ptr.contents)
self.io.write(addr, len(data), self.offset)
# Keep a reference in self.buffers; the callback built below overwrites that
# slot with None when it is invoked.
self.buffers.append(data) # append before callback is called
self.io.write(addr, len(data), self.offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1))
self.offset += len(data)
self.buffers.append(data)

return len(data)

def write_raw(self, py_ref: object, buffer: int, n_bytes: int, offset: int) -> None:
self.buffers.append(py_ref) # append before callback is called
self.io.write(buffer, n_bytes, offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1))
self.offset += n_bytes

@staticmethod
def gc_callback(listt: List, idx: int) -> None:
listt[idx] = None

def flush(self) -> None:
    """Do nothing and return ``None``."""
    return None

Expand Down

0 comments on commit b2f9944

Please sign in to comment.