Skip to content

Commit

Permalink
refactored alignment.py using classes
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Jan 12, 2024
1 parent 2d7e036 commit f3c1b27
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 260 deletions.
6 changes: 3 additions & 3 deletions brainglobe_template_builder/napari/_widget.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from brainglobe_utils.qtpy.collapsible_widget import CollapsibleWidgetContainer
from napari.viewer import Viewer

from brainglobe_template_builder.napari.align_widget import AlignMidplane
from brainglobe_template_builder.napari.mask_widget import CreateMask
from brainglobe_template_builder.napari.midline_widget import FindMidline


class PreprocWidgets(CollapsibleWidgetContainer):
Expand All @@ -17,9 +17,9 @@ def __init__(self, napari_viewer: Viewer, parent=None):
self._expand_mask_widget()

self.add_widget(
FindMidline(napari_viewer, parent=self),
AlignMidplane(napari_viewer, parent=self),
collapsible=True,
widget_title="Find midline",
widget_title="Align midplane",
)
self._connect_midline_widget_toggle()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@
)

from brainglobe_template_builder.preproc import (
apply_transform,
get_alignment_transform,
get_midline_points,
MidplaneAligner,
MidplaneEstimator,
)


class FindMidline(QWidget):
"""Widget to find the mid-sagittal plane based on annotated points."""
class AlignMidplane(QWidget):
"""Widget to align the plane of symmetry to the midplane of the image."""

def __init__(self, napari_viewer: Viewer, parent=None):
super().__init__(parent=parent)
Expand All @@ -27,9 +26,9 @@ def __init__(self, napari_viewer: Viewer, parent=None):
self._create_align_group()

def _create_estimate_group(self):
"""Create the group of widgets concerned with estimating midline
"""Create the group of widgets concerned with estimating midplane
points."""
self.estimate_groupbox = QGroupBox("Estimate points along midline")
self.estimate_groupbox = QGroupBox("Estimate points along midplane")
self.estimate_groupbox.setLayout(QFormLayout())
self.layout().addRow(self.estimate_groupbox)

Expand All @@ -43,7 +42,14 @@ def _create_estimate_group(self):
"mask:", self.select_mask_dropdown
)

# Initialise button to estimate midline points
# Add dropdown to select axis
self.select_axis_dropdown = QComboBox(parent=self.estimate_groupbox)
self.select_axis_dropdown.addItems(["x", "y", "z"])
self.estimate_groupbox.layout().addRow(
"symmetry axis:", self.select_axis_dropdown
)

# Initialise button to estimate midplane points
self.estimate_points_button = QPushButton(
"Estimate points", parent=self.estimate_groupbox
)
Expand All @@ -55,9 +61,9 @@ def _create_estimate_group(self):

def _create_align_group(self):
"""Create the group of widgets concerned with aligning the image to
the midline."""
the midplane."""

self.align_groupbox = QGroupBox("Align image to midline")
self.align_groupbox = QGroupBox("Align image to midplane")
self.align_groupbox.setLayout(QFormLayout())
self.layout().addRow(self.align_groupbox)

Expand All @@ -81,32 +87,14 @@ def _create_align_group(self):
"points:", self.select_points_dropdown
)

# Add dropdown to select axis
self.select_axis_dropdown = QComboBox(parent=self.align_groupbox)
self.select_axis_dropdown.addItems(["x", "y", "z"])
self.align_groupbox.layout().addRow("axis:", self.select_axis_dropdown)

# Add button to align image to midline
# Add button to align image to midplane
self.align_image_button = QPushButton(
"Align image", parent=self.align_groupbox
)
self.align_image_button.setEnabled(False)
self.align_image_button.clicked.connect(self._on_align_button_click)
self.align_groupbox.layout().addRow(self.align_image_button)

# 9 colors taken from ColorBrewer2.org Set3 palette
self.point_colors = [
"#8dd3c7",
"#ffffb3",
"#bebada",
"#fb8072",
"#80b1d3",
"#fdb462",
"#b3de69",
"#fccde5",
"#d9d9d9",
]

def _get_layers_by_type(self, layer_type: Layer) -> list:
"""Return a list of napari layers of a given type."""
return [
Expand All @@ -129,50 +117,46 @@ def refresh_dropdowns(self):
dropdown.addItems(self._get_layers_by_type(layer_type))

def _on_estimate_button_click(self):
"""Estimate midline points and add them to the viewer."""

# Estimate 9 midline points based on the selected mask
"""Estimate midplane points and add them to the viewer."""
# Estimate 9 midplane points based on the selected mask
mask_name = self.select_mask_dropdown.currentText()
mask = self.viewer.layers[mask_name]
points = get_midline_points(mask.data)
axis = self.select_axis_dropdown.currentText()
estimator = MidplaneEstimator(mask.data, symmetry_axis=axis)
points = estimator.get_points()

# Point layer attributes
point_attrs = {
"properties": {"label": range(1, points.shape[0] + 1)},
"face_color": "label",
"face_color_cycle": self.point_colors,
"properties": {"label": list(range(9))},
"face_color": "green",
"symbol": "cross",
"edge_width": 0,
"opacity": 0.6,
"size": 6,
"ndim": mask.ndim,
"name": "midline points",
"name": "midplane points",
}

mask.visible = False
self.viewer.add_points(points, **point_attrs)
self.refresh_dropdowns()
show_info(
"Please move the estimated points so that they sit exactly "
"on the mid-sagittal plane."
)
# Move viewer to show z-plane of first point
self.viewer.dims.set_point(0, points[0][0])
# Enable "Select points" mode
self.viewer.layers["midplane points"].mode = "select"
show_info("Please move all 9 estimated points exactly to the midplane")

def _on_align_button_click(self):
"""Align image and add the transformed image to the viewer."""
image_name = self.select_image_dropdown.currentText()
points_name = self.select_points_dropdown.currentText()
axis = self.select_axis_dropdown.currentText()

transform = get_alignment_transform(
aligner = MidplaneAligner(
self.viewer.layers[image_name].data,
self.viewer.layers[points_name].data,
axis=axis,
symmetry_axis=axis,
)

aligned_image = apply_transform(
self.viewer.layers[image_name].data,
transform,
)

aligned_image = aligner.transform_image()
self.viewer.add_image(aligned_image, name="aligned image")

def _on_dropdown_selection_change(self):
Expand Down
5 changes: 2 additions & 3 deletions brainglobe_template_builder/preproc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from brainglobe_template_builder.preproc.masking import create_mask
from brainglobe_template_builder.preproc.alignment import (
get_midline_points,
get_alignment_transform,
apply_transform,
MidplaneAligner,
MidplaneEstimator,
)
Loading

0 comments on commit f3c1b27

Please sign in to comment.