Skip to content

Commit

Permalink
490 finish settings for embedding widget (#493)
Browse files Browse the repository at this point in the history
Finish settings for the embedding widget

---------

Co-authored-by: Constantin Pape <[email protected]>
  • Loading branch information
lufre1 and constantinpape authored Apr 3, 2024
1 parent 3c503f2 commit 7e88291
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 5 deletions.
54 changes: 50 additions & 4 deletions micro_sam/sam_annotator/_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import z5py

from qtpy import QtWidgets
from qtpy.QtCore import QObject, Signal
from qtpy.QtCore import QObject, Signal, QFileInfo
from superqt import QCollapsible
from magicgui import magic_factory
from magicgui.widgets import ComboBox, Container, create_widget
Expand Down Expand Up @@ -123,6 +123,44 @@ def _add_shape_param(self, names, values, min_val, max_val, step=1):

return x_param, y_param, layout

def _add_path_param(self, name, value, title=None, select_file=False):
layout = QtWidgets.QHBoxLayout()
layout.addWidget(QtWidgets.QLabel(name if title is None else title))

directory_textbox = QtWidgets.QLineEdit()
directory_textbox.setText(value)
layout.addWidget(directory_textbox)

button_text = "Browse File" if select_file else "Browse Directory" # Adjust button text
directory_button = QtWidgets.QPushButton(button_text)
# Call appropriate function based on select_file
directory_button.clicked.connect(lambda: getattr(self, "_get_{}_path".format(
"directory" if not select_file else "file"))(name, directory_textbox))
layout.addWidget(directory_button)

return layout

def _get_directory_path(self, name, directory_textbox):
directory = QtWidgets.QFileDialog.getExistingDirectory(
self, "Select Directory", "", QtWidgets.QFileDialog.ShowDirsOnly)
if directory:
path = Path(directory) # Create a Path object from the string

if path.is_dir(): # Check if it's a valid directory
directory_textbox.setText(directory)
setattr(self, name, path)
else:
# Handle the case where the selected path is not a directory
print("Invalid directory selected. Please try again.")

def _get_file_path(self, name, directory_textbox):
file_path, _ = QtWidgets.QFileDialog.getOpenFileName(
self, "Select File", "", "All Files (*)"
)
if file_path:
directory_textbox.setText(file_path)
setattr(self, name, file_path)


# Custom signals for managing progress updates.
class PBarSignals(QObject):
Expand Down Expand Up @@ -587,10 +625,18 @@ def _create_settings_widget(self):
# TODO
# save_path: Optional[Path] = None, # where embeddings for this image are cached (optional, zarr file = folder)
# custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights.

# Create UI for the save path.
self.save_path = None
self.embeddings_save_path = None
layout = self._add_path_param(
"embeddings_save_path", self.embeddings_save_path, title="embeddings save path:")
setting_values.layout().addLayout(layout)

# Create UI for the custom weights.
self.custom_weights = None
self.custom_weights = None # select_file
layout = self._add_path_param(
"custom_weights", self.custom_weights, title="custom weights path:", select_file=True)
setting_values.layout().addLayout(layout)

# Create UI for the tile shape.
self.tile_x, self.tile_y = 0, 0
Expand Down Expand Up @@ -632,7 +678,7 @@ def __call__(self):

# Process tile_shape and halo, set other data.
tile_shape, halo = _process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)
save_path = self.save_path
save_path = self.embeddings_save_path
image_data = image.data

# Set up progress bar and signals for using it within a threadworker.
Expand Down
2 changes: 1 addition & 1 deletion test/test_sam_annotator/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_embedding_widget(make_napari_viewer, tmp_path):
my_widget.image = layer
my_widget.model_type = "vit_t"
my_widget.device = "cpu"
my_widget.save_path = tmp_path
my_widget.embeddings_save_path = tmp_path

# Run image embedding widget.
worker = my_widget()
Expand Down

0 comments on commit 7e88291

Please sign in to comment.