Skip to content

Commit

Permalink
Fix warning raised by torch.meshgrid on missing indexing (#8689)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <[email protected]>
  • Loading branch information
vfdev-5 and NicolasHug authored Nov 27, 2024
1 parent acbfd8d commit 9756650
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion references/depth/stereo/utils/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def make_gaussian_kernel(kernel_size: int, sigma: float) -> torch.Tensor:
y = torch.arange(kernel_size, dtype=torch.float32)
x = x - (kernel_size - 1) / 2
y = y - (kernel_size - 1) / 2
x, y = torch.meshgrid(x, y)
x, y = torch.meshgrid(x, y, indexing="ij")
grid = (x**2 + y**2) / (2 * sigma**2)
kernel = torch.exp(-grid)
kernel = kernel / kernel.sum()
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/maxvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _make_block_input_shapes(input_size: Tuple[int, int], n_blocks: int) -> List


def _get_relative_position_index(height: int, width: int) -> torch.Tensor:
coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)]))
coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)], indexing="ij"))
coords_flat = torch.flatten(coords, 1)
relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
Expand Down

0 comments on commit 9756650

Please sign in to comment.