Skip to content

Commit

Permalink
Merge pull request #3 from ornlneutronimaging/add_src_code
Browse files Browse the repository at this point in the history
Migrate prototype code
  • Loading branch information
KedoKudo authored May 15, 2024
2 parents 3a9b129 + 714b04f commit 937a529
Show file tree
Hide file tree
Showing 29 changed files with 2,217 additions and 224 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unittest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- name: run unit tests
run: |
echo "running unit tests"
python -m pytest --cov=src --cov-report=xml --cov-report=term-missing tests/
python -m pytest --cov=src --cov-report=xml --cov-report=term-missing -m "not cuda_required"
- name: upload coverage to codecov
uses: codecov/codecov-action@v4
with:
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.envrc
src/bm3dornl/_version.py
45 changes: 22 additions & 23 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -1,39 +1,38 @@
name: bm3dornl
channels:
- conda-forge
- nvidia
dependencies:
# -- Runtime dependencies
# base: list all base dependencies here
- python>=3.8 # please specify the mimimum version of python here
# base
- python>=3.10
- pip
- versioningit
# compute: list all compute dependencies here
- numpy
- pandas
# plot: list all plot dependencies here, if applicable
# compute
- cupy
- numba
- scipy<1.13 # avoid a bug in 1.13
- scikit-image
# [Optional]visualization
- matplotlib
# jupyter: list all jupyter dependencies here, if applicable
# [Optional]jupyter
- jupyterlab
- ipympl
# -- Development dependencies
# utils:
# utils
- pre-commit
# pacakge building:
- libmamba
- libarchive
- line_profiler # useful for development
- memory_profiler # useful for development
# packaging
- anaconda-client
- boa
- conda-build < 4 # conda-build 24.x has a bug, missing update_index from conda_build.index
- conda-build < 4
- conda-verify
- libmamba
- libarchive
- python-build
# test: list all test dependencies here
# test
- pytest
- pytest-cov
- pytest-mock
- pytest-xdist
# --------------------------------------------------
# add additional sections such as Qt, etc. if needed
# --------------------------------------------------
# if pakcages are not available on conda, list them here
- pip
# pip packages
- pip:
- bm3d-streak-removal # example
- pytest-playwright
- bm3d-streak-removal # this is our reference package
396 changes: 392 additions & 4 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ testpaths = ["tests"]
python_files = ["test*.py"]
norecursedirs = [".git", "tmp*", "_tmp*", "__pycache__", "*dataset*", "*data_set*"]
markers = [
"mymarker: example markers goes here"
"cuda_required: test requires cuda to run."
]

[tool.pylint]
Expand Down
7 changes: 0 additions & 7 deletions src/bm3dornl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,3 @@
from ._version import __version__ # noqa: F401
except ImportError:
__version__ = "0.0.1"


def PackageName(): # pylint: disable=invalid-name
"""This is needed for backward compatibility because mantid workbench does "from shiver import Shiver" """
from .packagenamepy import PackageName as packagename # pylint: disable=import-outside-toplevel

return packagename()
43 changes: 43 additions & 0 deletions src/bm3dornl/aggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/usr/bin/env python3
"""Functions for aggregating hyper patch block into a single image."""

import numpy as np
from numba import jit, prange


@jit(nopython=True, parallel=True)
def aggregate_patches(
estimate_denoised_image: np.ndarray,
weights: np.ndarray,
hyper_block: np.ndarray,
hyper_block_index: np.ndarray,
):
"""
Aggregate patches into the denoised image matrix and update the corresponding weights matrix.
Parameters
----------
estimate_denoised_image : np.ndarray
The 2D numpy array where the aggregate result of the denoised patches will be stored.
weights : np.ndarray
The 2D numpy array that counts the contributions of the patches to the cells of the `estimate_denoised_image`.
hyper_block : np.ndarray
A 4D numpy array of patches to be aggregated. Shape is (num_blocks, num_patches_per_block, patch_height, patch_width).
hyper_block_index : np.ndarray
A 3D numpy array containing the top-left indices (row, column) for each patch in the `hyper_block`.
Shape is (num_blocks, num_patches_per_block, 2).
Notes
-----
This function uses Numba's JIT compilation with parallel execution to speed up the aggregation of image patches.
Each thread handles a block of patches independently, reducing computational time significantly on multi-core processors.
"""
num_blocks, num_patches, ph, pw = hyper_block.shape
for i in prange(num_blocks):
for p in range(num_patches):
patch = hyper_block[i, p]
i_pos, j_pos = hyper_block_index[i, p]
for ii in range(ph):
for jj in range(pw):
estimate_denoised_image[i_pos + ii, j_pos + jj] += patch[ii, jj]
weights[i_pos + ii, j_pos + jj] += 1
162 changes: 162 additions & 0 deletions src/bm3dornl/block_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#!/usr/bin/env python3
"""Block matching to build hyper block from single sinogram."""

