feat: improve how device switch is handled between the metric device and the input tensors device #3043

Merged: 24 commits merged into master from refactor-_update, Aug 25, 2023

Changes from 8 commits (of 24)

Commits
750a6ca
refactor: remove outdated code and issue a warning if two tensors are…
MarcBresson Aug 23, 2023
78a4c78
feat: prioritize computation on GPU devices over CPUs
MarcBresson Aug 24, 2023
85eebd5
fix: use a temp var that will be moved with y_pred
MarcBresson Aug 24, 2023
9125e60
test: add metric and y_pred with different devices test
MarcBresson Aug 24, 2023
a4c2f7c
feat: move self._kernel directly and issue a warning only when not al…
MarcBresson Aug 24, 2023
1908fff
feat: adapt test to new behaviour
MarcBresson Aug 24, 2023
2547e70
feat: keep the accumulation on the same device as self._kernel
MarcBresson Aug 24, 2023
3269955
feat: move accumulation along side self._kernel
MarcBresson Aug 24, 2023
04af090
feat: allow different channel number
MarcBresson Aug 24, 2023
7922ec9
style: format using the run_code_style script
MarcBresson Aug 25, 2023
b0625e4
style: add line break to conform to E501
MarcBresson Aug 25, 2023
6817316
fix: use torch.empty to avoid type incompatibility between None and T…
MarcBresson Aug 25, 2023
d2aa8c8
feat: only operate on self._kernel, keep the accumulation on user's s…
MarcBresson Aug 25, 2023
c6bf8f8
test: add variable channel test and factorize the code
MarcBresson Aug 25, 2023
f6f82fe
Merge branch 'master' into refactor-_update
MarcBresson Aug 25, 2023
99c3469
refactor: remove redundant line between init and reset
MarcBresson Aug 25, 2023
eba6f68
refactor: elif comparison and replace RuntimeWarning by UserWarning
MarcBresson Aug 25, 2023
91ae235
refactor: set _kernel in __init__ and manually format to pass E501
MarcBresson Aug 25, 2023
7284b01
test: adapt test to new UserWarning
MarcBresson Aug 25, 2023
d96255c
test: remove skips
MarcBresson Aug 25, 2023
2807f28
refactor: use None instead of torch.empty
MarcBresson Aug 25, 2023
526234c
style: reorder imports
MarcBresson Aug 25, 2023
b6f1a21
refactor: rename channel to nb_channel
MarcBresson Aug 25, 2023
0a38aa5
Fixed failing test_distrib_accumulator_device
vfdev-5 Aug 25, 2023
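
Read in sequence, the commits converge on one device-resolution rule: when the metric state and the update tensors disagree, the non-CPU device wins. Below is a rough distillation of that rule as a hypothetical helper — not code from the PR; note that the later commits use an elif comparison and a UserWarning, while the 8-commit diff shown further down still uses a RuntimeWarning.

import warnings

import torch


def resolve_device(kernel, y_pred, y):
    """Hypothetical sketch of the PR's rule: prefer the non-CPU device.

    - kernel (metric state) on CPU, inputs elsewhere -> move the kernel to
      the inputs' device
    - inputs on CPU, kernel elsewhere -> warn, move the inputs to the
      kernel's device
    """
    if y_pred.device != kernel.device:
        if kernel.device == torch.device("cpu"):
            kernel = kernel.to(y_pred.device)
        elif y_pred.device == torch.device("cpu"):
            warnings.warn(
                f"update tensors are on CPU while the metric state is on {kernel.device}; "
                "keep them on the same device to avoid a performance hit",
                UserWarning,
            )
            y_pred, y = y_pred.to(kernel.device), y.to(kernel.device)
    return kernel, y_pred, y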
20 changes: 17 additions & 3 deletions ignite/metrics/ssim.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Callable, Sequence, Union
 
 import torch
@@ -158,8 +159,21 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
             )
 
         channel = y_pred.size(1)
-        if len(self._kernel.shape) < 4:
-            self._kernel = self._kernel.expand(channel, 1, -1, -1).to(device=y_pred.device)
+
+        self._kernel = self._kernel.expand(channel, 1, -1, -1)
+
+        if y_pred.device != self._kernel.device:
+            if self._kernel.device == torch.device("cpu"):
+                self._kernel = self._kernel.to(device=y_pred.device)
+                self._sum_of_ssim = self._sum_of_ssim.to(device=y_pred.device)
+
+            if y_pred.device == torch.device("cpu"):
+                warnings.warn(
+                    "The metric device or one of the previous update tensors was set on a different device than this update tensor, which is on the CPU. To avoid a performance hit, ensure that your metric device and all of your update tensors are on the same device.",
+                    RuntimeWarning,
+                )
+                y_pred = y_pred.to(device=self._kernel.device)
+                y = y.to(device=self._kernel.device)
 
         y_pred = F.pad(y_pred, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect")
         y = F.pad(y, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect")
@@ -186,7 +200,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
         b2 = sigma_pred_sq + sigma_target_sq + self.c2
 
         ssim_idx = (a1 * a2) / (b1 * b2)
-        self._sum_of_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=torch.float64).sum().to(self._device)
+        self._sum_of_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=torch.float64).sum()
 
         self._num_examples += y.shape[0]
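
For orientation (not part of the diff), here is a minimal usage sketch of the behavior this hunk implements, assuming a CUDA-capable machine; the shapes and values are illustrative, and at this point in the PR the warning class is still RuntimeWarning (a later commit switches it to UserWarning).

import torch
from ignite.metrics import SSIM

# The metric state lives on CUDA, but the update tensors arrive on CPU. With
# this change the inputs are moved to the kernel's device (with a warning)
# instead of the update failing on a device mismatch.
metric = SSIM(data_range=1.0, device=torch.device("cuda"))

y_pred = torch.rand(4, 3, 64, 64)  # CPU tensors
y = y_pred * 0.9

metric.update((y_pred, y))  # warns, then accumulates on CUDA
print(metric.compute())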
39 changes: 39 additions & 0 deletions tests/ignite/metrics/test_ssim.py
@@ -1,3 +1,5 @@
+import warnings
+
 import numpy as np
 import pytest
 import torch
@@ -102,6 +104,43 @@ def test_ssim(
     assert np.allclose(ignite_ssim, skimg_ssim, atol=precision)
 
 
+@pytest.mark.parametrize(
+    "metric_device, y_pred_device",
+    [
+        [torch.device("cpu"), torch.device("cpu")],
+        [torch.device("cpu"), torch.device("cuda")],
+        [torch.device("cuda"), torch.device("cpu")],
+        [torch.device("cuda"), torch.device("cuda")],
+    ],
+)
+def test_ssim_device(available_device, metric_device, y_pred_device):
+    if available_device == "cpu":
+        pytest.skip("This test requires a cuda device.")
+
+    data_range = 1.0
+    sigma = 1.5
+    shape = (12, 5, 256, 256)
+
+    ssim = SSIM(data_range=data_range, sigma=sigma, device=metric_device)
+
+    y_pred = torch.rand(shape, device=y_pred_device)
+    y = y_pred * 0.8
+
+    if metric_device == torch.device("cuda") and y_pred_device == torch.device("cpu"):
+        with pytest.warns(RuntimeWarning):
+            ssim.update((y_pred, y))
+    else:
+        ssim.update((y_pred, y))
+
+    if metric_device == torch.device("cuda") or y_pred_device == torch.device("cuda"):
+        # A tensor will always have the device index set
+        expected_device = torch.device("cuda:0")
+    else:
+        expected_device = torch.device("cpu")
+
+    assert ssim._kernel.device == expected_device
+
+
 def test_ssim_variable_batchsize(available_device):
     # Checks https://github.com/pytorch/ignite/issues/2532
     sigma = 1.5
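
A side note on the expected-device logic in the test above (illustrative snippet, not from the PR): torch.device equality includes the device index, and a tensor's .device always carries one, which is why the test compares against cuda:0 rather than plain cuda.

import torch

t = torch.empty(1, device="cuda")           # requires a CUDA device
print(t.device)                             # cuda:0 -- the index is always filled in
print(t.device == torch.device("cuda"))     # False: plain "cuda" has index None
print(t.device == torch.device("cuda:0"))   # True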