From 09af2b95418807a6ab6afe55641e169fd74d4047 Mon Sep 17 00:00:00 2001 From: CURTLab Date: Fri, 10 Jan 2025 20:31:08 +0100 Subject: [PATCH] fix: Improvements and bugfixes (#9) * Add batch size to prediction * Reshape the prediction to match the input axes * Enable patch_Z_spin when enable_3D is checked * Better handle the situation if the number of channels is not correct set in advanced settings * Fixed bug in configuration order * Added tests and fixed reshape_prediction --------- Co-authored-by: Fabian CRG --- .../careamics_utils/configuration.py | 3 +- .../signals/prediction_signal.py | 3 + src/careamics_napari/training_plugin.py | 15 ++- src/careamics_napari/utils/axes_utils.py | 54 +++++++++++ .../widgets/prediction_widget.py | 21 +++++ .../widgets/training_configuration_widget.py | 2 + .../workers/prediction_worker.py | 3 + .../workers/training_worker.py | 40 ++++---- tests/test_predictions.py | 94 +++++++++++++++++++ 9 files changed, 216 insertions(+), 19 deletions(-) create mode 100644 tests/test_predictions.py diff --git a/src/careamics_napari/careamics_utils/configuration.py b/src/careamics_napari/careamics_utils/configuration.py index bac4736..d7beeb5 100644 --- a/src/careamics_napari/careamics_utils/configuration.py +++ b/src/careamics_napari/careamics_utils/configuration.py @@ -39,10 +39,11 @@ def create_configuration(signal: TrainingSignal) -> Configuration: experiment_name = signal.experiment_name if signal.is_3d: + # order of patches is ZYX patches: list[int] = [ + signal.patch_size_z, signal.patch_size_xy, signal.patch_size_xy, - signal.patch_size_z, ] else: patches = [signal.patch_size_xy, signal.patch_size_xy] diff --git a/src/careamics_napari/signals/prediction_signal.py b/src/careamics_napari/signals/prediction_signal.py index a53476e..846a795 100644 --- a/src/careamics_napari/signals/prediction_signal.py +++ b/src/careamics_napari/signals/prediction_signal.py @@ -55,3 +55,6 @@ class PredictionSignal: tile_overlap_z: int = 4 # TODO currently fixed """Overlap between the tiles along the Z dimension.""" + + batch_size: int = 1 + """Batch size.""" diff --git a/src/careamics_napari/training_plugin.py b/src/careamics_napari/training_plugin.py index 7e1959f..8ce52c0 100644 --- a/src/careamics_napari/training_plugin.py +++ b/src/careamics_napari/training_plugin.py @@ -40,6 +40,9 @@ create_gpu_label, ) from careamics_napari.workers import predict_worker, save_worker, train_worker +from careamics_napari.utils.axes_utils import reshape_prediction + +import numpy as np if TYPE_CHECKING: import napari @@ -346,7 +349,17 @@ def _update_from_prediction(self, update: PredictionUpdate) -> None: # add image to napari # TODO keep scaling? if self.viewer is not None: - self.viewer.add_image(update.value, name="Prediction") + # value is eighter a numpy array or a list of numpy arrays with each sample/timepoint as an element + if isinstance(update.value, list): + # combine all samples + samples = np.concatenate(update.value, axis=0) + else: + samples = update.value + + # reshape the prediction to match the input axes + samples = reshape_prediction(samples, self.train_config_signal.axes, self.pred_config_signal.is_3d) + + self.viewer.add_image(samples, name="Prediction") else: self.pred_status.update(update) diff --git a/src/careamics_napari/utils/axes_utils.py b/src/careamics_napari/utils/axes_utils.py index 79ec95d..19a00c3 100644 --- a/src/careamics_napari/utils/axes_utils.py +++ b/src/careamics_napari/utils/axes_utils.py @@ -2,6 +2,7 @@ import warnings from itertools import permutations +import numpy as np REF_AXES = "STCZYX" """References axes in CAREamics.""" @@ -80,3 +81,56 @@ def are_axes_valid(axes: str) -> bool: # prior: X and Y contiguous return ("XY" in _axes) or ("YX" in _axes) + +def reshape_prediction(prediction: np.ndarray, axes: str, is_3d: bool) -> np.ndarray: + """Reshape the prediction to match the input axes. + The default axes of the model prediction is SC(Z)YX. + + Parameters + ---------- + prediction : np.ndarray + Prediction. + axes : str + Axes of the input data. + is_3d : bool + Whether the data is 3D. + + Returns + ------- + np.ndarray + Reshaped prediction. + """ + + # model outputs SC(Z)YX + pred_axes = "SCZYX" if is_3d else "SCYX" + + # transpose the axes + # TODO: during prediction T and S are merged. Check how to handle this + input_axes = axes.replace("T", "S") + remove_c, remove_s = False, False + + if not "C" in input_axes: + # add C if missing + input_axes = "C" + input_axes + remove_c = True + + if not "S" in input_axes: + # add S if missing + input_axes = "S" + input_axes + remove_s = True + + # TODO: check if all axes are present + assert all([ax in input_axes for ax in pred_axes]) + + indices = [pred_axes.index(ax) for ax in input_axes] + prediction = np.transpose(prediction, indices) + + # remove S if not present in the input axes + if remove_c: + prediction = prediction[0] + + # remove C if not present in the input axes + if remove_s: + prediction = prediction[0] + + return prediction \ No newline at end of file diff --git a/src/careamics_napari/widgets/prediction_widget.py b/src/careamics_napari/widgets/prediction_widget.py index 17c0e5f..992e35a 100644 --- a/src/careamics_napari/widgets/prediction_widget.py +++ b/src/careamics_napari/widgets/prediction_widget.py @@ -27,6 +27,7 @@ from .qt_widgets import ( PowerOfTwoSpinBox, create_progressbar, + create_int_spinbox ) @@ -103,9 +104,16 @@ def __init__( self.tile_size_z.setToolTip("Tile size in the z dimension.") self.tile_size_z.setEnabled(False) + self.batch_size_spin = create_int_spinbox(1, 512, 1, 1) + self.batch_size_spin.setToolTip( + "Number of patches per batch (decrease if GPU memory is insufficient)" + ) + self.batch_size_spin.setEnabled(False) + tiling_form = QFormLayout() tiling_form.addRow("XY tile size", self.tile_size_xy) tiling_form.addRow("Z tile size", self.tile_size_z) + tiling_form.addRow("Batch size", self.batch_size_spin) tiling_widget = QWidget() tiling_widget.setLayout(tiling_form) self.layout().addWidget(tiling_widget) @@ -139,6 +147,7 @@ def __init__( self.tile_size_xy.valueChanged.connect(self._set_xy_tile_size) self.tile_size_z.valueChanged.connect(self._set_z_tile_size) + self.batch_size_spin.valueChanged.connect(self._set_batch_size) # listening to the signals self.train_signal.events.is_3d.connect(self._set_3d) @@ -170,6 +179,17 @@ def _set_z_tile_size(self: Self, size: int) -> None: if self.pred_signal is not None: self.pred_signal.tile_size_z = size + def _set_batch_size(self: Self, size: int) -> None: + """Update the signal batch size. + + Parameters + ---------- + size : int + The new batch size. + """ + if self.pred_signal is not None: + self.pred_signal.batch_size = size + def _set_3d(self: Self, state: bool) -> None: """Enable the z tile size spinbox if the data is 3D. @@ -191,6 +211,7 @@ def _update_tiles(self: Self, state: bool) -> None: """ self.pred_signal.tiled = state self.tile_size_xy.setEnabled(state) + self.batch_size_spin.setEnabled(state) if self.train_signal.is_3d: self.tile_size_z.setEnabled(state) diff --git a/src/careamics_napari/widgets/training_configuration_widget.py b/src/careamics_napari/widgets/training_configuration_widget.py index 17eb0eb..c64a1db 100644 --- a/src/careamics_napari/widgets/training_configuration_widget.py +++ b/src/careamics_napari/widgets/training_configuration_widget.py @@ -78,6 +78,7 @@ def __init__(self: Self, training_signal: Optional[TrainingSignal] = None) -> No self.patch_Z_spin = PowerOfTwoSpinBox(8, 512, 8) self.patch_Z_spin.setToolTip("Dimension of the patches in Z.") + # TODO: is this necessary? if self.configuration_signal is not None: self.patch_Z_spin.setEnabled(self.configuration_signal.is_3d) @@ -126,6 +127,7 @@ def _enable_3d_changed(self: Self, state: bool) -> None: 3D state. """ self.patch_Z_spin.setVisible(state) + self.patch_Z_spin.setEnabled(state) if self.configuration_signal is not None: self.configuration_signal.is_3d = state diff --git a/src/careamics_napari/workers/prediction_worker.py b/src/careamics_napari/workers/prediction_worker.py index d57dd0c..9d54c59 100644 --- a/src/careamics_napari/workers/prediction_worker.py +++ b/src/careamics_napari/workers/prediction_worker.py @@ -146,9 +146,11 @@ def _predict( config_signal.tile_overlap_xy, config_signal.tile_overlap_xy, ) + batch_size = config_signal.batch_size else: tile_size = None tile_overlap = None + batch_size = 1 # Predict with CAREamist try: @@ -157,6 +159,7 @@ def _predict( data_type="tiff" if config_signal.load_from_disk else "array", tile_size=tile_size, tile_overlap=tile_overlap, + batch_size=batch_size, ) update_queue.put(PredictionUpdate(PredictionUpdateType.SAMPLE, result)) diff --git a/src/careamics_napari/workers/training_worker.py b/src/careamics_napari/workers/training_worker.py index 9a031ac..7428995 100644 --- a/src/careamics_napari/workers/training_worker.py +++ b/src/careamics_napari/workers/training_worker.py @@ -109,26 +109,32 @@ def _train( CAREamist instance. """ # get configuration and queue - config = create_configuration(config_signal) - - # Create CAREamist - if careamist is None: - careamist = CAREamist( - config, callbacks=[UpdaterCallBack(training_queue, predict_queue)] - ) + try: + # create_configuration can raise an exception + config = create_configuration(config_signal) - else: - # only update the number of epochs - careamist.cfg.training_config.num_epochs = config.training_config.num_epochs - - if config_signal.layer_val == "" and config_signal.path_val == "": - ntf.show_error( - "Continuing training is currently not supported without explicitely " - "passing validation. The reason is that otherwise, the data used for " - "validation will be different and there will be data leakage in the " - "training set." + # Create CAREamist + if careamist is None: + careamist = CAREamist( + config, callbacks=[UpdaterCallBack(training_queue, predict_queue)] ) + else: + # only update the number of epochs + careamist.cfg.training_config.num_epochs = config.training_config.num_epochs + + if config_signal.layer_val == "" and config_signal.path_val == "": + ntf.show_error( + "Continuing training is currently not supported without explicitely " + "passing validation. The reason is that otherwise, the data used for " + "validation will be different and there will be data leakage in the " + "training set." + ) + except Exception as e: + traceback.print_exc() + + training_queue.put(TrainUpdate(TrainUpdateType.EXCEPTION, e)) + # Register CAREamist training_queue.put(TrainUpdate(TrainUpdateType.CAREAMIST, careamist)) diff --git a/tests/test_predictions.py b/tests/test_predictions.py new file mode 100644 index 0000000..3869bb3 --- /dev/null +++ b/tests/test_predictions.py @@ -0,0 +1,94 @@ +#from src.careamics_napari.training_plugin import TrainingPlugin +from careamics import CAREamist + +from careamics.config import create_n2v_configuration + +import numpy as np +import contextlib +import sys +from itertools import combinations + +# disable logging +from careamics.careamist import logger +import logging + +from careamics_napari.utils.axes_utils import reshape_prediction + +logger.setLevel("ERROR") +logging.getLogger("pytorch_lightning.utilities.rank_zero").setLevel(logging.FATAL) + +# nostdout from https://stackoverflow.com/questions/2828953/silence-the-stdout-of-a-function-in-python-without-trashing-sys-stdout-and-resto +class DummyFile(object): + def write(self, x): + pass + +@contextlib.contextmanager +def nostdout(): + save_stdout = sys.stdout + sys.stdout = DummyFile() + yield + sys.stdout = save_stdout + +def generate_combinations_and_rotations(s): + # generate all combinations + combinations_list = [] + for r in range(1, len(s) + 1): + combinations_list.extend([''.join(comb) for comb in combinations(s, r)]) + + # generate all rotations + rotations = set() + for i in range(len(s)): + rotated = s[i:] + s[:i] + rotations.add(rotated) + + # combine results + all_results = set(combinations_list) + for rot in rotations: + for r in range(1, len(rot) + 1): + all_results.update([''.join(comb) for comb in combinations(rot, r)]) + + # add an empty + all_results.add("") + + return sorted(all_results) + +augmentation = generate_combinations_and_rotations("TZC") +for ax in augmentation: + test_axes = ax + "YX" + n_channels = 1 + shape = [] + for ax in test_axes: + if ax == "S": + shape.append(2) + elif ax == "T": + shape.append(4) + elif ax == "C": + shape.append(3) + n_channels = 3 + else: + shape.append(16) + + pred_data = np.random.randint(0, 255, shape).astype(np.float32) + with nostdout(): + # create a configuration + config = create_n2v_configuration( + experiment_name=f'N2V_{test_axes}', + data_type="array", + axes=test_axes, + n_channels=n_channels, + patch_size=[8, 8, 8] if "Z" in test_axes else [8, 8], + batch_size=1, + num_epochs=1, + ) + + # instantiate a careamist + careamist = CAREamist(config) + careamist.cfg.data_config.set_means_and_stds([127.0]*n_channels, [75.0]*n_channels) + + predction = careamist.predict(source=pred_data) + if isinstance(predction, list): + predction = np.concatenate(predction, axis=0) + + pred = reshape_prediction(predction, test_axes, "Z" in test_axes) + + assert pred_data.shape == pred.shape, f"Prediction shape {pred_data.shape} != {predction.shape} for axes {test_axes}" \ No newline at end of file