[pthread] init async gpu -> cpu (#49)
* [pthread] init async gpu -> cpu

* [chore] add callback

* [chore] add pinned mem buffer

* [tmp] non-blocking somehow not working

* [h2d] add individual sync for h2d

* [chore] enable notify when submitting tensor write task

* [chore] remove api

* [chore] remove api
botbw authored Nov 12, 2024
1 parent b2f9944 commit a1bf816
Showing 12 changed files with 151 additions and 12 deletions.
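Taken together, the changes below add an asynchronous GPU -> CPU -> disk write path to AsyncFileWriter. As a rough orientation before the per-file diffs, here is a minimal usage sketch of the intended flow; the file path, tensor shapes, entry count, the "pthread" backend string, and the assumption that the wrapper forwards synchronize() to the C++ writer are illustrative, not values taken from this commit.

import torch
from tensornvme.async_file_io import AsyncFileWriter

# Sketch only: path, shapes, n_entries, and backend name are assumptions.
with open("params.bin", "wb") as fp:
    writer = AsyncFileWriter(fp, n_entries=8, backend="pthread")
    tensors = [torch.randn(1024, 1024, device="cuda") for _ in range(4)]
    # one pinned staging buffer per tensor, so in-flight writes never share a buffer
    pinned = [torch.empty(t.shape, dtype=t.dtype, pin_memory=True) for t in tensors]

    writer.register_h2d(len(tensors))          # number of pending device -> host copies
    for t, p in zip(tensors, pinned):
        writer.write_tensor(t, pinned=p)       # queues the GPU -> CPU copy and the file write
    writer.sync_before_step()                  # waits for the device -> host copies only
    writer.synchronize()                       # assumed wrapper passthrough to the C++ synchronize()
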
21 changes: 18 additions & 3 deletions csrc/aio.cpp
@@ -1,5 +1,3 @@
#include <stdexcept>
#include <memory>
#include "aio.h"

AIOAsyncIO::AIOAsyncIO(unsigned int n_entries)
@@ -126,4 +124,21 @@ void AIOAsyncIO::readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned l
    io_submit(this->io_ctx, 1, &iocbs); /* submitting this I/O does not block */

    this->n_read_events++;
}

void AIOAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
    if (t.is_cuda()) {
        if (pinned.has_value()) {
            pinned.value().copy_(t);
            t = pinned.value();
        } else {
            t = t.to(torch::kCPU);
        }
    }
    void *buffer = t.data_ptr();
    size_t n_bytes = t.numel() * t.element_size();
    this->write(fd, buffer, n_bytes, offset, callback);
}

void AIOAsyncIO::register_h2d(unsigned int num_tensors) {}
void AIOAsyncIO::sync_h2d() {}
15 changes: 12 additions & 3 deletions csrc/async_file_io.cpp
@@ -1,7 +1,4 @@
#include "asyncio.h"
#include "async_file_io.h"
#include "backend.h"
#include <stdexcept>

AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend) : fd(fd), aio(create_asyncio(n_entries, backend)) {}

@@ -11,6 +8,18 @@ void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long of
    this->aio->write(this->fd, ptr, n_bytes, offset, callback);
}

void AsyncFileWriter::write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
    this->aio->write_tensor(this->fd, tensor, offset, callback, pinned);
}

void AsyncFileWriter::register_h2d(unsigned int num_tensors) {
    this->aio->register_h2d(num_tensors);
}

void AsyncFileWriter::sync_h2d() {
    this->aio->sync_h2d();
}

void AsyncFileWriter::synchronize()
{
    this->aio->synchronize();
47 changes: 46 additions & 1 deletion csrc/pthread_backend.cpp
@@ -76,4 +76,49 @@ void PthreadAsyncIO::synchronize() {
    this->get_event(WAIT);
}

void PthreadAsyncIO::register_file(int fd) {}

void PthreadAsyncIO::register_h2d(unsigned int num_tensors) {
    this->h2d_in_progress.store(num_tensors);  // register tensors to write for this run
}

void PthreadAsyncIO::sync_h2d() {
    std::unique_lock<std::mutex> lock(this->mtx);
    this->cv.wait(lock, [this] { return this->h2d_in_progress == 0; });  // block until all in-progress h2d copies are completed
}

void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
    auto stream = c10::cuda::getCurrentCUDAStream();
    if (!t.is_cuda()) {
        this->h2d_in_progress.fetch_sub(1);  // tensor is already on the CPU
        if (this->h2d_in_progress.load() == 0) {  // notify when all h2d copies are completed and it is safe to call optimizer.step()
            std::lock_guard<std::mutex> lock(this->mtx);
            cv.notify_one();
        }
    }
    auto fut = this->pool.submit_task(
        [this, fd, t, offset, pinned, stream] {
            torch::Tensor cpu_tensor;
            if (t.is_cuda()) {
                at::cuda::CUDAStreamGuard guard(stream);  // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html
                if (pinned.has_value()) {
                    pinned.value().copy_(t, /*non_blocking*/ false);
                    cpu_tensor = pinned.value();
                } else {
                    cpu_tensor = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false);  // modified from torch::Tensor::cpu()
                }
                this->h2d_in_progress.fetch_sub(1);
                if (this->h2d_in_progress.load() == 0) {  // notify when all h2d copies are completed and it is safe to call optimizer.step()
                    std::lock_guard<std::mutex> lock(this->mtx);
                    cv.notify_one();
                }
            } else {
                cpu_tensor = t;
            }
            void *buf = cpu_tensor.data_ptr();
            size_t n_bytes = cpu_tensor.numel() * cpu_tensor.element_size();
            return pwrite(fd, buf, n_bytes, offset);
        }
    );
    this->write_fut.push_back(std::make_tuple(std::move(fut), callback));
}
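The register_h2d / sync_h2d pair above is a plain counter-plus-condition-variable handshake: register_h2d stores how many device to host copies this round will perform, each finished copy decrements the counter and notifies, and sync_h2d blocks until the counter reaches zero. A Python analogue of that scheme, written with threading primitives purely for illustration (it is not part of the library):

import threading

class H2DTracker:
    # Illustrative analogue of h2d_in_progress + cv in PthreadAsyncIO.
    def __init__(self) -> None:
        self._cond = threading.Condition()
        self._in_progress = 0

    def register(self, num_tensors: int) -> None:      # cf. register_h2d
        with self._cond:
            self._in_progress = num_tensors

    def mark_done(self) -> None:                        # cf. fetch_sub(1) + notify_one
        with self._cond:
            self._in_progress -= 1
            if self._in_progress == 0:
                self._cond.notify()

    def sync(self) -> None:                             # cf. sync_h2d
        with self._cond:
            self._cond.wait_for(lambda: self._in_progress == 0)
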
5 changes: 4 additions & 1 deletion csrc/py_api.cpp
@@ -29,5 +29,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
    py::class_<AsyncFileWriter>(m, "AsyncFileWriter")
        .def(py::init<int, unsigned int, const std::string &>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio")
        .def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"), py::arg("callback") = py::none())
        .def("synchronize", &AsyncFileWriter::synchronize);
        .def("write_tensor", &AsyncFileWriter::write_tensor, py::arg("tensor"), py::arg("offset"), py::arg("callback") = py::none(), py::arg("pinned") = py::none())
        .def("synchronize", &AsyncFileWriter::synchronize)
        .def("sync_h2d", &AsyncFileWriter::sync_h2d)
        .def("register_h2d", &AsyncFileWriter::register_h2d, py::arg("num_tensors"));
}
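For completeness, a sketch of driving the new bindings directly through the compiled extension; the file path, the explicit byte offset, and the "pthread" backend string are assumptions made for illustration:

import os
import torch
from tensornvme._C import AsyncFileWriter

# Sketch only: path and backend name are assumptions.
fd = os.open("tensor.bin", os.O_CREAT | os.O_WRONLY, 0o644)
writer = AsyncFileWriter(fd, 8, backend="pthread")

t = torch.randn(256, 256, device="cuda")
writer.register_h2d(1)       # one pending device -> host copy
writer.write_tensor(t, 0)    # tensor and byte offset; callback and pinned default to None
writer.sync_h2d()            # returns once the GPU -> CPU copy has finished
writer.synchronize()         # returns once the data has been written to the file
os.close(fd)
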
19 changes: 18 additions & 1 deletion csrc/uring.cpp
@@ -97,4 +97,21 @@ void UringAsyncIO::readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned
    io_uring_sqe_set_data(sqe, data);
    io_uring_submit(&this->ring);
    this->n_read_events++;
}

void UringAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
    if (t.is_cuda()) {
        if (pinned.has_value()) {
            pinned.value().copy_(t);
            t = pinned.value();
        } else {
            t = t.to(torch::kCPU);
        }
    }
    void *buffer = t.data_ptr();  // raw, dtype-agnostic pointer, matching the aio backend
    size_t n_bytes = t.numel() * t.element_size();
    this->write(fd, buffer, n_bytes, offset, callback);
}

void UringAsyncIO::register_h2d(unsigned int num_tensors) {}
void UringAsyncIO::sync_h2d() {}
6 changes: 6 additions & 0 deletions include/aio.h
@@ -1,6 +1,9 @@
#pragma once

#include <libaio.h>
#include <torch/torch.h>
#include <stdexcept>
#include <memory>
#include "asyncio.h"

class AIOAsyncIO : public AsyncIO
@@ -24,9 +27,12 @@ class AIOAsyncIO
void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);

void register_h2d(unsigned int num_tensors);
void sync_h2d();
void sync_write_events();
void sync_read_events();
void synchronize();

void register_file(int fd);
void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
};
9 changes: 9 additions & 0 deletions include/async_file_io.h
@@ -1,9 +1,15 @@
#pragma once
#include <string>
#include <torch/torch.h>
#include <optional>

#include "asyncio.h"
#include "backend.h"

#ifndef DISABLE_URING
#include "uring.h"
#endif

#ifndef DISABLE_AIO
#include "aio.h"
#endif
@@ -13,7 +19,10 @@ class AsyncFileWriter
public:
AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend);
void write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
void write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
void synchronize();
void register_h2d(unsigned int num_tensors);
void sync_h2d();
~AsyncFileWriter();

private:
4 changes: 4 additions & 0 deletions include/asyncio.h
@@ -2,6 +2,7 @@

#include <fcntl.h>
#include <functional>
#include <torch/torch.h>

using callback_t = std::function<void()>;

@@ -44,7 +45,10 @@ class AsyncIO
virtual void get_event(WaitType wt) = 0;
virtual void sync_write_events() = 0;
virtual void sync_read_events() = 0;
virtual void register_h2d(unsigned int num_tensors) = 0;
virtual void sync_h2d() = 0;
virtual void synchronize() = 0;

virtual void register_file(int fd) = 0;
virtual void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) = 0;
};
15 changes: 14 additions & 1 deletion include/pthread_backend.h
@@ -9,6 +9,12 @@
#include <queue>
#include <tuple>
#include <functional>
#include <iostream>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
#include <atomic>
#include <condition_variable>
#include <mutex>

#include "asyncio.h"
#include "threadpool.hpp"
@@ -18,12 +24,15 @@ class PthreadAsyncIO
{
private:
BS::thread_pool pool;
std::atomic<unsigned int> h2d_in_progress;
std::condition_variable cv;
std::mutex mtx;
std::deque<std::tuple<std::future<ssize_t>, callback_t>> write_fut;
std::deque<std::tuple<std::future<ssize_t>, callback_t>> read_fut;

public:
PthreadAsyncIO(unsigned int n_entries)
: pool(n_entries) {}
: pool(n_entries), h2d_in_progress(0) {}

~PthreadAsyncIO() {}

@@ -35,7 +44,11 @@
void get_event(WaitType wt);
void sync_write_events();
void sync_read_events();
void register_h2d(unsigned int num_tensors);
void sync_h2d();
void synchronize();

void register_file(int fd);

void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
};
3 changes: 3 additions & 0 deletions include/uring.h
@@ -21,9 +21,12 @@ class UringAsyncIO
void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);

void register_h2d(unsigned int num_tensors);
void sync_h2d();
void sync_write_events();
void sync_read_events();
void synchronize();

void register_file(int fd);
void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
};
1 change: 1 addition & 0 deletions tensornvme/_C/__init__.pyi
@@ -22,4 +22,5 @@ def probe_backend(backend: str) -> bool: ...
class AsyncFileWriter:
    def __init__(self, fd: int, n_entries: int, backend: str = "aio") -> None: ...
    def write(self, buffer: int, n_bytes: int, offset: int, callback: Optional[Callable[[], None]] = None) -> None: ...
    def write_tensor(self, tensor: Tensor, offset: int, callback: Optional[Callable[[], None]] = None, pinned: Optional[Tensor] = None) -> None: ...
    def synchronize(self) -> None: ...
18 changes: 16 additions & 2 deletions tensornvme/async_file_io.py
@@ -1,7 +1,8 @@
import ctypes
import torch
from functools import partial

from typing import List
from torch import Tensor
from typing import List, Optional
from io import IOBase
from tensornvme._C import AsyncFileWriter as AsyncFileWriterC

@@ -16,6 +17,7 @@ def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None:
        self.offset = 0
        # must ensure the data is not garbage collected
        self.buffers = []
        self.comm_stream = torch.cuda.Stream()

    def write(self, data: bytes) -> int:
        ptr = ctypes.cast(data, ctypes.POINTER(ctypes.c_char))
@@ -31,6 +33,18 @@ def write_raw(self, py_ref: object, buffer: int, n_bytes: int, offset: int) -> N
        self.io.write(buffer, n_bytes, offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1))
        self.offset += n_bytes

    def write_tensor(self, tensor: Tensor, pinned: Optional[Tensor] = None) -> None:
        with torch.cuda.stream(self.comm_stream):
            self.buffers.append(tensor)  # append before callback is called
            self.io.write_tensor(tensor, self.offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1), pinned)
        self.offset += tensor.numel() * tensor.element_size()

    def register_h2d(self, num_tensors: int) -> None:
        self.io.register_h2d(num_tensors)

    def sync_before_step(self):
        self.io.sync_h2d()

    @staticmethod
    def gc_callback(listt: List, idx: int) -> None:
        listt[idx] = None
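The sync_before_step name hints at the intended overlap with training: once every device to host copy has finished, the GPU tensors may be mutated again even while the file writes are still in flight. A hedged sketch of that interleaving; the model, optimizer, loss, file path, backend name, and the assumed wrapper-level synchronize() are illustrative choices, not part of this commit:

import torch
from tensornvme.async_file_io import AsyncFileWriter

# Sketch only: model, optimizer, loss, path, and backend name are assumptions.
model = torch.nn.Linear(256, 256).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(32, 256, device="cuda")).sum()
loss.backward()

with open("snapshot.bin", "wb") as fp:
    writer = AsyncFileWriter(fp, backend="pthread")
    params = list(model.parameters())

    writer.register_h2d(len(params))    # copies that must finish before params change
    for p in params:
        writer.write_tensor(p.data)     # GPU -> CPU copy plus pwrite on the thread pool
    writer.sync_before_step()           # device -> host copies done; params may be mutated
    optimizer.step()                    # overlaps with the remaining disk writes
    writer.synchronize()                # assumed wrapper passthrough to the C++ synchronize()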