Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve how device switch is handled between the metric device and the input tensors device #3043

Merged
merged 24 commits into from
Aug 25, 2023
Merged
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
750a6ca
refactor: remove outdated code and issue a warning if two tensors are…
MarcBresson Aug 23, 2023
78a4c78
feat: prioritize computation on GPU devices over CPUs
MarcBresson Aug 24, 2023
85eebd5
fix: use a temp var that will be moved with y_pred
MarcBresson Aug 24, 2023
9125e60
test: add metric and y_pred with different devices test
MarcBresson Aug 24, 2023
a4c2f7c
feat: move self._kernel directly and issue a warning only when not al…
MarcBresson Aug 24, 2023
1908fff
feat: adapt test to new behaviour
MarcBresson Aug 24, 2023
2547e70
feat: keep the accumulation on the same device as self._kernel
MarcBresson Aug 24, 2023
3269955
feat: move accumulation along side self._kernel
MarcBresson Aug 24, 2023
04af090
feat: allow different channel number
MarcBresson Aug 24, 2023
7922ec9
style: format using the run_code_style script
MarcBresson Aug 25, 2023
b0625e4
style: add line brak to conform to E501
MarcBresson Aug 25, 2023
6817316
fix: use torch.empty to avoid type incompatibility between None and T…
MarcBresson Aug 25, 2023
d2aa8c8
feat: only operate on self._kernel, keep the accumulation on user's s…
MarcBresson Aug 25, 2023
c6bf8f8
test: add variable channel test and factorize the code
MarcBresson Aug 25, 2023
f6f82fe
Merge branch 'master' into refactor-_update
MarcBresson Aug 25, 2023
99c3469
refactor: remove redundant line between init and reset
MarcBresson Aug 25, 2023
eba6f68
refactor: elif comparison and replace RuntimeWarning by UserWarning
MarcBresson Aug 25, 2023
91ae235
refactor: set _kernel in __init__ and manually format to pass E501
MarcBresson Aug 25, 2023
7284b01
test: adapt test to new UserWarning
MarcBresson Aug 25, 2023
d96255c
test: remove skips
MarcBresson Aug 25, 2023
2807f28
refactor: use None instead of torch.empty
MarcBresson Aug 25, 2023
526234c
style: reorder imports
MarcBresson Aug 25, 2023
b6f1a21
refactor: rename channel to nb_channel
MarcBresson Aug 25, 2023
0a38aa5
Fixed failing test_distrib_accumulator_device
vfdev-5 Aug 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: move accumulation along side self._kernel
MarcBresson committed Aug 24, 2023
commit 326995590532666bcf2a10d2ddceff2a2ef2cb79
1 change: 1 addition & 0 deletions ignite/metrics/ssim.py
Original file line number Diff line number Diff line change
@@ -165,6 +165,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
if y_pred.device != self._kernel.device:
if self._kernel.device == torch.device("cpu"):
self._kernel = self._kernel.to(device=y_pred.device)
self._sum_of_ssim = self._sum_of_ssim.to(device=y_pred.device)

if y_pred.device == torch.device("cpu"):
warnings.warn(