[tmp] non-blocking somehow not working
botbw committed Nov 5, 2024
1 parent 2e88470 commit 39ea874
Showing 3 changed files with 29 additions and 16 deletions.
32 changes: 16 additions & 16 deletions csrc/pthread_backend.cpp
@@ -1,7 +1,5 @@
#include "pthread_backend.h"

#include <iostream>

void PthreadAsyncIO::write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback) {
auto fut = this->pool.submit_task(
[fd, buffer, n_bytes, offset] {
@@ -81,21 +79,23 @@ void PthreadAsyncIO::synchronize() {
 void PthreadAsyncIO::register_file(int fd) {}
 
 void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
+    auto stream = c10::cuda::getCurrentCUDAStream();
+    at::cuda::CUDAStreamGuard guard(stream);  // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html
+    auto event_ptr = std::make_shared<c10::Event>(torch::kCUDA);  // make a shared ptr here since event is not copyable
+    if (t.is_cuda()) {
+        if (pinned.has_value()) {
+            pinned.value().copy_(t, /*non_blocking*/ true);
+            t = pinned.value();
+        } else {
+            t = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ true, /*copy*/ false);  // modified from torch::Tensor::cpu()
+        }
+    }
+    event_ptr->record(stream);
     auto fut = this->pool.submit_task(
-        [fd, t, offset, pinned] {
-            torch::Tensor cpu_tensor;
-            if (t.is_cuda()) {
-                if (pinned.has_value()) {
-                    pinned.value().copy_(t);
-                    cpu_tensor = pinned.value();
-                } else {
-                    cpu_tensor = t.to(torch::kCPU);
-                }
-            } else {
-                cpu_tensor = t;
-            }
-            void *buf = cpu_tensor.data_ptr();
-            size_t n_bytes = cpu_tensor.numel() * cpu_tensor.element_size();
+        [fd, t, offset, pinned, event_ptr] {
+            event_ptr->synchronize();  // sync with comm stream
+            void *buf = t.data_ptr();
+            size_t n_bytes = t.numel() * t.element_size();
             return pwrite(fd, buf, n_bytes, offset);
         }
     );
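The change above moves the device-to-host copy out of the pooled worker thread: the copy is now enqueued on the caller's current CUDA stream with non_blocking, a c10::Event is recorded immediately after it, and the worker only calls event_ptr->synchronize() before handing the host buffer to pwrite. Below is a minimal Python sketch of the same record-then-synchronize pattern, for illustration only; it is not code from this repository, and the async_dump helper, file path, and thread usage are made-up assumptions.

import threading
import torch

def async_dump(t: torch.Tensor, path: str) -> threading.Thread:
    # Enqueue a non-blocking device-to-host copy on the current stream and
    # record an event marking its completion point (mirrors the C++ change above).
    stream = torch.cuda.current_stream()
    pinned = torch.empty_like(t, device="cpu", pin_memory=True)
    pinned.copy_(t, non_blocking=True)   # returns immediately; copy runs on `stream`
    event = torch.cuda.Event()
    event.record(stream)

    def worker():
        event.synchronize()              # wait for the copy before touching host memory
        with open(path, "wb") as f:
            f.write(pinned.numpy().tobytes())

    th = threading.Thread(target=worker)
    th.start()
    return th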
3 changes: 3 additions & 0 deletions include/pthread_backend.h
@@ -9,6 +9,9 @@
 #include <queue>
 #include <tuple>
 #include <functional>
+#include <iostream>
+#include <c10/cuda/CUDAStream.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "asyncio.h"
 #include "threadpool.hpp"
10 changes: 10 additions & 0 deletions tensornvme/async_file_io.py
@@ -1,4 +1,5 @@
 import ctypes
+import torch
 from functools import partial
 from torch import Tensor
 from typing import List, Optional
@@ -16,6 +17,7 @@ def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None:
         self.offset = 0
         # must ensure the data is not garbage collected
         self.buffers = []
+        self.comm_stream = torch.cuda.Stream()
 
     def write(self, data: bytes) -> int:
         ptr = ctypes.cast(data, ctypes.POINTER(ctypes.c_char))
@@ -36,6 +38,14 @@ def write_tensor(self, tensor: Tensor, pinned: Optional[Tensor] = None) -> None:
         self.io.write_tensor(tensor, self.offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1), pinned)
         self.offset += tensor.numel() * tensor.element_size()
 
+    def write_gpu_tensor(self, tensor: Tensor, pinned: Optional[Tensor] = None) -> None:
+        assert tensor.device.type == 'cuda', f"tensor must be on cuda device, got {tensor.device}"
+        with torch.cuda.stream(self.comm_stream):
+            self.write_tensor(tensor, pinned)
+
+    def sync_before_step(self):
+        self.comm_stream.synchronize()
+
     @staticmethod
     def gc_callback(listt: List, idx: int) -> None:
         listt[idx] = None
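Taken together, the Python-side additions give a two-step contract: queue GPU tensors through write_gpu_tensor (the copy is issued on the writer's private comm_stream) and call sync_before_step() before the next optimizer step mutates those tensors. A hedged usage sketch follows; the file name, tensor shape, and n_entries value are illustrative assumptions, not taken from the repository's tests.

import torch
from tensornvme.async_file_io import AsyncFileWriter

param = torch.randn(1024, 1024, device="cuda")
pinned = torch.empty_like(param, device="cpu", pin_memory=True)

with open("ckpt.bin", "wb") as f:
    writer = AsyncFileWriter(f, n_entries=16)
    writer.write_gpu_tensor(param, pinned)  # D2H copy runs on writer.comm_stream
    writer.sync_before_step()               # wait for the copy before param is mutated again
    # Completion of the pwrite itself is handled by the C++ backend's
    # synchronize() (visible in the hunk above), which is not shown in this diff.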
