Skip to content

Commit

Permalink
[Enh] Handle segmenting image layers that have non-1 layer.scale (#804)
Browse files Browse the repository at this point in the history
Handle segmenting image layers that have non-1 layer.scale
  • Loading branch information
psobolewskiPhD authored Dec 3, 2024
1 parent a2c2a7d commit 001878b
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 8 deletions.
6 changes: 5 additions & 1 deletion micro_sam/sam_annotator/_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,15 @@ def _update_image(self, segmentation_result=None):

# Reset all layers.
self._viewer.layers["current_object"].data = np.zeros(self._shape, dtype="uint32")
self._viewer.layers["current_object"].scale = state.image_scale
self._viewer.layers["auto_segmentation"].data = np.zeros(self._shape, dtype="uint32")
self._viewer.layers["auto_segmentation"].scale = state.image_scale
if segmentation_result is None or segmentation_result is False:
self._viewer.layers["committed_objects"].data = np.zeros(self._shape, dtype="uint32")
else:
assert segmentation_result.shape == self._shape
self._viewer.layers["committed_objects"].data = segmentation_result

self._viewer.layers["committed_objects"].scale = state.image_scale
self._viewer.layers["point_prompts"].scale = state.image_scale
self._viewer.layers["prompts"].scale = state.image_scale
vutil.clear_annotations(self._viewer, clear_segmentations=False)
2 changes: 2 additions & 0 deletions micro_sam/sam_annotator/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class AnnotatorState(metaclass=Singleton):
image_embeddings: Optional[util.ImageEmbeddings] = None
predictor: Optional[SamPredictor] = None
image_shape: Optional[Tuple[int, int]] = None
image_scale: Optional[Tuple[float, ...]] = None
embedding_path: Optional[str] = None
data_signature: Optional[str] = None

Expand Down Expand Up @@ -198,6 +199,7 @@ def reset_state(self):
self.image_embeddings = None
self.predictor = None
self.image_shape = None
self.image_scale = None
self.embedding_path = None
self.amg = None
self.amg_state = None
Expand Down
17 changes: 12 additions & 5 deletions micro_sam/sam_annotator/_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def clear_volume(viewer: "napari.viewer.Viewer", all_slices: bool = True) -> Non
if all_slices:
vutil.clear_annotations(viewer)
else:
i = int(viewer.cursor.position[0])
i = int(viewer.dims.point[0])
vutil.clear_annotations_slice(viewer, i=i)


Expand All @@ -341,7 +341,7 @@ def clear_track(viewer: "napari.viewer.Viewer", all_frames: bool = True) -> None
_reset_tracking_state(viewer)
vutil.clear_annotations(viewer)
else:
i = int(viewer.cursor.position[0])
i = int(viewer.dims.point[0])
vutil.clear_annotations_slice(viewer, i=i)


Expand Down Expand Up @@ -736,7 +736,9 @@ def segment_slice(viewer: "napari.viewer.Viewer") -> None:
return None

shape = viewer.layers["current_object"].data.shape[1:]
position = viewer.cursor.position

position_world = viewer.dims.point
position = viewer.layers["point_prompts"].world_to_data(position_world)
z = int(position[0])

point_prompts = vutil.point_layer_to_prompts(viewer.layers["point_prompts"], z)
Expand Down Expand Up @@ -775,7 +777,7 @@ def segment_frame(viewer: "napari.viewer.Viewer") -> None:
return None
state = AnnotatorState()
shape = state.image_shape[1:]
position = viewer.cursor.position
position = viewer.dims.point
t = int(position[0])

point_prompts = vutil.point_layer_to_prompts(viewer.layers["point_prompts"], i=t, track_id=state.current_track_id)
Expand Down Expand Up @@ -868,7 +870,9 @@ def __init__(self, parent=None):
def _initialize_image(self):
state = AnnotatorState()
image_shape = self.image_selection.get_value().data.shape
image_scale = tuple(self.image_selection.get_value().scale)
state.image_shape = image_shape
state.image_scale = image_scale

def _create_image_section(self):
image_section = QtWidgets.QVBoxLayout()
Expand Down Expand Up @@ -1083,6 +1087,9 @@ def __call__(self, skip_validate=False):
ndim = image.data.ndim
state.image_shape = image.data.shape

# Set layer scale
state.image_scale = tuple(image.scale)

# Process tile_shape and halo, set other data.
tile_shape, halo = _process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)
save_path = None if self.embeddings_save_path == "" else self.embeddings_save_path
Expand Down Expand Up @@ -1655,7 +1662,7 @@ def __call__(self):
if self.volumetric and self.apply_to_volume:
worker = self._run_segmentation_3d(kwargs)
elif self.volumetric and not self.apply_to_volume:
i = int(self._viewer.cursor.position[0])
i = int(self._viewer.dims.point[0])
worker = self._run_segmentation_2d(kwargs, i=i)
else:
worker = self._run_segmentation_2d(kwargs)
Expand Down
4 changes: 2 additions & 2 deletions micro_sam/sam_annotator/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def point_layer_to_prompts(
this_points, this_labels = points, labels
else:
assert points.shape[1] == 3, f"{points.shape}"
mask = points[:, 0] == i
mask = np.round(points[:, 0]) == i
this_points = points[mask][:, 1:]
this_labels = labels[mask]
assert len(this_points) == len(this_labels)
Expand Down Expand Up @@ -355,7 +355,7 @@ def segment_slices_with_prompts(
image_shape = shape[1:]
seg = np.zeros(shape, dtype="uint32")

z_values = point_prompts.data[:, 0]
z_values = np.round(point_prompts.data[:, 0])
z_values_boxes = np.concatenate([box[:1, 0] for box in box_prompts.data]) if box_prompts.data else\
np.zeros(0, dtype="int")

Expand Down

0 comments on commit 001878b

Please sign in to comment.