diff --git a/start.py b/start.py index 8fe34fa8..1b5b085c 100644 --- a/start.py +++ b/start.py @@ -15,6 +15,11 @@ def get_requirements_file(): # TODO: Check if the user has an AMD gpu on windows if ROCM_PATH: requirements_name = "requirements-amd" + + # Also override env vars for ROCm support on non-supported GPUs + os.environ["ROCM_PATH"] = '/opt/rocm' + os.environ["HSA_OVERRIDE_GFX_VERSION"] = '10.3.0' + os.environ["HCC_AMDGPU_TARGET"] = 'gfx1030' elif CUDA_PATH: cuda_version = pathlib.Path(CUDA_PATH).name if "12" in cuda_version: