Improve multi-scale vertical streak removal quality #29

Merged: 8 commits, Aug 28, 2024
6 changes: 3 additions & 3 deletions notebooks/demo_denoise_mode.ipynb

Large diffs are not rendered by default.

36 changes: 18 additions & 18 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

118 changes: 88 additions & 30 deletions notebooks/example_ms.ipynb

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions src/bm3dornl/block_matching.py
@@ -92,9 +92,7 @@ def get_signal_patch_positions(
# Note: raise an error when we can't find a single signal patch in the entire
# sinogram, which usually indicates a bad background estimation.
if len(signal_patches) == 0:
raise ValueError(
"Couldn't find any signal patches in the image! Please check the background threshold."
)
raise ValueError("Couldn't find any signal patches in the image!")

return np.array(signal_patches)

209 changes: 166 additions & 43 deletions src/bm3dornl/bm3d.py
@@ -3,8 +3,8 @@

import logging
import numpy as np
import cupy as cp
from typing import Tuple, Callable
from scipy.signal import medfilt2d
from .block_matching import (
get_signal_patch_positions,
get_patch_numba,
@@ -32,8 +32,8 @@
hadamard_transform,
)
from .utils import (
horizontal_binning,
horizontal_debinning,
downscale_2d_horizontal,
upscale_2d_horizontal,
)

# NOTE: These default parameters are based on the parameter tuning study.
@@ -295,6 +295,66 @@ def global_fourier_thresholding(
return new_noisy_image


def padded_piecewise_weighted_denoising(
sinogram: np.ndarray,
window_size: int = 50,
step_size: int = 10,
pad_size: int = None,
) -> np.ndarray:
"""
Perform piecewise weighted denoising on a sinogram with padding using CuPy for GPU acceleration.

Parameters
----------
sinogram : np.ndarray
Input sinogram to be denoised.
window_size : int, optional
Size of the window used for piecewise denoising, by default 50.
step_size : int, optional
Step size for moving the window, by default 10.
pad_size : int, optional
Padding size added to the sinogram before denoising. If None, pad_size is set to 2 * window_size.

Returns
-------
np.ndarray
Denoised sinogram after piecewise weighted denoising.
"""
if pad_size is None:
pad_size = window_size * 2

# Move data to GPU
sinogram_gpu = cp.asarray(sinogram)

# Pad the sinogram
padded_sinogram = cp.pad(sinogram_gpu, ((pad_size, pad_size), (0, 0)), mode="wrap")

rows, _ = padded_sinogram.shape
new_img = cp.zeros_like(padded_sinogram, dtype=cp.float32)
new_wgt = cp.zeros_like(padded_sinogram, dtype=cp.float32)

for i in range(0, rows, step_size):
end = min(i + window_size, rows)
start = max(0, end - window_size)
window = padded_sinogram[start:end, :]
median = cp.median(window, axis=0)
new_img[start:end, :] += window - median
new_wgt[start:end, :] += 1

# Avoid division by zero
new_wgt = cp.maximum(new_wgt, 1e-10)
denoised = new_img / new_wgt

# Restore the overall intensity level
denoised += cp.median(sinogram_gpu)

# Remove padding
denoised = denoised[pad_size:-pad_size, :]

# Move result back to CPU
return cp.asnumpy(denoised)
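
The routine above slides a fixed-size window down the rows of the (wrap-padded) sinogram, subtracts the per-column median inside each window, and averages the overlapping results before restoring the global intensity level. The following CPU-only sketch is a hypothetical helper, not part of this PR; it mirrors the same accumulation with plain NumPy (padding omitted for brevity) and can be useful for checking results on machines without a GPU.

import numpy as np

def sliding_median_residual(sino: np.ndarray, window_size: int = 50, step_size: int = 10) -> np.ndarray:
    # Accumulate window-minus-median residuals and their weights, as in the CuPy version.
    rows = sino.shape[0]
    acc = np.zeros_like(sino, dtype=np.float32)
    wgt = np.zeros_like(sino, dtype=np.float32)
    for i in range(0, rows, step_size):
        end = min(i + window_size, rows)
        start = max(0, end - window_size)
        window = sino[start:end, :]
        acc[start:end, :] += window - np.median(window, axis=0)
        wgt[start:end, :] += 1
    # Weighted average of overlapping windows, then restore the overall intensity level.
    return acc / np.maximum(wgt, 1e-10) + np.median(sino)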


def estimate_noise_free_sinogram(sinogram: np.ndarray) -> np.ndarray:
"""
Estimate noise-free sinogram from noisy sinogram.
@@ -309,14 +369,18 @@ def estimate_noise_free_sinogram(sinogram: np.ndarray) -> np.ndarray:
np.ndarray
Noise-free sinogram.
"""
# subtract column-wise median
sinogram = sinogram - np.median(sinogram, axis=0)
# perform median filtering to remove salt-and-pepper noise
sinogram = medfilt2d(sinogram, kernel_size=3)
# rescale to [0, 1]
sinogram -= sinogram.min()
sinogram /= sinogram.max()
return sinogram
# use piecewise weighted denoising
window_size = sinogram.shape[0] // 4
step_size = 1
denoised = padded_piecewise_weighted_denoising(
sinogram, window_size=window_size, step_size=step_size
)

# normalize to [0, 1]
denoised -= np.min(denoised)
denoised /= np.max(denoised)

return denoised
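
With this change, estimate_noise_free_sinogram chooses window_size = rows // 4 and step_size = 1 internally, so callers only pass the noisy sinogram and receive a [0, 1]-normalized estimate. A minimal call sketch, assuming the import path bm3dornl.bm3d, a CUDA-capable GPU for the CuPy step, and an illustrative array shape:

import numpy as np
from bm3dornl.bm3d import estimate_noise_free_sinogram

rng = np.random.default_rng(42)
noisy = rng.normal(loc=1.0, scale=0.05, size=(720, 512)).astype(np.float32)  # stand-in sinogram
estimate = estimate_noise_free_sinogram(noisy)  # values scaled to [0, 1]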


def bm3d_full(
@@ -654,14 +718,62 @@ def bm3d_ring_artifact_removal(
raise ValueError(f"Unknown mode: {mode}")


def get_scale_adjusted_blockmatching_params(
original_params: dict, scale_factor: int
) -> dict:
"""Scale the parameters based on the given factor.

Parameters
----------
original_params : dict
The original parameters.
scale_factor : int
The scale factor.

Returns
-------
dict
The adjusted parameters.
"""
adjusted_params = original_params.copy()

# Adjust patch size
# minimum patch size is 3x3
adjusted_params["patch_size"] = tuple(
max(3, int(x * scale_factor)) for x in original_params["patch_size"]
)

# Adjust stride
# minimum stride is 1
adjusted_params["stride"] = max(1, int(original_params["stride"] * scale_factor))

# Adjust cut-off distance
# minimum cut-off distance is 8
adjusted_params["cut_off_distance"] = tuple(
max(8, int(x / scale_factor)) for x in original_params["cut_off_distance"]
)

# Optionally adjust number of patches per group
if scale_factor > 1:
adjusted_params["num_patches_per_group"] = max(
16, original_params["num_patches_per_group"] // scale_factor
)

return adjusted_params
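
As a concrete illustration of the scaling rules (the dictionary values below are made up for the example and are not the package defaults): patch size and stride grow with the scale factor, the cut-off distance shrinks, and the group size is reduced but floored at 16.

params = {
    "patch_size": (8, 8),
    "stride": 3,
    "cut_off_distance": (64, 64),
    "num_patches_per_group": 32,
}
adjusted = get_scale_adjusted_blockmatching_params(params, scale_factor=2)
# patch_size -> (16, 16), stride -> 6, cut_off_distance -> (32, 32), num_patches_per_group -> 16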


def bm3d_ring_artifact_removal_ms(
sinogram: np.ndarray,
k: int = 3,
mode: str = "simple", # express, simple, full
block_matching_kwargs: dict = default_block_matching_kwargs,
filter_kwargs: dict = default_filter_kwargs,
use_iterative_refinement: bool = True,
refinement_iterations: int = 3,
scale_factor_base: int = 2,
) -> np.ndarray:
"""Multiscale BM3D for streak removal
"""
Multiscale BM3D for streak removal

Parameters
----------
@@ -675,6 +787,12 @@ def bm3d_ring_artifact_removal_ms(
The block matching parameters.
filter_kwargs : dict
The filter parameters.
use_iterative_refinement : bool, optional
Whether to use iterative refinement in upscaling, by default True
refinement_iterations : int, optional
Number of refinement iterations if using iterative refinement, by default 3
scale_factor_base : int, optional
The base scale factor for binning, by default 2

Returns
-------
@@ -686,8 +804,8 @@
[1] ref: `Collaborative Filtering of Correlated Noise <https://doi.org/10.1109/TIP.2020.3014721>`_
[2] ref: `Ring artifact reduction via multiscale nonlocal collaborative filtering of spatially correlated noise <https://doi.org/10.1107/S1600577521001910>`_
"""
# step 0: median filter the sinogram
sino_star = sinogram
# step 0: initialize
sino_star = np.array(sinogram)

if k == 0:
# single pass
@@ -698,46 +816,51 @@
filter_kwargs=filter_kwargs,
)

denoised_sino = None
# Make a copy of an original sinogram
sino_orig = horizontal_binning(sino_star, 1, dim=0)
binned_sinos_orig = [np.copy(sino_orig)]

# Contains upscaled denoised sinograms
binned_sinos = [np.zeros(0)]
binned_sinos_orig = [sino_star]

# Bin horizontally
for i in range(0, k):
binned_sinos_orig.append(
horizontal_binning(binned_sinos_orig[-1], fac=2, dim=1)
)
binned_sinos.append(np.zeros(0))

binned_sinos[-1] = binned_sinos_orig[-1]
for i in range(k):
binned_sinos_orig.append(downscale_2d_horizontal(binned_sinos_orig[-1], 2))

# Multi-scale denoising
for i in range(k, -1, -1):
logging.info(f"Processing binned sinogram {i + 1} of {k}")
logging.info(f"Processing binned sinogram {i + 1} of {k + 1}")

# compute the adjusted parameters
scale_factor = int(scale_factor_base ** (i / 2))
adjusted_block_matching_kwargs = get_scale_adjusted_blockmatching_params(
block_matching_kwargs, scale_factor
)
adjusted_filter_kwargs = filter_kwargs.copy()
adjusted_filter_kwargs["shrinkage_factor"] = (
filter_kwargs["shrinkage_factor"] / scale_factor
)

# Denoise binned sinogram
denoised_sino = bm3d_ring_artifact_removal(
binned_sinos[i],
binned_sinos_orig[i],
mode=mode,
block_matching_kwargs=block_matching_kwargs,
filter_kwargs=filter_kwargs,
block_matching_kwargs=adjusted_block_matching_kwargs,
filter_kwargs=adjusted_filter_kwargs,
)

# For iterations except the last, create the next noisy image with a finer scale residual
if i > 0:
debinned_sino = horizontal_debinning(
denoised_sino - binned_sinos_orig[i],
binned_sinos_orig[i - 1].shape[1],
# Calculate the noise at current scale
noise_at_scale_i = binned_sinos_orig[i] - denoised_sino

# Upscale the noise to the next finer scale
upscaled_noise = upscale_2d_horizontal(
noise_at_scale_i,
2,
30,
dim=1,
original_width=binned_sinos_orig[i - 1].shape[1],
use_iterative_refinement=use_iterative_refinement,
refinement_iterations=refinement_iterations,
)
binned_sinos[i - 1] = binned_sinos_orig[i - 1] + debinned_sino

# residual
sino_star = sino_star + horizontal_debinning(
denoised_sino - sino_orig, sino_star.shape[0], fac=1, n_iter=30, dim=0
)
# Remove the upscaled noise from the finer scale
# NOTE: The subtraction of noise will also be upscaled in the next iteration, thereby
# propagating the noise removal from coarser to finer scales
binned_sinos_orig[i - 1] -= upscaled_noise

return sino_star
return binned_sinos_orig[0]
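
A minimal end-to-end sketch of the updated multiscale entry point with the new keyword arguments, assuming the package defaults for the block-matching and filter parameters and a hypothetical input file name:

import numpy as np
from bm3dornl.bm3d import bm3d_ring_artifact_removal_ms

sino = np.load("noisy_sinogram.npy")  # hypothetical (angles, columns) sinogram
cleaned = bm3d_ring_artifact_removal_ms(
    sino,
    k=3,
    mode="simple",
    use_iterative_refinement=True,
    refinement_iterations=3,
    scale_factor_base=2,
)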