diff --git a/src/napari_em_stack_reg/_widget.py b/src/napari_em_stack_reg/_widget.py index ece383a..8a738af 100644 --- a/src/napari_em_stack_reg/_widget.py +++ b/src/napari_em_stack_reg/_widget.py @@ -42,6 +42,7 @@ # from qtpy.QtCore import Qt from qtpy.QtWidgets import ( + QFrame, # QHBoxLayout, QLabel, # QMainWindow, @@ -51,6 +52,9 @@ QWidget, ) +from napari_em_stack_reg.tools.register import StackableImage + + if TYPE_CHECKING: import napari @@ -65,25 +69,45 @@ def __init__(self, viewer: "napari.viewer.Viewer"): self._stack_details = QLabel("No image layer found.") self._stack_details.setWordWrap(True) - self._begin_registration_button = QPushButton("Begin registration") - self._begin_registration_button.hide() + self._begin_button = QPushButton("Begin") + self._begin_button.hide() self._update_stack_details() self.setLayout(QVBoxLayout()) self.layout().addWidget(self._stack_details) - self.layout().addWidget(self._begin_registration_button) + + self._divider = QFrame() + self._divider.setFrameShape(QFrame.HLine) + self._divider.setFrameShadow(QFrame.Sunken) + self.layout().addWidget(self._divider) + self.layout().addWidget(self._begin_button) self._viewer.layers.events.inserted.connect(self._on_layer_inserted) self._viewer.layers.events.removed.connect(self._on_layer_removed) + self._begin_button.clicked.connect(self._on_begin_button_clicked) + def _on_layer_inserted(self): self._update_stack_details() - self._begin_registration_button.show() + self._begin_button.show() def _on_layer_removed(self): self._update_stack_details() - self._begin_registration_button.hide() + self._begin_button.hide() + + def _on_begin_button_clicked(self): + # print("Begin button clicked") + + self._begin_button.hide() + stackable_img = StackableImage(self._viewer) + stackable_img.get_registration_images() + + for layer in self._viewer.layers: + layer.events.transform.connect(self._on_transform_changed) + + def _on_transform_changed(self, event): + print("Transform changed") def _update_stack_details(self): # print("this") diff --git a/src/napari_em_stack_reg/tools/__init__.py b/src/napari_em_stack_reg/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/napari_em_stack_reg/tools/register.py b/src/napari_em_stack_reg/tools/register.py new file mode 100644 index 0000000..4ee8321 --- /dev/null +++ b/src/napari_em_stack_reg/tools/register.py @@ -0,0 +1,70 @@ +import os + +import dask.array as da +import numpy as np +from napari.layers import Image +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import napari + + +class StackableImage: + def __init__(self, viewer: "napari.viewer.Viewer"): + self._viewer = viewer + self._set_original_stack() + + self._reference_index = 0 + self._moving_index = 1 + self._transforms = [] + + def _set_original_stack(self): + image_layer = next( + ( + layer + for layer in self._viewer.layers + if isinstance(layer, Image) + ), + None, + ) + + if image_layer is not None: + if isinstance(image_layer.data, np.ndarray): + self._original_stack = da.from_array(image_layer.data) + elif isinstance(image_layer.data, da.Array): + self._original_stack = image_layer.data + + def get_registration_images(self, reference_index: int = 0): + self._reference_img = self._original_stack[reference_index] + self._moving_img = self._original_stack[self._moving_index] + + current_image_layer = next( + ( + layer + for layer in self._viewer.layers + if isinstance(layer, Image) + ), + None, + ) + + if current_image_layer is not None: + self._viewer.layers.remove(current_image_layer) + + # add reference image + self._viewer.add_image( + self._reference_img, + name=f"ref - slice {reference_index}", + blending="translucent_no_depth", + colormap="gray", + ) + + # add moving image + self._viewer.add_image( + self._moving_img, + name=f"moving - slice {self._moving_index}", + blending="translucent_no_depth", + colormap="gray", + opacity=0.5, + ) + + self._viewer.layers[-1].mode = "transform"