Skip to content

Commit 24574ee

Browse files
committed
Brush initial functionality
1 parent f231267 commit 24574ee

File tree

1 file changed

+111
-2
lines changed

1 file changed

+111
-2
lines changed

pyidi/selection/main_window.py

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,29 @@
55
from matplotlib.path import Path
66
# import pyidi # Assuming pyidi is a custom module for video handling
77

8+
class BrushViewBox(pg.ViewBox):
9+
def __init__(self, parent_gui, *args, **kwargs):
10+
super().__init__(*args, **kwargs)
11+
self.setMouseMode(self.PanMode)
12+
self.parent_gui = parent_gui
13+
14+
def mouseClickEvent(self, ev):
15+
if self.parent_gui.mode == "manual" and self.parent_gui.method_buttons["Brush"].isChecked():
16+
ev.accept() # Prevent normal event
17+
self.parent_gui.handle_brush_start(ev)
18+
19+
def mouseDragEvent(self, ev, axis=None):
20+
if self.parent_gui.mode == "manual" and self.parent_gui.method_buttons["Brush"].isChecked():
21+
ev.accept()
22+
if ev.isStart():
23+
self.parent_gui._painting = True
24+
self.parent_gui._brush_path = []
25+
elif ev.isFinish():
26+
self.parent_gui._painting = False
27+
self.parent_gui.handle_brush_end(ev)
28+
else:
29+
self.parent_gui.handle_brush_move(ev)
30+
831
class SelectionGUI(QtWidgets.QMainWindow):
932
def __init__(self, video):
1033
app = QtWidgets.QApplication.instance()
@@ -16,6 +39,9 @@ def __init__(self, video):
1639
self.setWindowTitle("ROI Selection Tool")
1740
self.resize(1200, 800)
1841

42+
self._paint_mask = None # Same shape as the image
43+
self._paint_radius = 10 # pixels
44+
1945
self.selected_points = []
2046
self.manual_points = []
2147
self.candidate_points = []
@@ -122,7 +148,9 @@ def create_help_button(self, tooltip_text: str) -> QtWidgets.QToolButton:
122148
def ui_graphics(self):
123149
# Image viewer
124150
self.pg_widget = GraphicsLayoutWidget()
125-
self.view = self.pg_widget.addViewBox(lockAspect=True)
151+
self.view = BrushViewBox(parent_gui=self, lockAspect=True)
152+
self.pg_widget.addItem(self.view)
153+
126154

127155
self.image_item = ImageItem()
128156
self.polygon_line = pg.PlotDataItem(pen=pg.mkPen('y', width=2))
@@ -192,6 +220,7 @@ def ui_manual_right_menu(self):
192220
"Grid",
193221
"Manual",
194222
"Along the line",
223+
"Brush",
195224
"Remove point",
196225
]
197226
for i, name in enumerate(method_names):
@@ -374,6 +403,11 @@ def method_selected(self, id: int):
374403
print(f"Selected method: {method_name}")
375404
is_along = method_name == "Along the line"
376405
is_grid = method_name == "Grid"
406+
is_brush = method_name == "Brush"
407+
408+
# Disable panning
409+
self.view.setMouseEnabled(not is_brush, not is_brush)
410+
377411
show_spacing = is_along or is_grid
378412

379413
self.start_new_line_button.setVisible(is_along or is_grid)
@@ -401,7 +435,7 @@ def switch_mode(self, mode: str):
401435
self.automatic_mode_button.setChecked(True)
402436
self.stack.setCurrentWidget(self.automatic_widget)
403437

404-
self.compute_candidate_points()
438+
self.compute_candidate_points_shi_tomasi()
405439
self.show_points_checkbox.setChecked(False)
406440
self.roi_overlay.setVisible(False)
407441
self.scatter.setVisible(False)
@@ -419,6 +453,8 @@ def on_mouse_click(self, event):
419453
self.handle_grid_drawing(event)
420454
elif self.method_buttons["Remove point"].isChecked():
421455
self.handle_remove_point(event)
456+
elif self.method_buttons["Brush"].isChecked():
457+
self.handle_brush_start(event)
422458

