[hotfix] fix task count (#56)
* [hotfix] fix task count

* [hotfix] fix h2d count
ver217 authored Nov 27, 2024
1 parent a4d34bf commit 6403388
Showing 14 changed files with 34 additions and 32 deletions.
2 changes: 1 addition & 1 deletion csrc/aio.cpp
@@ -1,6 +1,6 @@
#include "aio.h"

AIOAsyncIO::AIOAsyncIO(unsigned int n_entries)
AIOAsyncIO::AIOAsyncIO(unsigned int n_entries, unsigned int n_tasks)
{
// printf("Initializing the io Context\n");
this->max_nr = n_entries;
2 changes: 1 addition & 1 deletion csrc/async_file_io.cpp
@@ -1,6 +1,6 @@
#include "async_file_io.h"

AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend) : fd(fd), aio(create_asyncio(n_entries, backend)) {}
AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend, unsigned int n_tasks) : fd(fd), aio(create_asyncio(n_entries, backend, n_tasks)) {}

void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback)
{
14 changes: 7 additions & 7 deletions csrc/backend.cpp
@@ -44,23 +44,23 @@ void probe_asyncio(const std::string &backend)
     if (backend == "uring")
     {
 #ifndef DISABLE_URING
-        aio.reset(new UringAsyncIO(2));
+        aio.reset(new UringAsyncIO(2, 0));
 #else
         throw std::runtime_error("backend uring is not installed\n");
 #endif
     }
     else if (backend == "aio")
     {
 #ifndef DISABLE_AIO
-        aio.reset(new AIOAsyncIO(2));
+        aio.reset(new AIOAsyncIO(2, 0));
 #else
         throw std::runtime_error("backend aio is not installed\n");
 #endif
     }
     else if (backend == "pthread")
     {
 #ifndef DISABLE_PTHREAD
-        aio.reset(new PthreadAsyncIO(2));
+        aio.reset(new PthreadAsyncIO(2, 0));
 #else
         throw std::runtime_error("backend pthread is not installed\n");
 #endif
@@ -160,7 +160,7 @@ std::string get_debug_log()
     return std::string(env_);
 }
 
-AsyncIO *create_asyncio(unsigned int n_entries, std::string backend)
+AsyncIO *create_asyncio(unsigned int n_entries, std::string backend, unsigned int n_tasks)
 {
     std::unordered_set<std::string> backends = get_backends();
     std::string default_backend = get_default_backend();
@@ -188,15 +188,15 @@ AsyncIO *create_asyncio(unsigned int n_entries, std::string backend)

 #ifndef DISABLE_URING
     if (backend == "uring")
-        return new UringAsyncIO(n_entries);
+        return new UringAsyncIO(n_entries, n_tasks);
 #endif
 #ifndef DISABLE_AIO
     if (backend == "aio")
-        return new AIOAsyncIO(n_entries);
+        return new AIOAsyncIO(n_entries, n_tasks);
 #endif
 #ifndef DISABLE_PTHREAD
     if (backend == "pthread")
-        return new PthreadAsyncIO(n_entries);
+        return new PthreadAsyncIO(n_entries, n_tasks);
 #endif
     throw std::runtime_error("Unsupported backend: " + backend);
 }
2 changes: 1 addition & 1 deletion csrc/offload.cpp
@@ -28,7 +28,7 @@ iovec *tensors_to_iovec(const std::vector<at::Tensor> &tensors)

 Offloader::Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend) : filename(filename), space_mgr(SpaceManager(0))
 {
-    this->aio = create_asyncio(n_entries, backend);
+    this->aio = create_asyncio(n_entries, backend, 0);
     this->fd = open(filename.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
     this->aio->register_file(fd);
 }
20 changes: 10 additions & 10 deletions csrc/pthread_backend.cpp
@@ -16,8 +16,8 @@ void PthreadAsyncIO::write(int fd, void *buffer, size_t n_bytes, unsigned long l
             auto val = pwrite(fd, buffer, n_bytes, offset);
             if (this->is_debug)
             {
-                auto cur_tasks = this->tasks_in_progress.fetch_sub(1);
-                if (cur_tasks == 1)
+                auto cur_tasks = this->tasks_in_progress.fetch_add(1);
+                if (cur_tasks + 1 == this->total_tasks)
                 {
                     if (this->debug_log.empty())
                     {
@@ -117,23 +117,23 @@ void PthreadAsyncIO::register_file(int fd) {}

 void PthreadAsyncIO::register_h2d(unsigned int num_tensors)
 {
-    this->h2d_in_progress.store(num_tensors); // register tensors to write for this run
+    this->total_h2d = num_tensors;
 }
 
 void PthreadAsyncIO::sync_h2d()
 {
     std::unique_lock<std::mutex> lock(this->mtx);
     this->cv.wait(lock, [this]
-                  { return this->h2d_in_progress == 0; }); // block until all in-progress h2d are completed
+                  { return this->h2d_in_progress == this->total_h2d; }); // block until all in-progress h2d are completed
 }
 
 void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned)
 {
     auto stream = c10::cuda::getCurrentCUDAStream();
     if (!t.is_cuda())
     {
-        this->h2d_in_progress.fetch_sub(1); // already moved to cpu
-        if (this->h2d_in_progress.load() == 0)
+        auto cur_h2d = this->h2d_in_progress.fetch_add(1); // already moved to cpu
+        if (cur_h2d + 1 == this->total_h2d)
         { // notify when all h2d are completed and safe to optimizer.step()
             std::lock_guard<std::mutex> lock(this->mtx);
             cv.notify_one();
@@ -155,8 +155,8 @@ void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long of
                 {
                     cpu_tensor = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); // modified from torch::Tensor::cpu()
                 }
-                this->h2d_in_progress.fetch_sub(1);
-                if (this->h2d_in_progress.load() == 0)
+                auto cur_h2d = this->h2d_in_progress.fetch_add(1);
+                if (cur_h2d + 1 == this->total_h2d)
                 { // notify when all h2d are completed and safe to optimizer.step()
                     std::lock_guard<std::mutex> lock(this->mtx);
                     cv.notify_one();
@@ -171,8 +171,8 @@ void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long of
             auto val = pwrite(fd, buf, n_bytes, offset);
             if (this->is_debug)
             {
-                auto cur_tasks = this->tasks_in_progress.fetch_sub(1);
-                if (cur_tasks == 1)
+                auto cur_tasks = this->tasks_in_progress.fetch_add(1);
+                if (cur_tasks + 1 == this->total_tasks)
                 {
                     if (this->debug_log.empty())
                     {
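Every hunk in this file makes the same change: instead of pre-loading a counter with the expected number of operations and decrementing it toward zero, the backend now records the expected total (total_tasks, total_h2d) up front and counts completed work upward, notifying the waiter once the count reaches that total. A minimal, self-contained sketch of this count-up-and-notify pattern follows; the names Tracker, register_total, mark_done and wait_all are illustrative only and are not part of TensorNVMe's API.

#include <atomic>
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

// Illustrative count-up tracker: workers call mark_done(); a waiter blocks
// in wait_all() until the completed count reaches the registered total.
class Tracker
{
public:
    void register_total(unsigned int n) { total = n; }

    void mark_done()
    {
        // fetch_add returns the value before the increment, so the last
        // finishing worker observes prev + 1 == total and wakes the waiter.
        auto prev = done.fetch_add(1);
        if (prev + 1 == total)
        {
            std::lock_guard<std::mutex> lock(mtx);
            cv.notify_one();
        }
    }

    void wait_all()
    {
        std::unique_lock<std::mutex> lock(mtx);
        cv.wait(lock, [this] { return done.load() == total; });
    }

private:
    std::atomic<unsigned int> done{0};
    unsigned int total{0};
    std::condition_variable cv;
    std::mutex mtx;
};

int main()
{
    Tracker tracker;
    tracker.register_total(4); // analogous to register_h2d / n_tasks

    std::vector<std::thread> workers;
    for (int i = 0; i < 4; ++i)
        workers.emplace_back([&tracker] { tracker.mark_done(); });

    tracker.wait_all(); // returns once all four workers have reported completion
    for (auto &w : workers)
        w.join();
    std::printf("all tasks done\n");
    return 0;
}

Because wait_all checks the predicate under the mutex, a notification that fires before the waiter blocks is not lost; the patched sync_h2d relies on the same property.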
2 changes: 1 addition & 1 deletion csrc/py_api.cpp
@@ -27,7 +27,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("get_backends", get_backends);
m.def("probe_backend", probe_backend, py::arg("backend"));
py::class_<AsyncFileWriter>(m, "AsyncFileWriter")
.def(py::init<int, unsigned int, const std::string &>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio")
.def(py::init<int, unsigned int, const std::string &, unsigned int>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio", py::arg("n_tasks") = 0)
.def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"), py::arg("callback") = py::none())
.def("write_tensor", &AsyncFileWriter::write_tensor, py::arg("tensor"), py::arg("offset"), py::arg("callback") = py::none(), py::arg("pinned") = py::none())
.def("synchronize", &AsyncFileWriter::synchronize)
2 changes: 1 addition & 1 deletion csrc/uring.cpp
@@ -2,7 +2,7 @@
 #include <memory>
 #include "uring.h"
 
-UringAsyncIO::UringAsyncIO(unsigned int n_entries) : n_write_events(0), n_read_events(0), n_entries(n_entries)
+UringAsyncIO::UringAsyncIO(unsigned int n_entries, unsigned int n_tasks) : n_write_events(0), n_read_events(0), n_entries(n_entries)
 {
     io_uring_queue_init(n_entries, &this->ring, 0);
 }
2 changes: 1 addition & 1 deletion include/aio.h
@@ -19,7 +19,7 @@ class AIOAsyncIO : public AsyncIO
     void get_event(WaitType wt);
 
 public:
-    AIOAsyncIO(unsigned int n_entries);
+    AIOAsyncIO(unsigned int n_entries, unsigned int n_tasks);
     ~AIOAsyncIO();
 
     void write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
2 changes: 1 addition & 1 deletion include/async_file_io.h
@@ -17,7 +17,7 @@
 class AsyncFileWriter
 {
 public:
-    AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend);
+    AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend, unsigned int n_tasks);
     void write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
     void write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
     void synchronize();
2 changes: 1 addition & 1 deletion include/backend.h
@@ -14,6 +14,6 @@ std::string get_default_backend();

 bool get_debug_flag();
 
-AsyncIO *create_asyncio(unsigned int n_entries, std::string backend);
+AsyncIO *create_asyncio(unsigned int n_entries, std::string backend, unsigned int n_tasks);
 
 std::string get_debug_log();
6 changes: 4 additions & 2 deletions include/pthread_backend.h
@@ -26,6 +26,7 @@ class PthreadAsyncIO : public AsyncIO
 private:
     BS::thread_pool pool;
     std::atomic<unsigned int> h2d_in_progress;
+    unsigned int total_h2d;
     std::condition_variable cv;
     std::mutex mtx;
     std::deque<std::tuple<std::future<ssize_t>, callback_t>> write_fut;
@@ -34,10 +35,11 @@ class PthreadAsyncIO : public AsyncIO
     const std::string debug_log = get_debug_log();
 
     std::atomic<unsigned int> tasks_in_progress;
+    unsigned int total_tasks;
 
 public:
-    PthreadAsyncIO(unsigned int n_entries)
-        : pool(n_entries), h2d_in_progress(0) {}
+    PthreadAsyncIO(unsigned int n_entries, unsigned int n_tasks)
+        : pool(n_entries), h2d_in_progress(0), tasks_in_progress(0), total_tasks(n_tasks), total_h2d(0) {}
 
     ~PthreadAsyncIO() {}

2 changes: 1 addition & 1 deletion include/uring.h
@@ -13,7 +13,7 @@ class UringAsyncIO : public AsyncIO
     void get_event(WaitType wt);
 
 public:
-    UringAsyncIO(unsigned int n_entries);
+    UringAsyncIO(unsigned int n_entries, unsigned int n_tasks);
     ~UringAsyncIO();
 
     void write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
2 changes: 1 addition & 1 deletion tensornvme/_C/__init__.pyi
@@ -20,7 +20,7 @@ def get_backends() -> Set[str]: ...
 def probe_backend(backend: str) -> bool: ...
 
 class AsyncFileWriter:
-    def __init__(self, fd: int, n_entries: int, backend: str = "aio") -> None: ...
+    def __init__(self, fd: int, n_entries: int, backend: str = "aio", n_tasks: int = 0) -> None: ...
     def write(self, buffer: int, n_bytes: int, offset: int, callback: Optional[Callable[[], None]] = None) -> None: ...
     def write_tensor(
         self,
6 changes: 3 additions & 3 deletions tensornvme/async_file_io.py
@@ -10,14 +10,14 @@


 class AsyncFileWriter:
-    def __init__(self, path: str, n_entries: int = 16, backend=None) -> None:
+    def __init__(self, path: str, n_entries: int = 16, backend=None, n_tasks: int = 0) -> None:
         # this still takes ram buffer, which may lead to OOM
         # self.f = open(path, "wb", buffering=0)
         self.fd = os.open(path, os.O_WRONLY | os.O_CREAT, mode=0o664)
         if backend is not None:
-            self.io = AsyncFileWriterC(self.fd, n_entries, backend=backend)
+            self.io = AsyncFileWriterC(self.fd, n_entries, backend=backend, n_tasks=n_tasks)
         else:
-            self.io = AsyncFileWriterC(self.fd, n_entries)
+            self.io = AsyncFileWriterC(self.fd, n_entries, n_tasks=n_tasks)
         self.offset = 0
         # must ensure the data is not garbage collected
         self.buffers = []
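With n_tasks plumbed through the Python wrapper, callers can state the expected task count when constructing the writer. A hypothetical usage sketch (the file path and the value 8 are made up for illustration; only the constructor changed in this diff is exercised):

from tensornvme.async_file_io import AsyncFileWriter

# Hypothetical example: the new n_tasks keyword is forwarded to the C++
# AsyncFileWriter both with and without an explicit backend choice.
writer = AsyncFileWriter("/tmp/example.bin", n_entries=16, backend="pthread", n_tasks=8)
default_writer = AsyncFileWriter("/tmp/example.bin", n_entries=16, n_tasks=8)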
