Skip to content

Commit

Permalink
[Hardware][AMD]: Replace HIPCC version with more precise ROCm version (
Browse files Browse the repository at this point in the history
…#11515)

Signed-off-by: hjwei <[email protected]>
  • Loading branch information
hj-wei authored Dec 28, 2024
1 parent b7dcc00 commit 59d6bb4
Showing 1 changed file with 29 additions and 23 deletions.
52 changes: 29 additions & 23 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ctypes
import importlib.util
import logging
import os
Expand All @@ -13,7 +14,7 @@
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
from setuptools_scm import get_version
from torch.utils.cpp_extension import CUDA_HOME
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME


def load_module_from_path(module_name, path):
Expand Down Expand Up @@ -379,25 +380,31 @@ def _build_custom_ops() -> bool:
return _is_cuda() or _is_hip() or _is_cpu()


def get_hipcc_rocm_version():
# Run the hipcc --version command
result = subprocess.run(['hipcc', '--version'],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True)
def get_rocm_version():
# Get the Rocm version from the ROCM_HOME/bin/librocm-core.so
# see https://github.com/ROCm/rocm-core/blob/d11f5c20d500f729c393680a01fa902ebf92094b/rocm_version.cpp#L21
try:
librocm_core_file = Path(ROCM_HOME) / "lib" / "librocm-core.so"
if not librocm_core_file.is_file():
return None
librocm_core = ctypes.CDLL(librocm_core_file)
VerErrors = ctypes.c_uint32
get_rocm_core_version = librocm_core.getROCmVersion
get_rocm_core_version.restype = VerErrors
get_rocm_core_version.argtypes = [
ctypes.POINTER(ctypes.c_uint32),
ctypes.POINTER(ctypes.c_uint32),
ctypes.POINTER(ctypes.c_uint32),
]
major = ctypes.c_uint32()
minor = ctypes.c_uint32()
patch = ctypes.c_uint32()

# Check if the command was executed successfully
if result.returncode != 0:
print("Error running 'hipcc --version'")
if (get_rocm_core_version(ctypes.byref(major), ctypes.byref(minor),
ctypes.byref(patch)) == 0):
return "%d.%d.%d" % (major.value, minor.value, patch.value)
return None

# Extract the version using a regular expression
match = re.search(r'HIP version: (\S+)', result.stdout)
if match:
# Return the version string
return match.group(1)
else:
print("Could not find HIP version in the output")
except Exception:
return None


Expand Down Expand Up @@ -479,11 +486,10 @@ def get_vllm_version() -> str:
if "sdist" not in sys.argv:
version += f"{sep}cu{cuda_version_str}"
elif _is_hip():
# Get the HIP version
hipcc_version = get_hipcc_rocm_version()
if hipcc_version != MAIN_CUDA_VERSION:
rocm_version_str = hipcc_version.replace(".", "")[:3]
version += f"{sep}rocm{rocm_version_str}"
# Get the Rocm Version
rocm_version = get_rocm_version() or torch.version.hip
if rocm_version and rocm_version != MAIN_CUDA_VERSION:
version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}"
elif _is_neuron():
# Get the Neuron version
neuron_version = str(get_neuronxcc_version())
Expand Down

0 comments on commit 59d6bb4

Please sign in to comment.