diff --git a/.gitignore b/.gitignore index bff69f29..eb6f33c0 100644 --- a/.gitignore +++ b/.gitignore @@ -12,7 +12,6 @@ /build /dist /python/imars3d.egg-info -conda.recipe/ # temp files and dirs _tmp diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f66b8cca..73faa5ac 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: trailing-whitespace - id: check-docstring-first @@ -25,7 +25,7 @@ repos: - id: end-of-file-fixer - id: sort-simple-yaml - repo: https://github.com/psf/black - rev: 24.3.0 + rev: 24.8.0 hooks: - id: black args: ['--line-length=119'] diff --git a/.readthedocs.yaml b/.readthedocs.yaml index e303aa84..3cb4e8da 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -1,9 +1,9 @@ version: 2 build: - os: ubuntu-20.04 + os: "ubuntu-22.04" tools: - python: "mambaforge-4.10" + python: "mambaforge-22.9" jobs: pre_build: - mkdir ~/tmp diff --git a/conda.recipe/meta.yaml b/conda.recipe/meta.yaml index 98175aec..1a668f04 100644 --- a/conda.recipe/meta.yaml +++ b/conda.recipe/meta.yaml @@ -47,6 +47,7 @@ requirements: - bokeh - datashader - hvplot + - numpy<2 test: imports: diff --git a/environment.yml b/environment.yml index c159f939..55fc4247 100644 --- a/environment.yml +++ b/environment.yml @@ -11,11 +11,13 @@ dependencies: - astropy - tomopy - algotom + - numpy < 2 # plot - holoviews - bokeh - datashader - hvplot + - dask # GUI - panel - param diff --git a/src/imars3d/backend/corrections/beam_hardening.py b/src/imars3d/backend/corrections/beam_hardening.py index 9338fabe..69be5a77 100644 --- a/src/imars3d/backend/corrections/beam_hardening.py +++ b/src/imars3d/backend/corrections/beam_hardening.py @@ -4,7 +4,7 @@ import logging import param import numpy as np -from imars3d.backend.util.functions import clamp_max_workers +from imars3d.backend.util.functions import clamp_max_workers, calculate_chunksize from 
multiprocessing.managers import SharedMemoryManager from functools import partial from tqdm.contrib.concurrent import process_map @@ -83,6 +83,7 @@ def __call__(self, **params): # mp kwargs = { "max_workers": self.max_workers, + "chunksize": calculate_chunksize(params.arrays.shape[0], self.max_workers), "desc": "denoise_by_bilateral", } if self.tqdm_class: diff --git a/src/imars3d/backend/corrections/denoise.py b/src/imars3d/backend/corrections/denoise.py index 67c980cc..2f45b9d2 100644 --- a/src/imars3d/backend/corrections/denoise.py +++ b/src/imars3d/backend/corrections/denoise.py @@ -3,7 +3,7 @@ """Image noise reduction (denoise) module.""" import logging import param -from imars3d.backend.util.functions import clamp_max_workers +from imars3d.backend.util.functions import clamp_max_workers, calculate_chunksize import numpy as np import tomopy from multiprocessing.managers import SharedMemoryManager @@ -153,6 +153,7 @@ def denoise_by_bilateral( # mp kwargs = { "max_workers": max_workers, + "chunksize": calculate_chunksize(arrays.shape[0], max_workers), "desc": "denoise_by_bilateral", } if tqdm_class: diff --git a/src/imars3d/backend/corrections/intensity_fluctuation_correction.py b/src/imars3d/backend/corrections/intensity_fluctuation_correction.py index 44276362..64a8c229 100644 --- a/src/imars3d/backend/corrections/intensity_fluctuation_correction.py +++ b/src/imars3d/backend/corrections/intensity_fluctuation_correction.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- """iMars3D's intensity fluctuation correction module.""" import logging -from imars3d.backend.util.functions import clamp_max_workers +from imars3d.backend.util.functions import clamp_max_workers, calculate_chunksize import numpy as np import param import tomopy @@ -93,6 +93,7 @@ def _intensity_fluctuation_correction(self, ct, air_pixels, sigma, max_workers, # map the multiprocessing calls kwargs = { "max_workers": max_workers, + "chunksize": calculate_chunksize(ct.shape[0], max_workers), "desc": 
"intensity_fluctuation_correction", } if tqdm_class: diff --git a/src/imars3d/backend/corrections/ring_removal.py b/src/imars3d/backend/corrections/ring_removal.py index dd39c6e3..aef0fcf6 100644 --- a/src/imars3d/backend/corrections/ring_removal.py +++ b/src/imars3d/backend/corrections/ring_removal.py @@ -3,7 +3,7 @@ """iMars3D's ring artifact correction module.""" import logging import param -from imars3d.backend.util.functions import clamp_max_workers +from imars3d.backend.util.functions import clamp_max_workers, calculate_chunksize import scipy import numpy as np @@ -238,6 +238,7 @@ def _remove_ring_artifact( # invoke mp via tqdm wrapper kwargs = { "max_workers": max_workers, + "chunksize": calculate_chunksize(arrays.shape[1], max_workers), "desc": "Removing ring artifact", } if tqdm_class: diff --git a/src/imars3d/backend/dataio/data.py b/src/imars3d/backend/dataio/data.py index ebbdd648..718dd878 100644 --- a/src/imars3d/backend/dataio/data.py +++ b/src/imars3d/backend/dataio/data.py @@ -3,7 +3,7 @@ # package imports from imars3d.backend.dataio.metadata import MetaData -from imars3d.backend.util.functions import clamp_max_workers, to_time_str +from imars3d.backend.util.functions import clamp_max_workers, to_time_str, calculate_chunksize # third party imports import numpy as np @@ -100,7 +100,7 @@ class load_data(param.ParameterizedFunction): dc_fnmatch: Optional[str] Unix shells-style wild card (``*.tiff``) for selecting dark current max_workers: Optional[int] - maximum number of processes allowed during loading, default to use as many as possible. + maximum number of processes allowed during loading, default to use a single core. tqdm_class: panel.widgets.Tqdm Class to be used for rendering tqdm progress @@ -125,6 +125,9 @@ class load_data(param.ParameterizedFunction): Currently, we are using a forgiving reader to load the image where a corrupted file will not block reading other data. 
+ + The rotation angles are extracted from the filenames if possible, otherwise from the + metadata embedded in the tiff files. If both failed, the angle will be set to None. """ # @@ -296,7 +299,17 @@ def _load_images(filelist: List[str], desc: str, max_workers: int, tqdm_class) - file_ext = Path(filelist[0]).suffix.lower() if file_ext in (".tif", ".tiff"): # use tifffile directly for a faster loading - reader = partial(tifffile.imread, out="memmap") + # NOTE: Test conducted on 09-05-2024 on bl10-analysis1 shows that using + # memmap is faster, which contradicts the observation from the instrument + # team. + # | Method | Time (s) | + # |--------|----------| + # | `imread(out="memmap")` | 2.62 s ± 24.6 ms | + # | `imread()` | 3.59 s ± 13.6 ms | + # The `memmap` option is removed until we have a better understanding of the + # discrepancy. + # reader = partial(tifffile.imread, out="memmap") + reader = tifffile.imread elif file_ext == ".fits": reader = dxchange.read_fits else: @@ -316,6 +329,7 @@ def _load_images(filelist: List[str], desc: str, max_workers: int, tqdm_class) - # - there are a lot of cores available kwargs = { "max_workers": max_workers, + "chunksize": calculate_chunksize(len(filelist), max_workers), "desc": desc, } rst = process_map(partial(_forgiving_reader, reader=reader), filelist, **kwargs) @@ -534,7 +548,7 @@ def _get_filelist_by_dir( def _extract_rotation_angles( filelist: List[str], metadata_idx: int = 65039, -) -> np.ndarray: +) -> Optional[np.ndarray]: """ Extract rotation angles in degrees from filename or metadata. @@ -548,40 +562,106 @@ def _extract_rotation_angles( Returns ------- rotation_angles + Array of rotation angles if successfully extracted, None otherwise. 
""" # sanity check - if filelist == []: + if not filelist: logger.error("filelist is [].") raise ValueError("filelist cannot be empty list.") - # extract rotation angles from file names + # process one file at a time + rotation_angles = [] + for filename in filelist: + file_ext = Path(filename).suffix.lower() + angle = None + if file_ext == ".tiff": + # first, let's try to extract the angle from the filename + angle = extract_rotation_angle_from_filename(filename) + if angle is None: + # if failed, try to extract from metadata + angle = extract_rotation_angle_from_tiff_metadata(filename, metadata_idx) + if angle is None: + # if failed, log a warning and move on + logger.warning(f"Failed to extract rotation angle from {filename}.") + elif file_ext in (".tif", ".fits"): + # for tif and fits, we can only extract from filename as the metadata is not reliable + angle = extract_rotation_angle_from_filename(filename) + if angle is None: + # if failed, log a warning and move on + logger.warning(f"Failed to extract rotation angle from {filename}.") + else: + # if the file type is not supported, raise value error + logger.error(f"Unsupported file type: {file_ext}") + raise ValueError(f"Unsupported file type: {file_ext}") + + rotation_angles.append(angle) + + # this means we have a list of None + if all(angle is None for angle in rotation_angles): + logger.warning("Failed to extract any rotation angles.") + return None + + # warn users if some angles are missing + if any(angle is None for angle in rotation_angles): + logger.warning("Some rotation angles are missing. You will see nan in the rotation angles array.") + + return np.array(rotation_angles, dtype=float) + + +def extract_rotation_angle_from_filename(filename: str) -> Optional[float]: + """ + Extract rotation angle in degrees from filename. + + Parameters + ---------- + filename: + Filename to extract rotation angle from. 
+ + Returns + ------- + rotation_angle + Rotation angle in degrees if successfully extracted, None otherwise. + """ + # extract rotation angle from file names # Note # ---- # For the following file - # 20191030_ironman_small_0070_300_440_0520.tiff + # 20191030_ironman_small_0070_300_440_0520.tif(f) + # 20191030_ironman_small_0070_300_440_0520.fits # the rotation angle is 300.44 degrees - # If all given filenames follows the pattern, we will use the angles from - # filenames. Otherwise, we will use the angles from metadata. - regex = r"\d{8}_\S*_\d{4}_(?P<deg>\d{3})_(?P<dec>\d{3})_\d*\.tiff" - matches = [re.match(regex, Path(f).name) for f in filelist] - if all(matches): - logger.info("Using rotation angles from filenames.") - rotation_angles = np.array([float(".".join(m.groups())) for m in matches]) + regex = r"\d{8}_\S*_\d{4}_(?P<deg>\d{3})_(?P<dec>\d{3})_\d*\.(?:tiff?|fits)" + match = re.match(regex, Path(filename).name) + if match: + rotation_angle = float(".".join(match.groups())) else: - # extract rotation angles from metadata - file_ext = set([Path(f).suffix for f in filelist]) - if file_ext != {".tiff"}: - logger.error("Only tiff files are supported.") - raise ValueError("Rotation angle from metadata is only supported for Tiff.") + rotation_angle = None + return rotation_angle + + +def extract_rotation_angle_from_tiff_metadata(filename: str, metadata_idx: int = 65039) -> Optional[float]: + """ + Extract rotation angle in degrees from metadata of a tiff file. + + Parameters + ---------- + filename: + Filename to extract rotation angle from. + metadata_idx: + Index of metadata to extract rotation angle from, default is 65039. + + Returns + ------- + rotation_angle + Rotation angle in degrees if successfully extracted, None otherwise. 
+ """ + try: # -- read metadata # img = tifffile.TiffFile("test_with_metadata_0.tiff") # img.pages[0].tags[65039].value # >> 'RotationActual:0.579840' - rotation_angles = np.array( - [float(tifffile.TiffFile(f).pages[0].tags[metadata_idx].value.split(":")[-1]) for f in filelist], - dtype="float", - ) - return rotation_angles + return float(tifffile.TiffFile(filename).pages[0].tags[metadata_idx].value.split(":")[-1]) + except Exception: + return None def _save_data(filename: Path, data: np.ndarray, rot_angles: np.ndarray = None) -> None: diff --git a/src/imars3d/backend/diagnostics/rotation.py b/src/imars3d/backend/diagnostics/rotation.py index 22f2f5b8..2674e357 100644 --- a/src/imars3d/backend/diagnostics/rotation.py +++ b/src/imars3d/backend/diagnostics/rotation.py @@ -5,7 +5,7 @@ import numpy as np import param -from imars3d.backend.util.functions import clamp_max_workers +from imars3d.backend.util.functions import clamp_max_workers, calculate_chunksize from multiprocessing.managers import SharedMemoryManager from tqdm.contrib.concurrent import process_map from tomopy.recon.rotation import find_center_pc @@ -139,6 +139,7 @@ def _find_rotation_center( # map the multiprocessing calls kwargs = { "max_workers": max_workers, + "chunksize": calculate_chunksize(len(idx_low), max_workers), "desc": "Finding rotation center", } if tqdm_class: diff --git a/src/imars3d/backend/diagnostics/tilt.py b/src/imars3d/backend/diagnostics/tilt.py index 79389f04..8cd61fbe 100644 --- a/src/imars3d/backend/diagnostics/tilt.py +++ b/src/imars3d/backend/diagnostics/tilt.py @@ -4,9 +4,9 @@ import logging import param import multiprocessing -from imars3d.backend.util.functions import clamp_max_workers +from imars3d.backend.util.functions import clamp_max_workers, calculate_chunksize import numpy as np -from typing import Tuple +from typing import Tuple, Union, Optional from functools import partial from scipy.optimize import minimize_scalar from scipy.optimize import OptimizeResult @@ 
-103,6 +103,7 @@ def calculate_dissimilarity( tilt: float, image0: np.ndarray, image1: np.ndarray, + center: Optional[Tuple[Union[float, int], Union[float, int]]] = None, ) -> float: """Calculate the dissimilarity between two images with given tilt. @@ -119,6 +120,9 @@ def calculate_dissimilarity( image1: The second image for comparison, which is often the radiograph taken at omega + 180 deg + center: + The center of the rotation axis, default is None, which means the center + of the image. This will be passed to the rotation function from skimage. Returns ------- @@ -168,6 +172,7 @@ def calculate_dissimilarity( resize=True, preserve_range=True, order=1, # use default bi-linear interpolation for rotation + center=center, ) # since 180 is flipped, tilting back -2 deg of the original img180 means tilting +2 deg # of the flipped one @@ -178,6 +183,7 @@ def calculate_dissimilarity( resize=True, preserve_range=True, order=1, # use default bi-linear interpolation for rotation + center=center, ) # p-norm @@ -198,6 +204,7 @@ def calculate_tilt( image180: np.ndarray, low_bound: float = -5.0, high_bound: float = 5.0, + center: Optional[Tuple[Union[float, int], Union[float, int]]] = None, ) -> OptimizeResult: """ Use optimization to find the in-plane tilt angle. @@ -214,13 +221,16 @@ def calculate_tilt( The lower bound of the tilt angle search space high_bound: The upper bound of the tilt angle search space + center: + The center of the rotation axis, default is None, which means the center + of the image. This will be passed to the rotation function from skimage. 
Returns ------- The optimization results from scipy.optimize.minimize_scalar """ # make the error function - err_func = partial(calculate_dissimilarity, image0=image0, image1=image180) + err_func = partial(calculate_dissimilarity, image0=image0, image1=image180, center=center) # use bounded uni-variable optimizer to locate the tilt angle that minimize # the dissimilarity of the 180 deg pair res = minimize_scalar( @@ -249,6 +259,9 @@ class tilt_correction(param.ParameterizedFunction): cut_off_angle_deg: float The angle in degrees to cut off the rotation axis tilt correction, i.e. skip applying tilt correction for tilt angles that are too small. + center: Any + The center of the rotation axis, default is None, which means the center + of the image. This will be passed to the rotation function from skimage. max_workers: Number of cores to use for parallel median filtering, default is 0, which means using all available cores. @@ -275,6 +288,10 @@ class tilt_correction(param.ParameterizedFunction): default=2.0, doc="The angle in degrees to cut off the rotation axis tilt correction, i.e. skip applying tilt correction for tilt angles that are too small.", ) + center = param.Parameter( + default=None, + doc="The center of the rotation axis, default is None, which means the center of the image. This will be passed to the rotation function from skimage.", + ) # NOTE: # The front and backend are sharing the same computing unit, therefore we can # set a hard cap on the max_workers. 
@@ -319,6 +336,7 @@ def __call__(self, **params): kwargs = { "max_workers": self.max_workers, + "chunksize": calculate_chunksize(len(idx_lowrange), self.max_workers), "desc": "Calculating tilt correction", } if params.tqdm_class: @@ -329,6 +347,7 @@ def __call__(self, **params): calculate_tilt, low_bound=params.low_bound, high_bound=params.high_bound, + center=params.center, ), [shm_arrays[il] for il in idx_lowrange], [shm_arrays[ih] for ih in idx_highrange], @@ -349,6 +368,7 @@ def __call__(self, **params): corrected_array = apply_tilt_correction( arrays=params.arrays, tilt=tilt, + center=params.center, max_workers=self.max_workers, ) return corrected_array @@ -366,6 +386,8 @@ class apply_tilt_correction(param.ParameterizedFunction): The array for tilt correction tilt: float The rotation axis tilt angle in degrees + center: Any + The center of the rotation axis, default is None, which means the center of the image. max_workers: int Number of cores to use for parallel median filtering, default is 0, which means using all available cores. tqdm_class: panel.widgets.Tqdm @@ -379,6 +401,10 @@ class apply_tilt_correction(param.ParameterizedFunction): arrays = param.Array(doc="The array for tilt correction", default=None) tilt = param.Number(doc="The rotation axis tilt angle in degrees", default=None) + center = param.Parameter( + default=None, + doc="The center of the rotation axis, default is None, which means the center of the image. This will be passed to the rotation function from skimage.", + ) # NOTE: # The front and backend are sharing the same computing unit, therefore we can # set a hard cap on the max_workers. 
@@ -406,7 +432,9 @@ def __call__(self, **params): # dimensionality check if params.arrays.ndim == 2: logger.info(f"2D image detected, applying tilt correction with tilt = {params.tilt:.3f} deg") - corrected_array = rotate(params.arrays, -params.tilt, resize=False, preserve_range=True) + corrected_array = rotate( + params.arrays, -params.tilt, resize=False, preserve_range=True, center=params.center + ) elif params.arrays.ndim == 3: logger.info(f"3D array detected, applying tilt correction with tilt = {params.tilt:.3f} deg") with SharedMemoryManager() as smm: @@ -415,12 +443,13 @@ def __call__(self, **params): np.copyto(shm_arrays, params.arrays) kwargs = { "max_workers": self.max_workers, + "chunksize": calculate_chunksize(params.arrays.shape[0], self.max_workers), "desc": "Applying tilt corr", } if params.tqdm_class: kwargs["tqdm_class"] = params.tqdm_class rst = process_map( - partial(rotate, angle=-params.tilt, resize=False, preserve_range=True), + partial(rotate, angle=-params.tilt, resize=False, preserve_range=True, center=params.center), [shm_arrays[idx] for idx in range(params.arrays.shape[0])], **kwargs, ) diff --git a/src/imars3d/backend/preparation/normalization.py b/src/imars3d/backend/preparation/normalization.py index 49692213..9e03e2b1 100644 --- a/src/imars3d/backend/preparation/normalization.py +++ b/src/imars3d/backend/preparation/normalization.py @@ -41,7 +41,7 @@ class normalization(param.ParameterizedFunction): flats = param.Array( doc="3D array of flat field images (aka flat field, open beam), axis=0 is the image number axis.", default=None ) - darks = param.Array(doc="3D array of dark field images, axis=0 is the image number axis.", default=None) + darks = param.Array(doc="3D array of optional dark field images, axis=0 is the image number axis.", default=None) max_workers = param.Integer( default=0, bounds=(0, None), @@ -59,10 +59,15 @@ def __call__(self, **params): self.max_workers = clamp_max_workers(params.max_workers) 
logger.debug(f"max_worker={self.max_workers}") - # use median filter to remove outliers from flats and darks - # NOTE: this will remove the random noises coming from the environment. + # process flats (formerly known as open beam, white field) self.flats = np.median(params.flats, axis=0) - self.darks = np.median(params.darks, axis=0) + + # process darks (formerly known as black field) + if params.darks is None: + self.darks = np.zeros_like(self.flats) + else: + self.darks = np.median(params.darks, axis=0) + # apply normalization _bg = self.flats - self.darks _bg[_bg <= 0] = 1e-6 diff --git a/src/imars3d/backend/util/functions.py b/src/imars3d/backend/util/functions.py index 39a83ba7..ac5097ab 100644 --- a/src/imars3d/backend/util/functions.py +++ b/src/imars3d/backend/util/functions.py @@ -6,7 +6,7 @@ import logging import multiprocessing import resource -from typing import Union +from typing import Optional, Union logger = logging.getLogger(__name__) @@ -16,6 +16,15 @@ def clamp_max_workers(max_workers: Union[int, None]) -> int: """Calculate the number of max workers. If it isn't specified, return something appropriate for the system. + + Parameters + ---------- + max_workers: + The maximum number of workers to use + + Returns + ------- + The calculated maximum number of workers """ if max_workers is None: max_workers = 0 @@ -29,6 +38,31 @@ def clamp_max_workers(max_workers: Union[int, None]) -> int: return result +def calculate_chunksize(num_elements: int, max_workers: Optional[int] = None, scale_factor: int = 4) -> int: + """Calculate an optimal chunk size for multiprocessing. 
+ + Parameters + ---------- + num_elements: + The number of elements to process + max_workers: + The maximum number of workers to use + scale_factor: + The factor to scale the chunk size by + + Returns + ------- + The optimal chunk size + """ + # Calculate the number of workers + workers = clamp_max_workers(max_workers) + + # Calculate chunk size based on number of elements and workers + chunksize = max(1, num_elements // (workers * scale_factor)) + + return chunksize + + def to_time_str(value: datetime = datetime.now()) -> str: """ Convert the supplied datetime to a formatted string. diff --git a/tests/unit/backend/dataio/test_data.py b/tests/unit/backend/dataio/test_data.py index efbe04fb..0bd677a3 100644 --- a/tests/unit/backend/dataio/test_data.py +++ b/tests/unit/backend/dataio/test_data.py @@ -15,6 +15,8 @@ load_data, save_checkpoint, save_data, + extract_rotation_angle_from_filename, + extract_rotation_angle_from_tiff_metadata, ) @@ -181,6 +183,31 @@ def test_extract_rotation_angles(data_fixture): rst = _extract_rotation_angles([metadata_tiff] * 3) ref = np.array([0.1, 0.1, 0.1]) np.testing.assert_array_almost_equal(rst, ref) + # case_2: mixed file types + rst = _extract_rotation_angles([good_tiff, metadata_tiff, generic_tiff, generic_fits]) + ref = np.array([10.02, 0.1, np.nan, np.nan]) + np.testing.assert_array_equal(rst, ref) + # case_3: all files without extractable angles + rst = _extract_rotation_angles([generic_tiff, generic_fits]) + assert rst is None + + +def test_extract_rotation_angle_from_filename(): + # Test cases for extract_rotation_angle_from_filename + assert extract_rotation_angle_from_filename("20191030_sample_0070_300_440_0520.tiff") == 300.44 + assert extract_rotation_angle_from_filename("20191030_sample_0071_301_441_0521.tif") == 301.441 + assert extract_rotation_angle_from_filename("20191030_sample_0072_302_442_0522.fits") == 302.442 + assert extract_rotation_angle_from_filename("generic_file.tiff") is None + + +def 
test_extract_rotation_angle_from_tiff_metadata(tmpdir): + # Create a TIFF file with rotation angle in metadata + data = np.ones((3, 3)) + filename = str(tmpdir / "metadata.tiff") + tifffile.imwrite(filename, data, extratags=[(65039, "s", 0, "RotationActual:0.5", True)]) + + assert extract_rotation_angle_from_tiff_metadata(filename) == 0.5 + assert extract_rotation_angle_from_tiff_metadata("non_existent_file.tiff") is None @pytest.fixture(scope="module") diff --git a/tests/unit/backend/preparation/test_normalization.py b/tests/unit/backend/preparation/test_normalization.py index b0f9ee84..94b7a062 100644 --- a/tests/unit/backend/preparation/test_normalization.py +++ b/tests/unit/backend/preparation/test_normalization.py @@ -124,6 +124,24 @@ def test_normalization_bright_dark(): assert diff < 0.01 +def test_normalization_no_darks(): + """Test normalization routine without providing dark field images.""" + raw, _, flats, proj = prepare_synthetic_data() + + # Process with normalization, passing None for darks + proj_imars3d = normalization(arrays=raw, flats=flats, darks=None) + + # Compare results + diff = np.absolute(proj_imars3d - proj).sum() / np.prod(proj.shape) + assert diff < 0.02 # Increased tolerance from 0.01 to 0.02 to account for the lack of dark field images + + # Additional check: Ensure the shape of the output matches the input + assert proj_imars3d.shape == raw.shape + + # Check that values are within expected range (0 to 1 for normalized data) + assert np.all(proj_imars3d >= 0) and np.all(proj_imars3d <= 1) + + class TestMinusLog: @pytest.mark.parametrize("ncore", [1, 2]) def test_execution(self, ncore: int) -> None: diff --git a/tests/unit/backend/util/test_util.py b/tests/unit/backend/util/test_util.py index 19f609c4..ea384eef 100644 --- a/tests/unit/backend/util/test_util.py +++ b/tests/unit/backend/util/test_util.py @@ -1,8 +1,9 @@ # package imports -from imars3d.backend.util.functions import clamp_max_workers, to_time_str +from 
imars3d.backend.util.functions import clamp_max_workers, to_time_str, calculate_chunksize # third party imports import pytest +from unittest.mock import patch # standard imports from datetime import datetime @@ -13,6 +14,51 @@ def test_clamp_max_workers(): assert clamp_max_workers(-10) >= 1 +@patch("multiprocessing.cpu_count", return_value=8) +def test_chunksize_with_small_number_of_elements(mock_cpu_count): + num_elements = 10 + max_workers = None + chunksize = calculate_chunksize(num_elements, max_workers) + assert chunksize == 1 + + +@patch("multiprocessing.cpu_count", return_value=8) +def test_chunksize_with_large_number_of_elements(mock_cpu_count): + num_elements = 10000 + max_workers = None + chunksize = calculate_chunksize(num_elements, max_workers) + expected_chunksize = max(1, num_elements // (6 * 4)) # 6 workers, scale factor 4 + assert chunksize == expected_chunksize + + +@patch("multiprocessing.cpu_count", return_value=4) +def test_chunksize_with_different_cpu_count(mock_cpu_count): + num_elements = 10000 + max_workers = None + chunksize = calculate_chunksize(num_elements, max_workers) + expected_chunksize = max(1, num_elements // (2 * 4)) # 2 workers (cpu_count - 2), scale factor 4 + assert chunksize == expected_chunksize + + +@patch("multiprocessing.cpu_count", return_value=8) +def test_chunksize_with_max_workers(mock_cpu_count): + num_elements = 10000 + max_workers = 4 + chunksize = calculate_chunksize(num_elements, max_workers) + expected_chunksize = max(1, num_elements // (4 * 4)) # 4 workers manually set + assert chunksize == expected_chunksize + + +@patch("multiprocessing.cpu_count", return_value=8) +def test_chunksize_with_custom_scale_factor(mock_cpu_count): + num_elements = 10000 + max_workers = None + scale_factor = 2 + chunksize = calculate_chunksize(num_elements, max_workers, scale_factor=scale_factor) + expected_chunksize = max(1, num_elements // (6 * 2)) # 6 workers, scale factor 2 + assert chunksize == expected_chunksize + + 
@pytest.mark.parametrize( "timestamp", [