Skip to content

Commit

Permalink
add unit test for bm3d module
Browse files Browse the repository at this point in the history
  • Loading branch information
KedoKudo committed Jun 18, 2024
1 parent 4f61430 commit 54b14d0
Show file tree
Hide file tree
Showing 9 changed files with 529 additions and 110 deletions.
84 changes: 84 additions & 0 deletions src/bm3dornl/bm3d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
"""Denoising functions using CuPy for GPU acceleration."""

import logging
import numpy as np
from typing import Tuple, Callable
from scipy.ndimage import gaussian_filter
Expand Down Expand Up @@ -30,6 +31,10 @@
fft_transform,
hadamard_transform,
)
from .utils import (
horizontal_binning,
horizontal_debinning,
)


def shrinkage_via_hardthresholding(
Expand Down Expand Up @@ -678,3 +683,82 @@ def bm3d_ring_artifact_removal(
)
else:
raise ValueError(f"Unknown mode: {mode}")


def bm3d_ring_artifact_removal_ms(
sinogram: np.ndarray,
k: int = 4,
mode: str = "simple", # express, simple, full
block_matching_kwargs: dict = {
"patch_size": (8, 8),
"stride": 3,
"background_threshold": 0.0,
"cut_off_distance": (64, 64),
"num_patches_per_group": 32,
"padding_mode": "circular",
},
filter_kwargs: dict = {
"filter_function": "fft",
"shrinkage_factor": 3e-2,
},
) -> np.ndarray:
"""Multiscale BM3D for streak removal
Parameters
----------
sinogram : np.ndarray
The input sinogram to be denoised.
k : int, optional
The number of iterations for horizontal binning, by default 3
mode : str
The denoising mode to use.
block_matching_kwargs : dict
The block matching parameters.
filter_kwargs : dict
The filter parameters.
Returns
-------
np.ndarray
The denoised sinogram.
References
----------
[1] ref: `Collaborative Filtering of Correlated Noise <https://doi.org/10.1109/TIP.2020.3014721>`_
[2] ref: `Ring artifact reduction via multiscale nonlocal collaborative filtering of spatially correlated noise <https://doi.org/10.1107/S1600577521001910>`_
"""
# step 0: median filter the sinogram
sino_star = sinogram

if k == 0:
# single pass
return bm3d_ring_artifact_removal(
sino_star,
mode=mode,
block_matching_kwargs=block_matching_kwargs,
filter_kwargs=filter_kwargs,
)

# step 1: create a list of binned sinograms
binned_sinos = horizontal_binning(sinogram, k=k)
# reverse the list
binned_sinos = binned_sinos[::-1]

# step 2: estimate the noise level from the coarsest sinogram, then working back to the original sinogram
noise_estimate = None
for i in range(len(binned_sinos)):
logging.info(f"Processing binned sinogram {i+1} of {len(binned_sinos)}")
sino = binned_sinos[i]
sino_star = (
sino if i == 0 else sino - horizontal_debinning(noise_estimate, sino)
)

if i < len(binned_sinos) - 1:
noise_estimate = sino - bm3d_ring_artifact_removal(
sino_star,
mode=mode,
block_matching_kwargs=block_matching_kwargs,
filter_kwargs=filter_kwargs,
)

return sino_star
47 changes: 47 additions & 0 deletions tests/unit/bm3dornl/test_bm3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest
import numpy as np
from bm3dornl.bm3d import (
global_fourier_thresholding,
global_wiener_filtering,
estimate_noise_free_sinogram,
)


def test_global_fourier_thresholding():
noisy_image = np.random.rand(256, 256)
noise_psd = np.random.rand(256, 256)
estimated_image = np.random.rand(256, 256)

result = global_fourier_thresholding(noisy_image, noise_psd, estimated_image)

assert result is not None
assert result.shape == noisy_image.shape
assert np.all(np.isfinite(result))


def test_global_wiener_filtering():
sinogram = np.random.rand(256, 256)

result = global_wiener_filtering(sinogram)

assert result is not None
assert result.shape == sinogram.shape
assert np.all(np.isfinite(result))
assert np.min(result) >= 0
assert np.max(result) <= 1


def test_estimate_noise_free_sinogram():
sinogram = np.random.rand(256, 256)

result = estimate_noise_free_sinogram(sinogram)

assert result is not None
assert result.shape == sinogram.shape
assert np.all(np.isfinite(result))
assert np.min(result) >= 0
assert np.max(result) <= 1


if __name__ == "__main__":
pytest.main([__file__])
86 changes: 86 additions & 0 deletions tests/unit/bm3dornl/test_bm3d_collaborative_filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest
from unittest.mock import patch
import numpy as np
from bm3dornl.bm3d import collaborative_filtering


@pytest.fixture
def setup_data():
sinogram = np.random.rand(256, 256)
denoised_sinogram = np.random.rand(256, 256)
patch_size = (8, 8)
num_patches_per_group = 32
padding_mode = "circular"
noise_variance = np.random.rand(100, 32, 8, 8)
patch_positions = np.random.randint(0, 256, (100, 2))
cut_off_distance = (64, 64)
return (
sinogram,
denoised_sinogram,
patch_size,
num_patches_per_group,
padding_mode,
noise_variance,
patch_positions,
cut_off_distance,
)


@patch("bm3dornl.bm3d.aggregate_denoised_block_to_image")
@patch("bm3dornl.bm3d.collaborative_wiener_filtering")
@patch("bm3dornl.bm3d.form_hyper_blocks_from_two_images")
@patch("bm3dornl.bm3d.compute_distance_matrix_no_variance")
@patch("bm3dornl.bm3d.get_patch_numba")
def test_collaborative_filtering(
mock_get_patch_numba,
mock_distance_matrix,
mock_form_hyper_blocks,
mock_collaborative_filtering,
mock_aggregate_block,
setup_data,
):
(
sinogram,
denoised_sinogram,
patch_size,
num_patches_per_group,
padding_mode,
noise_variance,
patch_positions,
cut_off_distance,
) = setup_data

mock_get_patch_numba.return_value = np.random.rand(8, 8)
mock_distance_matrix.return_value = np.random.rand(100, 100)
mock_form_hyper_blocks.return_value = (
np.random.rand(100, 32, 8, 8),
np.random.rand(100, 32, 8, 8),
np.random.randint(0, 256, (100, 32, 2)),
np.random.rand(100, 32, 8, 8),
)
mock_collaborative_filtering.return_value = np.random.rand(100, 32, 8, 8)
mock_aggregate_block.return_value = np.random.rand(256, 256)

result = collaborative_filtering(
sinogram,
denoised_sinogram,
patch_size,
num_patches_per_group,
padding_mode,
noise_variance,
patch_positions,
cut_off_distance,
lambda x: x,
mock_collaborative_filtering,
)

assert result is not None
assert mock_get_patch_numba.call_count == len(patch_positions)
mock_distance_matrix.assert_called_once()
mock_form_hyper_blocks.assert_called_once()
mock_collaborative_filtering.assert_called_once()
mock_aggregate_block.assert_called_once()


if __name__ == "__main__":
pytest.main([__file__])
55 changes: 55 additions & 0 deletions tests/unit/bm3dornl/test_bm3d_full.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest
from unittest.mock import patch
import numpy as np
from bm3dornl.bm3d import bm3d_full


@pytest.fixture
def setup_sinogram():
return np.random.rand(256, 256)


@patch("bm3dornl.bm3d.get_patch_numba")
@patch("bm3dornl.bm3d.global_fourier_thresholding")
@patch("bm3dornl.bm3d.collaborative_filtering")
@patch("bm3dornl.bm3d.shrinkage_via_hardthresholding")
@patch("bm3dornl.bm3d.get_signal_patch_positions")
@patch("bm3dornl.bm3d.estimate_noise_psd")
@patch("bm3dornl.bm3d.get_exact_noise_variance")
@patch("bm3dornl.bm3d.fft_transform")
def test_bm3d_full(
mock_fft_transform,
mock_get_exact_noise_variance,
mock_estimate_noise_psd,
mock_get_signal_patch_positions,
mock_shrinkage_via_hardthresholding,
mock_collaborative_filtering,
mock_global_fourier_thresholding,
mock_get_patch_numba,
setup_sinogram,
):
sinogram = setup_sinogram
mock_get_signal_patch_positions.return_value = np.random.randint(0, 256, (100, 2))
mock_fft_transform.return_value = np.random.rand(100, 8, 8)
mock_get_exact_noise_variance.return_value = np.random.rand(100, 8, 8)
mock_shrinkage_via_hardthresholding.return_value = np.random.rand(256, 256)
mock_collaborative_filtering.return_value = np.random.rand(256, 256)
mock_global_fourier_thresholding.return_value = np.random.rand(256, 256)
mock_estimate_noise_psd.return_value = np.random.rand(256, 256)
mock_get_patch_numba.return_value = np.random.rand(8, 8)

result = bm3d_full(sinogram)

assert result is not None
mock_get_signal_patch_positions.assert_called()
mock_fft_transform.assert_called()
mock_get_exact_noise_variance.assert_called()
mock_shrinkage_via_hardthresholding.assert_called()
mock_collaborative_filtering.assert_called()
mock_global_fourier_thresholding.assert_called()
mock_estimate_noise_psd.assert_called()
mock_get_patch_numba.assert_called()


if __name__ == "__main__":
pytest.main([__file__])
43 changes: 43 additions & 0 deletions tests/unit/bm3dornl/test_bm3d_lite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
from unittest.mock import patch
import numpy as np
from bm3dornl.bm3d import bm3d_lite


@pytest.fixture
def setup_sinogram():
return np.random.rand(256, 256)


@patch("bm3dornl.bm3d.get_patch_numba")
@patch("bm3dornl.bm3d.global_fourier_thresholding")
@patch("bm3dornl.bm3d.collaborative_filtering")
@patch("bm3dornl.bm3d.estimate_noise_free_sinogram")
@patch("bm3dornl.bm3d.get_signal_patch_positions")
def test_bm3d_lite(
mock_get_signal_patch_positions,
mock_estimate_noise_free_sinogram,
mock_collaborative_filtering,
mock_global_fourier_thresholding,
mock_get_patch_numba,
setup_sinogram,
):
sinogram = setup_sinogram
mock_get_signal_patch_positions.return_value = np.random.randint(0, 256, (100, 2))
mock_estimate_noise_free_sinogram.return_value = np.random.rand(256, 256)
mock_collaborative_filtering.return_value = np.random.rand(256, 256)
mock_global_fourier_thresholding.return_value = np.random.rand(256, 256)
mock_get_patch_numba.return_value = np.random.rand(8, 8)

result = bm3d_lite(sinogram)

assert result is not None
mock_get_signal_patch_positions.assert_called_once()
mock_estimate_noise_free_sinogram.assert_called_once()
mock_collaborative_filtering.assert_called()
mock_global_fourier_thresholding.assert_called()
mock_get_patch_numba.assert_called()


if __name__ == "__main__":
pytest.main([__file__])
Loading

0 comments on commit 54b14d0

Please sign in to comment.