Skip to content

Commit

Permalink
Bugfix: relax the requirements (#154)
Browse files Browse the repository at this point in the history
* bug(requirements): made requirements less strict, added warning.

* bug(requirements): Highlight it in the description and readme.

* bug(requirements): Minor.

* Release commit
  • Loading branch information
denproc authored Jul 28, 2020
1 parent c39b7e6 commit 3db33e3
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 8 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ prediction = torch.rand(3, 3, 256, 256)
brisque_index: torch.Tensor = brisque(prediction, data_range=1.)
```

In order to use BRISQUE as a loss function, use corresponding PyTorch module:
In order to use BRISQUE as a loss function, use corresponding PyTorch module.

Note: the back propagation is not available using `torch==1.5.0`. Update the environment with latest `torch` and `torchvision`.
```python
import torch
from piq import BRISQUELoss
Expand Down
8 changes: 4 additions & 4 deletions conda.recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ source:
requirements:
build:
- python
- pytorch>=1.2.0,!=1.5.0
- torchvision>=0.4.0,!=0.6.0
- pytorch>=1.2.0
- torchvision>=0.4.0
- scipy==1.3.3
- gudhi>=3.2
run:
- python
- pytorch>=1.2.0,!=1.5.0
- torchvision>=0.4.0,!=0.6.0
- pytorch>=1.2.0
- torchvision>=0.4.0
- scipy==1.3.3
- gudhi>=3.2

Expand Down
2 changes: 1 addition & 1 deletion piq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.5.0"
__version__ = "0.5.1"

from .ssim import ssim, multi_scale_ssim, SSIMLoss, MultiScaleSSIMLoss
from .msid import MSID
Expand Down
15 changes: 15 additions & 0 deletions piq/brisque.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
https://github.com/bukalapak/pybrisque
"""
from typing import Union, Tuple
import warnings
import torch
from torch.nn.modules.loss import _Loss
from torch.utils.model_zoo import load_url
Expand All @@ -33,10 +34,20 @@ def brisque(x: torch.Tensor,
Returns:
Value of BRISQUE index.
Note:
The back propagation is not available using torch=1.5.0 due to bug in argmin/argmax back propagation.
Update the torch and torchvision to the latest versions.
References:
.. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain",
https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf
"""
if '1.5.0' in torch.__version__:
warnings.warn(f'BRISQUE does not support back propagation due to bug in torch={torch.__version__}.'
f'Update torch to the latest version to access full functionality of the BRIQSUE.'
f'More info is available at https://github.com/photosynthesis-team/piq/pull/79 and'
f'https://github.com/pytorch/pytorch/issues/38869.')

_validate_input(input_tensors=x, allow_5d=False, kernel_size=kernel_size)
x = _adjust_dimensions(input_tensors=x)

Expand Down Expand Up @@ -91,6 +102,10 @@ class BRISQUELoss(_Loss):
>>> output = loss(prediction)
>>> output.backward()
Note:
The back propagation is not available using torch=1.5.0 due to bug in argmin/argmax back propagation.
Update the torch and torchvision to the latest versions.
References:
.. [1] Anish Mittal et al. "No-Reference Image Quality Assessment in the Spatial Domain",
https://live.ece.utexas.edu/publications/2012/TIP%20BRISQUE.pdf
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
torch>=1.2.0,!=1.5.0
torchvision>=0.4.0,!=0.6.0
torch>=1.2.0
torchvision>=0.4.0
scipy==1.3.3
gudhi>=3.2

0 comments on commit 3db33e3

Please sign in to comment.