Skip to content

Commit

Permalink
zluda torch 2.6.0
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Feb 7, 2025
1 parent cd1a9c5 commit 4f09014
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 11 deletions.
2 changes: 1 addition & 1 deletion installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ def install_rocm_zluda():
zluda_installer.set_blaslt_enabled(device.blaslt_supported)
zluda_installer.make_copy()
zluda_installer.load()
torch_command = os.environ.get('TORCH_COMMAND', f'torch=={zluda_installer.get_default_torch_version(device)} torchvision --index-url https://download.pytorch.org/whl/cu118')
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.6.0 torchvision --index-url https://download.pytorch.org/whl/cu118')
log.info(f'Using ZLUDA in {zluda_installer.path}')
except Exception as e:
error = e
Expand Down
14 changes: 4 additions & 10 deletions modules/zluda_installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import shutil
import zipfile
import urllib.request
from typing import Optional, Union
from typing import Union
from modules import rocm


Expand Down Expand Up @@ -94,8 +94,11 @@ def load() -> None:
ctypes.windll.LoadLibrary(os.path.join(path, v))

if hipBLASLt_enabled:
os.environ.setdefault("DISABLE_ADDMM_CUDA_LT", "0")
ctypes.windll.LoadLibrary(os.path.join(rocm.path, 'bin', 'hipblaslt.dll'))
ctypes.windll.LoadLibrary(os.path.join(path, 'cublasLt64_11.dll'))
else:
os.environ["DISABLE_ADDMM_CUDA_LT"] = "1"

def conceal():
import torch # pylint: disable=unused-import
Expand All @@ -110,12 +113,3 @@ def _join_rocm_home(*paths) -> str:
return os.path.join(cpp_extension.ROCM_HOME, *paths)
cpp_extension._join_rocm_home = _join_rocm_home # pylint: disable=protected-access
rocm.conceal = conceal


def get_default_torch_version(agent: Optional[rocm.Agent]) -> str:
if agent is not None:
if agent.arch in (rocm.MicroArchitecture.RDNA, rocm.MicroArchitecture.CDNA,):
return "2.4.1" if hipBLASLt_enabled else "2.3.1"
elif agent.arch == rocm.MicroArchitecture.GCN:
return "2.2.1"
return "2.4.1" if hipBLASLt_enabled else "2.3.1"

0 comments on commit 4f09014

Please sign in to comment.