add support for CUDA allocation flags
bratpiorka committed Feb 6, 2025
1 parent 1fa3f8a commit 71236fd
Showing 8 changed files with 163 additions and 16 deletions.
9 changes: 8 additions & 1 deletion include/umf/providers/provider_cuda.h
@@ -1,5 +1,5 @@
/*
* Copyright (C) 2024 Intel Corporation
* Copyright (C) 2024-2025 Intel Corporation
*
* Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -53,6 +53,13 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
umf_cuda_memory_provider_params_handle_t hParams,
umf_usm_memory_type_t memoryType);

/// @brief Set the allocation flags in the parameters struct.
/// @param hParams handle to the parameters of the CUDA Memory Provider.
/// @param flags valid combination of CUDA allocation flags.
/// @return UMF_RESULT_SUCCESS on success or appropriate error code on failure.
umf_result_t umfCUDAMemoryProviderParamsSetAllocFlags(
umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags);

umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void);

#ifdef __cplusplus
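A minimal usage sketch of the new setter, assuming the existing context/device setters from the same header and the generic provider API from <umf/memory_provider.h> (error handling mostly shortened; the function name and flag choice are illustrative only):

#include <cuda.h>
#include <umf/memory_provider.h>
#include <umf/providers/provider_cuda.h>

// Sketch: build a CUDA host (pinned) provider with explicit cuMemHostAlloc flags.
static umf_result_t create_pinned_host_provider(CUcontext ctx, CUdevice dev,
                                                umf_memory_provider_handle_t *out) {
    umf_cuda_memory_provider_params_handle_t params = NULL;
    umf_result_t ret = umfCUDAMemoryProviderParamsCreate(&params);
    if (ret != UMF_RESULT_SUCCESS) {
        return ret;
    }

    umfCUDAMemoryProviderParamsSetContext(params, ctx);
    umfCUDAMemoryProviderParamsSetDevice(params, dev);
    umfCUDAMemoryProviderParamsSetMemoryType(params, UMF_MEMORY_TYPE_HOST);
    // New in this commit: pass cuMemHostAlloc flags through the params struct.
    umfCUDAMemoryProviderParamsSetAllocFlags(
        params, CU_MEMHOSTALLOC_PORTABLE | CU_MEMHOSTALLOC_DEVICEMAP);

    ret = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(), params, out);
    umfCUDAMemoryProviderParamsDestroy(params);
    return ret;
}
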
1 change: 1 addition & 0 deletions src/libumf.def
@@ -118,6 +118,7 @@ EXPORTS
umfScalablePoolParamsSetGranularity
umfScalablePoolParamsSetKeepAllMemory
; Added in UMF_0.11
umfCUDAMemoryProviderParamsSetAllocFlags
umfFixedMemoryProviderOps
umfFixedMemoryProviderParamsCreate
umfFixedMemoryProviderParamsDestroy
1 change: 1 addition & 0 deletions src/libumf.map
@@ -116,6 +116,7 @@ UMF_0.10 {
};

UMF_0.11 {
umfCUDAMemoryProviderParamsSetAllocFlags;
umfFixedMemoryProviderOps;
umfFixedMemoryProviderParamsCreate;
umfFixedMemoryProviderParamsDestroy;
75 changes: 64 additions & 11 deletions src/provider/provider_cuda.c
@@ -55,6 +55,14 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
return UMF_RESULT_ERROR_NOT_SUPPORTED;
}

umf_result_t umfCUDAMemoryProviderParamsSetAllocFlags(
umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
(void)hParams;
(void)flags;
LOG_ERR("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!");
return UMF_RESULT_ERROR_NOT_SUPPORTED;
}

umf_memory_provider_ops_t *umfCUDAMemoryProviderOps(void) {
// not supported
LOG_ERR("CUDA provider is disabled (UMF_BUILD_CUDA_PROVIDER is OFF)!");
@@ -89,13 +97,22 @@ typedef struct cu_memory_provider_t {
CUdevice device;
umf_usm_memory_type_t memory_type;
size_t min_alignment;
unsigned int alloc_flags;
} cu_memory_provider_t;

// CUDA Memory Provider settings struct
typedef struct umf_cuda_memory_provider_params_t {
void *cuda_context_handle; ///< Handle to the CUDA context
int cuda_device_handle; ///< Handle to the CUDA device
umf_usm_memory_type_t memory_type; ///< Allocation memory type
// Handle to the CUDA context
void *cuda_context_handle;

// Handle to the CUDA device
int cuda_device_handle;

// Allocation memory type
umf_usm_memory_type_t memory_type;

// Allocation flags for cuMemHostAlloc/cuMemAllocManaged
unsigned int alloc_flags;
} umf_cuda_memory_provider_params_t;

typedef struct cu_ops_t {
@@ -104,6 +121,7 @@ typedef struct cu_ops_t {
CUmemAllocationGranularity_flags option);
CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t bytesize);
CUresult (*cuMemAllocHost)(void **pp, size_t bytesize);
CUresult (*cuMemHostAlloc)(void **pp, size_t bytesize, unsigned int flags);
CUresult (*cuMemAllocManaged)(CUdeviceptr *dptr, size_t bytesize,
unsigned int flags);
CUresult (*cuMemFree)(CUdeviceptr dptr);
@@ -175,6 +193,8 @@ static void init_cu_global_state(void) {
utils_get_symbol_addr(0, "cuMemAlloc_v2", lib_name);
*(void **)&g_cu_ops.cuMemAllocHost =
utils_get_symbol_addr(0, "cuMemAllocHost_v2", lib_name);
*(void **)&g_cu_ops.cuMemHostAlloc =
utils_get_symbol_addr(0, "cuMemHostAlloc", lib_name);
*(void **)&g_cu_ops.cuMemAllocManaged =
utils_get_symbol_addr(0, "cuMemAllocManaged", lib_name);
*(void **)&g_cu_ops.cuMemFree =
@@ -197,12 +217,12 @@ static void init_cu_global_state(void) {
utils_get_symbol_addr(0, "cuIpcCloseMemHandle", lib_name);

if (!g_cu_ops.cuMemGetAllocationGranularity || !g_cu_ops.cuMemAlloc ||
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemAllocManaged ||
!g_cu_ops.cuMemFree || !g_cu_ops.cuMemFreeHost ||
!g_cu_ops.cuGetErrorName || !g_cu_ops.cuGetErrorString ||
!g_cu_ops.cuCtxGetCurrent || !g_cu_ops.cuCtxSetCurrent ||
!g_cu_ops.cuIpcGetMemHandle || !g_cu_ops.cuIpcOpenMemHandle ||
!g_cu_ops.cuIpcCloseMemHandle) {
!g_cu_ops.cuMemAllocHost || !g_cu_ops.cuMemHostAlloc ||
!g_cu_ops.cuMemAllocManaged || !g_cu_ops.cuMemFree ||
!g_cu_ops.cuMemFreeHost || !g_cu_ops.cuGetErrorName ||
!g_cu_ops.cuGetErrorString || !g_cu_ops.cuCtxGetCurrent ||
!g_cu_ops.cuCtxSetCurrent || !g_cu_ops.cuIpcGetMemHandle ||
!g_cu_ops.cuIpcOpenMemHandle || !g_cu_ops.cuIpcCloseMemHandle) {
LOG_ERR("Required CUDA symbols not found.");
Init_cu_global_state_failed = true;
}
@@ -226,6 +246,7 @@ umf_result_t umfCUDAMemoryProviderParamsCreate(
params_data->cuda_context_handle = NULL;
params_data->cuda_device_handle = -1;
params_data->memory_type = UMF_MEMORY_TYPE_UNKNOWN;
params_data->alloc_flags = 0;

*hParams = params_data;

@@ -276,6 +297,18 @@ umf_result_t umfCUDAMemoryProviderParamsSetMemoryType(
return UMF_RESULT_SUCCESS;
}

umf_result_t umfCUDAMemoryProviderParamsSetAllocFlags(
umf_cuda_memory_provider_params_handle_t hParams, unsigned int flags) {
if (!hParams) {
LOG_ERR("CUDA Memory Provider params handle is NULL");
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

hParams->alloc_flags = flags;

return UMF_RESULT_SUCCESS;
}

static umf_result_t cu_memory_provider_initialize(void *params,
void **provider) {
if (params == NULL) {
@@ -295,6 +328,24 @@ static umf_result_t cu_memory_provider_initialize(void *params,
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

if (cu_params->memory_type == UMF_MEMORY_TYPE_SHARED) {
if (cu_params->alloc_flags == 0) {
// if flags are not set, the default setting is CU_MEM_ATTACH_GLOBAL
cu_params->alloc_flags = CU_MEM_ATTACH_GLOBAL;
} else if (cu_params->alloc_flags != CU_MEM_ATTACH_GLOBAL &&
cu_params->alloc_flags != CU_MEM_ATTACH_HOST) {
LOG_ERR("Invalid shared allocation flags");
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}
} else if (cu_params->memory_type == UMF_MEMORY_TYPE_HOST) {
if (cu_params->alloc_flags &
~(CU_MEMHOSTALLOC_PORTABLE | CU_MEMHOSTALLOC_DEVICEMAP |
CU_MEMHOSTALLOC_WRITECOMBINED)) {
LOG_ERR("Invalid host allocation flags");
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}
}

utils_init_once(&cu_is_initialized, init_cu_global_state);
if (Init_cu_global_state_failed) {
LOG_ERR("Loading CUDA symbols failed");
@@ -325,6 +376,7 @@ static umf_result_t cu_memory_provider_initialize(void *params,
cu_provider->device = cu_params->cuda_device_handle;
cu_provider->memory_type = cu_params->memory_type;
cu_provider->min_alignment = min_alignment;
cu_provider->alloc_flags = cu_params->alloc_flags;

*provider = cu_provider;

@@ -382,7 +434,8 @@ static umf_result_t cu_memory_provider_alloc(void *provider, size_t size,
CUresult cu_result = CUDA_SUCCESS;
switch (cu_provider->memory_type) {
case UMF_MEMORY_TYPE_HOST: {
cu_result = g_cu_ops.cuMemAllocHost(resultPtr, size);
cu_result =
g_cu_ops.cuMemHostAlloc(resultPtr, size, cu_provider->alloc_flags);
break;
}
case UMF_MEMORY_TYPE_DEVICE: {
@@ -391,7 +444,7 @@
}
case UMF_MEMORY_TYPE_SHARED: {
cu_result = g_cu_ops.cuMemAllocManaged((CUdeviceptr *)resultPtr, size,
CU_MEM_ATTACH_GLOBAL);
cu_provider->alloc_flags);
break;
}
default:
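The flags are validated per memory type in cu_memory_provider_initialize: for UMF_MEMORY_TYPE_SHARED, 0 is promoted to CU_MEM_ATTACH_GLOBAL and anything other than CU_MEM_ATTACH_GLOBAL or CU_MEM_ATTACH_HOST is rejected; for UMF_MEMORY_TYPE_HOST, any combination of the three CU_MEMHOSTALLOC_* flags is accepted. A hedged end-to-end sketch for the shared path (the create_pinned_host_provider sketch above covers the host path; same assumed setters and generic provider calls):

// Sketch: managed (shared) allocation attached to the host.
static umf_result_t shared_alloc_demo(CUcontext ctx, CUdevice dev) {
    umf_cuda_memory_provider_params_handle_t params = NULL;
    umf_memory_provider_handle_t provider = NULL;
    void *ptr = NULL;

    umfCUDAMemoryProviderParamsCreate(&params);
    umfCUDAMemoryProviderParamsSetContext(params, ctx);
    umfCUDAMemoryProviderParamsSetDevice(params, dev);
    umfCUDAMemoryProviderParamsSetMemoryType(params, UMF_MEMORY_TYPE_SHARED);
    // 0 would also be accepted (promoted to CU_MEM_ATTACH_GLOBAL); any value
    // other than GLOBAL/HOST fails provider initialization.
    umfCUDAMemoryProviderParamsSetAllocFlags(params, CU_MEM_ATTACH_HOST);

    umf_result_t ret =
        umfMemoryProviderCreate(umfCUDAMemoryProviderOps(), params, &provider);
    umfCUDAMemoryProviderParamsDestroy(params);
    if (ret != UMF_RESULT_SUCCESS) {
        return ret;
    }

    // The provider now issues cuMemAllocManaged(..., CU_MEM_ATTACH_HOST)
    // instead of the previously hard-coded CU_MEM_ATTACH_GLOBAL.
    ret = umfMemoryProviderAlloc(provider, 1 << 20, 0, &ptr);
    if (ret == UMF_RESULT_SUCCESS) {
        umfMemoryProviderFree(provider, ptr, 1 << 20);
    }
    umfMemoryProviderDestroy(provider);
    return ret;
}
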
33 changes: 33 additions & 0 deletions test/providers/cuda_helpers.cpp
@@ -23,6 +23,7 @@ struct libcu_ops {
CUresult (*cuMemAlloc)(CUdeviceptr *dptr, size_t size);
CUresult (*cuMemFree)(CUdeviceptr dptr);
CUresult (*cuMemAllocHost)(void **pp, size_t size);
CUresult (*cuMemHostAlloc)(void **pp, size_t size, unsigned int flags);
CUresult (*cuMemAllocManaged)(CUdeviceptr *dptr, size_t bytesize,
unsigned int flags);
CUresult (*cuMemFreeHost)(void *p);
@@ -34,6 +35,7 @@ struct libcu_ops {
CUresult (*cuPointerGetAttributes)(unsigned int numAttributes,
CUpointer_attribute *attributes,
void **data, CUdeviceptr ptr);
CUresult (*cuMemHostGetFlags)(unsigned int *pFlags, void *p);
CUresult (*cuStreamSynchronize)(CUstream hStream);
CUresult (*cuCtxSynchronize)(void);
} libcu_ops;
@@ -72,6 +74,9 @@ struct DlHandleCloser {
libcu_ops.cuMemAllocHost = [](auto... args) {
return noop_stub(args...);
};
libcu_ops.cuMemHostAlloc = [](auto... args) {
return noop_stub(args...);
};
libcu_ops.cuMemAllocManaged = [](auto... args) {
return noop_stub(args...);
};
@@ -90,6 +95,9 @@ struct DlHandleCloser {
libcu_ops.cuPointerGetAttributes = [](auto... args) {
return noop_stub(args...);
};
libcu_ops.cuMemHostGetFlags = [](auto... args) {
return noop_stub(args...);
};
libcu_ops.cuStreamSynchronize = [](auto... args) {
return noop_stub(args...);
};
@@ -170,6 +178,12 @@ int InitCUDAOps() {
fprintf(stderr, "cuMemAllocHost_v2 symbol not found in %s\n", lib_name);
return -1;
}
*(void **)&libcu_ops.cuMemHostAlloc =
utils_get_symbol_addr(cuDlHandle.get(), "cuMemHostAlloc", lib_name);
if (libcu_ops.cuMemHostAlloc == nullptr) {
fprintf(stderr, "cuMemHostAlloc symbol not found in %s\n", lib_name);
return -1;
}
*(void **)&libcu_ops.cuMemAllocManaged =
utils_get_symbol_addr(cuDlHandle.get(), "cuMemAllocManaged", lib_name);
if (libcu_ops.cuMemAllocManaged == nullptr) {
@@ -208,6 +222,12 @@ int InitCUDAOps() {
lib_name);
return -1;
}
*(void **)&libcu_ops.cuMemHostGetFlags =
utils_get_symbol_addr(cuDlHandle.get(), "cuMemHostGetFlags", lib_name);
if (libcu_ops.cuMemHostGetFlags == nullptr) {
fprintf(stderr, "cuMemHostGetFlags symbol not found in %s\n", lib_name);
return -1;
}
*(void **)&libcu_ops.cuStreamSynchronize = utils_get_symbol_addr(
cuDlHandle.get(), "cuStreamSynchronize", lib_name);
if (libcu_ops.cuStreamSynchronize == nullptr) {
@@ -237,13 +257,15 @@ int InitCUDAOps() {
libcu_ops.cuDeviceGet = cuDeviceGet;
libcu_ops.cuMemAlloc = cuMemAlloc;
libcu_ops.cuMemAllocHost = cuMemAllocHost;
libcu_ops.cuMemHostAlloc = cuMemHostAlloc;
libcu_ops.cuMemAllocManaged = cuMemAllocManaged;
libcu_ops.cuMemFree = cuMemFree;
libcu_ops.cuMemFreeHost = cuMemFreeHost;
libcu_ops.cuMemsetD32 = cuMemsetD32;
libcu_ops.cuMemcpy = cuMemcpy;
libcu_ops.cuPointerGetAttribute = cuPointerGetAttribute;
libcu_ops.cuPointerGetAttributes = cuPointerGetAttributes;
libcu_ops.cuMemHostGetFlags = cuMemHostGetFlags;
libcu_ops.cuStreamSynchronize = cuStreamSynchronize;
libcu_ops.cuCtxSynchronize = cuCtxSynchronize;

@@ -373,6 +395,17 @@ umf_usm_memory_type_t get_mem_type(CUcontext context, void *ptr) {
return UMF_MEMORY_TYPE_UNKNOWN;
}

unsigned int get_mem_host_alloc_flags(void *ptr) {
unsigned int flags;
CUresult res = libcu_ops.cuMemHostGetFlags(&flags, ptr);
if (res != CUDA_SUCCESS) {
fprintf(stderr, "cuPointerGetAttribute() failed!\n");
return 0;
}

return flags;
}

CUcontext get_mem_context(void *ptr) {
CUcontext context;
CUresult res = libcu_ops.cuPointerGetAttribute(
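The new get_mem_host_alloc_flags helper lets a test confirm that flags set through the params struct actually reach the driver. A hedged sketch of such a check (GTest-style, as used elsewhere in the suite; hostProvider is assumed to be a CUDA provider created with UMF_MEMORY_TYPE_HOST and CU_MEMHOSTALLOC_PORTABLE, setup not shown):

void *ptr = nullptr;
ASSERT_EQ(umfMemoryProviderAlloc(hostProvider, 4096, 0, &ptr),
          UMF_RESULT_SUCCESS);

// cuMemHostGetFlags reports the CU_MEMHOSTALLOC_* flags of a pinned buffer.
unsigned int flags = get_mem_host_alloc_flags(ptr);
ASSERT_TRUE(flags & CU_MEMHOSTALLOC_PORTABLE);

ASSERT_EQ(umfMemoryProviderFree(hostProvider, ptr, 4096), UMF_RESULT_SUCCESS);
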
2 changes: 2 additions & 0 deletions test/providers/cuda_helpers.h
@@ -42,6 +42,8 @@ int cuda_copy(CUcontext context, CUdevice device, void *dst_ptr,

umf_usm_memory_type_t get_mem_type(CUcontext context, void *ptr);

unsigned int get_mem_host_alloc_flags(void *ptr);

CUcontext get_mem_context(void *ptr);

CUcontext get_current_context();