Skip to content

Commit

Permalink
Update HAS_GPU variable to account for CUDA_VISIBLE_DEVICES (#221)
Browse files Browse the repository at this point in the history
  • Loading branch information
oliverholworthy authored Feb 15, 2023
1 parent 5dbafa6 commit e3d892e
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions merlin/core/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os

try:
from numba import cuda # pylint: disable=unused-import
except ImportError:
cuda = None

from dask.distributed.diagnostics import nvml

# Using the `dask.distributed.diagnostics.nvml.device_get_count`
# helper function from dask to check device counts with NVML
# since this handles some complexity of checking NVML state for us.

# Note: We can't use `numba.cuda.gpus`, since this has some side effects
# that are incompatible with Dask-CUDA. If CUDA runtime functions are
# called before Dask-CUDA can spawn worker processes
# then Dask-CUDA it will not work correctly (raises an exception)
def _get_gpu_count():
"""Get Number of GPU devices accounting for CUDA_VISIBLE_DEVICES environment variable"""
# Using the `dask.distributed.diagnostics.nvml.device_get_count`
# helper function from dask to check device counts with NVML
# since this handles some complexity of checking NVML state for us.

# Note: We can't use `numba.cuda.gpus`, since this has some side effects
# that are incompatible with Dask-CUDA. If CUDA runtime functions are
# called before Dask-CUDA can spawn worker processes
# then Dask-CUDA it will not work correctly (raises an exception)
nvml_device_count = nvml.device_get_count()
if nvml_device_count == 0:
return 0
try:
cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
if cuda_visible_devices:
return len(cuda_visible_devices.split(","))
else:
return 0
except KeyError:
return nvml_device_count


HAS_GPU = nvml.device_get_count() > 0
HAS_GPU = _get_gpu_count() > 0

0 comments on commit e3d892e

Please sign in to comment.