diff --git a/csrc/pthread_backend.cpp b/csrc/pthread_backend.cpp
index ff7da42..04f0221 100644
--- a/csrc/pthread_backend.cpp
+++ b/csrc/pthread_backend.cpp
@@ -1,7 +1,5 @@
 #include "pthread_backend.h"
 
-#include
-
 void PthreadAsyncIO::write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback) {
     auto fut = this->pool.submit_task(
         [fd, buffer, n_bytes, offset] {
@@ -81,21 +79,23 @@ void PthreadAsyncIO::synchronize() {
 void PthreadAsyncIO::register_file(int fd) {}
 
 void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
+    auto stream = c10::cuda::getCurrentCUDAStream();
+    at::cuda::CUDAStreamGuard guard(stream); // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html
+    auto event_ptr = std::make_shared<c10::Event>(torch::kCUDA); // make a shared ptr here since the event is not copyable
+    if (t.is_cuda()) {
+        if (pinned.has_value()) {
+            pinned.value().copy_(t, /*non_blocking*/ true);
+            t = pinned.value();
+        } else {
+            t = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ true, /*copy*/ false); // modified from torch::Tensor::cpu()
+        }
+    }
+    event_ptr->record(stream);
     auto fut = this->pool.submit_task(
-        [fd, t, offset, pinned] {
-            torch::Tensor cpu_tensor;
-            if (t.is_cuda()) {
-                if (pinned.has_value()) {
-                    pinned.value().copy_(t);
-                    cpu_tensor = pinned.value();
-                } else {
-                    cpu_tensor = t.to(torch::kCPU);
-                }
-            } else {
-                cpu_tensor = t;
-            }
-            void *buf = cpu_tensor.data_ptr();
-            size_t n_bytes = cpu_tensor.numel() * cpu_tensor.element_size();
+        [fd, t, offset, pinned, event_ptr] {
+            event_ptr->synchronize(); // sync with the D2H copy on the comm stream
+            void *buf = t.data_ptr();
+            size_t n_bytes = t.numel() * t.element_size();
             return pwrite(fd, buf, n_bytes, offset);
         }
     );
diff --git a/include/pthread_backend.h b/include/pthread_backend.h
index 97ce561..75c83b9 100644
--- a/include/pthread_backend.h
+++ b/include/pthread_backend.h
@@ -9,6 +9,9 @@
 #include
 #include
 #include
+#include <c10/cuda/CUDAStream.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/core/Event.h>
 
 #include "asyncio.h"
 #include "threadpool.hpp"
diff --git a/tensornvme/async_file_io.py b/tensornvme/async_file_io.py
index 48fd66a..b223a01 100644
--- a/tensornvme/async_file_io.py
+++ b/tensornvme/async_file_io.py
@@ -1,4 +1,5 @@
 import ctypes
+import torch
 from functools import partial
 from torch import Tensor
 from typing import List, Optional
@@ -16,6 +17,7 @@ def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None:
         self.offset = 0
         # must ensure the data is not garbage collected
         self.buffers = []
+        self.comm_stream = torch.cuda.Stream()
 
     def write(self, data: bytes) -> int:
         ptr = ctypes.cast(data, ctypes.POINTER(ctypes.c_char))
@@ -36,6 +38,14 @@ def write_tensor(self, tensor: Tensor, pinned: Optional[Tensor] = None) -> None:
         self.io.write_tensor(tensor, self.offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1), pinned)
         self.offset += tensor.numel() * tensor.element_size()
 
+    def write_gpu_tensor(self, tensor: Tensor, pinned: Optional[Tensor] = None) -> None:
+        assert tensor.device.type == 'cuda', f"tensor must be on cuda device, got {tensor.device}"
+        with torch.cuda.stream(self.comm_stream):
+            self.write_tensor(tensor, pinned)
+
+    def sync_before_step(self):
+        self.comm_stream.synchronize()
+
     @staticmethod
     def gc_callback(listt: List, idx: int) -> None:
         listt[idx] = None
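
A minimal usage sketch for the new Python API above. `AsyncFileWriter`, `write_gpu_tensor`, and `sync_before_step` come from the diff; the file name, tensor shape, and the optional pinned staging buffer are illustrative assumptions. The D2H copy and the `pwrite` are issued via the writer's dedicated `comm_stream`, so the default stream is not blocked; `sync_before_step` should be called before the tensor is mutated again (e.g., by an optimizer step):

    import torch
    from tensornvme.async_file_io import AsyncFileWriter

    tensor = torch.randn(1024, 1024, device="cuda")
    # optional staging buffer in pinned host memory, enabling a truly
    # asynchronous device-to-host copy (illustrative, not required)
    pinned = torch.empty(tensor.shape, dtype=tensor.dtype, device="cpu", pin_memory=True)

    with open("checkpoint.bin", "wb") as f:
        writer = AsyncFileWriter(f, n_entries=16)
        writer.write_gpu_tensor(tensor, pinned)  # enqueued on writer.comm_stream, returns immediately
        # ... overlap other GPU work here ...
        writer.sync_before_step()  # wait for the D2H copy before mutating `tensor`

Note that `sync_before_step` only synchronizes the CUDA stream (i.e., the copy out of GPU memory); the actual `pwrite` still completes asynchronously on the backend's thread pool.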