Skip to content

Commit

Permalink
Merge pull request #16 from ornlneutronimaging/bin
Browse files Browse the repository at this point in the history
Multiscale filtering functionality
  • Loading branch information
KedoKudo authored Jul 3, 2024
2 parents cfc310f + 972fa85 commit 306a6c9
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 138 deletions.
60 changes: 60 additions & 0 deletions notebooks/example_ms.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from bm3dornl.bm3d import bm3d_ring_artifact_removal_ms\n",
"from bm3dornl.denoiser_gpu import memory_cleanup\n",
"\n",
"with open(\"../../tests/bm3dornl-data/sino.npy\", \"rb\") as f:\n",
" sino_noisy = np.load(f)\n",
"\n",
"memory_cleanup()\n",
"\n",
"block_matching_kwargs: dict = {\n",
" \"patch_size\": (8, 8),\n",
" \"stride\": 3,\n",
" \"background_threshold\": 0.0,\n",
" \"cut_off_distance\": (64, 64),\n",
" \"num_patches_per_group\": 32,\n",
" \"padding_mode\": \"circular\",\n",
"}\n",
"filter_kwargs: dict = {\n",
" \"filter_function\": \"fft\",\n",
" \"shrinkage_factor\": 3e-2,\n",
"}\n",
"kwargs = {\n",
" \"mode\": \"simple\",\n",
" \"k\": 4,\n",
" \"block_matching_kwargs\": block_matching_kwargs,\n",
" \"filter_kwargs\": filter_kwargs,\n",
"}\n",
"\n",
"sino_bm3dornl = bm3d_ring_artifact_removal_ms(\n",
" sinogram=sino_noisy,\n",
" **kwargs,\n",
")\n",
"\n",
"\n",
"fig, axs = plt.subplots(1, 2, figsize=(12, 4))\n",
"axs[0].imshow(sino_noisy, cmap=\"gray\")\n",
"axs[0].set_title(\"Noisy sinogram\")\n",
"axs[1].imshow(sino_bm3dornl, cmap=\"gray\")\n",
"axs[1].set_title(\"BM3D denoised sinogram\")\n",
"plt.show()\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
56 changes: 38 additions & 18 deletions src/bm3dornl/bm3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,26 +739,46 @@ def bm3d_ring_artifact_removal_ms(
filter_kwargs=filter_kwargs,
)

# step 1: create a list of binned sinograms
binned_sinos = horizontal_binning(sinogram, k=k)
# reverse the list
binned_sinos = binned_sinos[::-1]