423459
def update_selected_points(self):
424460
polygon_points = [pt for poly in self.drawing_polygons for pt in poly['roi_points']]
@@ -825,6 +861,63 @@ def clear_candidates(self):
825861
self.candidate_scatter.clear()
826862

827863
self.update_selected_points() # Update main display to remove candidates
864+
865+
# Brush
866+
def handle_brush_start(self, ev):
867+
QtWidgets.QApplication.setOverrideCursor(QtCore.Qt.CursorShape.CrossCursor)
868+
if self.image_item.image is None:
869+
return
870+
h, w = self.image_item.image.shape[:2]
871+
self._paint_mask = np.zeros((h, w), dtype=bool)
872+
self.handle_brush_move(ev)
873+
874+
def handle_brush_move(self, ev):
875+
if self._paint_mask is None:
876+
return
877+
878+
pos = ev.pos()
879+
if self.view.sceneBoundingRect().contains(pos):
880+
mouse_point = self.view.mapSceneToView(pos)
881+
y, x = int(round(mouse_point.x())), int(round(mouse_point.y()))
882+
r = self._paint_radius
883+
884+
h, w = self._paint_mask.shape
885+
yy, xx = np.ogrid[max(0, y - r):min(h, y + r + 1),
886+
max(0, x - r):min(w, x + r + 1)]
887+
mask = (yy - y) ** 2 + (xx - x) ** 2 <= r ** 2
888+
self._paint_mask[max(0, y - r):min(h, y + r + 1),
889+
max(0, x - r):min(w, x + r + 1)][mask] = True
890+
891+
self.update_brush_overlay()
892+
893+
def handle_brush_end(self, ev):
894+
QtWidgets.QApplication.restoreOverrideCursor()
895+
896+
if self._paint_mask is None:
897+
return
898+
899+
subset_size = self.subset_size_spinbox.value()
900+
spacing = self.distance_slider.value()
901+
brush_rois = rois_inside_mask(self._paint_mask, subset_size, spacing)
902+
self.manual_points.extend(brush_rois)
903+
904+
self._paint_mask = None
905+
self.update_selected_points()
906+
self.update_brush_overlay()
907+
908+
909+
def update_brush_overlay(self):
910+
if not hasattr(self, 'brush_overlay'):
911+
self.brush_overlay = ImageItem()
912+
self.view.addItem(self.brush_overlay)
913+
914+
if self._paint_mask is not None:
915+
rgba = np.zeros((*self._paint_mask.shape, 4), dtype=np.uint8)
916+
rgba[self._paint_mask] = [0, 200, 255, 80] # Cyan with transparency
917+
self.brush_overlay.setImage(rgba, autoLevels=False)
918+
self.brush_overlay.setZValue(2)
919+
else:
920+
self.brush_overlay.clear()
828921
################################################################################################
829922
# Automatic subset detection
830923
################################################################################################
@@ -878,6 +971,22 @@ def rois_inside_polygon(polygon, subset_size, spacing):
878971
mask = Path(polygon).contains_points(points)
879972
return [tuple(p) for p in points[mask]]
880973

974+
def rois_inside_mask(mask, subset_size, spacing):
975+
step = subset_size + spacing
976+
if step <= 0:
977+
step = 1
978+
979+
h, w = mask.shape
980+
xs = np.arange(0, w, step)
981+
ys = np.arange(0, h, step)
982+
grid_x, grid_y = np.meshgrid(xs, ys)
983+
984+
candidate_points = np.vstack([grid_y.ravel(), grid_x.ravel()]).T # (y, x)
985+
986+
# Only keep points where the mask is True
987+
selected = [tuple(p) for p in candidate_points if mask[p[0], p[1]]]
988+
return selected
989+
881990
if __name__ == "__main__":
882991
# import pyidi
883992
# filename = "data/data_showcase.cih"

0 commit comments

Comments
 (0)