diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD index 9904f00f4f89a..0a6fe742dbafd 100644 --- a/xla/stream_executor/cuda/BUILD +++ b/xla/stream_executor/cuda/BUILD @@ -826,6 +826,7 @@ cuda_only_cc_library( ":cuda_kernel", # buildcleaner: keep ":cuda_platform_id", ":cuda_runtime", # buildcleaner: keep + ":cuda_status", ":cuda_version_parser", "//xla/stream_executor", "//xla/stream_executor:blas", diff --git a/xla/stream_executor/cuda/cuda_driver.cc b/xla/stream_executor/cuda/cuda_driver.cc index 3045c908ca12a..1308272ff09f8 100644 --- a/xla/stream_executor/cuda/cuda_driver.cc +++ b/xla/stream_executor/cuda/cuda_driver.cc @@ -295,14 +295,6 @@ void GpuDriver::DestroyContext(Context* context) { GetContextMap()->Remove(cuda_context->context()); } -absl::Status GpuDriver::FuncGetAttribute(CUfunction_attribute attribute, - CUfunction func, - int* attribute_value) { - return cuda::ToStatus( - cuFuncGetAttribute(attribute_value, attribute, func), - absl::StrCat("Failed to query kernel attribute: ", attribute)); -} - absl::Status GpuDriver::CreateGraph(CUgraph* graph) { VLOG(2) << "Create new CUDA graph"; TF_RETURN_IF_ERROR(cuda::ToStatus(cuGraphCreate(graph, /*flags=*/0), diff --git a/xla/stream_executor/cuda/cuda_executor.cc b/xla/stream_executor/cuda/cuda_executor.cc index e396a6cc7fff9..28e15f1546b6c 100644 --- a/xla/stream_executor/cuda/cuda_executor.cc +++ b/xla/stream_executor/cuda/cuda_executor.cc @@ -39,9 +39,9 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" -#include "xla/stream_executor/cuda/cuda_diagnostics.h" #include "xla/stream_executor/cuda/cuda_event.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/cuda/cuda_version_parser.h" #include "xla/stream_executor/cuda/delay_kernel.h" #include "xla/stream_executor/device_description.h" @@ -53,7 +53,6 @@ limitations under the License. #include "xla/stream_executor/gpu/context.h" #include "xla/stream_executor/gpu/gpu_collectives.h" #include "xla/stream_executor/gpu/gpu_command_buffer.h" -#include "xla/stream_executor/gpu/gpu_diagnostics.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_kernel.h" @@ -91,6 +90,13 @@ bool ShouldLaunchDelayKernel() { return value; } +absl::Status FuncGetAttribute(CUfunction_attribute attribute, CUfunction func, + int* attribute_value) { + return cuda::ToStatus( + cuFuncGetAttribute(attribute_value, attribute, func), + absl::StrCat("Failed to query kernel attribute: ", attribute)); +} + } // namespace // Given const GPU memory, returns a libcuda device pointer datatype, suitable @@ -423,13 +429,12 @@ CudaExecutor::CreateOrShareConstant(Stream* stream, absl::Status CudaExecutor::GetKernelMetadata(GpuKernel* cuda_kernel, KernelMetadata* kernel_metadata) { int value; - TF_RETURN_IF_ERROR(GpuDriver::FuncGetAttribute( - CU_FUNC_ATTRIBUTE_NUM_REGS, cuda_kernel->gpu_function(), &value)); + TF_RETURN_IF_ERROR(FuncGetAttribute(CU_FUNC_ATTRIBUTE_NUM_REGS, + cuda_kernel->gpu_function(), &value)); kernel_metadata->set_registers_per_thread(value); - TF_RETURN_IF_ERROR( - GpuDriver::FuncGetAttribute(CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, - cuda_kernel->gpu_function(), &value)); + TF_RETURN_IF_ERROR(FuncGetAttribute(CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, + cuda_kernel->gpu_function(), &value)); kernel_metadata->set_shared_memory_bytes(value); return absl::OkStatus(); } diff --git a/xla/stream_executor/gpu/gpu_driver.h b/xla/stream_executor/gpu/gpu_driver.h index 4befc7db65d0d..447201e63796e 100644 --- a/xla/stream_executor/gpu/gpu_driver.h +++ b/xla/stream_executor/gpu/gpu_driver.h @@ -167,15 +167,6 @@ class GpuDriver { // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g27a365aebb0eb548166309f58a1e8b8e static void DestroyContext(Context* context); - // Queries the runtime for the specified attribute of the specified function. - // cuFuncGetAttribute (the underlying CUDA driver API routine) only operates - // in terms of integer-sized values, so there's no potential for overrun (as - // of CUDA 5.5). - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g5e92a1b0d8d1b82cb00dcfb2de15961b - static absl::Status FuncGetAttribute(GpuFunctionAttribute attribute, - GpuFunctionHandle function, - int* attribute_value); - // Launches a CUDA/ROCm kernel via cuLaunchKernel/hipModuleLaunchKernel. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15 // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#execution-control diff --git a/xla/stream_executor/gpu/gpu_types.h b/xla/stream_executor/gpu/gpu_types.h index 011e8f251edbc..a47b5d81d66d2 100644 --- a/xla/stream_executor/gpu/gpu_types.h +++ b/xla/stream_executor/gpu/gpu_types.h @@ -47,14 +47,12 @@ struct UnsupportedGpuFeature {}; using GpuStreamHandle = ::sycl::queue*; using GpuEventHandle = ::sycl::event*; using GpuFunctionHandle = ::sycl::kernel*; -using GpuFunctionAttribute = UnsupportedGpuFeature; using GpuDeviceHandle = ::sycl::device*; using GpuDevicePtr = void*; using GpuDeviceAttribute = UnsupportedGpuFeature; using GpuDeviceProperty = UnsupportedGpuFeature; using GpuModuleHandle = ze_module_handle_t; using GpuFuncCachePreference = UnsupportedGpuFeature; -using GpuRngHandle = UnsupportedGpuFeature; using GpuGraphHandle = UnsupportedGpuFeature; using GpuGraphExecHandle = UnsupportedGpuFeature; using GpuGraphNodeHandle = UnsupportedGpuFeature; @@ -65,14 +63,12 @@ using GpuGraphConditionalHandle = UnsupportedGpuFeature; using GpuStreamHandle = hipStream_t; using GpuEventHandle = hipEvent_t; using GpuFunctionHandle = hipFunction_t; -using GpuFunctionAttribute = hipFunction_attribute; using GpuDeviceHandle = hipDevice_t; using GpuDevicePtr = hipDeviceptr_t; using GpuDeviceAttribute = hipDeviceAttribute_t; using GpuDeviceProperty = hipDeviceProp_t; using GpuModuleHandle = hipModule_t; using GpuFuncCachePreference = hipFuncCache_t; -using GpuRngHandle = hiprandGenerator_t; using GpuGraphHandle = hipGraph_t; using GpuGraphExecHandle = hipGraphExec_t; using GpuGraphNodeHandle = hipGraphNode_t; @@ -82,7 +78,6 @@ using GpuGraphConditionalHandle = UnsupportedGpuFeature; using GpuStreamHandle = CUstream; using GpuEventHandle = CUevent; using GpuFunctionHandle = CUfunction; -using GpuFunctionAttribute = CUfunction_attribute; using GpuDeviceHandle = CUdevice; using GpuDevicePtr = CUdeviceptr; using GpuDeviceAttribute = CUdevice_attribute; diff --git a/xla/stream_executor/rocm/rocm_driver.cc b/xla/stream_executor/rocm/rocm_driver.cc index 3381b97c5553e..81f7f3d76fbd7 100644 --- a/xla/stream_executor/rocm/rocm_driver.cc +++ b/xla/stream_executor/rocm/rocm_driver.cc @@ -366,15 +366,6 @@ void GpuDriver::DestroyContext(Context* context) { GetContextMap()->Remove(gpu_context->context()); } -absl::Status GpuDriver::FuncGetAttribute(hipFunction_attribute attribute, - hipFunction_t func, - int* attribute_value) { - RETURN_IF_ROCM_ERROR( - wrap::hipFuncGetAttribute(attribute_value, attribute, func), - "Failed to query kernel attribute: ", attribute); - return absl::OkStatus(); -} - absl::Status GpuDriver::CreateGraph(hipGraph_t* graph) { VLOG(2) << "Create new HIP graph"; RETURN_IF_ROCM_ERROR(wrap::hipGraphCreate(graph, /*flags=*/0), diff --git a/xla/stream_executor/rocm/rocm_executor.cc b/xla/stream_executor/rocm/rocm_executor.cc index 251f72c8fbcb1..e46408be1974f 100644 --- a/xla/stream_executor/rocm/rocm_executor.cc +++ b/xla/stream_executor/rocm/rocm_executor.cc @@ -71,6 +71,7 @@ limitations under the License. #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/rocm/rocm_diagnostics.h" #include "xla/stream_executor/rocm/rocm_driver.h" +#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "xla/stream_executor/rocm/rocm_event.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/rocm/rocm_version_parser.h" @@ -83,6 +84,19 @@ limitations under the License. #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" +#define RETURN_IF_ROCM_ERROR(expr, ...) \ + do { \ + hipError_t _res = (expr); \ + if (TF_PREDICT_FALSE(_res != hipSuccess)) { \ + if (_res == hipErrorOutOfMemory) \ + return absl::ResourceExhaustedError(absl::StrCat( \ + __VA_ARGS__, ":", ::stream_executor::gpu::ToString(_res))); \ + else \ + return absl::InternalError(absl::StrCat( \ + __VA_ARGS__, ": ", ::stream_executor::gpu::ToString(_res))); \ + } \ + } while (0) + namespace stream_executor { namespace gpu { @@ -135,6 +149,14 @@ int fpus_per_core(std::string gcn_arch_name) { } return n; } + +absl::Status FuncGetAttribute(hipFunction_attribute attribute, + hipFunction_t func, int* attribute_value) { + RETURN_IF_ROCM_ERROR( + wrap::hipFuncGetAttribute(attribute_value, attribute, func), + "Failed to query kernel attribute: ", attribute); + return absl::OkStatus(); +} } // namespace absl::StatusOr> @@ -317,13 +339,12 @@ absl::StatusOr> RocmExecutor::LoadKernel( absl::Status RocmExecutor::GetKernelMetadata(GpuKernel* rocm_kernel, KernelMetadata* kernel_metadata) { int value = 0; - TF_RETURN_IF_ERROR(GpuDriver::FuncGetAttribute( - HIP_FUNC_ATTRIBUTE_NUM_REGS, rocm_kernel->gpu_function(), &value)); + TF_RETURN_IF_ERROR(FuncGetAttribute(HIP_FUNC_ATTRIBUTE_NUM_REGS, + rocm_kernel->gpu_function(), &value)); kernel_metadata->set_registers_per_thread(value); - TF_RETURN_IF_ERROR( - GpuDriver::FuncGetAttribute(HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, - rocm_kernel->gpu_function(), &value)); + TF_RETURN_IF_ERROR(FuncGetAttribute(HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, + rocm_kernel->gpu_function(), &value)); kernel_metadata->set_shared_memory_bytes(value); return absl::OkStatus(); }