diff --git a/photosynthesis_metrics/brisque.py b/photosynthesis_metrics/brisque.py index 8e4f08a7..03db15ca 100644 --- a/photosynthesis_metrics/brisque.py +++ b/photosynthesis_metrics/brisque.py @@ -22,6 +22,10 @@ def _ggd_parameters(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: sigma_sq = x.pow(2).mean(dim=(-1, -2)) sigma = sigma_sq.sqrt().squeeze(dim=-1) + + assert not torch.isclose(sigma, torch.zeros_like(sigma)).all(), \ + 'Expected image with non zero variance of pixel values' + E = x.abs().mean(dim=(-1, -2)) rho = sigma_sq / E ** 2 @@ -32,16 +36,25 @@ def _ggd_parameters(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def _aggd_parameters(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: gamma = torch.arange(start=0.2, end=10.001, step=0.001).to(x) - r_table = torch.exp(2 * torch.lgamma(2. / gamma) - torch.lgamma(1. / gamma) - torch.lgamma(3. / gamma)).repeat( - x.size(0), 1) + r_table = torch.exp(2 * torch.lgamma(2. / gamma) - torch.lgamma(1. / gamma) - torch.lgamma(3. / gamma)) + r_table = r_table.repeat(x.size(0), 1) mask_left = x < 0 mask_right = x > 0 - count_left = mask_left.sum(dim=(-1, -2)) - count_right = mask_right.sum(dim=(-1, -2)) + count_left = mask_left.sum(dim=(-1, -2), dtype=torch.float32) + count_right = mask_right.sum(dim=(-1, -2), dtype=torch.float32) + + assert (count_left > 0).all(), 'Expected input tensor (pairwise products of neighboring MSCN coefficients)' \ + ' with values below zero to compute parameters of AGGD' + assert (count_right > 0).all(), 'Expected input tensor (pairwise products of neighboring MSCN coefficients)' \ + ' with values above zero to compute parameters of AGGD' left_sigma = ((x * mask_left).pow(2).sum(dim=(-1, -2)) / count_left).sqrt() right_sigma = ((x * mask_right).pow(2).sum(dim=(-1, -2)) / count_right).sqrt() + + assert (left_sigma > 0).all() and (right_sigma > 0).all(), f'Expected non-zero left and right variances, ' \ + f'got {left_sigma} and {right_sigma}' + gamma_hat = left_sigma / right_sigma ro_hat = x.abs().mean(dim=(-1, -2)).pow(2) / x.pow(2).mean(dim=(-1, -2)) ro_hat_norm = (ro_hat * (gamma_hat.pow(3) + 1) * (gamma_hat + 1)) / (gamma_hat.pow(2) + 1).pow(2) @@ -125,8 +138,8 @@ def _RBF_kernel(features: torch.Tensor, sv: torch.Tensor, gamma: float = 0.05) - def _score_svr(features: torch.Tensor) -> torch.Tensor: - url = 'https://github.com/photosynthesis-team/photosynthesis.metrics/releases/' \ - 'latest/download/brisque_svm_weights.pt' + url = 'https://github.com/photosynthesis-team/photosynthesis.metrics/' \ + 'releases/download/v0.4.0/brisque_svm_weights.pt' sv_coef, sv = load_url(url, map_location=features.device) # gamma and rho are SVM model parameters taken from official implementation of BRISQUE on MATLAB diff --git a/photosynthesis_metrics/utils.py b/photosynthesis_metrics/utils.py index 2ffdf0a3..e5782b80 100644 --- a/photosynthesis_metrics/utils.py +++ b/photosynthesis_metrics/utils.py @@ -43,6 +43,7 @@ def _validate_input( assert isinstance(tensor, torch.Tensor), f'Expected input to be torch.Tensor, got {type(tensor)}.' assert min_n_dim <= tensor.dim() <= max_n_dim, \ f'Input images must be {min_n_dim}D - {max_n_dim}D tensors, got images of shape {tensor.size()}.' + assert torch.all(tensor >= 0), 'Expected input tensor greater or equal 0' if tensor.dim() == 5: assert tensor.size(-1) == 2, f'Expected Complex 5D tensor with (N,C,H,W,2) size, got {tensor.size()}' diff --git a/tests/test_brisque.py b/tests/test_brisque.py index 41af27af..45f82207 100644 --- a/tests/test_brisque.py +++ b/tests/test_brisque.py @@ -17,43 +17,29 @@ def prediction_RGB() -> torch.Tensor: # ================== Test function: `brisque` ================== def test_brisque_if_works_with_grey(prediction_grey: torch.Tensor) -> None: - try: - brisque(prediction_grey) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") + brisque(prediction_grey) @pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test if there is no GPU.') def test_brisque_if_works_with_grey_on_gpu(prediction_grey: torch.Tensor) -> None: - try: - prediction_grey = prediction_grey.cuda() - brisque(prediction_grey) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") + prediction_grey = prediction_grey.cuda() + brisque(prediction_grey) def test_brisque_if_works_with_RGB(prediction_RGB: torch.Tensor) -> None: - try: - brisque(prediction_RGB) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") + brisque(prediction_RGB) @pytest.mark.skipif(not torch.cuda.is_available(), reason='No need to run test if there is no GPU.') def test_brisque_if_works_with_RGB_on_gpu(prediction_RGB: torch.Tensor) -> None: - try: - prediction_RGB = prediction_RGB.cuda() - brisque(prediction_RGB) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") + prediction_RGB = prediction_RGB.cuda() + brisque(prediction_RGB) def test_brisque_raises_if_wrong_reduction(prediction_grey: torch.Tensor) -> None: for mode in ['mean', 'sum', 'none']: - try: - brisque(prediction_grey, reduction=mode) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") + brisque(prediction_grey, reduction=mode) + for mode in [None, 'n', 2]: with pytest.raises(KeyError): brisque(prediction_grey, reduction=mode) @@ -62,49 +48,48 @@ def test_brisque_raises_if_wrong_reduction(prediction_grey: torch.Tensor) -> Non def test_brisque_values_grey(prediction_grey: torch.Tensor) -> None: score = brisque(prediction_grey, reduction='none') score_baseline = torch.tensor([BRISQUE().get_score(img.squeeze().numpy()) for img in prediction_grey]) - assert torch.isclose(score, score_baseline, rtol=2e-4, atol=1e-6).all(), f'Expected values to be equal to ' \ - f'baseline prediction.' \ - f'got {score} and {score_baseline}' + assert torch.isclose(score, score_baseline, atol=1e-1).all(), f'Expected values to be equal to ' \ + f'baseline prediction.' \ + f'got {score} and {score_baseline}' def test_brisque_values_RGB(prediction_RGB: torch.Tensor) -> None: score = brisque(prediction_RGB, reduction='none') score_baseline = torch.tensor([BRISQUE().get_score(img.squeeze().permute(1, 2, 0).numpy()[..., ::-1]) for img in prediction_RGB]) - assert torch.isclose(score, score_baseline, rtol=2e-4).all(), f'Expected values to be equal to ' \ + assert torch.isclose(score, score_baseline, atol=1e-1).all(), f'Expected values to be equal to ' \ f'baseline prediction.' \ f'got {score} and {score_baseline}' +def test_brisque_all_zeros_or_ones() -> None: + size = (1, 1, 256, 256) + for tensor in [torch.zeros(size), torch.ones(size)]: + with pytest.raises(AssertionError): + brisque(tensor, reduction='mean') + + # ================== Test class: `BRISQUELoss` ================== def test_brisque_loss_if_works_with_grey(prediction_grey: torch.Tensor) -> None: prediction_grey_grad = prediction_grey.clone() prediction_grey_grad.requires_grad_() - try: - loss_value = BRISQUELoss()(prediction_grey_grad) - loss_value.backward() - assert prediction_grey_grad.grad is not None, 'Expected non None gradient of leaf variable' - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") + loss_value = BRISQUELoss()(prediction_grey_grad) + loss_value.backward() + assert prediction_grey_grad.grad is not None, 'Expected non None gradient of leaf variable' def test_brisque_loss_if_works_with_RGB(prediction_RGB: torch.Tensor) -> None: prediction_RGB_grad = prediction_RGB.clone() prediction_RGB_grad.requires_grad_() - try: - loss_value = BRISQUELoss()(prediction_RGB_grad) - loss_value.backward() - assert prediction_RGB_grad.grad is not None, 'Expected non None gradient of leaf variable' - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") + loss_value = BRISQUELoss()(prediction_RGB_grad) + loss_value.backward() + assert prediction_RGB_grad.grad is not None, 'Expected non None gradient of leaf variable' def test_brisque_loss_raises_if_wrong_reduction(prediction_grey: torch.Tensor) -> None: for mode in ['mean', 'sum', 'none']: - try: - BRISQUELoss(reduction=mode)(prediction_grey) - except Exception as e: - pytest.fail(f"Unexpected error occurred: {e}") + BRISQUELoss(reduction=mode)(prediction_grey) + for mode in [None, 'n', 2]: with pytest.raises(KeyError): BRISQUELoss(reduction=mode)(prediction_grey)