Skip to content

Commit

Permalink
Make GpuDriver::AsynchronousMemcpyH2D return an absl::Status instead …
Browse files Browse the repository at this point in the history
…of a bool.

PiperOrigin-RevId: 671794765
  • Loading branch information
klucke authored and copybara-github committed Sep 6, 2024
1 parent 6ab86d6 commit 8e9efb1
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 38 deletions.
22 changes: 8 additions & 14 deletions xla/stream_executor/cuda/cuda_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1638,24 +1638,18 @@ bool GpuDriver::AsynchronousMemcpyD2H(Context* context, void* host_dst,
return true;
}

bool GpuDriver::AsynchronousMemcpyH2D(Context* context, CUdeviceptr gpu_dst,
const void* host_src, uint64_t size,
CUstream stream) {
absl::Status GpuDriver::AsynchronousMemcpyH2D(Context* context,
CUdeviceptr gpu_dst,
const void* host_src,
uint64_t size, CUstream stream) {
ScopedActivateContext activation(context);
auto status =
cuda::ToStatus(cuMemcpyHtoDAsync(gpu_dst, host_src, size, stream));
if (!status.ok()) {
LOG(ERROR) << absl::StrFormat(
"failed to enqueue async memcpy from host to device: %s; GPU dst: %p; "
"host src: %p; size: %u=0x%x",
status.ToString(), absl::bit_cast<void*>(gpu_dst), host_src, size,
size);
return false;
}
TF_RETURN_IF_ERROR(
cuda::ToStatus(cuMemcpyHtoDAsync(gpu_dst, host_src, size, stream)));

VLOG(2) << "successfully enqueued async memcpy h2d of " << size << " bytes"
<< " from " << host_src << " to " << absl::bit_cast<void*>(gpu_dst)
<< " on stream " << stream;
return true;
return absl::OkStatus();
}

absl::Status GpuDriver::AsynchronousMemcpyD2D(Context* context,
Expand Down
7 changes: 4 additions & 3 deletions xla/stream_executor/gpu/gpu_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,10 @@ class GpuDriver {
static bool AsynchronousMemcpyD2H(Context* context, void* host_dst,
GpuDevicePtr gpu_src, uint64_t size,
GpuStreamHandle stream);
static bool AsynchronousMemcpyH2D(Context* context, GpuDevicePtr gpu_dst,
const void* host_src, uint64_t size,
GpuStreamHandle stream);
static absl::Status AsynchronousMemcpyH2D(Context* context,
GpuDevicePtr gpu_dst,
const void* host_src, uint64_t size,
GpuStreamHandle stream);
static absl::Status AsynchronousMemcpyD2D(Context* context,
GpuDevicePtr gpu_dst,
GpuDevicePtr gpu_src, uint64_t size,
Expand Down
6 changes: 1 addition & 5 deletions xla/stream_executor/gpu/gpu_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,9 @@ absl::Status GpuStream::Memcpy(DeviceMemoryBase* gpu_dst,

absl::Status GpuStream::Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src,
uint64_t size) {
bool ok = GpuDriver::AsynchronousMemcpyH2D(
return GpuDriver::AsynchronousMemcpyH2D(
parent_->gpu_context(), reinterpret_cast<GpuDevicePtr>(gpu_dst->opaque()),
host_src, size, gpu_stream());
if (!ok) {
return absl::InternalError("Failed to memcpy from device to host.");
}
return absl::OkStatus();
}

absl::Status GpuStream::Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src,
Expand Down
30 changes: 14 additions & 16 deletions xla/stream_executor/rocm/rocm_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1548,27 +1548,25 @@ void GpuDriver::DestroyStream(Context* context, GpuStreamHandle stream) {
return true;
}

/* static */ bool GpuDriver::AsynchronousMemcpyH2D(Context* context,
hipDeviceptr_t gpu_dst,
const void* host_src,
uint64_t size,
GpuStreamHandle stream) {
absl::Status GpuDriver::AsynchronousMemcpyH2D(Context* context,
hipDeviceptr_t gpu_dst,
const void* host_src,
uint64_t size,
GpuStreamHandle stream) {
ScopedActivateContext activation{context};
hipError_t res = wrap::hipMemcpyHtoDAsync(
gpu_dst, const_cast<void*>(host_src), size, stream);
if (res != hipSuccess) {
LOG(ERROR) << absl::StrFormat(
"failed to enqueue async memcpy from host to device: %s; Gpu dst: %p; "
"host src: %p; size: %llu=0x%llx",
ToString(res).c_str(), absl::bit_cast<void*>(gpu_dst), host_src, size,
size);
return false;
}
RETURN_IF_ROCM_ERROR(
wrap::hipMemcpyHtoDAsync(gpu_dst, const_cast<void*>(host_src), size,
stream),
absl::StrFormat(
"failed to enqueue async memcpy from host to device: Gpu dst: %p; "
"host src: %p; size: %llu=0x%llx",
absl::bit_cast<void*>(gpu_dst), host_src, size, size));

VLOG(2) << "successfully enqueued async memcpy h2d of " << size
<< " bytes from " << host_src << " to "
<< absl::bit_cast<void*>(gpu_dst) << " on stream " << stream
<< " device: " << context->device_ordinal();
return true;
return absl::OkStatus();
}

absl::Status GpuDriver::AsynchronousMemcpyD2D(Context* context,
Expand Down

0 comments on commit 8e9efb1

Please sign in to comment.