Skip to content

Commit

Permalink
fix: Improvements and bugfixes (#9)
Browse files Browse the repository at this point in the history
* Add batch size to prediction

* Reshape the prediction to match the input axes

* Enable patch_Z_spin when enable_3D is checked

* Better handle the situation if the number of channels is not correct set in advanced settings

* Fixed bug in configuration order

* Added tests and fixed reshape_prediction

---------

Co-authored-by: Fabian CRG <[email protected]>
  • Loading branch information
CURTLab and Fabian CRG authored Jan 10, 2025
1 parent 86ff1a0 commit 09af2b9
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 19 deletions.
3 changes: 2 additions & 1 deletion src/careamics_napari/careamics_utils/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ def create_configuration(signal: TrainingSignal) -> Configuration:
experiment_name = signal.experiment_name

if signal.is_3d:
# order of patches is ZYX
patches: list[int] = [
signal.patch_size_z,
signal.patch_size_xy,
signal.patch_size_xy,
signal.patch_size_z,
]
else:
patches = [signal.patch_size_xy, signal.patch_size_xy]
Expand Down
3 changes: 3 additions & 0 deletions src/careamics_napari/signals/prediction_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,6 @@ class PredictionSignal:

tile_overlap_z: int = 4 # TODO currently fixed
"""Overlap between the tiles along the Z dimension."""

batch_size: int = 1
"""Batch size."""
15 changes: 14 additions & 1 deletion src/careamics_napari/training_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
create_gpu_label,
)
from careamics_napari.workers import predict_worker, save_worker, train_worker
from careamics_napari.utils.axes_utils import reshape_prediction

import numpy as np

if TYPE_CHECKING:
import napari
Expand Down Expand Up @@ -346,7 +349,17 @@ def _update_from_prediction(self, update: PredictionUpdate) -> None:
# add image to napari
# TODO keep scaling?
if self.viewer is not None:
self.viewer.add_image(update.value, name="Prediction")
# value is eighter a numpy array or a list of numpy arrays with each sample/timepoint as an element
if isinstance(update.value, list):
# combine all samples
samples = np.concatenate(update.value, axis=0)
else:
samples = update.value

# reshape the prediction to match the input axes
samples = reshape_prediction(samples, self.train_config_signal.axes, self.pred_config_signal.is_3d)

self.viewer.add_image(samples, name="Prediction")
else:
self.pred_status.update(update)

Expand Down
54 changes: 54 additions & 0 deletions src/careamics_napari/utils/axes_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import warnings
from itertools import permutations
import numpy as np

REF_AXES = "STCZYX"
"""References axes in CAREamics."""
Expand Down Expand Up @@ -80,3 +81,56 @@ def are_axes_valid(axes: str) -> bool:

# prior: X and Y contiguous
return ("XY" in _axes) or ("YX" in _axes)

def reshape_prediction(prediction: np.ndarray, axes: str, is_3d: bool) -> np.ndarray:
"""Reshape the prediction to match the input axes.
The default axes of the model prediction is SC(Z)YX.
Parameters
----------
prediction : np.ndarray
Prediction.
axes : str
Axes of the input data.
is_3d : bool
Whether the data is 3D.
Returns
-------
np.ndarray
Reshaped prediction.
"""

# model outputs SC(Z)YX
pred_axes = "SCZYX" if is_3d else "SCYX"

# transpose the axes
# TODO: during prediction T and S are merged. Check how to handle this
input_axes = axes.replace("T", "S")
remove_c, remove_s = False, False

if not "C" in input_axes:
# add C if missing
input_axes = "C" + input_axes
remove_c = True

if not "S" in input_axes:
# add S if missing
input_axes = "S" + input_axes
remove_s = True

# TODO: check if all axes are present
assert all([ax in input_axes for ax in pred_axes])

indices = [pred_axes.index(ax) for ax in input_axes]
prediction = np.transpose(prediction, indices)

# remove S if not present in the input axes
if remove_c:
prediction = prediction[0]

# remove C if not present in the input axes
if remove_s:
prediction = prediction[0]

