From e4cef7833da5800aea510e7243071eefa639af4c Mon Sep 17 00:00:00 2001 From: Ken Raffenetti Date: Tue, 1 Oct 2024 16:46:40 -0500 Subject: [PATCH] ch4/ofi: Convert CUDA device id to handle for fi_mr_regattr Libfabric docs say that the value of the cuda field in the regattr struct is the device handle gotten from cuDeviceGet, not the ordinal. Fixes pmodels/mpich#7148. --- src/mpid/ch4/netmod/ofi/ofi_impl.h | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/mpid/ch4/netmod/ofi/ofi_impl.h b/src/mpid/ch4/netmod/ofi/ofi_impl.h index 9917c2e9ea8..af9dd9e2d05 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_impl.h +++ b/src/mpid/ch4/netmod/ofi/ofi_impl.h @@ -11,6 +11,9 @@ #include "ofi_types.h" #include "mpidch4r.h" #include "ch4_impl.h" +#ifdef MPL_HAVE_CUDA +#include <cuda.h> /* for cuDeviceGet */ +#endif extern unsigned long long PVAR_COUNTER_nic_sent_bytes_count[MPIDI_OFI_MAX_NICS] ATTRIBUTE((unused)); extern unsigned long long PVAR_COUNTER_nic_recvd_bytes_count[MPIDI_OFI_MAX_NICS] @@ -707,8 +710,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_register_memory(char *send_buf, size_t da mr_attr.context = NULL; if (MPL_gpu_attr_is_strict_dev(attr)) { #ifdef MPL_HAVE_CUDA + CUdevice device; + int dev_id; + + /* libfabric says to get the device handle from cuDeviceGet */ + dev_id = MPL_gpu_get_dev_id_from_attr(attr); + cuDeviceGet(&device, dev_id); + mr_attr.iface = FI_HMEM_CUDA; - mr_attr.device.cuda = MPL_gpu_get_dev_id_from_attr(attr); + mr_attr.device.cuda = device; #elif defined MPL_HAVE_ZE /* OFI does not support tiles yet, need to pass the root device. */ mr_attr.iface = FI_HMEM_ZE;