Skip to content

Commit

Permalink
Refactors flat fielding
Browse files Browse the repository at this point in the history
  • Loading branch information
Dimitar Tasev committed Dec 18, 2020
1 parent 5ea736b commit 3bbf042
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
33 changes: 17 additions & 16 deletions mantidimaging/core/operations/flat_fielding/flat_fielding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from mantidimaging import helper as h
from mantidimaging.core.data import Images
from mantidimaging.core.operations.base_filter import BaseFilter, FilterGroup
from mantidimaging.core.parallel import two_shared_mem as ptsm
from mantidimaging.core.parallel import utility as pu
from mantidimaging.core.parallel import utility as pu, shared as ps
from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.gui.utility.qt_helpers import Type
from mantidimaging.gui.widgets.stack_selector import StackSelectorWidgetView
Expand Down Expand Up @@ -58,7 +57,7 @@ class FlatFieldFilter(BaseFilter):
filter_name = 'Flat-fielding'

@staticmethod
def filter_func(data: Images,
def filter_func(images: Images,
flat_before: Images = None,
flat_after: Images = None,
dark_before: Images = None,
Expand All @@ -80,7 +79,7 @@ def filter_func(data: Images,
:param chunksize: The number of chunks that each worker will receive.
:return: Filtered data (stack of images)
"""
h.check_data_stack(data)
h.check_data_stack(images)

if selected_flat_fielding is not None:
if selected_flat_fielding == "Both, concatenated" and flat_after is not None and flat_before is not None \
Expand All @@ -101,19 +100,19 @@ def filter_func(data: Images,
if 2 != flat_avg.ndim or 2 != dark_avg.ndim:
raise ValueError(
f"Incorrect shape of the flat image ({flat_avg.shape}) or dark image ({dark_avg.shape}) \
which should match the shape of the sample images ({data.data.shape})")
which should match the shape of the sample images ({images.data.shape})")

if not data.data.shape[1:] == flat_avg.shape == dark_avg.shape:
raise ValueError(f"Not all images are the expected shape: {data.data.shape[1:]}, instead "
if not images.data.shape[1:] == flat_avg.shape == dark_avg.shape:
raise ValueError(f"Not all images are the expected shape: {images.data.shape[1:]}, instead "
f"flat had shape: {flat_avg.shape}, and dark had shape: {dark_avg.shape}")

progress = Progress.ensure_instance(progress,
num_steps=data.data.shape[0],
num_steps=images.data.shape[0],
task_name='Background Correction')
_execute(data.data, flat_avg, dark_avg, cores, chunksize, progress)
_execute(images.data, flat_avg, dark_avg, cores, chunksize, progress)

h.check_data_stack(data)
return data
h.check_data_stack(images)
return images

@staticmethod
def register_gui(form, on_change, view: FiltersWindowView) -> Dict[str, Any]:
Expand Down Expand Up @@ -260,7 +259,7 @@ def _subtract(data, dark=None):
np.subtract(data, dark, out=data)


def _execute(data, flat=None, dark=None, cores=None, chunksize=None, progress=None):
def _execute(data: np.ndarray, flat=None, dark=None, cores=None, chunksize=None, progress=None):
"""A benchmark justifying the current implementation, performed on
500x2048x2048 images.
Expand Down Expand Up @@ -289,11 +288,13 @@ def _execute(data, flat=None, dark=None, cores=None, chunksize=None, progress=No
norm_divide[norm_divide == 0] = MINIMUM_PIXEL_VALUE

# subtract the dark from all images
f = ptsm.create_partial(_subtract, fwd_function=ptsm.inplace_second_2d)
data, dark = ptsm.execute(data, dark, f, cores, chunksize, progress=progress)
do_subtract = ps.create_partial(_subtract, fwd_function=ps.inplace_second_2d)
ps.shared_list = [data, dark]
ps.execute(do_subtract, data.shape[0], progress, cores=cores)

# divide the data by (flat - dark)
f = ptsm.create_partial(_divide, fwd_function=ptsm.inplace_second_2d)
data, norm_divide = ptsm.execute(data, norm_divide, f, cores, chunksize, progress=progress)
do_divide = ps.create_partial(_divide, fwd_function=ps.inplace_second_2d)
ps.shared_list = [data, norm_divide]
ps.execute(do_divide, data.shape[0], progress, cores=cores)

return data
4 changes: 4 additions & 0 deletions mantidimaging/core/parallel/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ def return_to_self1(func, i, **kwargs):
shared_list[0][i] = func(shared_list[0][i], **kwargs)


def inplace_second_2d(func, i, **kwargs):
func(shared_list[0][i], shared_list[1], **kwargs)


def create_partial(func, fwd_function, **kwargs):
"""
Create a partial using functools.partial, to forward the kwargs to the
Expand Down

0 comments on commit 3bbf042

Please sign in to comment.