[Bug]: debugging guide for device >= 0 && device < num_gpus INTERNAL ASSERT FAILED at "../aten/src/ATen/cuda/CUDAContext.cpp" #6056
Hmm, I'm getting this error for #5276. The stack trace suggests that … When I try to log the traceback in … But I think this is supposed to happen, right?
Your current environment
🐛 Describe the bug
This is a compound and annoying bug, coupled with PyTorch bug pytorch/pytorch#122815.

Basically, PyTorch's `torch.cuda.device_count` function caches the device count when it is first called. Users might not call it directly, but it is called if you `import torch._dynamo`. In our case, some image-processing code imports `torchvision`, which implicitly imports `torch._dynamo`.

Since `torch._dynamo` remembers the device count, it registers a hook to initialize all devices after CUDA is initialized. If we shrink `CUDA_VISIBLE_DEVICES` later, before we initialize CUDA, then `torch._dynamo` will hit this error.

PyTorch fixes this bug in pytorch/pytorch#122795.
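The caching pitfall can be illustrated with a minimal stand-in (this is not PyTorch's actual code, just an analogue of the cached behavior, using a parse of `CUDA_VISIBLE_DEVICES` in place of a real driver query):

```python
import functools
import os

@functools.lru_cache(maxsize=None)
def device_count() -> int:
    # Stand-in for torch.cuda.device_count: reads the environment once,
    # then caches the result forever, like the pre-2.4 PyTorch behavior.
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    return len([d for d in visible.split(",") if d])

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
assert device_count() == 4  # first call caches 4

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # shrink visibility later
assert device_count() == 4  # stale: the cache never sees the change
```

Once any import path has triggered the first call, later changes to `CUDA_VISIBLE_DEVICES` are invisible to every subsequent caller.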
However, before we upgrade to PyTorch 2.4, we cannot do anything about the caching itself.
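A sketch of the idea behind a "stateless" count: recompute from the *current* environment on every call instead of trusting a value cached at first call. (This is only an illustration; vLLM's real `vllm.utils.cuda_device_count_stateless` is implemented differently, and a complete version would also have to query the driver when the variable is unset.)

```python
import os

def device_count_stateless() -> int:
    # Re-read CUDA_VISIBLE_DEVICES on every call, so shrinking it
    # later is reflected immediately; no process-wide cache involved.
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        raise RuntimeError("querying the physical count is not sketched here")
    return len([d for d in visible.split(",") if d])

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
assert device_count_stateless() == 4
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
assert device_count_stateless() == 1  # reflects the shrink immediately
```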
Inside vLLM, we already use `vllm.utils.cuda_device_count_stateless` as much as possible. (If you see `torch.cuda.device_count()`, it is a bug, and we should fix it by calling `vllm.utils.cuda_device_count_stateless()` instead.)

If some other library (e.g. `transformers` in this case) accidentally calls `torch.cuda.device_count()`, we cannot do anything but defer the `import`, as is done in #6055.

How do we find the code to blame? My current approach is to manually insert `import traceback; traceback.print_stack()` inside `torch.cuda.device_count`. Yes, modify PyTorch's code, that's it. If it prints a stack trace before we initialize the engine, then we need to find the line to blame. After deferring all possible lines to blame, we should fix this bug.
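Instead of editing PyTorch's installed source, the same stack-trace trick can be applied by monkey-patching the function at interpreter startup. A sketch, with a hypothetical stand-in function so it runs without a GPU; with real PyTorch you would wrap `torch.cuda.device_count` instead:

```python
import traceback

def log_callers(fn):
    """Wrap fn so every call prints the stack that reached it."""
    def wrapper(*args, **kwargs):
        traceback.print_stack()  # shows which import/code path called us
        return fn(*args, **kwargs)
    return wrapper

# Hypothetical stand-in; with real PyTorch you would do:
#   torch.cuda.device_count = log_callers(torch.cuda.device_count)
def device_count():
    return 8  # pretend device count for illustration only

device_count = log_callers(device_count)

n = device_count()  # prints a stack trace pointing at this call site
assert n == 8
```

Any stack trace printed before engine initialization then identifies the line to defer.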