return prediction
21 changes: 21 additions & 0 deletions src/careamics_napari/widgets/prediction_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .qt_widgets import (
PowerOfTwoSpinBox,
create_progressbar,
create_int_spinbox
)


Expand Down Expand Up @@ -103,9 +104,16 @@ def __init__(
self.tile_size_z.setToolTip("Tile size in the z dimension.")
self.tile_size_z.setEnabled(False)

self.batch_size_spin = create_int_spinbox(1, 512, 1, 1)
self.batch_size_spin.setToolTip(
"Number of patches per batch (decrease if GPU memory is insufficient)"
)
self.batch_size_spin.setEnabled(False)

tiling_form = QFormLayout()
tiling_form.addRow("XY tile size", self.tile_size_xy)
tiling_form.addRow("Z tile size", self.tile_size_z)
tiling_form.addRow("Batch size", self.batch_size_spin)
tiling_widget = QWidget()
tiling_widget.setLayout(tiling_form)
self.layout().addWidget(tiling_widget)
Expand Down Expand Up @@ -139,6 +147,7 @@ def __init__(

self.tile_size_xy.valueChanged.connect(self._set_xy_tile_size)
self.tile_size_z.valueChanged.connect(self._set_z_tile_size)
self.batch_size_spin.valueChanged.connect(self._set_batch_size)

# listening to the signals
self.train_signal.events.is_3d.connect(self._set_3d)
Expand Down Expand Up @@ -170,6 +179,17 @@ def _set_z_tile_size(self: Self, size: int) -> None:
if self.pred_signal is not None:
self.pred_signal.tile_size_z = size

def _set_batch_size(self: Self, size: int) -> None:
"""Update the signal batch size.
Parameters
----------
size : int
The new batch size.
"""
if self.pred_signal is not None:
self.pred_signal.batch_size = size

def _set_3d(self: Self, state: bool) -> None:
"""Enable the z tile size spinbox if the data is 3D.
Expand All @@ -191,6 +211,7 @@ def _update_tiles(self: Self, state: bool) -> None:
"""
self.pred_signal.tiled = state
self.tile_size_xy.setEnabled(state)
self.batch_size_spin.setEnabled(state)

if self.train_signal.is_3d:
self.tile_size_z.setEnabled(state)
Expand Down
2 changes: 2 additions & 0 deletions src/careamics_napari/widgets/training_configuration_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(self: Self, training_signal: Optional[TrainingSignal] = None) -> No
self.patch_Z_spin = PowerOfTwoSpinBox(8, 512, 8)
self.patch_Z_spin.setToolTip("Dimension of the patches in Z.")

# TODO: is this necessary?
if self.configuration_signal is not None:
self.patch_Z_spin.setEnabled(self.configuration_signal.is_3d)

Expand Down Expand Up @@ -126,6 +127,7 @@ def _enable_3d_changed(self: Self, state: bool) -> None:
3D state.
"""
self.patch_Z_spin.setVisible(state)
self.patch_Z_spin.setEnabled(state)

if self.configuration_signal is not None:
self.configuration_signal.is_3d = state
Expand Down
3 changes: 3 additions & 0 deletions src/careamics_napari/workers/prediction_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,11 @@ def _predict(
config_signal.tile_overlap_xy,
config_signal.tile_overlap_xy,
)
batch_size = config_signal.batch_size
else:
tile_size = None
tile_overlap = None
batch_size = 1

# Predict with CAREamist
try:
Expand All @@ -157,6 +159,7 @@ def _predict(
data_type="tiff" if config_signal.load_from_disk else "array",
tile_size=tile_size,
tile_overlap=tile_overlap,
batch_size=batch_size,
)

update_queue.put(PredictionUpdate(PredictionUpdateType.SAMPLE, result))
Expand Down
40 changes: 23 additions & 17 deletions src/careamics_napari/workers/training_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,26 +109,32 @@ def _train(
CAREamist instance.
"""
# get configuration and queue
config = create_configuration(config_signal)

# Create CAREamist
if careamist is None:
careamist = CAREamist(
config, callbacks=[UpdaterCallBack(training_queue, predict_queue)]
)
try:
# create_configuration can raise an exception
config = create_configuration(config_signal)

else:
# only update the number of epochs
careamist.cfg.training_config.num_epochs = config.training_config.num_epochs

if config_signal.layer_val == "" and config_signal.path_val == "":
ntf.show_error(
"Continuing training is currently not supported without explicitely "
"passing validation. The reason is that otherwise, the data used for "
"validation will be different and there will be data leakage in the "
"training set."
# Create CAREamist
if careamist is None:
careamist = CAREamist(
config, callbacks=[UpdaterCallBack(training_queue, predict_queue)]
)

else:
# only update the number of epochs
careamist.cfg.training_config.num_epochs = config.training_config.num_epochs

if config_signal.layer_val == "" and config_signal.path_val == "":
ntf.show_error(
"Continuing training is currently not supported without explicitely "
"passing validation. The reason is that otherwise, the data used for "
"validation will be different and there will be data leakage in the "
"training set."
)
except Exception as e:
traceback.print_exc()

training_queue.put(TrainUpdate(TrainUpdateType.EXCEPTION, e))

# Register CAREamist
training_queue.put(TrainUpdate(TrainUpdateType.CAREAMIST, careamist))

Expand Down
94 changes: 94 additions & 0 deletions tests/test_predictions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#from src.careamics_napari.training_plugin import TrainingPlugin
from careamics import CAREamist

from careamics.config import create_n2v_configuration

import numpy as np
import contextlib
import sys
from itertools import combinations

# disable logging
from careamics.careamist import logger
import logging

from careamics_napari.utils.axes_utils import reshape_prediction

logger.setLevel("ERROR")
logging.getLogger("pytorch_lightning.utilities.rank_zero").setLevel(logging.FATAL)

# nostdout from https://stackoverflow.com/questions/2828953/silence-the-stdout-of-a-function-in-python-without-trashing-sys-stdout-and-resto
class DummyFile(object):
def write(self, x):
pass

@contextlib.contextmanager
def nostdout():
save_stdout = sys.stdout
sys.stdout = DummyFile()
yield
sys.stdout = save_stdout

def generate_combinations_and_rotations(s):
# generate all combinations
combinations_list = []
for r in range(1, len(s) + 1):
combinations_list.extend([''.join(comb) for comb in combinations(s, r)])

# generate all rotations
rotations = set()
for i in range(len(s)):
rotated = s[i:] + s[:i]
rotations.add(rotated)

# combine results
all_results = set(combinations_list)
for rot in rotations:
for r in range(1, len(rot) + 1):
all_results.update([''.join(comb) for comb in combinations(rot, r)])

# add an empty
all_results.add("")

return sorted(all_results)

augmentation = generate_combinations_and_rotations("TZC")
for ax in augmentation:
test_axes = ax + "YX"
n_channels = 1
shape = []
for ax in test_axes:
if ax == "S":
shape.append(2)
elif ax == "T":
shape.append(4)
elif ax == "C":
shape.append(3)
n_channels = 3
else:
shape.append(16)

pred_data = np.random.randint(0, 255, shape).astype(np.float32)
with nostdout():
# create a configuration
config = create_n2v_configuration(
experiment_name=f'N2V_{test_axes}',
data_type="array",
axes=test_axes,
n_channels=n_channels,
patch_size=[8, 8, 8] if "Z" in test_axes else [8, 8],
batch_size=1,
num_epochs=1,
)

# instantiate a careamist
careamist = CAREamist(config)
careamist.cfg.data_config.set_means_and_stds([127.0]*n_channels, [75.0]*n_channels)

predction = careamist.predict(source=pred_data)
if isinstance(predction, list):
predction = np.concatenate(predction, axis=0)

pred = reshape_prediction(predction, test_axes, "Z" in test_axes)

assert pred_data.shape == pred.shape, f"Prediction shape {pred_data.shape} != {predction.shape} for axes {test_axes}"

0 comments on commit 09af2b9

Please sign in to comment.