From 001878b7627fa8fefb57068af24f86a5563d92a0 Mon Sep 17 00:00:00 2001 From: Peter Sobolewski <76622105+psobolewskiPhD@users.noreply.github.com> Date: Tue, 3 Dec 2024 15:53:29 -0500 Subject: [PATCH] [Enh] Handle segmenting image layers that have non-1 layer.scale (#804) Handle segmenting image layers that have non-1 layer.scale --- micro_sam/sam_annotator/_annotator.py | 6 +++++- micro_sam/sam_annotator/_state.py | 2 ++ micro_sam/sam_annotator/_widgets.py | 17 ++++++++++++----- micro_sam/sam_annotator/util.py | 4 ++-- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/micro_sam/sam_annotator/_annotator.py b/micro_sam/sam_annotator/_annotator.py index fcd58cc44..2dc7efee4 100644 --- a/micro_sam/sam_annotator/_annotator.py +++ b/micro_sam/sam_annotator/_annotator.py @@ -163,11 +163,15 @@ def _update_image(self, segmentation_result=None): # Reset all layers. self._viewer.layers["current_object"].data = np.zeros(self._shape, dtype="uint32") + self._viewer.layers["current_object"].scale = state.image_scale self._viewer.layers["auto_segmentation"].data = np.zeros(self._shape, dtype="uint32") + self._viewer.layers["auto_segmentation"].scale = state.image_scale if segmentation_result is None or segmentation_result is False: self._viewer.layers["committed_objects"].data = np.zeros(self._shape, dtype="uint32") else: assert segmentation_result.shape == self._shape self._viewer.layers["committed_objects"].data = segmentation_result - + self._viewer.layers["committed_objects"].scale = state.image_scale + self._viewer.layers["point_prompts"].scale = state.image_scale + self._viewer.layers["prompts"].scale = state.image_scale vutil.clear_annotations(self._viewer, clear_segmentations=False) diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py index ee42e4ebb..3de8affb5 100644 --- a/micro_sam/sam_annotator/_state.py +++ b/micro_sam/sam_annotator/_state.py @@ -42,6 +42,7 @@ class AnnotatorState(metaclass=Singleton): image_embeddings: Optional[util.ImageEmbeddings] = None predictor: Optional[SamPredictor] = None image_shape: Optional[Tuple[int, int]] = None + image_scale: Optional[Tuple[float, ...]] = None embedding_path: Optional[str] = None data_signature: Optional[str] = None @@ -198,6 +199,7 @@ def reset_state(self): self.image_embeddings = None self.predictor = None self.image_shape = None + self.image_scale = None self.embedding_path = None self.amg = None self.amg_state = None diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 4a58a42a7..63ba5347f 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -325,7 +325,7 @@ def clear_volume(viewer: "napari.viewer.Viewer", all_slices: bool = True) -> Non if all_slices: vutil.clear_annotations(viewer) else: - i = int(viewer.cursor.position[0]) + i = int(viewer.dims.point[0]) vutil.clear_annotations_slice(viewer, i=i) @@ -341,7 +341,7 @@ def clear_track(viewer: "napari.viewer.Viewer", all_frames: bool = True) -> None _reset_tracking_state(viewer) vutil.clear_annotations(viewer) else: - i = int(viewer.cursor.position[0]) + i = int(viewer.dims.point[0]) vutil.clear_annotations_slice(viewer, i=i) @@ -736,7 +736,9 @@ def segment_slice(viewer: "napari.viewer.Viewer") -> None: return None shape = viewer.layers["current_object"].data.shape[1:] - position = viewer.cursor.position + + position_world = viewer.dims.point + position = viewer.layers["point_prompts"].world_to_data(position_world) z = int(position[0]) point_prompts = vutil.point_layer_to_prompts(viewer.layers["point_prompts"], z) @@ -775,7 +777,7 @@ def segment_frame(viewer: "napari.viewer.Viewer") -> None: return None state = AnnotatorState() shape = state.image_shape[1:] - position = viewer.cursor.position + position = viewer.dims.point t = int(position[0]) point_prompts = vutil.point_layer_to_prompts(viewer.layers["point_prompts"], i=t, track_id=state.current_track_id) @@ -868,7 +870,9 @@ def __init__(self, parent=None): def _initialize_image(self): state = AnnotatorState() image_shape = self.image_selection.get_value().data.shape + image_scale = tuple(self.image_selection.get_value().scale) state.image_shape = image_shape + state.image_scale = image_scale def _create_image_section(self): image_section = QtWidgets.QVBoxLayout() @@ -1083,6 +1087,9 @@ def __call__(self, skip_validate=False): ndim = image.data.ndim state.image_shape = image.data.shape + # Set layer scale + state.image_scale = tuple(image.scale) + # Process tile_shape and halo, set other data. tile_shape, halo = _process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y) save_path = None if self.embeddings_save_path == "" else self.embeddings_save_path @@ -1655,7 +1662,7 @@ def __call__(self): if self.volumetric and self.apply_to_volume: worker = self._run_segmentation_3d(kwargs) elif self.volumetric and not self.apply_to_volume: - i = int(self._viewer.cursor.position[0]) + i = int(self._viewer.dims.point[0]) worker = self._run_segmentation_2d(kwargs, i=i) else: worker = self._run_segmentation_2d(kwargs) diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py index 1887f371e..7916ea142 100644 --- a/micro_sam/sam_annotator/util.py +++ b/micro_sam/sam_annotator/util.py @@ -178,7 +178,7 @@ def point_layer_to_prompts( this_points, this_labels = points, labels else: assert points.shape[1] == 3, f"{points.shape}" - mask = points[:, 0] == i + mask = np.round(points[:, 0]) == i this_points = points[mask][:, 1:] this_labels = labels[mask] assert len(this_points) == len(this_labels) @@ -355,7 +355,7 @@ def segment_slices_with_prompts( image_shape = shape[1:] seg = np.zeros(shape, dtype="uint32") - z_values = point_prompts.data[:, 0] + z_values = np.round(point_prompts.data[:, 0]) z_values_boxes = np.concatenate([box[:1, 0] for box in box_prompts.data]) if box_prompts.data else\ np.zeros(0, dtype="int")