Skip to content

Commit

Permalink
Defaulted CUDA_HOME/CUDA_PATH to /usr/local/cuda when attempting to d…
Browse files Browse the repository at this point in the history
…ynamically load cuDNN and NVRTC

Signed-off-by: Alp Dener <[email protected]>
  • Loading branch information
denera committed Sep 16, 2024
1 parent af5daa0 commit 6388c31
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions transformer_engine/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,34 +47,35 @@ def _get_sys_extension():

def _load_cudnn():
"""Load CUDNN shared library."""

# Attempt to locate cuDNN in Python dist-packages
lib_path = glob.glob(
os.path.join(
sysconfig.get_path("purelib"),
f"nvidia/cudnn/lib/libcudnn.{_get_sys_extension()}.*[0-9]",
)
)

if lib_path:
assert (
len(lib_path) == 1
), f"Found {len(lib_path)} libcudnn.{_get_sys_extension()}.x in nvidia-cudnn-cuXX."
return ctypes.CDLL(lib_path[0], mode=ctypes.RTLD_GLOBAL)

# Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
if cudnn_home:
libs = glob.glob(f"{cudnn_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home:
libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libcudnn.{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcudnn.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)


Expand All @@ -91,14 +92,15 @@ def _load_library():

def _load_nvrtc():
"""Load NVRTC shared library."""
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
if cuda_home:
libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True)
libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

# Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libnvrtc.{_get_sys_extension()}*", recursive=True)
libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

# Attempt to locate NVRTC via ldconfig
libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True)
libs = libs.decode("utf-8").split("\n")
sos = []
Expand All @@ -109,6 +111,8 @@ def _load_nvrtc():
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)

# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libnvrtc.{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)


Expand Down

0 comments on commit 6388c31

Please sign in to comment.