Skip to content

Commit

Permalink
Remove FuncGetAttribute from gpu_driver.h.
Browse files Browse the repository at this point in the history
This enabled some deletions from gpu_types.h, and the simplification of the gpu_driver interface.

PiperOrigin-RevId: 676119615
  • Loading branch information
klucke authored and Google-ML-Automation committed Sep 18, 2024
1 parent d2434f2 commit 0b5b884
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 43 deletions.
1 change: 1 addition & 0 deletions xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,7 @@ cuda_only_cc_library(
":cuda_kernel", # buildcleaner: keep
":cuda_platform_id",
":cuda_runtime", # buildcleaner: keep
":cuda_status",
":cuda_version_parser",
"//xla/stream_executor",
"//xla/stream_executor:blas",
Expand Down
8 changes: 0 additions & 8 deletions xla/stream_executor/cuda/cuda_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,14 +295,6 @@ void GpuDriver::DestroyContext(Context* context) {
GetContextMap()->Remove(cuda_context->context());
}

absl::Status GpuDriver::FuncGetAttribute(CUfunction_attribute attribute,
CUfunction func,
int* attribute_value) {
return cuda::ToStatus(
cuFuncGetAttribute(attribute_value, attribute, func),
absl::StrCat("Failed to query kernel attribute: ", attribute));
}

absl::Status GpuDriver::CreateGraph(CUgraph* graph) {
VLOG(2) << "Create new CUDA graph";
TF_RETURN_IF_ERROR(cuda::ToStatus(cuGraphCreate(graph, /*flags=*/0),
Expand Down
19 changes: 12 additions & 7 deletions xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/blas.h"
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/cuda/cuda_diagnostics.h"
#include "xla/stream_executor/cuda/cuda_event.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/cuda/cuda_status.h"
#include "xla/stream_executor/cuda/cuda_version_parser.h"
#include "xla/stream_executor/cuda/delay_kernel.h"
#include "xla/stream_executor/device_description.h"
Expand All @@ -53,7 +53,6 @@ limitations under the License.
#include "xla/stream_executor/gpu/context.h"
#include "xla/stream_executor/gpu/gpu_collectives.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
#include "xla/stream_executor/gpu/gpu_diagnostics.h"
#include "xla/stream_executor/gpu/gpu_driver.h"
#include "xla/stream_executor/gpu/gpu_event.h"
#include "xla/stream_executor/gpu/gpu_kernel.h"
Expand Down Expand Up @@ -91,6 +90,13 @@ bool ShouldLaunchDelayKernel() {
return value;
}

absl::Status FuncGetAttribute(CUfunction_attribute attribute, CUfunction func,
int* attribute_value) {
return cuda::ToStatus(
cuFuncGetAttribute(attribute_value, attribute, func),
absl::StrCat("Failed to query kernel attribute: ", attribute));
}

} // namespace

// Given const GPU memory, returns a libcuda device pointer datatype, suitable
Expand Down Expand Up @@ -423,13 +429,12 @@ CudaExecutor::CreateOrShareConstant(Stream* stream,
absl::Status CudaExecutor::GetKernelMetadata(GpuKernel* cuda_kernel,
KernelMetadata* kernel_metadata) {
int value;
TF_RETURN_IF_ERROR(GpuDriver::FuncGetAttribute(
CU_FUNC_ATTRIBUTE_NUM_REGS, cuda_kernel->gpu_function(), &value));
TF_RETURN_IF_ERROR(FuncGetAttribute(CU_FUNC_ATTRIBUTE_NUM_REGS,
cuda_kernel->gpu_function(), &value));
kernel_metadata->set_registers_per_thread(value);

TF_RETURN_IF_ERROR(
GpuDriver::FuncGetAttribute(CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES,
cuda_kernel->gpu_function(), &value));
TF_RETURN_IF_ERROR(FuncGetAttribute(CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES,
cuda_kernel->gpu_function(), &value));
kernel_metadata->set_shared_memory_bytes(value);
return absl::OkStatus();
}
Expand Down
9 changes: 0 additions & 9 deletions xla/stream_executor/gpu/gpu_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,6 @@ class GpuDriver {
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g27a365aebb0eb548166309f58a1e8b8e
static void DestroyContext(Context* context);

// Queries the runtime for the specified attribute of the specified function.
// cuFuncGetAttribute (the underlying CUDA driver API routine) only operates
// in terms of integer-sized values, so there's no potential for overrun (as
// of CUDA 5.5).
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g5e92a1b0d8d1b82cb00dcfb2de15961b
static absl::Status FuncGetAttribute(GpuFunctionAttribute attribute,
GpuFunctionHandle function,
int* attribute_value);

// Launches a CUDA/ROCm kernel via cuLaunchKernel/hipModuleLaunchKernel.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#execution-control
Expand Down
5 changes: 0 additions & 5 deletions xla/stream_executor/gpu/gpu_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,12 @@ struct UnsupportedGpuFeature {};
using GpuStreamHandle = ::sycl::queue*;
using GpuEventHandle = ::sycl::event*;
using GpuFunctionHandle = ::sycl::kernel*;
using GpuFunctionAttribute = UnsupportedGpuFeature;
using GpuDeviceHandle = ::sycl::device*;
using GpuDevicePtr = void*;
using GpuDeviceAttribute = UnsupportedGpuFeature;
using GpuDeviceProperty = UnsupportedGpuFeature;
using GpuModuleHandle = ze_module_handle_t;
using GpuFuncCachePreference = UnsupportedGpuFeature;
using GpuRngHandle = UnsupportedGpuFeature;
using GpuGraphHandle = UnsupportedGpuFeature;
using GpuGraphExecHandle = UnsupportedGpuFeature;
using GpuGraphNodeHandle = UnsupportedGpuFeature;
Expand All @@ -65,14 +63,12 @@ using GpuGraphConditionalHandle = UnsupportedGpuFeature;
using GpuStreamHandle = hipStream_t;
using GpuEventHandle = hipEvent_t;
using GpuFunctionHandle = hipFunction_t;
using GpuFunctionAttribute = hipFunction_attribute;
using GpuDeviceHandle = hipDevice_t;
using GpuDevicePtr = hipDeviceptr_t;
using GpuDeviceAttribute = hipDeviceAttribute_t;
using GpuDeviceProperty = hipDeviceProp_t;
using GpuModuleHandle = hipModule_t;
using GpuFuncCachePreference = hipFuncCache_t;
using GpuRngHandle = hiprandGenerator_t;
using GpuGraphHandle = hipGraph_t;
using GpuGraphExecHandle = hipGraphExec_t;
using GpuGraphNodeHandle = hipGraphNode_t;
Expand All @@ -82,7 +78,6 @@ using GpuGraphConditionalHandle = UnsupportedGpuFeature;
using GpuStreamHandle = CUstream;
using GpuEventHandle = CUevent;
using GpuFunctionHandle = CUfunction;
using GpuFunctionAttribute = CUfunction_attribute;
using GpuDeviceHandle = CUdevice;
using GpuDevicePtr = CUdeviceptr;
using GpuDeviceAttribute = CUdevice_attribute;
Expand Down
9 changes: 0 additions & 9 deletions xla/stream_executor/rocm/rocm_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -366,15 +366,6 @@ void GpuDriver::DestroyContext(Context* context) {
GetContextMap()->Remove(gpu_context->context());
}

absl::Status GpuDriver::FuncGetAttribute(hipFunction_attribute attribute,
hipFunction_t func,
int* attribute_value) {
RETURN_IF_ROCM_ERROR(
wrap::hipFuncGetAttribute(attribute_value, attribute, func),
"Failed to query kernel attribute: ", attribute);
return absl::OkStatus();
}

absl::Status GpuDriver::CreateGraph(hipGraph_t* graph) {
VLOG(2) << "Create new HIP graph";
RETURN_IF_ROCM_ERROR(wrap::hipGraphCreate(graph, /*flags=*/0),
Expand Down
31 changes: 26 additions & 5 deletions xla/stream_executor/rocm/rocm_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ limitations under the License.
#include "xla/stream_executor/plugin_registry.h"
#include "xla/stream_executor/rocm/rocm_diagnostics.h"
#include "xla/stream_executor/rocm/rocm_driver.h"
#include "xla/stream_executor/rocm/rocm_driver_wrapper.h"
#include "xla/stream_executor/rocm/rocm_event.h"
#include "xla/stream_executor/rocm/rocm_platform_id.h"
#include "xla/stream_executor/rocm/rocm_version_parser.h"
Expand All @@ -83,6 +84,19 @@ limitations under the License.
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"

#define RETURN_IF_ROCM_ERROR(expr, ...) \
do { \
hipError_t _res = (expr); \
if (TF_PREDICT_FALSE(_res != hipSuccess)) { \
if (_res == hipErrorOutOfMemory) \
return absl::ResourceExhaustedError(absl::StrCat( \
__VA_ARGS__, ":", ::stream_executor::gpu::ToString(_res))); \
else \
return absl::InternalError(absl::StrCat( \
__VA_ARGS__, ": ", ::stream_executor::gpu::ToString(_res))); \
} \
} while (0)

namespace stream_executor {
namespace gpu {

Expand Down Expand Up @@ -135,6 +149,14 @@ int fpus_per_core(std::string gcn_arch_name) {
}
return n;
}

absl::Status FuncGetAttribute(hipFunction_attribute attribute,
hipFunction_t func, int* attribute_value) {
RETURN_IF_ROCM_ERROR(
wrap::hipFuncGetAttribute(attribute_value, attribute, func),
"Failed to query kernel attribute: ", attribute);
return absl::OkStatus();
}
} // namespace

absl::StatusOr<std::shared_ptr<DeviceMemoryBase>>
Expand Down Expand Up @@ -317,13 +339,12 @@ absl::StatusOr<std::unique_ptr<Kernel>> RocmExecutor::LoadKernel(
absl::Status RocmExecutor::GetKernelMetadata(GpuKernel* rocm_kernel,
KernelMetadata* kernel_metadata) {
int value = 0;
TF_RETURN_IF_ERROR(GpuDriver::FuncGetAttribute(
HIP_FUNC_ATTRIBUTE_NUM_REGS, rocm_kernel->gpu_function(), &value));
TF_RETURN_IF_ERROR(FuncGetAttribute(HIP_FUNC_ATTRIBUTE_NUM_REGS,
rocm_kernel->gpu_function(), &value));
kernel_metadata->set_registers_per_thread(value);

TF_RETURN_IF_ERROR(
GpuDriver::FuncGetAttribute(HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES,
rocm_kernel->gpu_function(), &value));
TF_RETURN_IF_ERROR(FuncGetAttribute(HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES,
rocm_kernel->gpu_function(), &value));
kernel_metadata->set_shared_memory_bytes(value);
return absl::OkStatus();
}
Expand Down

0 comments on commit 0b5b884

Please sign in to comment.