diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py
index 031a4cb01d1..e5c8cb3b0a5 100644
--- a/tests/unittests/bases/test_ddp.py
+++ b/tests/unittests/bases/test_ddp.py
@@ -84,6 +84,7 @@ def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESS
     This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
     preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained with
     the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. This test
     only considers tensors of the same shape across different ranks. Note that this test only works for torch>=2.0.
+
     """
     tensor = torch.ones(50, requires_grad=True)
     result = gather_all_tensors(tensor)
@@ -108,6 +109,7 @@ def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PR
     This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
     preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained with
     the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. This test
     considers tensors of different shapes across different ranks. Note that this test only works for torch>=2.0.
+
     """
     tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True)
     result = gather_all_tensors(tensor)
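
For context, the property these docstrings describe is that `gather_all_tensors` keeps the local rank's slot of the gathered list attached to that rank's autograd graph, so differentiating through `result[rank]` gives the same gradient as differentiating through the original tensor. Below is a minimal, runnable sketch of that check outside the test harness; it is an illustration, not the test itself. The `gloo` backend, the hardcoded `localhost:29500` rendezvous, the two-process world size, and the helper name `_check_gather_preserves_autograd` are all arbitrary choices for this sketch, and it assumes torch>=2.0 (as the docstrings note) plus an installed torchmetrics.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from torchmetrics.utilities.distributed import gather_all_tensors

WORLD_SIZE = 2  # arbitrary world size for this sketch


def _check_gather_preserves_autograd(rank: int, world_size: int) -> None:
    # Minimal single-node process-group setup; address and port are arbitrary.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    tensor = torch.ones(50, requires_grad=True)
    result = gather_all_tensors(tensor)

    # The local rank's entry in the gathered list should still be wired into
    # this process's autograd graph, so differentiating through it yields the
    # same gradient as differentiating through the original tensor directly.
    grad_via_gather = torch.autograd.grad(result[rank].sum(), [tensor], retain_graph=True)[0]
    grad_direct = torch.autograd.grad(tensor.sum(), [tensor])[0]
    assert torch.allclose(grad_via_gather, grad_direct)

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_check_gather_preserves_autograd, args=(WORLD_SIZE,), nprocs=WORLD_SIZE)
```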