Merge pull request #174 from computational-cell-analytics/vit-tiny
Add initial support for the vit tiny model from MobileSAM
constantinpape authored Sep 8, 2023
2 parents 3db6c0e + 05eb6c3 commit 0d8d5d4
Showing 5 changed files with 30 additions and 9 deletions.
7 changes: 5 additions & 2 deletions environment_cpu.yaml
@@ -6,12 +6,15 @@ name:
 dependencies:
 - cpuonly
 - napari
+- pip
 - pooch
 - python-elf >=0.4.8
 - pytorch
 - segment-anything
 - torchvision
 - torch_em >=0.5.1
 - tqdm
-# - pip:
-#   - git+https://github.com/facebookresearch/segment-anything.git
+- timm
+- pip:
+  - git+https://github.com/ChaoningZhang/MobileSAM.git
+  # - git+https://github.com/facebookresearch/segment-anything.git
7 changes: 5 additions & 2 deletions environment_gpu.yaml
@@ -6,6 +6,7 @@ name:
 sam
 dependencies:
 - napari
+- pip
 - pooch
 - python-elf >=0.4.8
 - pytorch
@@ -14,5 +15,7 @@ dependencies:
 - torchvision
 - torch_em >=0.5.1
 - tqdm
-# - pip:
-#   - git+https://github.com/facebookresearch/segment-anything.git
+- timm
+- pip:
+  - git+https://github.com/ChaoningZhang/MobileSAM.git
+  # - git+https://github.com/facebookresearch/segment-anything.git
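
Both environment files now install MobileSAM from GitHub via pip. As a quick sanity check after creating either environment, the import below should succeed (a minimal sketch, not part of this commit; the module name mobile_sam matches the import used in micro_sam/util.py further down):

try:
    import mobile_sam  # provided by the pip install of MobileSAM added above
    print("MobileSAM available, the vit_t model can be used.")
except ImportError:
    print("MobileSAM not installed, only the segment_anything models are available.")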
21 changes: 18 additions & 3 deletions micro_sam/util.py
@@ -21,7 +21,12 @@
 from nifty.tools import blocking
 from skimage.measure import regionprops

-from segment_anything import sam_model_registry, SamPredictor
+try:
+    from mobile_sam import sam_model_registry, SamPredictor
+    VIT_T_SUPPORT = True
+except ImportError:
+    from segment_anything import sam_model_registry, SamPredictor
+    VIT_T_SUPPORT = False

 try:
     from napari.utils import progress as tqdm
@@ -33,6 +38,8 @@
     "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
     "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
     "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
+    # the model with vit tiny backbone from https://github.com/ChaoningZhang/MobileSAM
+    "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
     # first version of finetuned models on zenodo
     "vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1",
     "vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1",
@@ -45,6 +52,8 @@
     "vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e",
     "vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622",
     "vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912",
+    # the model with vit tiny backbone from https://github.com/ChaoningZhang/MobileSAM
+    "vit_t": "6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f",
     # first version of finetuned models on zenodo
     "vit_h_lm": "9a65ee0cddc05a98d60469a12a058859c89dc3ea3ba39fed9b90d786253fbf26",
     "vit_b_lm": "5a59cc4064092d54cd4d92cd967e39168f3760905431e868e474d60fe5464ecd",
@@ -53,6 +62,7 @@
 }
 # this is required so that the downloaded file is not called 'download'
 _DOWNLOAD_NAMES = {
+    "vit_t": "vit_t_mobile_sam.pth",
     "vit_h_lm": "vit_h_lm.pth",
     "vit_b_lm": "vit_b_lm.pth",
     "vit_h_em": "vit_h_em.pth",
@@ -161,7 +171,12 @@ def get_sam_model(
     # Our custom model types have a suffix "_...". This suffix needs to be stripped
     # before calling sam_model_registry.
     model_type_ = model_type[:5]
-    assert model_type_ in ("vit_h", "vit_b", "vit_l")
+    assert model_type_ in ("vit_h", "vit_b", "vit_l", "vit_t")
+    if model_type == "vit_t" and not VIT_T_SUPPORT:
+        raise RuntimeError(
+            "mobile_sam is required for the vit-tiny model. "
+            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'."
+        )

     sam = sam_model_registry[model_type_](checkpoint=checkpoint)
     sam.to(device=device)
@@ -223,7 +238,7 @@ def get_custom_sam_model(
     # copy the model weights from torch_em's training format
     sam_prefix = "sam."
     model_state = OrderedDict(
-        [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()]
+        [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()]
     )
     sam.load_state_dict(model_state)
     sam.to(device)
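
One detail worth noting: the new "vit_t" entry in _DOWNLOAD_NAMES exists because the owncloud URL ends in /download, so a downloader that names files after the last URL component would save the checkpoint as 'download'. A sketch of how such a download might look with pooch (an illustration under assumptions, not the actual call in micro_sam; the dictionary names _MODEL_URLS and _CHECKSUMS stand in for the URL and checksum dicts shown above):

import pooch

_MODEL_URLS = {"vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download"}
_CHECKSUMS = {"vit_t": "6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f"}
_DOWNLOAD_NAMES = {"vit_t": "vit_t_mobile_sam.pth"}

# fname overrides the file name derived from the URL, which would otherwise
# be its last component, i.e. 'download'.
checkpoint_path = pooch.retrieve(
    url=_MODEL_URLS["vit_t"],
    known_hash=_CHECKSUMS["vit_t"],  # sha256, verified after download
    fname=_DOWNLOAD_NAMES["vit_t"],
)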
2 changes: 1 addition & 1 deletion test/test_instance_segmentation.py
@@ -38,7 +38,7 @@ def write_object(center, radius):

     @staticmethod
     def _get_model(image):
-        predictor = util.get_sam_model(model_type="vit_b")
+        predictor = util.get_sam_model(model_type="vit_t")
         image_embeddings = util.precompute_image_embeddings(predictor, image)
         return predictor, image_embeddings

2 changes: 1 addition & 1 deletion test/test_prompt_based_segmentation.py
@@ -18,7 +18,7 @@ def _get_input(shape=(256, 256)):

     @staticmethod
     def _get_model(image):
-        predictor = util.get_sam_model(model_type="vit_b")
+        predictor = util.get_sam_model(model_type="vit_t")
         image_embeddings = util.precompute_image_embeddings(predictor, image)
         util.set_precomputed(predictor, image_embeddings)
         return predictor
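
Both test helpers now default to the lighter vit_t model. End-to-end, the new code path looks roughly like this (a sketch assembled from the calls in the tests above; if mobile_sam is not installed, get_sam_model raises the RuntimeError added in util.py):

import numpy as np
from micro_sam import util

image = np.random.randint(0, 256, (256, 256)).astype("uint8")
predictor = util.get_sam_model(model_type="vit_t")  # fetches the MobileSAM checkpoint on first use
image_embeddings = util.precompute_image_embeddings(predictor, image)
util.set_precomputed(predictor, image_embeddings)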
