
Check Pytorch cuda context is valid across GPUs #284

Open
VibhuJawa opened this issue Oct 8, 2024 · 2 comments

@VibhuJawa (Collaborator)

Describe the bug

We have seen multiple breakages in dask + PyTorch environments where the CUDA context ends up bound only to GPU 0. This can happen when a library creates a CUDA context through PyTorch before the cluster is started.

What ends up happening is that PyTorch models are deployed only on GPU 0, and that issue is hard to debug.

@VibhuJawa VibhuJawa added the bug Something isn't working label Oct 8, 2024
@VibhuJawa VibhuJawa self-assigned this Oct 8, 2024
@VibhuJawa VibhuJawa changed the title [WIP] Check Pytorch cuda context is valid across GPUs Check Pytorch cuda context is valid across GPUs Oct 9, 2024
@VibhuJawa (Collaborator, Author)


I think a better fix is ensuring we don't fork the CUDA context if it's already present when starting a LocalCUDACluster.

import cupy as cp

cp.cuda.runtime.getDeviceCount()

# Uncomment to reproduce the failure: this creates a CUDA context
# before the cluster starts.
# import torch
# t = torch.as_tensor([1, 2, 3])

from dask_cuda import LocalCUDACluster
from distributed import Client
from distributed.diagnostics.nvml import has_cuda_context


def check_cuda_context():
    _warning_suffix = (
        "This is often the result of a CUDA-enabled library calling a CUDA runtime function before "
        "Dask-CUDA can spawn worker processes. Please make sure any such function calls don't happen "
        "at import time or in the global scope of a program."
    )
    if has_cuda_context().has_context:
        # A CUDA context already exists in this process, so workers forked
        # from it would all be pinned to the same device.
        raise RuntimeError(
            f"CUDA context was initialized before the dask-cuda cluster was spun up. {_warning_suffix}"
        )


if __name__ == "__main__":
    check_cuda_context()
    cluster = LocalCUDACluster(rmm_async=True, rmm_pool_size="2GiB")
    client = Client(cluster)

CC: @ayushdg
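The guard above needs a GPU (and a running NVML) to exercise. Here is a hardware-independent sketch of the same logic; the `probe` parameter is hypothetical, introduced only so that both outcomes of `distributed.diagnostics.nvml.has_cuda_context` can be simulated without a GPU.

```python
from types import SimpleNamespace

_WARNING_SUFFIX = (
    "This is often the result of a CUDA-enabled library calling a CUDA runtime "
    "function before Dask-CUDA can spawn worker processes."
)


def check_cuda_context(probe):
    """Raise if `probe()` reports an existing CUDA context.

    `probe` is a stand-in for distributed.diagnostics.nvml.has_cuda_context,
    which returns an object with a boolean `has_context` attribute.
    """
    if probe().has_context:
        raise RuntimeError(
            "CUDA context was initialized before the dask-cuda cluster was "
            f"spun up. {_WARNING_SUFFIX}"
        )


# No pre-existing context: the check passes silently.
check_cuda_context(lambda: SimpleNamespace(has_context=False))

# A context already exists: the check raises before the cluster is created.
try:
    check_cuda_context(lambda: SimpleNamespace(has_context=True))
except RuntimeError as err:
    print("caught:", err)
```

Running the check before constructing `LocalCUDACluster` turns a silent all-on-GPU-0 deployment into a loud, early error.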

@VibhuJawa (Collaborator, Author)


Moving to the next release, as this is not high priority.
