From 645dfbbd9beb3058e3474fedb7b230e3a58ac9d4 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 4 Sep 2023 10:11:57 -0400 Subject: [PATCH 1/4] Add initial support for vit tiny model from mobileSAM --- environment_cpu.yaml | 6 ++++-- micro_sam/util.py | 10 ++++++++-- test/test_prompt_based_segmentation.py | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/environment_cpu.yaml b/environment_cpu.yaml index ad91dab9..48277532 100644 --- a/environment_cpu.yaml +++ b/environment_cpu.yaml @@ -13,5 +13,7 @@ dependencies: - torchvision - torch_em >=0.5.1 - tqdm - # - pip: - # - git+https://github.com/facebookresearch/segment-anything.git + - timm + - pip: + - git+https://github.com/ChaoningZhang/MobileSAM.git + # - git+https://github.com/facebookresearch/segment-anything.git diff --git a/micro_sam/util.py b/micro_sam/util.py index 1ea21303..a3f2b7f7 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -21,7 +21,8 @@ from nifty.tools import blocking from skimage.measure import regionprops -from segment_anything import sam_model_registry, SamPredictor +# from segment_anything import sam_model_registry, SamPredictor +from mobile_sam import sam_model_registry, SamPredictor try: from napari.utils import progress as tqdm @@ -33,6 +34,8 @@ "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", + # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM + "vit_t": "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt", # first version of finetuned models on zenodo "vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1", "vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1", @@ -45,6 +48,8 @@ "vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e", "vit_l": 
"3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622", "vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912", + # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM + "vit_t": "6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f", # first version of finetuned models on zenodo "vit_h_lm": "9a65ee0cddc05a98d60469a12a058859c89dc3ea3ba39fed9b90d786253fbf26", "vit_b_lm": "5a59cc4064092d54cd4d92cd967e39168f3760905431e868e474d60fe5464ecd", @@ -53,6 +58,7 @@ } # this is required so that the downloaded file is not called 'download' _DOWNLOAD_NAMES = { + "vit_t": "vit_t_mobile_sam.pth", "vit_h_lm": "vit_h_lm.pth", "vit_b_lm": "vit_b_lm.pth", "vit_h_em": "vit_h_em.pth", @@ -144,7 +150,7 @@ def get_sam_model( # Our custom model types have a suffix "_...". This suffix needs to be stripped # before calling sam_model_registry. model_type_ = model_type[:5] - assert model_type_ in ("vit_h", "vit_b", "vit_l") + assert model_type_ in ("vit_h", "vit_b", "vit_l", "vit_t") sam = sam_model_registry[model_type_](checkpoint=checkpoint) sam.to(device=device) diff --git a/test/test_prompt_based_segmentation.py b/test/test_prompt_based_segmentation.py index 87f8ce87..f2a30071 100644 --- a/test/test_prompt_based_segmentation.py +++ b/test/test_prompt_based_segmentation.py @@ -18,7 +18,7 @@ def _get_input(shape=(256, 256)): @staticmethod def _get_model(image): - predictor = util.get_sam_model(model_type="vit_b") + predictor = util.get_sam_model(model_type="vit_t") image_embeddings = util.precompute_image_embeddings(predictor, image) util.set_precomputed(predictor, image_embeddings) return predictor From dbee590238078025c1b7dae20ed371dfaf70e865 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 4 Sep 2023 19:26:05 -0400 Subject: [PATCH 2/4] Improve support for vit-tiny --- micro_sam/util.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/micro_sam/util.py b/micro_sam/util.py 
index a3f2b7f7..b1e09bf2 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -21,8 +21,12 @@ from nifty.tools import blocking from skimage.measure import regionprops -# from segment_anything import sam_model_registry, SamPredictor -from mobile_sam import sam_model_registry, SamPredictor +try: + from mobile_sam import sam_model_registry, SamPredictor + VIT_T_SUPPORT = True +except ImportError: + from segment_anything import sam_model_registry, SamPredictor + VIT_T_SUPPORT = False try: from napari.utils import progress as tqdm @@ -151,6 +155,11 @@ def get_sam_model( # before calling sam_model_registry. model_type_ = model_type[:5] assert model_type_ in ("vit_h", "vit_b", "vit_l", "vit_t") + if model_type == "vit_t" and not VIT_T_SUPPORT: + raise RuntimeError( + "mobile_sam is required for the vit-tiny." + "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" + ) sam = sam_model_registry[model_type_](checkpoint=checkpoint) sam.to(device=device) @@ -213,7 +222,7 @@ def get_custom_sam_model( # copy the model weights from torch_em's training format sam_prefix = "sam." 
model_state = OrderedDict( - [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()] + [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()] ) sam.load_state_dict(model_state) sam.to(device) From f1dd0e8eb974e47ad8deb6ca58fb571d734b65e1 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 8 Sep 2023 23:48:44 +0200 Subject: [PATCH 3/4] Use different URL for vit-t and use it to speed up tests --- micro_sam/util.py | 2 +- test/test_instance_segmentation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/micro_sam/util.py b/micro_sam/util.py index b1e09bf2..54b904bd 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -39,7 +39,7 @@ "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM - "vit_t": "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt", + "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", # first version of finetuned models on zenodo "vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1", "vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1", diff --git a/test/test_instance_segmentation.py b/test/test_instance_segmentation.py index 8ee81614..ddf76656 100644 --- a/test/test_instance_segmentation.py +++ b/test/test_instance_segmentation.py @@ -38,7 +38,7 @@ def write_object(center, radius): @staticmethod def _get_model(image): - predictor = util.get_sam_model(model_type="vit_b") + predictor = util.get_sam_model(model_type="vit_t") image_embeddings = util.precompute_image_embeddings(predictor, image) return predictor, image_embeddings From 05eb6c3bb471eeb2f82c8e022143b355b4c181af Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 8 Sep 2023 23:52:31 +0200 Subject: 
[PATCH 4/4] Update envs --- environment_cpu.yaml | 1 + environment_gpu.yaml | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/environment_cpu.yaml b/environment_cpu.yaml index 48277532..59f0aaef 100644 --- a/environment_cpu.yaml +++ b/environment_cpu.yaml @@ -6,6 +6,7 @@ name: dependencies: - cpuonly - napari + - pip - pooch - python-elf >=0.4.8 - pytorch diff --git a/environment_gpu.yaml b/environment_gpu.yaml index 57700759..41edeb5d 100644 --- a/environment_gpu.yaml +++ b/environment_gpu.yaml @@ -6,6 +6,7 @@ name: sam dependencies: - napari + - pip - pooch - python-elf >=0.4.8 - pytorch @@ -14,5 +15,7 @@ dependencies: - torchvision - torch_em >=0.5.1 - tqdm - # - pip: - # - git+https://github.com/facebookresearch/segment-anything.git + - timm + - pip: + - git+https://github.com/ChaoningZhang/MobileSAM.git + # - git+https://github.com/facebookresearch/segment-anything.git