Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate prototype code #3

Merged
merged 16 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/unittest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- name: run unit tests
run: |
echo "running unit tests"
python -m pytest --cov=src --cov-report=xml --cov-report=term-missing tests/
python -m pytest --cov=src --cov-report=xml --cov-report=term-missing -m "not cuda_required" tests/
KedoKudo marked this conversation as resolved.
Show resolved Hide resolved
- name: upload coverage to codecov
uses: codecov/codecov-action@v4
with:
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.envrc
src/bm3dornl/_version.py
45 changes: 22 additions & 23 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -1,39 +1,38 @@
name: bm3dornl
channels:
- conda-forge
- nvidia
dependencies:
# -- Runtime dependencies
# base: list all base dependencies here
- python>=3.8 # please specify the mimimum version of python here
# base
- python>=3.10
- pip
- versioningit
# compute: list all compute dependencies here
- numpy
- pandas
# plot: list all plot dependencies here, if applicable
# compute
- cupy
- numba
- scipy<1.13 # avoid a bug in 1.13
- scikit-image
# [Optional]visualization
- matplotlib
# jupyter: list all jupyter dependencies here, if applicable
# [Optional]jupyter
- jupyterlab
- ipympl
# -- Development dependencies
# utils:
# utils
- pre-commit
# pacakge building:
- libmamba
- libarchive
- line_profiler # useful for development
- memory_profiler # useful for development
# packaging
- anaconda-client
- boa
- conda-build < 4 # conda-build 24.x has a bug, missing update_index from conda_build.index
- conda-build < 4
- conda-verify
- libmamba
- libarchive
- python-build
# test: list all test dependencies here
# test
- pytest
- pytest-cov
- pytest-mock
- pytest-xdist
# --------------------------------------------------
# add additional sections such as Qt, etc. if needed
# --------------------------------------------------
# if pakcages are not available on conda, list them here
- pip
# pip packages
- pip:
- bm3d-streak-removal # example
- pytest-playwright
- bm3d-streak-removal # this is our reference package
396 changes: 392 additions & 4 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ testpaths = ["tests"]
python_files = ["test*.py"]
norecursedirs = [".git", "tmp*", "_tmp*", "__pycache__", "*dataset*", "*data_set*"]
markers = [
"mymarker: example markers goes here"
"cuda_required: test requires cuda to run."
]

[tool.pylint]
Expand Down
7 changes: 0 additions & 7 deletions src/bm3dornl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,3 @@
from ._version import __version__ # noqa: F401
except ImportError:
__version__ = "0.0.1"


def PackageName(): # pylint: disable=invalid-name
"""This is needed for backward compatibility because mantid workbench does "from shiver import Shiver" """
from .packagenamepy import PackageName as packagename # pylint: disable=import-outside-toplevel

return packagename()
43 changes: 43 additions & 0 deletions src/bm3dornl/aggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/usr/bin/env python3
"""Functions for aggregating hyper patch block into a single image."""

import numpy as np
from numba import jit, prange


@jit(nopython=True, parallel=True)
def aggregate_patches(
estimate_denoised_image: np.ndarray,
weights: np.ndarray,
hyper_block: np.ndarray,
hyper_block_index: np.ndarray,
):
"""
Aggregate patches into the denoised image matrix and update the corresponding weights matrix.

Parameters
----------
estimate_denoised_image : np.ndarray
The 2D numpy array where the aggregate result of the denoised patches will be stored.
weights : np.ndarray
The 2D numpy array that counts the contributions of the patches to the cells of the `estimate_denoised_image`.
hyper_block : np.ndarray
A 4D numpy array of patches to be aggregated. Shape is (num_blocks, num_patches_per_block, patch_height, patch_width).
hyper_block_index : np.ndarray
A 3D numpy array containing the top-left indices (row, column) for each patch in the `hyper_block`.
Shape is (num_blocks, num_patches_per_block, 2).

Notes
-----
This function uses Numba's JIT compilation with parallel execution to speed up the aggregation of image patches.
Each thread handles a block of patches independently, reducing computational time significantly on multi-core processors.
"""
num_blocks, num_patches, ph, pw = hyper_block.shape
for i in prange(num_blocks):
for p in range(num_patches):
patch = hyper_block[i, p]
i_pos, j_pos = hyper_block_index[i, p]
for ii in range(ph):
for jj in range(pw):
estimate_denoised_image[i_pos + ii, j_pos + jj] += patch[ii, jj]
weights[i_pos + ii, j_pos + jj] += 1
162 changes: 162 additions & 0 deletions src/bm3dornl/block_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#!/usr/bin/env python3
"""Block matching to build hyper block from single sinogram."""

