Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve how device switch is handled between the metric device and the input tensors device #3043

Merged
merged 24 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
750a6ca
refactor: remove outdated code and issue a warning if two tensors are…
MarcBresson Aug 23, 2023
78a4c78
feat: prioritize computation on GPU devices over CPUs
MarcBresson Aug 24, 2023
85eebd5
fix: use a temp var that will be moved with y_pred
MarcBresson Aug 24, 2023
9125e60
test: add metric and y_pred with different devices test
MarcBresson Aug 24, 2023
a4c2f7c
feat: move self._kernel directly and issue a warning only when not al…
MarcBresson Aug 24, 2023
1908fff
feat: adapt test to new behaviour
MarcBresson Aug 24, 2023
2547e70
feat: keep the accumulation on the same device as self._kernel
MarcBresson Aug 24, 2023
3269955
feat: move accumulation along side self._kernel
MarcBresson Aug 24, 2023
04af090
feat: allow different channel number
MarcBresson Aug 24, 2023
7922ec9
style: format using the run_code_style script
MarcBresson Aug 25, 2023
b0625e4
style: add line brak to conform to E501
MarcBresson Aug 25, 2023
6817316
fix: use torch.empty to avoid type incompatibility between None and T…
MarcBresson Aug 25, 2023
d2aa8c8
feat: only operate on self._kernel, keep the accumulation on user's s…
MarcBresson Aug 25, 2023
c6bf8f8
test: add variable channel test and factorize the code
MarcBresson Aug 25, 2023
f6f82fe
Merge branch 'master' into refactor-_update
MarcBresson Aug 25, 2023
99c3469
refactor: remove redundant line between init and reset
MarcBresson Aug 25, 2023
eba6f68
refactor: elif comparison and replace RuntimeWarning by UserWarning
MarcBresson Aug 25, 2023
91ae235
refactor: set _kernel in __init__ and manually format to pass E501
MarcBresson Aug 25, 2023
7284b01
test: adapt test to new UserWarning
MarcBresson Aug 25, 2023
d96255c
test: remove skips
MarcBresson Aug 25, 2023
2807f28
refactor: use None instead of torch.empty
MarcBresson Aug 25, 2023
526234c
style: reorder imports
MarcBresson Aug 25, 2023
b6f1a21
refactor: rename channel to nb_channel
MarcBresson Aug 25, 2023
0a38aa5
Fixed failing test_distrib_accumulator_device
vfdev-5 Aug 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions ignite/metrics/ssim.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Sequence, Union
import warnings
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -102,7 +103,8 @@
self.c2 = (k2 * data_range) ** 2
self.pad_h = (self.kernel_size[0] - 1) // 2
self.pad_w = (self.kernel_size[1] - 1) // 2
self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma)
self._kernel_2d = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma)
self._kernel: Optional[torch.Tensor] = None

@reinit__is_reduced
def reset(self) -> None:
Expand Down Expand Up @@ -155,9 +157,22 @@
f"Expected y_pred and y to have BxCxHxW shape. Got y_pred: {y_pred.shape} and y: {y.shape}."
)

channel = y_pred.size(1)
if len(self._kernel.shape) < 4:
self._kernel = self._kernel.expand(channel, 1, -1, -1).to(device=y_pred.device)
nb_channel = y_pred.size(1)
if self._kernel is None or self._kernel.shape[0] != nb_channel:
self._kernel = self._kernel_2d.expand(nb_channel, 1, -1, -1)

if y_pred.device != self._kernel.device:
MarcBresson marked this conversation as resolved.
Show resolved Hide resolved
if self._kernel.device == torch.device("cpu"):
self._kernel = self._kernel.to(device=y_pred.device)

Check warning on line 166 in ignite/metrics/ssim.py

View check run for this annotation

Codecov / codecov/patch

ignite/metrics/ssim.py#L165-L166

Added lines #L165 - L166 were not covered by tests

elif y_pred.device == torch.device("cpu"):
warnings.warn(

Check warning on line 169 in ignite/metrics/ssim.py

View check run for this annotation

Codecov / codecov/patch

ignite/metrics/ssim.py#L168-L169

Added lines #L168 - L169 were not covered by tests
"y_pred tensor is on cpu device but previous computation was on another device: "
f"{self._kernel.device}. To avoid having a performance hit, please ensure that all "
"y and y_pred tensors are on the same device.",
)
y_pred = y_pred.to(device=self._kernel.device)
y = y.to(device=self._kernel.device)

Check warning on line 175 in ignite/metrics/ssim.py

View check run for this annotation

Codecov / codecov/patch

ignite/metrics/ssim.py#L174-L175

Added lines #L174 - L175 were not covered by tests

y_pred = F.pad(y_pred, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect")
y = F.pad(y, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect")
Expand All @@ -166,7 +181,7 @@
self._kernel = self._kernel.to(dtype=y_pred.dtype)

input_list = [y_pred, y, y_pred * y_pred, y * y, y_pred * y]
outputs = F.conv2d(torch.cat(input_list), self._kernel, groups=channel)
outputs = F.conv2d(torch.cat(input_list), self._kernel, groups=nb_channel)
batch_size = y_pred.size(0)
output_list = [outputs[x * batch_size : (x + 1) * batch_size] for x in range(len(input_list))]

Expand All @@ -184,7 +199,7 @@
b2 = sigma_pred_sq + sigma_target_sq + self.c2

ssim_idx = (a1 * a2) / (b1 * b2)
self._sum_of_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=torch.float64).sum().to(self._device)
self._sum_of_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=torch.float64).sum().to(device=self._device)

