Skip to content

Commit

Permalink
[backend] add backend option to async file writer
Browse files Browse the repository at this point in the history
  • Loading branch information
botbw committed Oct 9, 2024
1 parent 8f63b21 commit 2b313a5
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 26 deletions.
14 changes: 1 addition & 13 deletions csrc/async_file_io.cpp
Original file line number Diff line number Diff line change
@@ -1,20 +1,8 @@
#include "async_file_io.h"
#include "backend.h"
#include <stdexcept>
#include <string>

// Builds a writer for `fd`, picking the asyncio backend automatically.
// Every name yielded by get_backends() is tried in iteration order; the first
// one for which probe_backend() returns true is passed to create_asyncio()
// together with `n_entries`, and its result is stored in `aio`.
// Throws std::runtime_error if no backend passes the probe.
AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries) : fd(fd)
{
    auto &&candidates = get_backends();
    for (const std::string &name : candidates)
    {
        if (!probe_backend(name))
            continue; // probe failed, move on to the next candidate

        this->aio = create_asyncio(n_entries, name);
        return;
    }

    throw std::runtime_error("No asyncio backend is installed");
}
// Builds a writer for `fd` whose `aio` member is the object returned by
// create_asyncio() for `n_entries` and the caller-supplied `backend` name.
// Any exception raised by create_asyncio() propagates out of the constructor.
AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend)
    : fd(fd),
      aio(create_asyncio(n_entries, backend))
{
}

void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long offset)
{
Expand Down
14 changes: 9 additions & 5 deletions csrc/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,19 +130,23 @@ std::string get_default_backend() {
AsyncIO *create_asyncio(unsigned int n_entries, std::string backend)
{
std::unordered_set<std::string> backends = get_backends();
std::string default_backend = get_default_backend();

if (backends.empty())
throw std::runtime_error("No asyncio backend is installed");

std::string default_backend = get_default_backend();
if (default_backend.size() > 0) {
std::cout << "[backend] backend is overwritten by environ TENSORNVME_BACKEND from " << backend << std::endl;
if (default_backend.size() > 0) { // priority 1: environ is set
std::cout << "[backend] backend is overwritten by environ TENSORNVME_BACKEND from " << backend << " to " << default_backend << std::endl;
backend = default_backend;
} else if (backend.size() > 0) { // priority 2: backend is set
if (backends.find(backend) == backends.end())
throw std::runtime_error("Unsupported backend: " + backend);
}
std::cout << "[backend] using backend: " << backend << std::endl;
if (backends.find(backend) == backends.end())
throw std::runtime_error("Unsupported backend: " + backend);

if (!probe_backend(backend))
throw std::runtime_error("Backend \"" + backend + "\" is not install correctly");

#ifndef DISABLE_URING
if (backend == "uring")
return new UringAsyncIO(n_entries);
Expand Down
4 changes: 2 additions & 2 deletions csrc/py_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace py = pybind11;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
py::class_<Offloader>(m, "Offloader")
.def(py::init<const std::string &, unsigned int, const std::string &>(), py::arg("filename"), py::arg("n_entries"), py::arg("backend") = "uring")
.def(py::init<const std::string &, unsigned int, const std::string &>(), py::arg("filename"), py::arg("n_entries"), py::arg("backend") = "aio")
.def("async_write", &Offloader::async_write, py::arg("tensor"), py::arg("key"), py::arg("callback") = py::none())
.def("async_read", &Offloader::async_read, py::arg("tensor"), py::arg("key"), py::arg("callback") = py::none())
.def("sync_write", &Offloader::sync_write, py::arg("tensor"), py::arg("key"))
Expand All @@ -27,7 +27,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("get_backends", get_backends);
m.def("probe_backend", probe_backend, py::arg("backend"));
py::class_<AsyncFileWriter>(m, "AsyncFileWriter")
.def(py::init<int, unsigned int>(), py::arg("fd"), py::arg("n_entries"))
.def(py::init<int, unsigned int, const std::string &>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio")
.def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"))
.def("synchronize", &AsyncFileWriter::synchronize);
}
3 changes: 2 additions & 1 deletion include/async_file_io.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once
#include <string>
#include "asyncio.h"
#ifndef DISABLE_URING
#include "uring.h"
Expand All @@ -10,7 +11,7 @@
class AsyncFileWriter
{
public:
AsyncFileWriter(int fd, unsigned int n_entries);
AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend);
void write(size_t buffer, size_t n_bytes, unsigned long long offset);
void synchronize();
~AsyncFileWriter();
Expand Down
2 changes: 1 addition & 1 deletion include/offload.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class Offloader
{
public:
Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend = "uring");
Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend);
SpaceInfo prepare_write(const at::Tensor &tensor, const std::string &key);
SpaceInfo prepare_read(const at::Tensor &tensor, const std::string &key);
void async_write(const at::Tensor &tensor, const std::string &key, callback_t callback = nullptr);
Expand Down
4 changes: 2 additions & 2 deletions tensornvme/_C/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Set
from torch import Tensor

class Offloader:
def __init__(self, filename: str, n_entries: int, backend: str = "uring") -> None: ...
def __init__(self, filename: str, n_entries: int, backend: str = "aio") -> None: ...
def async_write(self, tensor: Tensor, key: str, callback: Optional[Callable[[], None]] = None) -> None: ...
def async_read(self, tensor: Tensor, key: str, callback: Optional[Callable[[], None]] = None) -> None: ...
def sync_write(self, tensor: Tensor, key: str) -> None: ...
Expand All @@ -20,6 +20,6 @@ def get_backends() -> Set[str]: ...
def probe_backend(backend: str) -> bool: ...

class AsyncFileWriter:
def __init__(self, fd: int, n_entries: int) -> None: ...
def __init__(self, fd: int, n_entries: int, backend: str = "aio") -> None: ...
def write(self, buffer, n_bytes: int, offset: int) -> None: ...
def synchronize(self) -> None: ...
7 changes: 5 additions & 2 deletions tensornvme/async_file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@


class AsyncFileWriter:
def __init__(self, fp: IOBase, n_entries: int = 16) -> None:
def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None:
fd = fp.fileno()
self.io = AsyncFileWriterC(fd, n_entries)
if backend is not None:
self.io = AsyncFileWriterC(fd, n_entries, backend=backend)
else:
self.io = AsyncFileWriterC(fd, n_entries)
self.fp = fp
self.offset = 0
# must ensure the data is not garbage collected
Expand Down

0 comments on commit 2b313a5

Please sign in to comment.