import numpy as np
from typing import Tuple, Optional
from bm3dornl.utils import (
get_signal_patch_positions,
find_candidate_patch_ids,
is_within_threshold,
pad_patch_ids,
)


class PatchManager:
def __init__(
self,
image: np.ndarray,
patch_size: Tuple[int, int] = (8, 8),
stride: int = 1,
background_threshold: float = 0.1,
):
"""
Initialize the PatchManager with an image, patch configuration, and background threshold
for distinguishing between signal and background patches.
Parameters
----------
image : np.ndarray
The image from which patches will be managed.
patch_size : tuple
Dimensions (height, width) of each patch. Default is (8, 8).
stride : int
The stride with which to slide the window across the image. Default is 1.
background_threshold : float
The mean intensity threshold below which a patch is considered a background patch.
"""
self._image = image
self.patch_size = patch_size
self.stride = stride
self.background_threshold = background_threshold
self.signal_patches_pos = []
self.signal_blocks_matrix = []
self._generate_patch_positions()

def _generate_patch_positions(self):
"""Generate the positions of signal and background patches in the image."""
self.signal_patches_pos = get_signal_patch_positions(
self._image, self.patch_size, self.stride, self.background_threshold
)

@property
def image(self):
return self._image

@image.setter
def image(self, value):
self._image = value
self._generate_patch_positions()

def get_patch(
self, position: tuple, source_image: Optional[np.ndarray] = None
) -> np.ndarray:
"""Retreive a patch from the image at the specified position.
Parameters:
----------
position : tuple
The row and column indices of the top-left corner of the patch.
source_image : np.ndarray
Returns:
-------
np.ndarray
The patch extracted from the image.
"""
source_image = self._image if source_image is None else source_image
i, j = position
return source_image[i : i + self.patch_size[0], j : j + self.patch_size[1]]

def group_signal_patches(
self, cut_off_distance: tuple, intensity_diff_threshold: float
):
"""
Group signal patches into blocks based on spatial and intensity distance thresholds.
Parameters:
----------
cut_off_distance : tuple
Maximum spatial distance in terms of row and column indices for patches in the same block, Manhattan distance (taxi cab distance).
intensity_diff_threshold : float
Maximum Euclidean distance in intensity for patches to be considered similar.
"""
num_patches = len(self.signal_patches_pos)
self.signal_blocks_matrix = np.eye(num_patches, dtype=bool)

# Cache patches as views
cached_patches = [self.get_patch(pos) for pos in self.signal_patches_pos]

for ref_patch_id in range(num_patches):
ref_patch = cached_patches[ref_patch_id]
candidate_patch_ids = find_candidate_patch_ids(
self.signal_patches_pos, ref_patch_id, cut_off_distance
)
# iterate over the candidate patches
for neightbor_patch_id in candidate_patch_ids:
if is_within_threshold(
ref_patch,
cached_patches[neightbor_patch_id],
intensity_diff_threshold,
):
self.signal_blocks_matrix[ref_patch_id, neightbor_patch_id] = True
self.signal_blocks_matrix[neightbor_patch_id, ref_patch_id] = True

def get_hyper_block(
self,
num_patches_per_group: int,
padding_mode="circular",
alternative_source: np.ndarray = None,
):
"""
Return groups of similar patches as 4D arrays with each group having a fixed number of patches.
Parameters:
----------
num_patches_per_group : int
Number of patches in each group.
padding_mode : str
Mode for padding the patch IDs when the number of candidates is less than `num_patches_per_group`.
Options are 'first', 'repeat_sequence', 'circular', 'mirror', 'random'.
alternative_source : cp.ndarray
An alternative source image to extract patches from. Default is None.
Returns:
-------
tuple
A tuple containing the 4D array of patch groups and the corresponding positions.
TODO:
-----
- use multi-processing to further improve the speed of block building
"""
group_size = len(self.signal_blocks_matrix)
block = np.empty(
(group_size, num_patches_per_group, *self.patch_size), dtype=np.float32
)
positions = np.empty((group_size, num_patches_per_group, 2), dtype=np.int32)

for i, row in enumerate(self.signal_blocks_matrix):
candidate_patch_ids = np.where(row)[0]
padded_patch_ids = pad_patch_ids(
candidate_patch_ids, num_patches_per_group, mode=padding_mode
)
# update block and positions
block[i] = np.array(
[
self.get_patch(self.signal_patches_pos[idx], alternative_source)
for idx in padded_patch_ids
]
)
positions[i] = np.array(self.signal_patches_pos[padded_patch_ids])

return block, positions
66 changes: 0 additions & 66 deletions src/bm3dornl/bm3dornl.py

This file was deleted.

Loading

0 comments on commit 937a529

Please sign in to comment.