diff --git a/installer.py b/installer.py index be500e5a4..6a08d1b82 100644 --- a/installer.py +++ b/installer.py @@ -633,7 +633,7 @@ def install_rocm_zluda(): zluda_installer.set_blaslt_enabled(device.blaslt_supported) zluda_installer.make_copy() zluda_installer.load() - torch_command = os.environ.get('TORCH_COMMAND', f'torch=={zluda_installer.get_default_torch_version(device)} torchvision --index-url https://download.pytorch.org/whl/cu118') + torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.6.0 torchvision --index-url https://download.pytorch.org/whl/cu118') log.info(f'Using ZLUDA in {zluda_installer.path}') except Exception as e: error = e diff --git a/modules/zluda_installer.py b/modules/zluda_installer.py index 8888a56f4..df3002b79 100644 --- a/modules/zluda_installer.py +++ b/modules/zluda_installer.py @@ -5,7 +5,7 @@ import shutil import zipfile import urllib.request -from typing import Optional, Union +from typing import Union from modules import rocm @@ -94,8 +94,11 @@ def load() -> None: ctypes.windll.LoadLibrary(os.path.join(path, v)) if hipBLASLt_enabled: + os.environ.setdefault("DISABLE_ADDMM_CUDA_LT", "0") ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', 'hipblaslt.dll')) ctypes.windll.LoadLibrary(os.path.join(path, 'cublasLt64_11.dll')) + else: + os.environ["DISABLE_ADDMM_CUDA_LT"] = "1" def conceal(): import torch # pylint: disable=unused-import @@ -110,12 +113,3 @@ def _join_rocm_home(*paths) -> str: return os.path.join(cpp_extension.ROCM_HOME, *paths) cpp_extension._join_rocm_home = _join_rocm_home # pylint: disable=protected-access rocm.conceal = conceal - - -def get_default_torch_version(agent: Optional[rocm.Agent]) -> str: - if agent is not None: - if agent.arch in (rocm.MicroArchitecture.RDNA, rocm.MicroArchitecture.CDNA,): - return "2.4.1" if hipBLASLt_enabled else "2.3.1" - elif agent.arch == rocm.MicroArchitecture.GCN: - return "2.2.1" - return "2.4.1" if hipBLASLt_enabled else "2.3.1"