Skip to content

Commit

Permalink
Set points_per_batch in AMG to None and choose default based on device
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Sep 8, 2023
1 parent db666af commit 211a33c
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def __init__(
self,
predictor: SamPredictor,
points_per_side: Optional[int] = 32,
points_per_batch: int = 64,
points_per_batch: Optional[int] = None,
crop_n_layers: int = 0,
crop_overlap_ratio: float = 512 / 1500,
crop_n_points_downscale_factor: int = 1,
Expand All @@ -358,7 +358,13 @@ def __init__(

self._predictor = predictor
self._points_per_side = points_per_side

# we set the points per batch to 16 for mps for performance reasons
# and otherwise keep them at the default of 64
if points_per_batch is None:
points_per_batch = 16 if str(predictor.device) == "mps" else 64
self._points_per_batch = points_per_batch

self._crop_n_layers = crop_n_layers
self._crop_overlap_ratio = crop_overlap_ratio
self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
Expand Down

0 comments on commit 211a33c

Please sign in to comment.