Skip to content

Commit

Permalink
Make GpuDriver::SynchronizeContext return an absl::Status.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 671920622
  • Loading branch information
klucke authored and copybara-github committed Sep 6, 2024
1 parent b3a48ed commit 18f61b0
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 21 deletions.
11 changes: 2 additions & 9 deletions xla/stream_executor/cuda/cuda_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1543,16 +1543,9 @@ absl::Status GpuDriver::WaitStreamOnEvent(Context* context, CUstream stream,
return cuda::ToStatus(cuStreamWaitEvent(stream, event, 0 /* = flags */));
}

bool GpuDriver::SynchronizeContext(Context* context) {
absl::Status GpuDriver::SynchronizeContext(Context* context) {
ScopedActivateContext activation(context);
auto status = cuda::ToStatus(cuCtxSynchronize());
if (!status.ok()) {
LOG(ERROR) << "could not synchronize on CUDA context: " << status
<< " :: " << tsl::CurrentStackTrace();
return false;
}

return true;
return cuda::ToStatus(cuCtxSynchronize());
}

absl::Status GpuDriver::SynchronizeStream(Context* context, CUstream stream) {
Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ void GpuExecutor::Deallocate(DeviceMemoryBase* mem) {
}

bool GpuExecutor::SynchronizeAllActivity() {
return GpuDriver::SynchronizeContext(context_);
return GpuDriver::SynchronizeContext(context_).ok();
}

absl::Status GpuExecutor::SynchronousMemZero(DeviceMemoryBase* location,
Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/gpu/gpu_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ class GpuDriver {
// have been completed, via cuCtxSynchronize.
//
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g7a54725f28d34b8c6299f0c6ca579616
static bool SynchronizeContext(Context* context);
static absl::Status SynchronizeContext(Context* context);

// Returns true if all stream tasks have completed at time of the call. Note
// the potential for races around this call (if another thread adds work to
Expand Down
13 changes: 4 additions & 9 deletions xla/stream_executor/rocm/rocm_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1436,16 +1436,11 @@ absl::Status GpuDriver::WaitStreamOnEvent(Context* context,
return absl::OkStatus();
}

bool GpuDriver::SynchronizeContext(Context* context) {
absl::Status GpuDriver::SynchronizeContext(Context* context) {
ScopedActivateContext activation{context};
hipError_t res = wrap::hipDeviceSynchronize();
if (res != hipSuccess) {
LOG(ERROR) << "could not synchronize on ROCM device: " << ToString(res)
<< " :: " << tsl::CurrentStackTrace();
return false;
}

return true;
RETURN_IF_ROCM_ERROR(wrap::hipDeviceSynchronize(),
"could not synchronize on ROCM device");
return absl::OkStatus();
}

absl::Status GpuDriver::SynchronizeStream(Context* context,
Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/rocm/rocm_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ void GpuExecutor::Deallocate(DeviceMemoryBase* mem) {
}

bool GpuExecutor::SynchronizeAllActivity() {
return GpuDriver::SynchronizeContext(context_);
return GpuDriver::SynchronizeContext(context_).ok();
}

absl::Status GpuExecutor::SynchronousMemZero(DeviceMemoryBase* location,
Expand Down

0 comments on commit 18f61b0

Please sign in to comment.