Skip to content

Commit

Permalink
Add initial support for vit tiny model from MobileSAM
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Sep 4, 2023
1 parent 9bad755 commit 645dfbb
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
6 changes: 4 additions & 2 deletions environment_cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@ dependencies:
- torchvision
- torch_em >=0.5.1
- tqdm
# - pip:
# - git+https://github.com/facebookresearch/segment-anything.git
- timm
- pip:
- git+https://github.com/ChaoningZhang/MobileSAM.git
# - git+https://github.com/facebookresearch/segment-anything.git
10 changes: 8 additions & 2 deletions micro_sam/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from nifty.tools import blocking
from skimage.measure import regionprops

from segment_anything import sam_model_registry, SamPredictor
# from segment_anything import sam_model_registry, SamPredictor
from mobile_sam import sam_model_registry, SamPredictor

try:
from napari.utils import progress as tqdm
Expand All @@ -33,6 +34,8 @@
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
# the model with vit tiny backbone from https://github.com/ChaoningZhang/MobileSAM
"vit_t": "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt",
# first version of finetuned models on zenodo
"vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1",
"vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1",
Expand All @@ -45,6 +48,8 @@
"vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e",
"vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622",
"vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912",
# the model with vit tiny backbone from https://github.com/ChaoningZhang/MobileSAM
"vit_t": "6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f",
# first version of finetuned models on zenodo
"vit_h_lm": "9a65ee0cddc05a98d60469a12a058859c89dc3ea3ba39fed9b90d786253fbf26",
"vit_b_lm": "5a59cc4064092d54cd4d92cd967e39168f3760905431e868e474d60fe5464ecd",
Expand All @@ -53,6 +58,7 @@
}
# this is required so that the downloaded file is not called 'download'
_DOWNLOAD_NAMES = {
"vit_t": "vit_t_mobile_sam.pth",
"vit_h_lm": "vit_h_lm.pth",
"vit_b_lm": "vit_b_lm.pth",
"vit_h_em": "vit_h_em.pth",
Expand Down Expand Up @@ -144,7 +150,7 @@ def get_sam_model(
# Our custom model types have a suffix "_...". This suffix needs to be stripped
# before calling sam_model_registry.
model_type_ = model_type[:5]
assert model_type_ in ("vit_h", "vit_b", "vit_l")
assert model_type_ in ("vit_h", "vit_b", "vit_l", "vit_t")

sam = sam_model_registry[model_type_](checkpoint=checkpoint)
sam.to(device=device)
Expand Down
2 changes: 1 addition & 1 deletion test/test_prompt_based_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def _get_input(shape=(256, 256)):

@staticmethod
def _get_model(image):
predictor = util.get_sam_model(model_type="vit_b")
predictor = util.get_sam_model(model_type="vit_t")
image_embeddings = util.precompute_image_embeddings(predictor, image)
util.set_precomputed(predictor, image_embeddings)
return predictor
Expand Down

0 comments on commit 645dfbb

Please sign in to comment.