Skip to content

Commit

Permalink
Bug fix for BRISQUE score (#88)
Browse files Browse the repository at this point in the history
* bugfix(brisque): Improved robustness of BRISQUE score

* bugfix(brisque): Added asserts according to theory from GGD and AGGD.

* minor(brisque): Changed link, groomed tests

Co-authored-by: Sergey Kastryulin <[email protected]>
  • Loading branch information
denproc and snk4tr authored Jun 18, 2020
1 parent 73a9071 commit 0734b30
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 48 deletions.
25 changes: 19 additions & 6 deletions photosynthesis_metrics/brisque.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def _ggd_parameters(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

sigma_sq = x.pow(2).mean(dim=(-1, -2))
sigma = sigma_sq.sqrt().squeeze(dim=-1)

assert not torch.isclose(sigma, torch.zeros_like(sigma)).all(), \
'Expected image with non zero variance of pixel values'

E = x.abs().mean(dim=(-1, -2))
rho = sigma_sq / E ** 2

Expand All @@ -32,16 +36,25 @@ def _ggd_parameters(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

def _aggd_parameters(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
gamma = torch.arange(start=0.2, end=10.001, step=0.001).to(x)
r_table = torch.exp(2 * torch.lgamma(2. / gamma) - torch.lgamma(1. / gamma) - torch.lgamma(3. / gamma)).repeat(
x.size(0), 1)
r_table = torch.exp(2 * torch.lgamma(2. / gamma) - torch.lgamma(1. / gamma) - torch.lgamma(3. / gamma))
r_table = r_table.repeat(x.size(0), 1)

mask_left = x < 0
mask_right = x > 0
count_left = mask_left.sum(dim=(-1, -2))
count_right = mask_right.sum(dim=(-1, -2))
count_left = mask_left.sum(dim=(-1, -2), dtype=torch.float32)
count_right = mask_right.sum(dim=(-1, -2), dtype=torch.float32)

assert (count_left > 0).all(), 'Expected input tensor (pairwise products of neighboring MSCN coefficients)' \
' with values below zero to compute parameters of AGGD'
assert (count_right > 0).all(), 'Expected input tensor (pairwise products of neighboring MSCN coefficients)' \
' with values above zero to compute parameters of AGGD'

left_sigma = ((x * mask_left).pow(2).sum(dim=(-1, -2)) / count_left).sqrt()
right_sigma = ((x * mask_right).pow(2).sum(dim=(-1, -2)) / count_right).sqrt()

assert (left_sigma > 0).all() and (right_sigma > 0).all(), f'Expected non-zero left and right variances, ' \
f'got {left_sigma} and {right_sigma}'

gamma_hat = left_sigma / right_sigma
ro_hat = x.abs().mean(dim=(-1, -2)).pow(2) / x.pow(2).mean(dim=(-1, -2))
ro_hat_norm = (ro_hat * (gamma_hat.pow(3) + 1) * (gamma_hat + 1)) / (gamma_hat.pow(2) + 1).pow(2)
Expand Down Expand Up @@ -125,8 +138,8 @@ def _RBF_kernel(features: torch.Tensor, sv: torch.Tensor, gamma: float = 0.05) -


def _score_svr(features: torch.Tensor) -> torch.Tensor:
url = 'https://github.com/photosynthesis-team/photosynthesis.metrics/releases/' \
'latest/download/brisque_svm_weights.pt'
url = 'https://github.com/photosynthesis-team/photosynthesis.metrics/' \
'releases/download/v0.4.0/brisque_svm_weights.pt'
sv_coef, sv = load_url(url, map_location=features.device)

# gamma and rho are SVM model parameters taken from official implementation of BRISQUE on MATLAB
Expand Down
1 change: 1 addition & 0 deletions photosynthesis_metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def _validate_input(
assert isinstance(tensor, torch.Tensor), f'Expected input to be torch.Tensor, got {type(tensor)}.'
assert min_n_dim <= tensor.dim() <= max_n_dim, \
f'Input images must be {min_n_dim}D - {max_n_dim}D tensors, got images of shape {tensor.size()}.'
assert torch.all(tensor >= 0), 'Expected input tensor greater or equal 0'
if tensor.dim() == 5:
assert tensor.size(-1) == 2, f'Expected Complex 5D tensor with (N,C,H,W,2) size, got {tensor.size()}'

Expand Down
69 changes: 27 additions & 42 deletions tests/test_brisque.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,29 @@ def prediction_RGB() -> torch.Tensor:

# ================== Test function: `brisque` ==================
def test_brisque_if_works_with_grey(prediction_grey: torch.Tensor) -> None:
try:
brisque(prediction_grey)
except Exception as e:
pytest.fail(f"Unexpected error occurred: {e}")
brisque(prediction_grey)


@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test if there is no GPU.')
def test_brisque_if_works_with_grey_on_gpu(prediction_grey: torch.Tensor) -> None:
try:
prediction_grey = prediction_grey.cuda()
brisque(prediction_grey)
except Exception as e:
pytest.fail(f"Unexpected error occurred: {e}")
prediction_grey = prediction_grey.cuda()
brisque(prediction_grey)


def test_brisque_if_works_with_RGB(prediction_RGB: torch.Tensor) -> None:
try:
brisque(prediction_RGB)
except Exception as e:
pytest.fail(f"Unexpected error occurred: {e}")
brisque(prediction_RGB)


@pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test if there is no GPU.')
def test_brisque_if_works_with_RGB_on_gpu(prediction_RGB: torch.Tensor) -> None:
try:
prediction_RGB = prediction_RGB.cuda()
brisque(prediction_RGB)
except Exception as e:
pytest.fail(f"Unexpected error occurred: {e}")
prediction_RGB = prediction_RGB.cuda()
brisque(prediction_RGB)


def test_brisque_raises_if_wrong_reduction(prediction_grey: torch.Tensor) -> None:
for mode in ['mean', 'sum', 'none']:
try:
brisque(prediction_grey, reduction=mode)
except Exception as e:
pytest.fail(f"Unexpected error occurred: {e}")
brisque(prediction_grey, reduction=mode)

for mode in [None, 'n', 2]:
with pytest.raises(KeyError):
brisque(prediction_grey, reduction=mode)
Expand All @@ -62,49 +48,48 @@ def test_brisque_raises_if_wrong_reduction(prediction_grey: torch.Tensor) -> Non
def test_brisque_values_grey(prediction_grey: torch.Tensor) -> None:
score = brisque(prediction_grey, reduction='none')
score_baseline = torch.tensor([BRISQUE().get_score(img.squeeze().numpy()) for img in prediction_grey])
assert torch.isclose(score, score_baseline, rtol=2e-4, atol=1e-6).all(), f'Expected values to be equal to ' \
f'baseline prediction.' \
f'got {score} and {score_baseline}'
assert torch.isclose(score, score_baseline, atol=1e-1).all(), f'Expected values to be equal to ' \
f'baseline prediction.' \
f'got {score} and {score_baseline}'


def test_brisque_values_RGB(prediction_RGB: torch.Tensor) -> None:
score = brisque(prediction_RGB, reduction='none')
score_baseline = torch.tensor([BRISQUE().get_score(img.squeeze().permute(1, 2, 0).numpy()[..., ::-1])
for img in prediction_RGB])
assert torch.isclose(score, score_baseline, rtol=2e-4).all(), f'Expected values to be equal to ' \
assert torch.isclose(score, score_baseline, atol=1e-1).all(), f'Expected values to be equal to ' \
f'baseline prediction.' \
f'got {score} and {score_baseline}'


def test_brisque_all_zeros_or_ones() -> None:
size = (1, 1, 256, 256)
for tensor in [torch.zeros(size), torch.ones(size)]:
with pytest.raises(AssertionError):
brisque(tensor, reduction='mean')


# ================== Test class: `BRISQUELoss` ==================
def test_brisque_loss_if_works_with_grey(prediction_grey: torch.Tensor) -> None:
prediction_grey_grad = prediction_grey.clone()
prediction_grey_grad.requires_grad_()
try:
loss_value = BRISQUELoss()(prediction_grey_grad)
loss_value.backward()
assert prediction_grey_grad.grad is not None, 'Expected non None gradient of leaf variable'
except Exception as e:
pytest.fail(f"Unexpected error occurred: {e}")
loss_value = BRISQUELoss()(prediction_grey_grad)
loss_value.backward()
assert prediction_grey_grad.grad is not None, 'Expected non None gradient of leaf variable'


def test_brisque_loss_if_works_with_RGB(prediction_RGB: torch.Tensor) -> None:
prediction_RGB_grad = prediction_RGB.clone()
prediction_RGB_grad.requires_grad_()
try:
loss_value = BRISQUELoss()(prediction_RGB_grad)
loss_value.backward()
assert prediction_RGB_grad.grad is not None, 'Expected non None gradient of leaf variable'
except Exception as e:
pytest.fail(f"Unexpected error occurred: {e}")
loss_value = BRISQUELoss()(prediction_RGB_grad)
loss_value.backward()
assert prediction_RGB_grad.grad is not None, 'Expected non None gradient of leaf variable'


def test_brisque_loss_raises_if_wrong_reduction(prediction_grey: torch.Tensor) -> None:
for mode in ['mean', 'sum', 'none']:
try:
BRISQUELoss(reduction=mode)(prediction_grey)
except Exception as e:
pytest.fail(f"Unexpected error occurred: {e}")
BRISQUELoss(reduction=mode)(prediction_grey)

for mode in [None, 'n', 2]:
with pytest.raises(KeyError):
BRISQUELoss(reduction=mode)(prediction_grey)

0 comments on commit 0734b30

Please sign in to comment.