From 2b313a53fb3e8adbe803b6a32889429cde394659 Mon Sep 17 00:00:00 2001 From: botbw Date: Wed, 9 Oct 2024 11:44:36 +0800 Subject: [PATCH] [backend] add backend option to async file writer --- csrc/async_file_io.cpp | 14 +------------- csrc/backend.cpp | 14 +++++++++----- csrc/py_api.cpp | 4 ++-- include/async_file_io.h | 3 ++- include/offload.h | 2 +- tensornvme/_C/__init__.pyi | 4 ++-- tensornvme/async_file_io.py | 7 +++++-- 7 files changed, 22 insertions(+), 26 deletions(-) diff --git a/csrc/async_file_io.cpp b/csrc/async_file_io.cpp index c510bc7..b474e56 100644 --- a/csrc/async_file_io.cpp +++ b/csrc/async_file_io.cpp @@ -1,20 +1,8 @@ #include "async_file_io.h" #include "backend.h" #include -#include -AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries) : fd(fd) -{ - for (const std::string &backend : get_backends()) - { - if (probe_backend(backend)) - { - this->aio = create_asyncio(n_entries, backend); - return; - } - } - throw std::runtime_error("No asyncio backend is installed"); -} +AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend) : fd(fd), aio(create_asyncio(n_entries, backend)) {} void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long offset) { diff --git a/csrc/backend.cpp b/csrc/backend.cpp index ad4f972..286c94c 100644 --- a/csrc/backend.cpp +++ b/csrc/backend.cpp @@ -130,19 +130,23 @@ std::string get_default_backend() { AsyncIO *create_asyncio(unsigned int n_entries, std::string backend) { std::unordered_set backends = get_backends(); + std::string default_backend = get_default_backend(); + if (backends.empty()) throw std::runtime_error("No asyncio backend is installed"); - std::string default_backend = get_default_backend(); - if (default_backend.size() > 0) { - std::cout << "[backend] backend is overwritten by environ TENSORNVME_BACKEND from " << backend << std::endl; + if (default_backend.size() > 0) { // priority 1: environ is set + std::cout << "[backend] backend is overwritten by environ TENSORNVME_BACKEND from " << backend << " to " << default_backend << std::endl; backend = default_backend; + } else if (backend.size() > 0) { // priority 2: backend is set + if (backends.find(backend) == backends.end()) + throw std::runtime_error("Unsupported backend: " + backend); } std::cout << "[backend] using backend: " << backend << std::endl; - if (backends.find(backend) == backends.end()) - throw std::runtime_error("Unsupported backend: " + backend); + if (!probe_backend(backend)) throw std::runtime_error("Backend \"" + backend + "\" is not install correctly"); + #ifndef DISABLE_URING if (backend == "uring") return new UringAsyncIO(n_entries); diff --git a/csrc/py_api.cpp b/csrc/py_api.cpp index 4d95a9f..17152d9 100644 --- a/csrc/py_api.cpp +++ b/csrc/py_api.cpp @@ -12,7 +12,7 @@ namespace py = pybind11; PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::class_(m, "Offloader") - .def(py::init(), py::arg("filename"), py::arg("n_entries"), py::arg("backend") = "uring") + .def(py::init(), py::arg("filename"), py::arg("n_entries"), py::arg("backend") = "aio") .def("async_write", &Offloader::async_write, py::arg("tensor"), py::arg("key"), py::arg("callback") = py::none()) .def("async_read", &Offloader::async_read, py::arg("tensor"), py::arg("key"), py::arg("callback") = py::none()) .def("sync_write", &Offloader::sync_write, py::arg("tensor"), py::arg("key")) @@ -27,7 +27,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("get_backends", get_backends); m.def("probe_backend", probe_backend, py::arg("backend")); py::class_(m, "AsyncFileWriter") - .def(py::init(), py::arg("fd"), py::arg("n_entries")) + .def(py::init(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio") .def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset")) .def("synchronize", &AsyncFileWriter::synchronize); } \ No newline at end of file diff --git a/include/async_file_io.h b/include/async_file_io.h index 82c4843..210e30e 100644 --- a/include/async_file_io.h +++ b/include/async_file_io.h @@ -1,4 +1,5 @@ #pragma once +#include #include "asyncio.h" #ifndef DISABLE_URING #include "uring.h" @@ -10,7 +11,7 @@ class AsyncFileWriter { public: - AsyncFileWriter(int fd, unsigned int n_entries); + AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend); void write(size_t buffer, size_t n_bytes, unsigned long long offset); void synchronize(); ~AsyncFileWriter(); diff --git a/include/offload.h b/include/offload.h index 7fc72df..6ea42f6 100644 --- a/include/offload.h +++ b/include/offload.h @@ -14,7 +14,7 @@ class Offloader { public: - Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend = "uring"); + Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend); SpaceInfo prepare_write(const at::Tensor &tensor, const std::string &key); SpaceInfo prepare_read(const at::Tensor &tensor, const std::string &key); void async_write(const at::Tensor &tensor, const std::string &key, callback_t callback = nullptr); diff --git a/tensornvme/_C/__init__.pyi b/tensornvme/_C/__init__.pyi index 84b41d1..3825d96 100644 --- a/tensornvme/_C/__init__.pyi +++ b/tensornvme/_C/__init__.pyi @@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Set from torch import Tensor class Offloader: - def __init__(self, filename: str, n_entries: int, backend: str = "uring") -> None: ... + def __init__(self, filename: str, n_entries: int, backend: str = "aio") -> None: ... def async_write(self, tensor: Tensor, key: str, callback: Optional[Callable[[], None]] = None) -> None: ... def async_read(self, tensor: Tensor, key: str, callback: Optional[Callable[[], None]] = None) -> None: ... def sync_write(self, tensor: Tensor, key: str) -> None: ... @@ -20,6 +20,6 @@ def get_backends() -> Set[str]: ... def probe_backend(backend: str) -> bool: ... class AsyncFileWriter: - def __init__(self, fd: int, n_entries: int) -> None: ... + def __init__(self, fd: int, n_entries: int, backend: str = "aio") -> None: ... def write(self, buffer, n_bytes: int, offset: int) -> None: ... def synchronize(self) -> None: ... diff --git a/tensornvme/async_file_io.py b/tensornvme/async_file_io.py index 4129a93..027807c 100644 --- a/tensornvme/async_file_io.py +++ b/tensornvme/async_file_io.py @@ -5,9 +5,12 @@ class AsyncFileWriter: - def __init__(self, fp: IOBase, n_entries: int = 16) -> None: + def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None: fd = fp.fileno() - self.io = AsyncFileWriterC(fd, n_entries) + if backend is not None: + self.io = AsyncFileWriterC(fd, n_entries, backend=backend) + else: + self.io = AsyncFileWriterC(fd, n_entries) self.fp = fp self.offset = 0 # must ensure the data is not garbage collected