self._num_examples += y.shape[0]

Expand Down
103 changes: 93 additions & 10 deletions tests/ignite/metrics/test_ssim.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Sequence, Union

import numpy as np
import pytest
import torch
Expand Down Expand Up @@ -70,25 +72,49 @@ def test_invalid_ssim():
"shape, kernel_size, gaussian, use_sample_covariance",
[[(8, 3, 224, 224), 7, False, True], [(12, 3, 28, 28), 11, True, False]],
)
def test_ssim(
available_device, shape, kernel_size, gaussian, use_sample_covariance, dtype=torch.float32, precision=7e-5
):
y_pred = torch.rand(shape, device=available_device, dtype=dtype)
def test_ssim(available_device, shape, kernel_size, gaussian, use_sample_covariance):
y_pred = torch.rand(shape, device=available_device)
y = y_pred * 0.8

compare_ssim_ignite_skiimg(
y_pred,
y,
available_device,
kernel_size=kernel_size,
gaussian=gaussian,
use_sample_covariance=use_sample_covariance,
)


def compare_ssim_ignite_skiimg(
y_pred: torch.Tensor,
y: torch.Tensor,
device: torch.device,
precision: float = 2e-5, # default to float32 expected precision
*,
skimg_y_pred: Union[np.ndarray, None] = None,
skimg_y: Union[np.ndarray, None] = None,
data_range: float = 1.0,
kernel_size: Union[int, Sequence[int]] = 11,
gaussian: bool = True,
use_sample_covariance: bool = False,
):
sigma = 1.5
data_range = 1.0
ssim = SSIM(data_range=data_range, sigma=sigma, device=available_device)

ssim = SSIM(data_range=data_range, sigma=sigma, device=device)
ssim.update((y_pred, y))
ignite_ssim = ssim.compute()

if y_pred.dtype == torch.bfloat16:
y_pred = y_pred.to(dtype=torch.float16)

skimg_pred = y_pred.cpu().numpy()
skimg_y = skimg_pred * 0.8
if skimg_y_pred is None:
skimg_y_pred = y_pred.cpu().numpy()
if skimg_y is None:
skimg_y = skimg_y_pred * 0.8

skimg_ssim = ski_ssim(
skimg_pred,
skimg_y_pred,
skimg_y,
win_size=kernel_size,
sigma=sigma,
Expand All @@ -102,6 +128,43 @@ def test_ssim(
assert np.allclose(ignite_ssim, skimg_ssim, atol=precision)


@pytest.mark.parametrize(
"metric_device, y_pred_device",
[
[torch.device("cpu"), torch.device("cpu")],
[torch.device("cpu"), torch.device("cuda")],
[torch.device("cuda"), torch.device("cpu")],
[torch.device("cuda"), torch.device("cuda")],
],
)
def test_ssim_device(available_device, metric_device, y_pred_device):
if available_device == "cpu":
pytest.skip("This test requires a cuda device.")

data_range = 1.0
sigma = 1.5
shape = (12, 5, 256, 256)

ssim = SSIM(data_range=data_range, sigma=sigma, device=metric_device)

y_pred = torch.rand(shape, device=y_pred_device)
y = y_pred * 0.8

if metric_device == torch.device("cuda") and y_pred_device == torch.device("cpu"):
with pytest.warns(UserWarning):
ssim.update((y_pred, y))
else:
ssim.update((y_pred, y))

if metric_device == torch.device("cuda") or y_pred_device == torch.device("cuda"):
# A tensor will always have the device index set
excepted_device = torch.device("cuda:0")
else:
excepted_device = torch.device("cpu")

assert ssim._kernel.device == excepted_device


def test_ssim_variable_batchsize(available_device):
# Checks https://github.com/pytorch/ignite/issues/2532
sigma = 1.5
Expand All @@ -128,6 +191,21 @@ def test_ssim_variable_batchsize(available_device):
assert np.allclose(out, expected)


def test_ssim_variable_channel(available_device):
y_preds = [
torch.rand(12, 5, 28, 28, device=available_device),
torch.rand(12, 4, 28, 28, device=available_device),
torch.rand(12, 7, 28, 28, device=available_device),
torch.rand(12, 3, 28, 28, device=available_device),
torch.rand(12, 11, 28, 28, device=available_device),
torch.rand(12, 6, 28, 28, device=available_device),
]
y_true = [v * 0.8 for v in y_preds]

for y_pred, y in zip(y_preds, y_true):
compare_ssim_ignite_skiimg(y_pred, y, available_device)


@pytest.mark.parametrize(
"dtype, precision", [(torch.bfloat16, 2e-3), (torch.float16, 4e-4), (torch.float32, 2e-5), (torch.float64, 2e-5)]
)
Expand All @@ -136,7 +214,12 @@ def test_cuda_ssim_dtypes(available_device, dtype, precision):
if available_device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
pytest.skip(reason=f"Unsupported dtype {dtype} on CPU device")

test_ssim(available_device, (12, 3, 28, 28), 11, True, False, dtype=dtype, precision=precision)
shape = (12, 3, 28, 28)

y_pred = torch.rand(shape, device=available_device, dtype=dtype)
y = y_pred * 0.8

compare_ssim_ignite_skiimg(y_pred, y, available_device, precision)


@pytest.mark.parametrize("metric_device", ["cpu", "process_device"])
Expand Down
Loading