diff --git a/accelerator/real_accelerator.py b/accelerator/real_accelerator.py index 69e96d285bb8..a6173ac70abd 100644 --- a/accelerator/real_accelerator.py +++ b/accelerator/real_accelerator.py @@ -167,7 +167,12 @@ def get_accelerator(): import torch # Determine if we are on a GPU or x86 CPU with torch. - if torch.cuda.is_available(): #ignore-cuda + # "torch.cuda.is_available()" provides a stronger guarantee, #ignore-cuda + # ensuring that we are free from CUDA initialization errors, + # while the "torch.cuda.device_count() > 0" check ensures that #ignore-cuda + # we won't try to make any CUDA calls when no device is available. + # For reference: https://github.com/microsoft/DeepSpeed/pull/6810 + if torch.cuda.device_count() > 0 and torch.cuda.is_available(): #ignore-cuda accelerator_name = "cuda" else: if accel_logger is not None: