add support for CUDA allocation flags #1079
src/provider/provider_cuda.c
Outdated
if (cu_params->memory_type == UMF_MEMORY_TYPE_SHARED) {
    if (cu_params->alloc_flags == 0) {
        // if flags are not set, the default setting is CU_MEM_ATTACH_GLOBAL
        cu_params->alloc_flags = CU_MEM_ATTACH_GLOBAL;
    } else if (cu_params->alloc_flags != CU_MEM_ATTACH_GLOBAL &&
               cu_params->alloc_flags != CU_MEM_ATTACH_HOST) {
        LOG_ERR("Invalid shared allocation flags");
        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
    }
} else if (cu_params->memory_type == UMF_MEMORY_TYPE_HOST) {
    if (cu_params->alloc_flags &
        ~(CU_MEMHOSTALLOC_PORTABLE | CU_MEMHOSTALLOC_DEVICEMAP |
          CU_MEMHOSTALLOC_WRITECOMBINED)) {
        LOG_ERR("Invalid host allocation flags");
        return UMF_RESULT_ERROR_INVALID_ARGUMENT;
    }
}
According to the CUDA Driver API spec, these flags are defined as follows:
#define CU_MEMHOSTALLOC_PORTABLE 0x01
#define CU_MEMHOSTALLOC_DEVICEMAP 0x02
#define CU_MEMHOSTALLOC_WRITECOMBINED 0x04
and
enum CUmemAttach_flags {
    CU_MEM_ATTACH_GLOBAL = 0x1,
    CU_MEM_ATTACH_HOST = 0x2,
    CU_MEM_ATTACH_SINGLE = 0x4
};
We cannot catch the case where cu_params->memory_type == UMF_MEMORY_TYPE_HOST and cu_params->alloc_flags == CU_MEM_ATTACH_GLOBAL, because CU_MEM_ATTACH_GLOBAL == CU_MEMHOSTALLOC_PORTABLE.
I think there is nothing we can do here; it is the user's responsibility to provide a correct combination of memory type and allocation flags.
I am leaving this comment just as a note.
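To make the collision concrete, here is a minimal sketch of the host-branch bitmask test applied to attach flags. The SKETCH_* constants are local stand-ins that copy the numeric values quoted above (real code would take them from cuda.h), and host_flags_pass_mask_check is a hypothetical helper, not a function from this PR.

```c
#include <stdbool.h>

/* Local stand-ins that mirror the values quoted in the comment above.
 * They are assumptions for illustration; real code would use cuda.h. */
enum {
    SKETCH_MEMHOSTALLOC_PORTABLE = 0x01,
    SKETCH_MEMHOSTALLOC_DEVICEMAP = 0x02,
    SKETCH_MEMHOSTALLOC_WRITECOMBINED = 0x04,
    SKETCH_MEM_ATTACH_GLOBAL = 0x1,
    SKETCH_MEM_ATTACH_HOST = 0x2,
    SKETCH_MEM_ATTACH_SINGLE = 0x4
};

/* Hypothetical helper: the same kind of test as the host branch in the
 * diff. It rejects a value only if it has bits outside the host mask. */
static bool host_flags_pass_mask_check(unsigned int flags) {
    const unsigned int allowed = SKETCH_MEMHOSTALLOC_PORTABLE |
                                 SKETCH_MEMHOSTALLOC_DEVICEMAP |
                                 SKETCH_MEMHOSTALLOC_WRITECOMBINED;
    return (flags & ~allowed) == 0;
}
```

Because every attach flag shares its numeric value with a host-allocation flag, any attach flag passed with UMF_MEMORY_TYPE_HOST passes this check; only bits outside 0x7 are rejected.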
I agree. We shouldn't try to catch invalid combinations of flags: CUDA will always be more up to date than we are and will do the error checking for us. As long as errors from the CUDA runtime are propagated correctly, I think we should just pass the flags through.
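A rough sketch of the pass-through idea follows: the provider forwards the flags unchanged and only translates whatever the driver returns. All names here are hypothetical stand-ins (the sketch_* types, the translation table, and the fake driver function); only the three numeric driver codes are meant to mirror CUDA_SUCCESS, CUDA_ERROR_INVALID_VALUE and CUDA_ERROR_OUT_OF_MEMORY, and the actual mapping in UMF may differ.

```c
#include <stddef.h>

/* Stand-in for a few driver result codes (assumed to mirror CUresult). */
typedef enum {
    SKETCH_CU_SUCCESS = 0,
    SKETCH_CU_ERROR_INVALID_VALUE = 1,
    SKETCH_CU_ERROR_OUT_OF_MEMORY = 2
} sketch_cu_result_t;

/* Stand-in for the provider's own result codes (hypothetical). */
typedef enum {
    SKETCH_OK,
    SKETCH_ERR_INVALID_ARGUMENT,
    SKETCH_ERR_OUT_OF_MEMORY,
    SKETCH_ERR_PROVIDER_SPECIFIC
} sketch_result_t;

/* Translate a driver error instead of pre-validating the flags. */
static sketch_result_t sketch_translate(sketch_cu_result_t r) {
    switch (r) {
    case SKETCH_CU_SUCCESS:
        return SKETCH_OK;
    case SKETCH_CU_ERROR_INVALID_VALUE:
        return SKETCH_ERR_INVALID_ARGUMENT;
    case SKETCH_CU_ERROR_OUT_OF_MEMORY:
        return SKETCH_ERR_OUT_OF_MEMORY;
    default:
        return SKETCH_ERR_PROVIDER_SPECIFIC;
    }
}

typedef sketch_cu_result_t (*sketch_alloc_fn)(void **ptr, size_t size,
                                              unsigned int flags);

/* Forward the flags untouched; the driver decides whether they are valid. */
static sketch_result_t sketch_alloc(sketch_alloc_fn driver_alloc, void **ptr,
                                    size_t size, unsigned int flags) {
    return sketch_translate(driver_alloc(ptr, size, flags));
}

/* Fake driver used only for this sketch: rejects bits outside 0x7. */
static sketch_cu_result_t sketch_fake_driver_alloc(void **ptr, size_t size,
                                                   unsigned int flags) {
    (void)size;
    *ptr = NULL;
    return (flags & ~0x7u) ? SKETCH_CU_ERROR_INVALID_VALUE : SKETCH_CU_SUCCESS;
}
```

The point of this shape is that the provider carries no list of valid flags that could go stale; a rejected combination surfaces as the translated driver error.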
You are right. I have removed this check and added a comment.
if (cu_params->memory_type == UMF_MEMORY_TYPE_SHARED &&
    cu_params->alloc_flags == 0) {
    // the default setting is CU_MEM_ATTACH_GLOBAL
    cu_params->alloc_flags = CU_MEM_ATTACH_GLOBAL;
Hm, I think it would be better not to modify the params. In theory, the following scenario is valid, right?
auto params = umfCUDAMemoryProviderParamsCreate();
params.setMemType(UMF_MEMORY_TYPE_SHARED);
provider1 = umfMemoryProviderCreate(..., params);
params.setMemType(UMF_MEMORY_TYPE_DEVICE);
provider2 = umfMemoryProviderCreate(..., params);
In this case, the alloc_flags passed to the second provider will differ from what the user would expect.
Good catch. We should do this for cu_provider->alloc_flags instead.
BTW, can we test such a scenario in our tests?
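One possible shape for that fix, as a sketch only: copy the flags into the provider's own state and apply the default there, leaving the caller's params untouched. The sketch_* types and sketch_provider_init are hypothetical stand-ins, not the UMF structures, and SKETCH_MEM_ATTACH_GLOBAL just copies the value 0x1 quoted earlier in the thread.

```c
#include <stddef.h>

#define SKETCH_MEM_ATTACH_GLOBAL 0x1u /* stand-in for CU_MEM_ATTACH_GLOBAL */

typedef enum {
    SKETCH_MEMORY_TYPE_HOST,
    SKETCH_MEMORY_TYPE_DEVICE,
    SKETCH_MEMORY_TYPE_SHARED
} sketch_memory_type_t;

/* Hypothetical params object owned by the caller. */
typedef struct {
    sketch_memory_type_t memory_type;
    unsigned int alloc_flags;
} sketch_params_t;

/* Hypothetical provider state owned by the provider. */
typedef struct {
    sketch_memory_type_t memory_type;
    unsigned int alloc_flags;
} sketch_provider_t;

/* Takes params as const: the default is written to the provider's copy,
 * so reusing the same params for another provider is safe. */
static int sketch_provider_init(const sketch_params_t *params,
                                sketch_provider_t *provider) {
    if (params == NULL || provider == NULL) {
        return -1;
    }
    provider->memory_type = params->memory_type;
    provider->alloc_flags = params->alloc_flags;
    if (provider->memory_type == SKETCH_MEMORY_TYPE_SHARED &&
        provider->alloc_flags == 0) {
        /* unset flags default to global attach for shared memory */
        provider->alloc_flags = SKETCH_MEM_ATTACH_GLOBAL;
    }
    return 0;
}
```

A test for the scenario above could create a shared provider and then a device provider from the same params, and check that the params still hold 0 and that the device provider did not inherit the shared default.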
Add support for CUDA allocation flags.
This PR adds one new API call for CUDA:
umfCUDAMemoryProviderParamsSetAllocFlags()