Commit e299a59
[FIX] Skip Distributed Sampler Tests if PyTorch with CUDA is not Available (#4518)

Some nightly tests are failing because no CUDA-enabled build of PyTorch is available (which is expected, e.g. with CUDA 11.4). Instead, the CPU-only build of PyTorch gets installed, and the tests crash when attempting to set the CUDA allocator. This PR skips those tests when only a CPU build of PyTorch is available, preventing the crash.
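
In outline, the fix wraps the RMM allocator setup in a torch.cuda.is_available() check and skips the whole test module otherwise. A minimal self-contained sketch of the resulting guard, assuming import_optional and MissingModule come from cugraph.utilities.utils as elsewhere in the test suite:

    import pytest

    from cugraph.utilities.utils import import_optional, MissingModule

    torch = import_optional("torch")
    if not isinstance(torch, MissingModule):
        if torch.cuda.is_available():
            # Route PyTorch's CUDA allocations through RMM so the tests
            # share a memory pool with the rest of the RAPIDS stack.
            from rmm.allocators.torch import rmm_torch_allocator

            torch.cuda.change_current_allocator(rmm_torch_allocator)
        else:
            # CPU-only PyTorch build: skip the entire module instead of
            # crashing while installing a CUDA allocator.
            pytest.skip("CUDA-enabled PyTorch is unavailable", allow_module_level=True)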

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Rick Ratzel (https://github.com/rlratzel)

URL: #4518
alexbarghi-nv authored Jul 9, 2024
1 parent 407cdab commit e299a59
Showing 2 changed files with 10 additions and 4 deletions.
7 changes: 5 additions & 2 deletions python/cugraph/cugraph/tests/sampling/test_dist_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@

 torch = import_optional("torch")
 if not isinstance(torch, MissingModule):
-    from rmm.allocators.torch import rmm_torch_allocator
+    if torch.cuda.is_available():
+        from rmm.allocators.torch import rmm_torch_allocator
 
-    torch.cuda.change_current_allocator(rmm_torch_allocator)
+        torch.cuda.change_current_allocator(rmm_torch_allocator)
+    else:
+        pytest.skip("CUDA-enabled PyTorch is unavailable", allow_module_level=True)


@pytest.fixture
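A note on the mechanism, based on pytest's documented behavior: pytest.skip(..., allow_module_level=True) raises during module import, so collection stops before change_current_allocator or any fixture below the guard ever runs, and the whole module is reported as skipped rather than as an error.
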
7 changes: 5 additions & 2 deletions python/cugraph/cugraph/tests/sampling/test_dist_sampler_mg.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@

 torch = import_optional("torch")
 if __name__ == "__main__" and not isinstance(torch, MissingModule):
-    from rmm.allocators.torch import rmm_torch_allocator
+    if torch.cuda.is_available():
+        from rmm.allocators.torch import rmm_torch_allocator
 
-    torch.cuda.change_current_allocator(rmm_torch_allocator)
+        torch.cuda.change_current_allocator(rmm_torch_allocator)
+    else:
+        pytest.skip("CUDA-enabled PyTorch is unavailable", allow_module_level=True)


def karate_mg_graph(rank, world_size):
