diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 46cfa9176a..4bcd1f8e27 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -47,20 +47,20 @@ def _get_sys_extension(): def _load_cudnn(): """Load CUDNN shared library.""" - + # Attempt to locate cuDNN in Python dist-packages lib_path = glob.glob( os.path.join( sysconfig.get_path("purelib"), f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]", ) ) - if lib_path: assert ( len(lib_path) == 1 ), f"Found {len(lib_path)} libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX." return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL) + # Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH") if cudnn_home: libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) @@ -68,13 +68,14 @@ def _load_cudnn(): if libs: return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) - cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") - if cuda_home: - libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) - libs.sort(reverse=True, key=os.path.basename) - if libs: - return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + # Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" + libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True) + libs.sort(reverse=True, key=os.path.basename) + if libs: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) @@ -91,14 +92,15 @@ def _load_library(): def _load_nvrtc(): """Load NVRTC shared library.""" - cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") - if cuda_home: - libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True) - libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs)) - libs.sort(reverse=True, key=os.path.basename) - if libs: - return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) - + # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" + libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True) + libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs)) + libs.sort(reverse=True, key=os.path.basename) + if libs: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + + # Attempt to locate NVRTC via ldconfig libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True) libs = libs.decode("utf-8").split("\n") sos = [] @@ -109,6 +111,8 @@ def _load_nvrtc(): sos.append(lib.split(">")[1].strip()) if sos: return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) + + # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)