# step 2: estimate the noise level from the coarsest sinogram, then working back to the original sinogram
noise_estimate = None
for i in range(len(binned_sinos)):
logging.info(f"Processing binned sinogram {i+1} of {len(binned_sinos)}")
sino = binned_sinos[i]
sino_star = (
sino if i == 0 else sino - horizontal_debinning(noise_estimate, sino)
denoised_sino = None
# Make a copy of an original sinogram
sino_orig = horizontal_binning(sino_star, 1, dim=0)
binned_sinos_orig = [np.copy(sino_orig)]

# Contains upscaled denoised sinograms
binned_sinos = [np.zeros(0)]

# Bin horizontally
for i in range(0, k):
binned_sinos_orig.append(
horizontal_binning(binned_sinos_orig[-1], fac=2, dim=1)
)
binned_sinos.append(np.zeros(0))

if i < len(binned_sinos) - 1:
noise_estimate = sino - bm3d_ring_artifact_removal(
sino_star,
mode=mode,
block_matching_kwargs=block_matching_kwargs,
filter_kwargs=filter_kwargs,
binned_sinos[-1] = binned_sinos_orig[-1]

for i in range(k, -1, -1):
logging.info(f"Processing binned sinogram {i + 1} of {k}")
# Denoise binned sinogram
denoised_sino = bm3d_ring_artifact_removal(
binned_sinos[i],
mode=mode,
block_matching_kwargs=block_matching_kwargs,
filter_kwargs=filter_kwargs,
)
# For iterations except the last, create the next noisy image with a finer scale residual
if i > 0:
debinned_sino = horizontal_debinning(
denoised_sino - binned_sinos_orig[i],
binned_sinos_orig[i - 1].shape[1],
2,
30,
dim=1,
)
binned_sinos[i - 1] = binned_sinos_orig[i - 1] + debinned_sino

# residual
sino_star = sino_star + horizontal_debinning(
denoised_sino - sino_orig, sino_star.shape[0], fac=1, n_iter=30, dim=0
)

return sino_star
161 changes: 104 additions & 57 deletions src/bm3dornl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
"""Utility functions for BM3DORNL."""

import numpy as np
from scipy.interpolate import RectBivariateSpline
from numba import njit
from scipy.signal import convolve2d
from scipy.interpolate import interp1d


@njit
Expand Down Expand Up @@ -43,84 +44,130 @@ def is_within_threshold(
return np.linalg.norm(ref_patch - cmp_patch) <= intensity_diff_threshold


def horizontal_binning(Z: np.ndarray, k: int = 0) -> list[np.ndarray]:
def create_array(base_arr: np.ndarray, h: int, dim: int):
"""
Horizontal binning of the image Z into a list of k images.
Create a padded and convolved array used in both binning and debinning.
Parameters
----------
base_arr: Input array
h: bin count
dim: bin dimension (0 or 1)
Returns
-------
resulting array
"""
mod_pad = h - ((base_arr.shape[dim] - 1) % h) - 1
if dim == 0:
pads = ((0, mod_pad), (0, 0))
kernel = np.ones((h, 1))
else:
pads = ((0, 0), (0, mod_pad))
kernel = np.ones((1, h))

return convolve2d(np.pad(base_arr, pads, "symmetric"), kernel, "same", "fill")


def horizontal_binning(Z: np.ndarray, fac: int = 2, dim: int = 1) -> np.ndarray:
"""
Horizontal binning of the image Z
Parameters
----------
Z : np.ndarray
The image to be binned.
k : int
Number of iterations to bin the image by half.
fac : int
binning factor
dim : direction X=0, Y=1
Returns
-------
list[np.ndarray]
List of k images.
np.ndarray
binned image
Example
-------
>>> Z = np.random.rand(64, 64)
>>> binned_zs = horizontal_binning(Z, 3)
>>> len(binned_zs)
4
"""
binned_zs = [Z]
for _ in range(k):
sub_z0 = Z[:, ::2]
sub_z1 = Z[:, 1::2]
# make sure z0 and z1 have the same shape
if sub_z0.shape[1] > sub_z1.shape[1]:
sub_z0 = sub_z0[:, :-1]
elif sub_z0.shape[1] < sub_z1.shape[1]:
sub_z1 = sub_z1[:, :-1]
# average z0 and z1
Z = (sub_z0 + sub_z1) * 0.5
binned_zs.append(Z)
return binned_zs


def horizontal_debinning(original_image: np.ndarray, target: np.ndarray) -> np.ndarray:

if fac > 1:
fac_half = fac // 2
binned_zs = create_array(Z, fac, dim)

# get coordinates of bin centres
if dim == 0:
binned_zs = binned_zs[
fac_half + ((fac % 2) == 1) : binned_zs.shape[dim] - fac_half + 1 : fac,
:,
]
else:
binned_zs = binned_zs[
:,
fac_half + ((fac % 2) == 1) : binned_zs.shape[dim] - fac_half + 1 : fac,
]

return binned_zs

return Z


def horizontal_debinning(
Z: np.ndarray, size: int, fac: int, n_iter: int, dim: int = 1
) -> np.ndarray:
"""
Horizontal debinning of the image Z into the same shape as Z_target.
Parameters
----------
original_image : np.ndarray
Z : np.ndarray
The image to be debinned.
target : np.ndarray
The target image to match the shape.
size: target size (original size before binning) for the second dimension
fac: binning factor (original divisor)
n_iter: number of iterations
dim: dimension for binning (Y = 0 or X = 1)
Returns
-------
np.ndarray
The debinned image.
Example
-------
>>> Z = np.random.rand(64, 64)
>>> target = np.random.rand(64, 128)
>>> debinned_z = horizontal_debinning(Z, target)
>>> debinned_z.shape
(64, 128)
"""
# Original dimensions
original_height, original_width = original_image.shape
# Target dimensions
new_height, new_width = target.shape

# Original grid
original_x = np.arange(original_width)
original_y = np.arange(original_height)

# Target grid
new_x = np.linspace(0, original_width - 1, new_width)
new_y = np.linspace(0, original_height - 1, new_height)

# Spline interpolation
spline = RectBivariateSpline(original_y, original_x, original_image)
interpolated_image = spline(new_y, new_x)
if fac <= 1:
return np.copy(Z)

fac_half = fac // 2

if dim == 0:
base_array = np.ones((size, 1))
else:
base_array = np.ones((1, size))

n_counter = create_array(base_array, fac, dim)

# coordinates of bin counts
x1c = np.arange(fac_half + ((fac % 2) == 1), (Z.shape[dim]) * fac, fac)
x1 = np.arange(fac_half + 1 - ((fac % 2) == 0) / 2, (Z.shape[dim]) * fac, fac)

# coordinates of image pixels
ix1 = np.arange(1, size + 1)

interpolated_image = 0

for j in range(max(1, n_iter)):
# residual
if j > 0:
residual = Z - horizontal_binning(interpolated_image, fac, dim)
else:
residual = Z

# interpolation
if dim == 0:
interp = interp1d(
x1,
residual / n_counter[x1c, :],
kind="cubic",
fill_value="extrapolate",
axis=0,
)
else:
interp = interp1d(
x1, residual / n_counter[:, x1c], kind="cubic", fill_value="extrapolate"
)
interpolated_image = interpolated_image + interp(ix1)

return interpolated_image

Expand Down
49 changes: 16 additions & 33 deletions tests/unit/bm3dornl/test_bm3d_ring_artifact_removal_ms.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,36 @@
import pytest
from unittest.mock import patch
import numpy as np
from bm3dornl.bm3d import bm3d_ring_artifact_removal_ms


size_x = 256
size_y = 256


@pytest.fixture
def setup_sinogram():
return np.random.rand(256, 256)
return np.random.rand(size_x, size_y)


@patch("bm3dornl.bm3d.horizontal_binning")
@patch("bm3dornl.bm3d.horizontal_debinning")
@patch("bm3dornl.bm3d.bm3d_ring_artifact_removal")
@pytest.mark.cuda_required
def test_bm3d_ring_artifact_removal_ms(
mock_bm3d_ring_artifact_removal,
mock_horizontal_debinning,
mock_horizontal_binning,
setup_sinogram,
):
sinogram = setup_sinogram
binned_sinos = [np.random.rand(64, 64) for _ in range(4)]
mock_horizontal_binning.return_value = binned_sinos
binned_sinos = [np.random.rand(64, 64) for _ in range(4)]
mock_bm3d_ring_artifact_removal.return_value = binned_sinos
binned_sinos = [np.random.rand(64, 64) for _ in range(4)]
mock_horizontal_debinning.return_value = binned_sinos

result = bm3d_ring_artifact_removal_ms(sinogram, k=4)

assert result is not None
mock_horizontal_binning.assert_called_once_with(sinogram, k=4)
assert mock_bm3d_ring_artifact_removal.call_count == 3
assert mock_horizontal_debinning.call_count == 3

r, c = result.shape

assert c == size_x
assert r == size_y

result_single_pass = bm3d_ring_artifact_removal_ms(sinogram, k=0)
assert result_single_pass is not None
mock_bm3d_ring_artifact_removal.assert_called_with(
sinogram,
mode="simple",
block_matching_kwargs={
"patch_size": (8, 8),
"stride": 3,
"background_threshold": 0.0,
"cut_off_distance": (64, 64),
"num_patches_per_group": 32,
"padding_mode": "circular",
},
filter_kwargs={
"filter_function": "fft",
"shrinkage_factor": 3e-2,
},
)

r, c = result_single_pass.shape

assert c == size_x
assert r == size_y
Loading

0 comments on commit 306a6c9

Please sign in to comment.