Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the denoise quality from hard thresholding #5

Merged
merged 10 commits into from
May 15, 2024
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1 +1,9 @@
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/ornlneutronimaging/bm3dornl/next.svg)](https://results.pre-commit.ci/latest/github/ornlneutronimaging/bm3dornl/next)

BM3D ORNL repo
--------------
This repository contains the BM3D ORNL code, which is a Python implementation of the BM3D denoising algorithm. The BM3D algorithm was originally proposed by K. Dabov, A. Foi, V. Katkovnik, and K. Egiazarian in the paper "Image Denoising by Sparse 3-D Transform-Domain Collaborative Filtering" (2007).
The BM3D algorithm is a state-of-the-art denoising algorithm that is widely used in the image processing community.
The BM3D ORNL code is a Python implementation of the BM3D algorithm that has been optimized for performance using both `Numba` and `CuPy`.
The BM3D ORNL code is designed to be easy to use and easy to integrate into existing Python workflows.
The BM3D ORNL code is released under an open-source license and is freely available for download and use.
2 changes: 1 addition & 1 deletion docs/README.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
How to build doc
How to build doc
2 changes: 1 addition & 1 deletion docs/developer.rst
Original file line number Diff line number Diff line change
@@ -1 +1 @@
developer file
developer file
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Index file
Index file
2 changes: 1 addition & 1 deletion docs/user.rst
Original file line number Diff line number Diff line change
@@ -1 +1 @@
user file
user file
45 changes: 24 additions & 21 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

30 changes: 26 additions & 4 deletions src/bm3dornl/block_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,15 @@ def group_signal_patches(
Maximum Euclidean distance in intensity for patches to be considered similar.
"""
num_patches = len(self.signal_patches_pos)
self.signal_blocks_matrix = np.eye(num_patches, dtype=bool)
# Initialize the signal blocks matrix
# note:
# - the matrix is symmetric
# - a zero value means the patches are not similar
# - a non-zero value is the Euclidean distance between the patches, i.e. a smaller value means a smaller distance and therefore a higher similarity
self.signal_blocks_matrix = np.zeros(
(num_patches, num_patches),
dtype=float,
)

# Cache patches as views
cached_patches = [self.get_patch(pos) for pos in self.signal_patches_pos]
Expand All @@ -108,8 +116,16 @@ def group_signal_patches(
cached_patches[neightbor_patch_id],
intensity_diff_threshold,
):
self.signal_blocks_matrix[ref_patch_id, neightbor_patch_id] = True
self.signal_blocks_matrix[neightbor_patch_id, ref_patch_id] = True
val_diff = max(
np.linalg.norm(ref_patch - cached_patches[neightbor_patch_id]),
1e-8,
)
self.signal_blocks_matrix[ref_patch_id, neightbor_patch_id] = (
val_diff
)
self.signal_blocks_matrix[neightbor_patch_id, ref_patch_id] = (
val_diff
)

def get_hyper_block(
self,
Expand Down Expand Up @@ -146,7 +162,13 @@ def get_hyper_block(
positions = np.empty((group_size, num_patches_per_group, 2), dtype=np.int32)

for i, row in enumerate(self.signal_blocks_matrix):
candidate_patch_ids = np.where(row)[0]
# find the ids
candidate_patch_ids = np.where(row > 0)[0]
# get the difference
candidate_patch_val = row[candidate_patch_ids]
# sort candidate_patch_ids by candidate_patch_val, smallest first
candidate_patch_ids = candidate_patch_ids[np.argsort(candidate_patch_val)]
# pad the patch ids
padded_patch_ids = pad_patch_ids(
candidate_patch_ids, num_patches_per_group, mode=padding_mode
)
Expand Down
27 changes: 22 additions & 5 deletions src/bm3dornl/denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def re_filtering(
)

logging.info("Wiener-Hadamard filtering...")
block = wiener_hadamard(block, sigma_squared)
block = wiener_hadamard(block, sigma_squared * 1e3) # why does this work?

# manual release of memory
memory_cleanup()
Expand Down Expand Up @@ -214,11 +214,12 @@ def denoise(
self.thresholding(
cut_off_distance, intensity_diff_threshold, num_patches_per_group, threshold
)
self.final_denoised_image = self.estimate_denoised_image

logging.info("Second pass: Re-filtering")
self.re_filtering(
cut_off_distance, intensity_diff_threshold, num_patches_per_group
)
# logging.info("Second pass: Re-filtering")
# self.re_filtering(
# cut_off_distance, intensity_diff_threshold, num_patches_per_group
# )


def bm3d_streak_removal(
Expand Down Expand Up @@ -269,6 +270,22 @@ def bm3d_streak_removal(
sinogram = medfilt2d(sinogram, kernel_size=3)
sino_star = sinogram

if k == 0:
# direct without multi-scale
worker = BM3D(
image=sino_star,
patch_size=patch_size,
stride=stride,
background_threshold=background_threshold,
)
worker.denoise(
cut_off_distance=cut_off_distance,
intensity_diff_threshold=intensity_diff_threshold,
num_patches_per_group=num_patches_per_group,
threshold=shrinkage_threshold,
)
return worker.final_denoised_image

# step 1: create a list of binned sinograms
binned_sinos = horizontal_binning(sinogram, k=k)
# reverse the list
Expand Down
3 changes: 3 additions & 0 deletions src/bm3dornl/gpu_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def hard_thresholding(
# Transform the patch block to the frequency domain using rfft
hyper_block = cp.fft.rfft2(hyper_block, axes=(1, 2, 3))

# find the quantile value based on the threshold
threshold = cp.quantile(cp.abs(hyper_block), threshold)

# Apply hard thresholding
hyper_block[cp.abs(hyper_block) < threshold] = 0

Expand Down
2 changes: 1 addition & 1 deletion src/bm3dornl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def find_candidate_patch_ids(
"""
num_patches = signal_patches.shape[0]
ref_pos = signal_patches[ref_index]
candidate_patch_ids = []
candidate_patch_ids = [ref_index]

for i in range(ref_index + 1, num_patches): # Ensure only checking upper triangle
if (
Expand Down
13 changes: 10 additions & 3 deletions tests/unit/bm3dornl/test_block_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,16 @@ def test_group_signal_patches_geometric(patch_manager):
cut_off_distance = (100, 100) # Larger than image dimensions
intensity_diff_threshold = 0.5 # Irrelevant due to uniform image
patch_manager.group_signal_patches(cut_off_distance, intensity_diff_threshold)
expected_blocks = np.ones(
(len(patch_manager.signal_patches_pos), len(patch_manager.signal_patches_pos)),
dtype=bool,
# uniform image: all patches are similar, so every pair gets the smallest (clamped) distance
expected_blocks = (
np.ones(
(
len(patch_manager.signal_patches_pos),
len(patch_manager.signal_patches_pos),
),
dtype=float,
)
* 1e-8
)
np.testing.assert_array_equal(
patch_manager.signal_blocks_matrix,
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/bm3dornl/test_gpu_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
def test_hard_thresholding():
# Setup the patch block
patch_block = np.random.rand(2, 5, 8, 8) # Random block of patches on GPU
threshold = 0.5 # Threshold for hard thresholding
threshold_quantile = 0.5 # Threshold for hard thresholding

# Apply shrinkage
denoised_block = hard_thresholding(patch_block, threshold)
denoised_block = hard_thresholding(patch_block, threshold_quantile)

# Convert back to frequency domain to check thresholding
dct_block_check = cp.fft.rfft2(cp.asarray(denoised_block), axes=(1, 2, 3)).get()
threshold = np.quantile(np.abs(dct_block_check), threshold_quantile)

# Test if all values in the frequency (FFT) domain are either zero or above the threshold
# Allow a small tolerance for floating point arithmetic issues
Expand All @@ -39,6 +40,7 @@ def test_hard_thresholding():

# Check for any values that should not have been zeroed out
original_dct_block = cp.fft.rfft2(cp.asarray(patch_block), axes=(1, 2, 3)).get()
threshold = np.quantile(np.abs(original_dct_block), threshold_quantile)
should_not_change = np.abs(original_dct_block) >= threshold
assert np.allclose(
dct_block_check[should_not_change],
Expand Down
7 changes: 4 additions & 3 deletions tests/unit/bm3dornl/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_find_candidate_patch_ids():
ref_index = 0
cut_off_distance = (1, 1)
expected = [
0,
1,
3,
4,
Expand All @@ -32,21 +33,21 @@ def test_find_candidate_patch_ids():
# Test case 2
ref_index = 2
cut_off_distance = (2, 2)
expected = [3, 4, 5] # Indices that are within 2 units from (0, 2)
expected = [2, 3, 4, 5] # Indices that are within 2 units from (0, 2)
result = find_candidate_patch_ids(signal_patches, ref_index, cut_off_distance)
assert result == expected, "Test case 2 failed"

# Test case 3
ref_index = 4
cut_off_distance = (3, 3)
expected = [5, 6] # Indices that are within 3 units from (1, 1)
expected = [4, 5, 6] # Indices that are within 3 units from (1, 1)
result = find_candidate_patch_ids(signal_patches, ref_index, cut_off_distance)
assert result == expected, "Test case 3 failed"

# Test case 4
ref_index = 0
cut_off_distance = (0, 0)
expected = [] # No patch within 0 distance from (0, 0) except itself
expected = [0] # No patch within 0 distance from (0, 0) except itself
result = find_candidate_patch_ids(signal_patches, ref_index, cut_off_distance)
assert result == expected, "Test case 4 failed"

Expand Down
Loading