From 8122e9f3e29efa4e1a4ec623b943740051f430d1 Mon Sep 17 00:00:00 2001
From: cw-tan
Date: Tue, 17 Sep 2024 13:58:40 -0400
Subject: [PATCH] propagate rank result to gathered result for autograd compatibility

---
 src/torchmetrics/utilities/distributed.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py
index 455d64c4ae0..4f6eacea866 100644
--- a/src/torchmetrics/utilities/distributed.py
+++ b/src/torchmetrics/utilities/distributed.py
@@ -91,6 +91,8 @@ def class_reduce(
 def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
     gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
     torch.distributed.all_gather(gathered_result, result, group)
+    # to propagate autograd graph from local rank (achieves intended effect for torch > 2.0)
+    gathered_result[torch.distributed.get_rank(group)] = result
     return gathered_result


@@ -144,4 +146,6 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
     for idx, item_size in enumerate(local_sizes):
         slice_param = [slice(dim_size) for dim_size in item_size]
         gathered_result[idx] = gathered_result[idx][slice_param]
+    # to propagate autograd graph from local rank (achieves intended effect for torch > 2.0)
+    gathered_result[torch.distributed.get_rank(group)] = result
     return gathered_result
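
Note on the change: torch.distributed.all_gather writes into pre-allocated output buffers that are not connected to the autograd graph, so gradients cannot flow back through the gathered copies. Overwriting the local rank's slot with the original `result` tensor keeps that rank's gradient history intact. The following is a minimal, self-contained sketch of the same pattern outside torchmetrics; the helper name `gather_with_local_grad` is hypothetical, and it assumes a process group has already been initialized (e.g. via torchrun with the gloo backend).

    import torch
    import torch.distributed as dist


    def gather_with_local_grad(result: torch.Tensor, group=None) -> list:
        # all_gather fills detached buffers; re-inserting the local tensor
        # preserves its autograd graph (hypothetical helper, for illustration)
        world_size = dist.get_world_size(group)
        gathered = [torch.zeros_like(result) for _ in range(world_size)]
        dist.all_gather(gathered, result, group=group)
        # replace the detached copy of this rank's data with the original,
        # differentiable tensor, as done in the patch above
        gathered[dist.get_rank(group)] = result
        return gathered


    if __name__ == "__main__":
        # assumes the distributed environment is set up, e.g. by torchrun
        dist.init_process_group("gloo")
        x = torch.randn(3, requires_grad=True)
        out = torch.stack(gather_with_local_grad(x)).sum()
        out.backward()  # gradients flow only through this rank's contribution
        print(x.grad)

With this pattern, each rank backpropagates through its own contribution to the gathered list, which is the intended effect of the patch.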