Skip to content

Commit

Permalink
Avoid poisoning process with CUDA calls as soon as importing (#6810)
Browse files Browse the repository at this point in the history
Call `torch.cuda.device_count() > 0` before `torch.cuda.is_available()`,
to give priority to nvml based availability, so that we can try not to
poison process with CUDA calls as soon as we execute `import deepspeed`.


https://github.com/pytorch/pytorch/blob/v2.5.1/torch/cuda/__init__.py#L120-L124

There are 2 reasons to make this change:

Firstly, if we accidentally import deepspeed, since the CUDA runtime
initializes when the first CUDA API call is made and caches the device
list, changing the CUDA_VISIBLE_DEVICES within the same process after
initialization won't have any effect on the visible devices. The
specific case:
OpenRLHF/OpenRLHF#524 (comment)

A demo for reproduction before the fix is applied:

```python
import torch
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import deepspeed
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
torch.cuda.set_device('cuda:0')
```

Secondly, https://pytorch.org/docs/stable/notes/cuda.html

When assessing the availability of CUDA in a given environment
(is_available()), PyTorch’s default behavior is to call the CUDA Runtime
API method cudaGetDeviceCount. Because this call in turn initializes the
CUDA Driver API (via cuInit) if it is not already initialized,
subsequent forks of a process that has run is_available() will fail with
a CUDA initialization error.

Signed-off-by: Hollow Man <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
HollowMan6 and loadams authored Dec 12, 2024
1 parent bd6fd50 commit b166a9b
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion accelerator/real_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,12 @@ def get_accelerator():
import torch

# Determine if we are on a GPU or x86 CPU with torch.
if torch.cuda.is_available(): #ignore-cuda
# "torch.cuda.is_available()" provides a stronger guarantee, #ignore-cuda
# ensuring that we are free from CUDA initialization errors.
# While "torch.cuda.device_count() > 0" check ensures that #ignore-cuda
# we won't try to do any CUDA calls when no device is available
# For reference: https://github.com/microsoft/DeepSpeed/pull/6810
if torch.cuda.device_count() > 0 and torch.cuda.is_available(): #ignore-cuda
accelerator_name = "cuda"
else:
if accel_logger is not None:
Expand Down

0 comments on commit b166a9b

Please sign in to comment.