diff --git a/csrc/async_file_io.cpp b/csrc/async_file_io.cpp
index c510bc7..b474e56 100644
--- a/csrc/async_file_io.cpp
+++ b/csrc/async_file_io.cpp
@@ -1,20 +1,8 @@
 #include "async_file_io.h"
 #include "backend.h"
 #include <stdexcept>
-#include <memory>
 
-AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries) : fd(fd)
-{
-    for (const std::string &backend : get_backends())
-    {
-        if (probe_backend(backend))
-        {
-            this->aio = create_asyncio(n_entries, backend);
-            return;
-        }
-    }
-    throw std::runtime_error("No asyncio backend is installed");
-}
+AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend) : fd(fd), aio(create_asyncio(n_entries, backend)) {}
 
 void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long offset)
 {
diff --git a/csrc/backend.cpp b/csrc/backend.cpp
index 98e40ff..69950ce 100644
--- a/csrc/backend.cpp
+++ b/csrc/backend.cpp
@@ -119,15 +119,50 @@ bool probe_backend(const std::string &backend)
     }
 }
 
-AsyncIO *create_asyncio(unsigned int n_entries, const std::string &backend)
+std::string get_default_backend() {
+    const char* env = getenv("TENSORNVME_BACKEND");
+    if (env == nullptr) {
+        return std::string("");
+    }
+    return std::string(env);
+}
+
+bool get_debug_flag() {
+    const char* env_ = getenv("TENSORNVME_DEBUG");
+    if (env_ == nullptr) {
+        return false;
+    }
+    std::string env(env_);
+    std::transform(env.begin(), env.end(), env.begin(),
+                   [](unsigned char c) { return std::tolower(c); });
+    return env == "1" || env == "true";
+}
+
+AsyncIO *create_asyncio(unsigned int n_entries, std::string backend)
 {
     std::unordered_set<std::string> backends = get_backends();
+    std::string default_backend = get_default_backend();
+    bool is_debugging = get_debug_flag();
+
     if (backends.empty())
         throw std::runtime_error("No asyncio backend is installed");
-    if (backends.find(backend) == backends.end())
-        throw std::runtime_error("Unsupported backend: " + backend);
+
+    if (default_backend.size() > 0) { // priority 1: environ is set
+        if (is_debugging) {
+            std::cout << "[backend] backend is overwritten by environ TENSORNVME_BACKEND from " << backend << " to " << default_backend << std::endl;
+        }
+        backend = default_backend;
+    } else if (backend.size() > 0) { // priority 2: backend is set
+        if (backends.find(backend) == backends.end())
+            throw std::runtime_error("Unsupported backend: " + backend);
+    }
+
+    if (is_debugging) {
+        std::cout << "[backend] using backend: " << backend << std::endl;
+    }
+
     if (!probe_backend(backend))
         throw std::runtime_error("Backend \"" + backend + "\" is not install correctly");
+
 #ifndef DISABLE_URING
     if (backend == "uring")
         return new UringAsyncIO(n_entries);
diff --git a/csrc/py_api.cpp b/csrc/py_api.cpp
index 4d95a9f..17152d9 100644
--- a/csrc/py_api.cpp
+++ b/csrc/py_api.cpp
@@ -12,7 +12,7 @@ namespace py = pybind11;
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     py::class_<Offloader>(m, "Offloader")
-        .def(py::init<const std::string &, unsigned int, const std::string &>(), py::arg("filename"), py::arg("n_entries"), py::arg("backend") = "uring")
+        .def(py::init<const std::string &, unsigned int, const std::string &>(), py::arg("filename"), py::arg("n_entries"), py::arg("backend") = "aio")
         .def("async_write", &Offloader::async_write, py::arg("tensor"), py::arg("key"), py::arg("callback") = py::none())
         .def("async_read", &Offloader::async_read, py::arg("tensor"), py::arg("key"), py::arg("callback") = py::none())
         .def("sync_write", &Offloader::sync_write, py::arg("tensor"), py::arg("key"))
@@ -27,7 +27,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("get_backends", get_backends);
     m.def("probe_backend", probe_backend, py::arg("backend"));
     py::class_<AsyncFileWriter>(m, "AsyncFileWriter")
-        .def(py::init<int, unsigned int>(), py::arg("fd"), py::arg("n_entries"))
+        .def(py::init<int, unsigned int, const std::string &>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio")
         .def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"))
         .def("synchronize", &AsyncFileWriter::synchronize);
 }
\ No newline at end of file
diff --git a/include/async_file_io.h b/include/async_file_io.h
index 82c4843..210e30e 100644
--- a/include/async_file_io.h
+++ b/include/async_file_io.h
@@ -1,4 +1,5 @@
 #pragma once
+#include <string>
 #include "asyncio.h"
 #ifndef DISABLE_URING
 #include "uring.h"
@@ -10,7 +11,7 @@ class AsyncFileWriter
 {
 public:
-    AsyncFileWriter(int fd, unsigned int n_entries);
+    AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend);
     void write(size_t buffer, size_t n_bytes, unsigned long long offset);
     void synchronize();
     ~AsyncFileWriter();
diff --git a/include/backend.h b/include/backend.h
index 1d51c70..33c63a7 100644
--- a/include/backend.h
+++ b/include/backend.h
@@ -1,9 +1,17 @@
 #include "asyncio.h"
 #include <string>
+#include <algorithm>
+#include <cctype>
 #include <unordered_set>
+#include <iostream>
+#include <cstdlib>
 
 std::unordered_set<std::string> get_backends();
 
 bool probe_backend(const std::string &backend);
 
-AsyncIO *create_asyncio(unsigned int n_entries, const std::string &backend);
\ No newline at end of file
+std::string get_default_backend();
+
+bool get_debug_flag();
+
+AsyncIO *create_asyncio(unsigned int n_entries, std::string backend);
diff --git a/include/offload.h b/include/offload.h
index 7fc72df..6ea42f6 100644
--- a/include/offload.h
+++ b/include/offload.h
@@ -14,7 +14,7 @@ class Offloader
 {
 public:
-    Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend = "uring");
+    Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend);
     SpaceInfo prepare_write(const at::Tensor &tensor, const std::string &key);
     SpaceInfo prepare_read(const at::Tensor &tensor, const std::string &key);
     void async_write(const at::Tensor &tensor, const std::string &key, callback_t callback = nullptr);
