Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement power spectrum density for filtering #13

Merged
merged 18 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,6 @@ cython_debug/
#.idea/
.envrc
src/bm3dornl/_version.py
.vscode/settings.json
tmp/*
dev*.ipynb
333 changes: 333 additions & 0 deletions notebooks/demo_denoise_mode.ipynb

Large diffs are not rendered by default.

158 changes: 75 additions & 83 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

646 changes: 646 additions & 0 deletions notebooks/stepbystep_Makinen2020.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/bm3dornl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
from ._version import __version__ # noqa: F401
except ImportError:
__version__ = "0.0.1"

from .bm3d import bm3d_ring_artifact_removal # noqa: F401
110 changes: 85 additions & 25 deletions src/bm3dornl/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,102 @@
"""Functions for aggregating hyper patch block into a single image."""

import numpy as np
from numba import jit, prange
from typing import Tuple
from numba import njit, prange


@jit(nopython=True, parallel=True)
def aggregate_patches(
estimate_denoised_image: np.ndarray,
weights: np.ndarray,
hyper_block: np.ndarray,
hyper_block_index: np.ndarray,
):
@njit(parallel=True)
def aggregate_block_to_image(
image_shape: Tuple[int, int],
hyper_blocks: np.ndarray,
hyper_block_indices: np.ndarray,
variance_blocks: np.ndarray,
) -> np.ndarray:
"""
Aggregate patches into the denoised image matrix and update the corresponding weights matrix.
Aggregate patches into the denoised image matrix and update the corresponding weights matrix using smart weighting.

Parameters
----------
estimate_denoised_image : np.ndarray
The 2D numpy array where the aggregate result of the denoised patches will be stored.
weights : np.ndarray
The 2D numpy array that counts the contributions of the patches to the cells of the `estimate_denoised_image`.
hyper_block : np.ndarray
image_shape : tuple
The shape of the image to be denoised.
hyper_blocks : np.ndarray
A 4D numpy array of patches to be aggregated. Shape is (num_blocks, num_patches_per_block, patch_height, patch_width).
hyper_block_index : np.ndarray
A 3D numpy array containing the top-left indices (row, column) for each patch in the `hyper_block`.
hyper_block_indices : np.ndarray
A 3D numpy array containing the top-left indices (row, column) for each patch in the `hyper_blocks`.
Shape is (num_blocks, num_patches_per_block, 2).
variance_blocks : np.ndarray
A 4D numpy array of the variances for each patch. Shape is the same as `hyper_blocks`.

Notes
-----
This function uses Numba's JIT compilation with parallel execution to speed up the aggregation of image patches.
Each thread handles a block of patches independently, reducing computational time significantly on multi-core processors.
Returns
-------
np.ndarray
The denoised image.
"""
num_blocks, num_patches, ph, pw = hyper_block.shape
estimate_denoised_image = np.zeros(image_shape)
weights = np.zeros(image_shape)

num_blocks, num_patches, ph, pw = hyper_blocks.shape

for i in prange(num_blocks):
for p in range(num_patches):
patch = hyper_block[i, p]
i_pos, j_pos = hyper_block_index[i, p]
patch = hyper_blocks[i, p]
variance = variance_blocks[i, p]
weight = 1 / (variance + 1e-8) # Small epsilon to avoid division by zero
i_pos, j_pos = hyper_block_indices[i, p]
for ii in range(ph):
for jj in range(pw):
estimate_denoised_image[i_pos + ii, j_pos + jj] += patch[ii, jj]
weights[i_pos + ii, j_pos + jj] += 1
estimate_denoised_image[i_pos + ii, j_pos + jj] += (
patch[ii, jj] * weight[ii, jj]
)
weights[i_pos + ii, j_pos + jj] += weight[ii, jj]

# Normalize the denoised image by the sum of weights
estimate_denoised_image /= np.maximum(weights, 1)

return estimate_denoised_image


@njit(parallel=True)
def aggregate_denoised_block_to_image(
image_shape: Tuple[int, int],
denoised_patches: np.ndarray,
patch_positions: np.ndarray,
) -> np.ndarray:
"""
Aggregate denoised patches into the final denoised image.

Parameters
----------
image_shape : tuple
The shape of the final denoised image (height, width).
denoised_patches : np.ndarray
A 4D numpy array of denoised patches. Shape is (num_blocks, num_patches_per_block, patch_height, patch_width).
patch_positions : np.ndarray
A 3D numpy array containing the top-left indices (row, column) for each patch in the `denoised_patches`.
Shape is (num_blocks, num_patches_per_block, 2).

Returns
-------
np.ndarray
The final denoised image.
"""
denoised_image = np.zeros(image_shape, dtype=np.float32)
weights = np.zeros(image_shape, dtype=np.float32)

num_blocks, num_patches_per_block, patch_height, patch_width = (
denoised_patches.shape
)

for block_idx in prange(num_blocks):
for patch_idx in range(num_patches_per_block):
patch = denoised_patches[block_idx, patch_idx]
top_left_row, top_left_col = patch_positions[block_idx, patch_idx]
for i in range(patch_height):
for j in range(patch_width):
denoised_image[top_left_row + i, top_left_col + j] += patch[i, j]
weights[top_left_row + i, top_left_col + j] += 1

# Normalize the denoised image by the sum of weights
denoised_image /= np.maximum(weights, 1)

return denoised_image
Loading
Loading