From 975665039d85cbb1cf74bebb70010a75af54c396 Mon Sep 17 00:00:00 2001 From: vfdev Date: Wed, 27 Nov 2024 19:21:10 +0100 Subject: [PATCH] Fix warning raised by torch.meshgrid on missing indexing (#8689) Co-authored-by: Nicolas Hug --- references/depth/stereo/utils/losses.py | 2 +- torchvision/models/maxvit.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/references/depth/stereo/utils/losses.py b/references/depth/stereo/utils/losses.py index c809cc74d0f..1c21353a056 100644 --- a/references/depth/stereo/utils/losses.py +++ b/references/depth/stereo/utils/losses.py @@ -13,7 +13,7 @@ def make_gaussian_kernel(kernel_size: int, sigma: float) -> torch.Tensor: y = torch.arange(kernel_size, dtype=torch.float32) x = x - (kernel_size - 1) / 2 y = y - (kernel_size - 1) / 2 - x, y = torch.meshgrid(x, y) + x, y = torch.meshgrid(x, y, indexing="ij") grid = (x**2 + y**2) / (2 * sigma**2) kernel = torch.exp(-grid) kernel = kernel / kernel.sum() diff --git a/torchvision/models/maxvit.py b/torchvision/models/maxvit.py index 2a3888b2af3..66f49772218 100644 --- a/torchvision/models/maxvit.py +++ b/torchvision/models/maxvit.py @@ -40,7 +40,7 @@ def _make_block_input_shapes(input_size: Tuple[int, int], n_blocks: int) -> List def _get_relative_position_index(height: int, width: int) -> torch.Tensor: - coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)])) + coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)], indexing="ij")) coords_flat = torch.flatten(coords, 1) relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :] relative_coords = relative_coords.permute(1, 2, 0).contiguous()