Skip to content

Commit

Permalink
Add exclude_gpus properties to ROCm and DirectML
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Mar 26, 2024
1 parent be2e8c5 commit 8c79afd
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
26 changes: 25 additions & 1 deletion lib/gpu_stats/directml.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from comtypes import COMError, IUnknown, GUID, STDMETHOD, HRESULT # pylint:disable=import-error

from ._base import _GPUStats
from ._base import _GPUStats, _EXCLUDE_DEVICES

if T.TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -628,3 +628,27 @@ def _get_free_vram(self) -> list[int]:
vram = [int(device.local_mem.Budget / (1024 * 1024)) for device in self._devices]
self._log("debug", f"GPU VRAM free: {vram}")
return vram

def exclude_devices(self, devices: list[int]) -> None:
    """ Exclude GPU devices from being used by Faceswap.

    Sets the ``DML_VISIBLE_DEVICES`` environment variable to the remaining active device
    indices. This must be called before Torch/Keras are imported.

    Parameters
    ----------
    devices: list[int]
        The GPU device IDs to be excluded
    """
    if not devices:  # Nothing requested for exclusion; leave environment untouched
        return
    self._logger.debug("Excluding GPU indices: %s", devices)

    # Record the exclusions in the module-level register shared with _base
    _EXCLUDE_DEVICES.extend(devices)

    active = self._get_active_devices()

    # Expose only the devices that remain after exclusion
    os.environ["DML_VISIBLE_DEVICES"] = ",".join(str(d) for d in active
                                                 if d not in _EXCLUDE_DEVICES)

    self._logger.debug("DML environment variables: %s",
                       [f"{k}: {v}" for k, v in os.environ.items()
                        if k.lower().startswith("dml")])
26 changes: 25 additions & 1 deletion lib/gpu_stats/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import re
from subprocess import run

from ._base import _GPUStats
from ._base import _GPUStats, _EXCLUDE_DEVICES

_DEVICE_LOOKUP = { # ref: https://gist.github.com/roalercon/51f13a387f3754615cce
int("0x130F", 0): "AMD Radeon(TM) R7 Graphics",
Expand Down Expand Up @@ -448,3 +448,27 @@ def _get_free_vram(self) -> list[int]:
retval.append(vram - int(used / (1024 * 1024)))
self._log("debug", f"GPU VRAM free: {retval}")
return retval

def exclude_devices(self, devices: list[int]) -> None:
    """ Exclude GPU devices from being used by Faceswap.

    Sets the ``HIP_VISIBLE_DEVICES`` environment variable to the remaining active device
    indices. This must be called before Torch/Keras are imported.

    Parameters
    ----------
    devices: list[int]
        The GPU device IDs to be excluded
    """
    if not devices:  # Nothing requested for exclusion; leave environment untouched
        return
    self._logger.debug("Excluding GPU indices: %s", devices)

    # Record the exclusions in the module-level register shared with _base
    _EXCLUDE_DEVICES.extend(devices)

    active = self._get_active_devices()

    # Expose only the devices that remain after exclusion
    os.environ["HIP_VISIBLE_DEVICES"] = ",".join(str(d) for d in active
                                                 if d not in _EXCLUDE_DEVICES)

    self._logger.debug("HIP environment variables: %s",
                       [f"{k}: {v}" for k, v in os.environ.items()
                        if k.lower().startswith("hip")])

0 comments on commit 8c79afd

Please sign in to comment.