import numpy as np
from typing import Tuple, Optional
from bm3dornl.utils import (
get_signal_patch_positions,
find_candidate_patch_ids,
is_within_threshold,
pad_patch_ids,
)


class PatchManager:
def __init__(
self,
image: np.ndarray,
patch_size: Tuple[int, int] = (8, 8),
stride: int = 1,
background_threshold: float = 0.1,
):
"""
Initialize the PatchManager with an image, patch configuration, and background threshold
for distinguishing between signal and background patches.

Parameters
----------
image : np.ndarray
The image from which patches will be managed.
patch_size : tuple
Dimensions (height, width) of each patch. Default is (8, 8).
stride : int
The stride with which to slide the window across the image. Default is 1.
background_threshold : float
The mean intensity threshold below which a patch is considered a background patch.
"""
self._image = image
self.patch_size = patch_size
self.stride = stride
self.background_threshold = background_threshold
self.signal_patches_pos = []
self.signal_blocks_matrix = []
self._generate_patch_positions()

def _generate_patch_positions(self):
"""Generate the positions of signal and background patches in the image."""
self.signal_patches_pos = get_signal_patch_positions(
self._image, self.patch_size, self.stride, self.background_threshold
)

@property
def image(self):
return self._image

@image.setter
def image(self, value):
self._image = value
self._generate_patch_positions()

def get_patch(
self, position: tuple, source_image: Optional[np.ndarray] = None
) -> np.ndarray:
"""Retreive a patch from the image at the specified position.

Parameters:
----------
position : tuple
The row and column indices of the top-left corner of the patch.
source_image : np.ndarray

Returns:
-------
np.ndarray
The patch extracted from the image.
"""
source_image = self._image if source_image is None else source_image
i, j = position
return source_image[i : i + self.patch_size[0], j : j + self.patch_size[1]]

def group_signal_patches(
self, cut_off_distance: tuple, intensity_diff_threshold: float
):
"""
Group signal patches into blocks based on spatial and intensity distance thresholds.

Parameters:
----------
cut_off_distance : tuple
Maximum spatial distance in terms of row and column indices for patches in the same block, Manhattan distance (taxi cab distance).
intensity_diff_threshold : float
Maximum Euclidean distance in intensity for patches to be considered similar.
"""
num_patches = len(self.signal_patches_pos)
self.signal_blocks_matrix = np.eye(num_patches, dtype=bool)

# Cache patches as views
cached_patches = [self.get_patch(pos) for pos in self.signal_patches_pos]

for ref_patch_id in range(num_patches):
ref_patch = cached_patches[ref_patch_id]
candidate_patch_ids = find_candidate_patch_ids(
self.signal_patches_pos, ref_patch_id, cut_off_distance
)
# iterate over the candidate patches
for neightbor_patch_id in candidate_patch_ids:
if is_within_threshold(
ref_patch,
cached_patches[neightbor_patch_id],
intensity_diff_threshold,
):
self.signal_blocks_matrix[ref_patch_id, neightbor_patch_id] = True
self.signal_blocks_matrix[neightbor_patch_id, ref_patch_id] = True

def get_hyper_block(
self,
num_patches_per_group: int,
padding_mode="circular",
alternative_source: np.ndarray = None,
):
"""
Return groups of similar patches as 4D arrays with each group having a fixed number of patches.

Parameters:
----------
num_patches_per_group : int
Number of patches in each group.
padding_mode : str
Mode for padding the patch IDs when the number of candidates is less than `num_patches_per_group`.
Options are 'first', 'repeat_sequence', 'circular', 'mirror', 'random'.
alternative_source : cp.ndarray
An alternative source image to extract patches from. Default is None.

Returns:
-------
tuple
A tuple containing the 4D array of patch groups and the corresponding positions.

TODO:
-----
- use multi-processing to further improve the speed of block building
"""
group_size = len(self.signal_blocks_matrix)
block = np.empty(
(group_size, num_patches_per_group, *self.patch_size), dtype=np.float32
)
positions = np.empty((group_size, num_patches_per_group, 2), dtype=np.int32)

for i, row in enumerate(self.signal_blocks_matrix):
candidate_patch_ids = np.where(row)[0]
padded_patch_ids = pad_patch_ids(
candidate_patch_ids, num_patches_per_group, mode=padding_mode
)
# update block and positions
block[i] = np.array(
[
self.get_patch(self.signal_patches_pos[idx], alternative_source)
for idx in padded_patch_ids
]
)
positions[i] = np.array(self.signal_patches_pos[padded_patch_ids])

return block, positions
66 changes: 0 additions & 66 deletions src/bm3dornl/bm3dornl.py

This file was deleted.

Loading
Loading