diff --git a/tensornvme/_C/__init__.pyi b/tensornvme/_C/__init__.pyi
index 84b41d1..3825d96 100644
--- a/tensornvme/_C/__init__.pyi
+++ b/tensornvme/_C/__init__.pyi
@@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Set
 from torch import Tensor
 
 class Offloader:
-    def __init__(self, filename: str, n_entries: int, backend: str = "uring") -> None: ...
+    def __init__(self, filename: str, n_entries: int, backend: str = "aio") -> None: ...
     def async_write(self, tensor: Tensor, key: str, callback: Optional[Callable[[], None]] = None) -> None: ...
     def async_read(self, tensor: Tensor, key: str, callback: Optional[Callable[[], None]] = None) -> None: ...
     def sync_write(self, tensor: Tensor, key: str) -> None: ...
@@ -20,6 +20,6 @@ def get_backends() -> Set[str]: ...
 def probe_backend(backend: str) -> bool: ...
 
 class AsyncFileWriter:
-    def __init__(self, fd: int, n_entries: int) -> None: ...
+    def __init__(self, fd: int, n_entries: int, backend: str = "aio") -> None: ...
     def write(self, buffer, n_bytes: int, offset: int) -> None: ...
     def synchronize(self) -> None: ...
diff --git a/tensornvme/async_file_io.py b/tensornvme/async_file_io.py
index 4129a93..027807c 100644
--- a/tensornvme/async_file_io.py
+++ b/tensornvme/async_file_io.py
@@ -5,9 +5,12 @@
 
 
 class AsyncFileWriter:
-    def __init__(self, fp: IOBase, n_entries: int = 16) -> None:
+    def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None:
         fd = fp.fileno()
-        self.io = AsyncFileWriterC(fd, n_entries)
+        if backend is not None:
+            self.io = AsyncFileWriterC(fd, n_entries, backend=backend)
+        else:
+            self.io = AsyncFileWriterC(fd, n_entries)
         self.fp = fp
         self.offset = 0
         # must ensure the data is not garbage collected
diff --git a/tensornvme/offload.py b/tensornvme/offload.py
index f6930b1..3d18fe1 100644
--- a/tensornvme/offload.py
+++ b/tensornvme/offload.py
@@ -7,8 +7,6 @@ class DiskOffloader(Offloader):
 
     def __init__(self, dir_name: str, n_entries: int = 16, backend: str = 'uring') -> None:
-        assert backend in get_backends(
-        ), f'Unsupported backend: {backend}, please install tensornvme with this backend'
         if not os.path.exists(dir_name):
             os.mkdir(dir_name)
         assert os.path.isdir(dir_name)
 