From 7e88291b91dd1edad2ec917ef1ac09b2fe832370 Mon Sep 17 00:00:00 2001 From: lufre1 <155526548+lufre1@users.noreply.github.com> Date: Wed, 3 Apr 2024 16:59:05 +0200 Subject: [PATCH] 490 finish settings for embedding widget (#493) Finish settings for the embedding widget --------- Co-authored-by: Constantin Pape --- micro_sam/sam_annotator/_widgets.py | 54 +++++++++++++++++++++++-- test/test_sam_annotator/test_widgets.py | 2 +- 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 44b131d4..89a41e7e 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -16,7 +16,7 @@ import z5py from qtpy import QtWidgets -from qtpy.QtCore import QObject, Signal +from qtpy.QtCore import QObject, Signal, QFileInfo from superqt import QCollapsible from magicgui import magic_factory from magicgui.widgets import ComboBox, Container, create_widget @@ -123,6 +123,44 @@ def _add_shape_param(self, names, values, min_val, max_val, step=1): return x_param, y_param, layout + def _add_path_param(self, name, value, title=None, select_file=False): + layout = QtWidgets.QHBoxLayout() + layout.addWidget(QtWidgets.QLabel(name if title is None else title)) + + directory_textbox = QtWidgets.QLineEdit() + directory_textbox.setText(value) + layout.addWidget(directory_textbox) + + button_text = "Browse File" if select_file else "Browse Directory" # Adjust button text + directory_button = QtWidgets.QPushButton(button_text) + # Call appropriate function based on select_file + directory_button.clicked.connect(lambda: getattr(self, "_get_{}_path".format( + "directory" if not select_file else "file"))(name, directory_textbox)) + layout.addWidget(directory_button) + + return layout + + def _get_directory_path(self, name, directory_textbox): + directory = QtWidgets.QFileDialog.getExistingDirectory( + self, "Select Directory", "", QtWidgets.QFileDialog.ShowDirsOnly) + if directory: + path = Path(directory) # Create a Path object from the string + + if path.is_dir(): # Check if it's a valid directory + directory_textbox.setText(directory) + setattr(self, name, path) + else: + # Handle the case where the selected path is not a directory + print("Invalid directory selected. Please try again.") + + def _get_file_path(self, name, directory_textbox): + file_path, _ = QtWidgets.QFileDialog.getOpenFileName( + self, "Select File", "", "All Files (*)" + ) + if file_path: + directory_textbox.setText(file_path) + setattr(self, name, file_path) + # Custom signals for managing progress updates. class PBarSignals(QObject): @@ -587,10 +625,18 @@ def _create_settings_widget(self): # TODO # save_path: Optional[Path] = None, # where embeddings for this image are cached (optional, zarr file = folder) # custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights. + # Create UI for the save path. - self.save_path = None + self.embeddings_save_path = None + layout = self._add_path_param( + "embeddings_save_path", self.embeddings_save_path, title="embeddings save path:") + setting_values.layout().addLayout(layout) + # Create UI for the custom weights. - self.custom_weights = None + self.custom_weights = None # select_file + layout = self._add_path_param( + "custom_weights", self.custom_weights, title="custom weights path:", select_file=True) + setting_values.layout().addLayout(layout) # Create UI for the tile shape. self.tile_x, self.tile_y = 0, 0 @@ -632,7 +678,7 @@ def __call__(self): # Process tile_shape and halo, set other data. tile_shape, halo = _process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y) - save_path = self.save_path + save_path = self.embeddings_save_path image_data = image.data # Set up progress bar and signals for using it within a threadworker. diff --git a/test/test_sam_annotator/test_widgets.py b/test/test_sam_annotator/test_widgets.py index 03ecc15a..04bede16 100644 --- a/test/test_sam_annotator/test_widgets.py +++ b/test/test_sam_annotator/test_widgets.py @@ -30,7 +30,7 @@ def test_embedding_widget(make_napari_viewer, tmp_path): my_widget.image = layer my_widget.model_type = "vit_t" my_widget.device = "cpu" - my_widget.save_path = tmp_path + my_widget.embeddings_save_path = tmp_path # Run image embedding widget. worker = my_widget()