diff --git a/notebooks/example_ms.ipynb b/notebooks/example_ms.ipynb new file mode 100644 index 0000000..28f008e --- /dev/null +++ b/notebooks/example_ms.ipynb @@ -0,0 +1,60 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from bm3dornl.bm3d import bm3d_ring_artifact_removal_ms\n", + "from bm3dornl.denoiser_gpu import memory_cleanup\n", + "\n", + "with open(\"../../tests/bm3dornl-data/sino.npy\", \"rb\") as f:\n", + " sino_noisy = np.load(f)\n", + "\n", + "memory_cleanup()\n", + "\n", + "block_matching_kwargs: dict = {\n", + " \"patch_size\": (8, 8),\n", + " \"stride\": 3,\n", + " \"background_threshold\": 0.0,\n", + " \"cut_off_distance\": (64, 64),\n", + " \"num_patches_per_group\": 32,\n", + " \"padding_mode\": \"circular\",\n", + "}\n", + "filter_kwargs: dict = {\n", + " \"filter_function\": \"fft\",\n", + " \"shrinkage_factor\": 3e-2,\n", + "}\n", + "kwargs = {\n", + " \"mode\": \"simple\",\n", + " \"k\": 4,\n", + " \"block_matching_kwargs\": block_matching_kwargs,\n", + " \"filter_kwargs\": filter_kwargs,\n", + "}\n", + "\n", + "sino_bm3dornl = bm3d_ring_artifact_removal_ms(\n", + " sinogram=sino_noisy,\n", + " **kwargs,\n", + ")\n", + "\n", + "\n", + "fig, axs = plt.subplots(1, 2, figsize=(12, 4))\n", + "axs[0].imshow(sino_noisy, cmap=\"gray\")\n", + "axs[0].set_title(\"Noisy sinogram\")\n", + "axs[1].imshow(sino_bm3dornl, cmap=\"gray\")\n", + "axs[1].set_title(\"BM3D denoised sinogram\")\n", + "plt.show()\n" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/bm3dornl/bm3d.py b/src/bm3dornl/bm3d.py index 953bbd5..6b586e2 100644 --- a/src/bm3dornl/bm3d.py +++ b/src/bm3dornl/bm3d.py @@ -739,26 +739,46 @@ def bm3d_ring_artifact_removal_ms( filter_kwargs=filter_kwargs, ) - # step 1: create a list of binned sinograms - binned_sinos = horizontal_binning(sinogram, k=k) - # reverse the list - binned_sinos = binned_sinos[::-1] - - # step 2: estimate the noise level from the coarsest sinogram, then working back to the original sinogram - noise_estimate = None - for i in range(len(binned_sinos)): - logging.info(f"Processing binned sinogram {i+1} of {len(binned_sinos)}") - sino = binned_sinos[i] - sino_star = ( - sino if i == 0 else sino - horizontal_debinning(noise_estimate, sino) + denoised_sino = None + # Make a copy of an original sinogram + sino_orig = horizontal_binning(sino_star, 1, dim=0) + binned_sinos_orig = [np.copy(sino_orig)] + + # Contains upscaled denoised sinograms + binned_sinos = [np.zeros(0)] + + # Bin horizontally + for i in range(0, k): + binned_sinos_orig.append( + horizontal_binning(binned_sinos_orig[-1], fac=2, dim=1) ) + binned_sinos.append(np.zeros(0)) - if i < len(binned_sinos) - 1: - noise_estimate = sino - bm3d_ring_artifact_removal( - sino_star, - mode=mode, - block_matching_kwargs=block_matching_kwargs, - filter_kwargs=filter_kwargs, + binned_sinos[-1] = binned_sinos_orig[-1] + + for i in range(k, -1, -1): + logging.info(f"Processing binned sinogram {i + 1} of {k}") + # Denoise binned sinogram + denoised_sino = bm3d_ring_artifact_removal( + binned_sinos[i], + mode=mode, + block_matching_kwargs=block_matching_kwargs, + filter_kwargs=filter_kwargs, + ) + # For iterations except the last, create the next noisy image with a finer scale residual + if i > 0: + debinned_sino = horizontal_debinning( + denoised_sino - binned_sinos_orig[i], + binned_sinos_orig[i - 1].shape[1], + 2, + 30, + dim=1, ) + binned_sinos[i - 1] = binned_sinos_orig[i - 1] + debinned_sino + + # residual + sino_star = sino_star + horizontal_debinning( + denoised_sino - sino_orig, sino_star.shape[0], fac=1, n_iter=30, dim=0 + ) return sino_star diff --git a/src/bm3dornl/utils.py b/src/bm3dornl/utils.py index 276a02e..4052cbf 100644 --- a/src/bm3dornl/utils.py +++ b/src/bm3dornl/utils.py @@ -2,8 +2,9 @@ """Utility functions for BM3DORNL.""" import numpy as np -from scipy.interpolate import RectBivariateSpline from numba import njit +from scipy.signal import convolve2d +from scipy.interpolate import interp1d @njit @@ -43,84 +44,130 @@ def is_within_threshold( return np.linalg.norm(ref_patch - cmp_patch) <= intensity_diff_threshold -def horizontal_binning(Z: np.ndarray, k: int = 0) -> list[np.ndarray]: +def create_array(base_arr: np.ndarray, h: int, dim: int): """ - Horizontal binning of the image Z into a list of k images. + Create a padded and convolved array used in both binning and debinning. + Parameters + ---------- + base_arr: Input array + h: bin count + dim: bin dimension (0 or 1) + Returns + ------- + resulting array + """ + mod_pad = h - ((base_arr.shape[dim] - 1) % h) - 1 + if dim == 0: + pads = ((0, mod_pad), (0, 0)) + kernel = np.ones((h, 1)) + else: + pads = ((0, 0), (0, mod_pad)) + kernel = np.ones((1, h)) + + return convolve2d(np.pad(base_arr, pads, "symmetric"), kernel, "same", "fill") + + +def horizontal_binning(Z: np.ndarray, fac: int = 2, dim: int = 1) -> np.ndarray: + """ + Horizontal binning of the image Z Parameters ---------- Z : np.ndarray The image to be binned. - k : int - Number of iterations to bin the image by half. - + fac : int + binning factor + dim : direction X=0, Y=1 Returns ------- - list[np.ndarray] - List of k images. + np.ndarray + binned image - Example - ------- - >>> Z = np.random.rand(64, 64) - >>> binned_zs = horizontal_binning(Z, 3) - >>> len(binned_zs) - 4 """ - binned_zs = [Z] - for _ in range(k): - sub_z0 = Z[:, ::2] - sub_z1 = Z[:, 1::2] - # make sure z0 and z1 have the same shape - if sub_z0.shape[1] > sub_z1.shape[1]: - sub_z0 = sub_z0[:, :-1] - elif sub_z0.shape[1] < sub_z1.shape[1]: - sub_z1 = sub_z1[:, :-1] - # average z0 and z1 - Z = (sub_z0 + sub_z1) * 0.5 - binned_zs.append(Z) - return binned_zs - - -def horizontal_debinning(original_image: np.ndarray, target: np.ndarray) -> np.ndarray: + + if fac > 1: + fac_half = fac // 2 + binned_zs = create_array(Z, fac, dim) + + # get coordinates of bin centres + if dim == 0: + binned_zs = binned_zs[ + fac_half + ((fac % 2) == 1) : binned_zs.shape[dim] - fac_half + 1 : fac, + :, + ] + else: + binned_zs = binned_zs[ + :, + fac_half + ((fac % 2) == 1) : binned_zs.shape[dim] - fac_half + 1 : fac, + ] + + return binned_zs + + return Z + + +def horizontal_debinning( + Z: np.ndarray, size: int, fac: int, n_iter: int, dim: int = 1 +) -> np.ndarray: """ Horizontal debinning of the image Z into the same shape as Z_target. Parameters ---------- - original_image : np.ndarray + Z : np.ndarray The image to be debinned. - target : np.ndarray - The target image to match the shape. + size: target size (original size before binning) for the second dimension + fac: binning factor (original divisor) + n_iter: number of iterations + dim: dimension for binning (Y = 0 or X = 1) Returns ------- np.ndarray The debinned image. - - Example - ------- - >>> Z = np.random.rand(64, 64) - >>> target = np.random.rand(64, 128) - >>> debinned_z = horizontal_debinning(Z, target) - >>> debinned_z.shape - (64, 128) """ - # Original dimensions - original_height, original_width = original_image.shape - # Target dimensions - new_height, new_width = target.shape - - # Original grid - original_x = np.arange(original_width) - original_y = np.arange(original_height) - - # Target grid - new_x = np.linspace(0, original_width - 1, new_width) - new_y = np.linspace(0, original_height - 1, new_height) - - # Spline interpolation - spline = RectBivariateSpline(original_y, original_x, original_image) - interpolated_image = spline(new_y, new_x) + if fac <= 1: + return np.copy(Z) + + fac_half = fac // 2 + + if dim == 0: + base_array = np.ones((size, 1)) + else: + base_array = np.ones((1, size)) + + n_counter = create_array(base_array, fac, dim) + + # coordinates of bin counts + x1c = np.arange(fac_half + ((fac % 2) == 1), (Z.shape[dim]) * fac, fac) + x1 = np.arange(fac_half + 1 - ((fac % 2) == 0) / 2, (Z.shape[dim]) * fac, fac) + + # coordinates of image pixels + ix1 = np.arange(1, size + 1) + + interpolated_image = 0 + + for j in range(max(1, n_iter)): + # residual + if j > 0: + residual = Z - horizontal_binning(interpolated_image, fac, dim) + else: + residual = Z + + # interpolation + if dim == 0: + interp = interp1d( + x1, + residual / n_counter[x1c, :], + kind="cubic", + fill_value="extrapolate", + axis=0, + ) + else: + interp = interp1d( + x1, residual / n_counter[:, x1c], kind="cubic", fill_value="extrapolate" + ) + interpolated_image = interpolated_image + interp(ix1) return interpolated_image diff --git a/tests/unit/bm3dornl/test_bm3d_ring_artifact_removal_ms.py b/tests/unit/bm3dornl/test_bm3d_ring_artifact_removal_ms.py index bcf9e92..895f5d6 100644 --- a/tests/unit/bm3dornl/test_bm3d_ring_artifact_removal_ms.py +++ b/tests/unit/bm3dornl/test_bm3d_ring_artifact_removal_ms.py @@ -1,53 +1,36 @@ import pytest -from unittest.mock import patch import numpy as np from bm3dornl.bm3d import bm3d_ring_artifact_removal_ms +size_x = 256 +size_y = 256 + + @pytest.fixture def setup_sinogram(): - return np.random.rand(256, 256) + return np.random.rand(size_x, size_y) -@patch("bm3dornl.bm3d.horizontal_binning") -@patch("bm3dornl.bm3d.horizontal_debinning") -@patch("bm3dornl.bm3d.bm3d_ring_artifact_removal") +@pytest.mark.cuda_required def test_bm3d_ring_artifact_removal_ms( - mock_bm3d_ring_artifact_removal, - mock_horizontal_debinning, - mock_horizontal_binning, setup_sinogram, ): sinogram = setup_sinogram - binned_sinos = [np.random.rand(64, 64) for _ in range(4)] - mock_horizontal_binning.return_value = binned_sinos - binned_sinos = [np.random.rand(64, 64) for _ in range(4)] - mock_bm3d_ring_artifact_removal.return_value = binned_sinos - binned_sinos = [np.random.rand(64, 64) for _ in range(4)] - mock_horizontal_debinning.return_value = binned_sinos result = bm3d_ring_artifact_removal_ms(sinogram, k=4) assert result is not None - mock_horizontal_binning.assert_called_once_with(sinogram, k=4) - assert mock_bm3d_ring_artifact_removal.call_count == 3 - assert mock_horizontal_debinning.call_count == 3 + + r, c = result.shape + + assert c == size_x + assert r == size_y result_single_pass = bm3d_ring_artifact_removal_ms(sinogram, k=0) assert result_single_pass is not None - mock_bm3d_ring_artifact_removal.assert_called_with( - sinogram, - mode="simple", - block_matching_kwargs={ - "patch_size": (8, 8), - "stride": 3, - "background_threshold": 0.0, - "cut_off_distance": (64, 64), - "num_patches_per_group": 32, - "padding_mode": "circular", - }, - filter_kwargs={ - "filter_function": "fft", - "shrinkage_factor": 3e-2, - }, - ) + + r, c = result_single_pass.shape + + assert c == size_x + assert r == size_y diff --git a/tests/unit/bm3dornl/test_utils.py b/tests/unit/bm3dornl/test_utils.py index 49000a5..2df4783 100644 --- a/tests/unit/bm3dornl/test_utils.py +++ b/tests/unit/bm3dornl/test_utils.py @@ -46,39 +46,22 @@ def test_is_within_threshold(): def test_horizontal_binning(): + size_x, size_y = 64, 64 + k = 6 # Initial setup: Create a test image - Z = np.random.rand(64, 64) - - # Number of binning iterations - k = 3 - - # Perform the binning - binned_images = horizontal_binning(Z, k) - - # Assert we have the correct number of images - assert len(binned_images) == k + 1, "Incorrect number of binned images returned" + Z = np.random.rand(size_y, size_y) # Assert that each image has the correct dimensions - expected_width = 64 - for i, img in enumerate(binned_images): - assert img.shape[0] == 64, f"Height of image {i} is incorrect" - assert img.shape[1] == expected_width, f"Width of image {i} is incorrect" + expected_width = size_x + for i in range(k): + # Perform the binning expected_width = (expected_width + 1) // 2 # Calculate the next expected width - - -def test_horizontal_binning_k_zero(): - Z = np.random.rand(64, 64) - binned_images = horizontal_binning(Z, 0) - assert len(binned_images) == 1 and np.array_equal( - binned_images[0], Z - ), "Binning with k=0 should return only the original image" - - -def test_horizontal_binning_large_k(): - Z = np.random.rand(64, 64) - binned_images = horizontal_binning(Z, 6) - assert len(binned_images) == 7, "Incorrect number of images for large k" - assert binned_images[-1].shape[1] == 1, "Final image width should be 1 for large k" + binned_image = horizontal_binning(Z, fac=2) + assert binned_image.shape[0] == 64, f"Height of image {i} is incorrect" + assert ( + binned_image.shape[1] == expected_width + ), f"Width of image {i} is incorrect" + Z = binned_image @pytest.mark.parametrize( @@ -87,7 +70,9 @@ def test_horizontal_binning_large_k(): def test_horizontal_debinning_scaling(original_width, target_width): original_image = np.random.rand(64, original_width) target_shape = (64, target_width) - debinned_image = horizontal_debinning(original_image, np.empty(target_shape)) + debinned_image = horizontal_debinning( + original_image, target_width, fac=2, dim=1, n_iter=1 + ) assert ( debinned_image.shape == target_shape ), f"Failed to scale from {original_width} to {target_width}"