From 00006fc6a86d4f6888f7ef333cf070987ffabb51 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Tue, 14 Nov 2023 16:25:04 +1100 Subject: [PATCH 01/16] Use pooch to download model weights --- micro_sam/sam_annotator/_widgets.py | 4 +- micro_sam/util.py | 113 ++++++++-------------------- 2 files changed, 35 insertions(+), 82 deletions(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 887aae8a..dc2bf2fe 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -13,7 +13,7 @@ ImageEmbeddings, get_sam_model, precompute_image_embeddings, - _MODEL_URLS, + MODELS, _DEFAULT_MODEL, _available_devices, ) @@ -21,7 +21,7 @@ if TYPE_CHECKING: import napari -Model = Enum("Model", _MODEL_URLS) +Model = Enum("Model", MODELS.urls) available_devices_list = ["auto"] + _available_devices() diff --git a/micro_sam/util.py b/micro_sam/util.py index 28460479..60de9f89 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -35,42 +35,7 @@ except ImportError: from tqdm import tqdm -_MODEL_URLS = { - # the default segment anything models - "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", - "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", - "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", - # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM - "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", - # first version of finetuned models on zenodo - "vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1", - "vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1", - "vit_h_em": "https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1", - "vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1", -} -_CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam') -_CHECKPOINT_FOLDER = os.path.join(_CACHE_DIR, 'models') -_CHECKSUMS = { - # the default segment anything models - "vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e", - "vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622", - "vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912", - # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM - "vit_t": "6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f", - # first version of finetuned models on zenodo - "vit_h_lm": "9a65ee0cddc05a98d60469a12a058859c89dc3ea3ba39fed9b90d786253fbf26", - "vit_b_lm": "5a59cc4064092d54cd4d92cd967e39168f3760905431e868e474d60fe5464ecd", - "vit_h_em": "ae3798a0646c8df1d4db147998a2d37e402ff57d3aa4e571792fbb911d8a979c", - "vit_b_em": "c04a714a4e14a110f0eec055a65f7409d54e6bf733164d2933a0ce556f7d6f81", -} -# this is required so that the downloaded file is not called 'download' -_DOWNLOAD_NAMES = { - "vit_t": "vit_t_mobile_sam.pth", - "vit_h_lm": "vit_h_lm.pth", - "vit_b_lm": "vit_b_lm.pth", - "vit_h_em": "vit_h_em.pth", - "vit_b_em": "vit_b_em.pth", -} + # this is the default model used in micro_sam # currently set to the default vit_h _DEFAULT_MODEL = "vit_h" @@ -84,49 +49,37 @@ # # Functionality for model download and export # - - -def _download(url, path, model_type): - with requests.get(url, stream=True, verify=True) as r: - if r.status_code != 200: - r.raise_for_status() - raise RuntimeError(f"Request to {url} returned 
status code {r.status_code}") - file_size = int(r.headers.get("Content-Length", 0)) - desc = f"Download {url} to {path}" - if file_size == 0: - desc += " (unknown file size)" - with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f: - copyfileobj(r_raw, f) - - # validate the checksum - expected_checksum = _CHECKSUMS[model_type] - if expected_checksum is None: - return - with open(path, "rb") as f: - file_ = f.read() - checksum = hashlib.sha256(file_).hexdigest() - if checksum != expected_checksum: - raise RuntimeError( - "The checksum of the download does not match the expected checksum." - f"Expected: {expected_checksum}, got: {checksum}" - ) - print("Download successful and checksums agree.") - - -def _get_checkpoint(model_type, checkpoint_path=None): - if checkpoint_path is None: - checkpoint_url = _MODEL_URLS[model_type] - checkpoint_name = _DOWNLOAD_NAMES.get(model_type, checkpoint_url.split("/")[-1]) - checkpoint_path = os.path.join(_CHECKPOINT_FOLDER, checkpoint_name) - - # download the checkpoint if necessary - if not os.path.exists(checkpoint_path): - os.makedirs(_CHECKPOINT_FOLDER, exist_ok=True) - _download(checkpoint_url, checkpoint_path, model_type) - elif not os.path.exists(checkpoint_path): - raise ValueError(f"The checkpoint path {checkpoint_path} that was passed does not exist.") - - return checkpoint_path +MODELS = pooch.create( + path=pooch.os_cache(os.path.join("micro-sam", "models")), + base_url="", + registry={ + # the default segment anything models + "vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e", + "vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622", + "vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912", + # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM + "vit_t": "6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f", + # first version of finetuned models on zenodo + "vit_h_lm": "9a65ee0cddc05a98d60469a12a058859c89dc3ea3ba39fed9b90d786253fbf26", + "vit_b_lm": "5a59cc4064092d54cd4d92cd967e39168f3760905431e868e474d60fe5464ecd", + "vit_h_em": "ae3798a0646c8df1d4db147998a2d37e402ff57d3aa4e571792fbb911d8a979c", + "vit_b_em": "c04a714a4e14a110f0eec055a65f7409d54e6bf733164d2933a0ce556f7d6f81", + }, + # Now specify custom URLs for some of the files in the registry. + urls={ + # the default segment anything models + "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", + "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", + "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", + # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM + "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", + # first version of finetuned models on zenodo + "vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1", + "vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1", + "vit_h_em": "https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1", + "vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1", + }, +) def _get_default_device(): @@ -203,7 +156,7 @@ def get_sam_model( Returns: The segment anything predictor. """ - checkpoint = _get_checkpoint(model_type, checkpoint_path) + checkpoint = MODELS.fetch(model_type) device = _get_device(device) # Our custom model types have a suffix "_...". 
This suffix needs to be stripped From 926b4ba6d200c6cfafb3fb0cb02909eb42992f01 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Tue, 14 Nov 2023 16:52:26 +1100 Subject: [PATCH 02/16] Remove checkpoint_path from default sam model download (custom weights should belong with the custom model download function) --- examples/annotator_with_custom_model.py | 4 ++-- examples/finetuning/use_finetuned_model.py | 2 +- examples/use_as_library/instance_segmentation.py | 3 +-- micro_sam/evaluation/inference.py | 4 +--- micro_sam/precompute_state.py | 4 +--- micro_sam/sam_annotator/_widgets.py | 8 +++----- 6 files changed, 9 insertions(+), 16 deletions(-) diff --git a/examples/annotator_with_custom_model.py b/examples/annotator_with_custom_model.py index ceb8b2cb..deba6638 100644 --- a/examples/annotator_with_custom_model.py +++ b/examples/annotator_with_custom_model.py @@ -1,6 +1,6 @@ import h5py import micro_sam.sam_annotator as annotator -from micro_sam.util import get_sam_model +from micro_sam.util import get_custom_sam_model # TODO add an example for the 2d annotator with a custom model @@ -11,7 +11,7 @@ def annotator_3d_with_custom_model(): custom_model = "/home/pape/Work/data/models/sam/user-study/vit_h_nuclei_em_finetuned.pt" embedding_path = "./embeddings/nuclei3d-custom-vit-h.zarr" - predictor = get_sam_model(checkpoint_path=custom_model, model_type="vit_h") + predictor = get_custom_sam_model(checkpoint_path=custom_model, model_type="vit_h") annotator.annotator_3d(raw, embedding_path, predictor=predictor) diff --git a/examples/finetuning/use_finetuned_model.py b/examples/finetuning/use_finetuned_model.py index 19600241..e07b0e36 100644 --- a/examples/finetuning/use_finetuned_model.py +++ b/examples/finetuning/use_finetuned_model.py @@ -21,7 +21,7 @@ def run_annotator_with_custom_model(): # Adapt this if you finetune a different model type, e.g. vit_h. # Load the custom model. - predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint) + predictor = util.get_custom_sam_model(model_type=model_type, checkpoint_path=checkpoint) # Run the 2d annotator with the custom model. annotator_2d( diff --git a/examples/use_as_library/instance_segmentation.py b/examples/use_as_library/instance_segmentation.py index a447ed0a..bb58e1d3 100644 --- a/examples/use_as_library/instance_segmentation.py +++ b/examples/use_as_library/instance_segmentation.py @@ -101,8 +101,7 @@ def segmentation_in_3d(): # Load the SAM model for prediction. model_type = "vit_b" # The model-type to use: vit_h, vit_l, vit_b etc. - checkpoint_path = None # You can specifiy the path to a custom (fine-tuned) model here. - predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path) + predictor = util.get_sam_model(model_type=model_type) # Run 3d segmentation for a given slice. Will segment all objects found in that slice # throughout the volume. 
diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index bfa1b160..bf9ea068 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -163,9 +163,7 @@ def get_predictor( ) else: # Vanilla SAM model assert not return_state - predictor = util.get_sam_model( - model_type=model_type, device=device, checkpoint_path=checkpoint_path - ) # type: ignore + predictor = util.get_sam_model(model_type=model_type, device=device) # type: ignore return predictor diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py index 66e6edf5..fbef37a4 100644 --- a/micro_sam/precompute_state.py +++ b/micro_sam/precompute_state.py @@ -124,8 +124,6 @@ def precompute_state( it can be given to provide a glob pattern to subselect files from the folder. output_path: The output path were the embeddings and other state will be saved. model_type: The SegmentAnything model to use. Will use the standard vit_h model by default. - checkpoint_path: Path to a checkpoint for a custom model. - key: The key to the input file. This is needed for contaner files (e.g. hdf5 or zarr) and can be used to provide a glob pattern if the input is a folder with image files. ndim: The dimensionality of the data. tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. @@ -133,7 +131,7 @@ def precompute_state( precompute_amg_state: Whether to precompute the state for automatic instance segmentation in addition to the image embeddings. """ - predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path) + predictor = util.get_sam_model(model_type=model_type) # check if we precompute the state for a single file or for a folder with image files if os.path.isdir(input_path) and Path(input_path).suffix not in (".n5", ".zarr"): pattern = "*" if key is None else key diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index dc2bf2fe..c286d8eb 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -37,7 +37,6 @@ def embedding_widget( model: Model = Model.__getitem__(_DEFAULT_MODEL), device = "auto", save_path: Optional[Path] = None, # where embeddings for this image are cached (optional) - optional_custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights. 
) -> ImageEmbeddings: """Image embedding widget.""" state = AnnotatorState() @@ -53,7 +52,7 @@ def embedding_widget( @thread_worker(connect={'started': pbar.show, 'finished': pbar.hide}) def _compute_image_embedding(state, image_data, save_path, ndim=None, device="auto", model=Model.__getitem__(_DEFAULT_MODEL), - optional_custom_weights=None): + ): # Make sure save directory exists and is an empty directory if save_path is not None: os.makedirs(save_path, exist_ok=True) @@ -70,8 +69,7 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None, f"or empty directory: {save_path}" ) # Initialize the model - state.predictor = get_sam_model(device=device, model_type=model.name, - checkpoint_path=optional_custom_weights) + state.predictor = get_sam_model(device=device, model_type=model.name) # Compute the image embeddings state.image_embeddings = precompute_image_embeddings( predictor = state.predictor, @@ -81,4 +79,4 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None, ) return state # returns napari._qt.qthreading.FunctionWorker - return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model, optional_custom_weights=optional_custom_weights) + return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model) From 09bc898415264e747356d3dc2ea11177ee32c5ea Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Tue, 14 Nov 2023 17:04:04 +1100 Subject: [PATCH 03/16] Allow MICROSAM_CACHEDIR os environment variable --- micro_sam/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/micro_sam/util.py b/micro_sam/util.py index 60de9f89..bd7fed16 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -49,8 +49,9 @@ # # Functionality for model download and export # +_CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam') MODELS = pooch.create( - path=pooch.os_cache(os.path.join("micro-sam", "models")), + path=pooch.os_cache(os.path.join(_CACHE_DIR, "models")), base_url="", registry={ # the default segment anything models @@ -130,7 +131,6 @@ def _available_devices(): def get_sam_model( model_type: str = _DEFAULT_MODEL, device: Optional[str] = None, - checkpoint_path: Optional[Union[str, os.PathLike]] = None, return_sam: bool = False, ) -> SamPredictor: r"""Get the SegmentAnything Predictor. From e5b1433fa3823167423679d4a92e3d0bb2753c19 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Tue, 14 Nov 2023 17:06:30 +1100 Subject: [PATCH 04/16] Remove checkpoint kwarg --- micro_sam/training/util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 52096462..ab75d9de 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -11,7 +11,6 @@ def get_trainable_sam_model( model_type: str = "vit_h", device: Optional[str] = None, - checkpoint_path: Optional[Union[str, os.PathLike]] = None, freeze: Optional[List[str]] = None, ) -> TrainableSAM: """Get the trainable sam model. 
@@ -28,7 +27,7 @@ def get_trainable_sam_model( """ # set the device here so that the correct one is passed to TrainableSAM below device = _get_device(device) - _, sam = get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path, return_sam=True) + _, sam = get_sam_model(model_type=model_type, device=device, return_sam=True) # freeze components of the model if freeze was passed # ideally we would want to add components in such a way that: From 84609548f0d53c1074d99e030dab531f2a03431c Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Wed, 22 Nov 2023 17:46:40 +1100 Subject: [PATCH 05/16] Bugfixes for image embedding widget --- micro_sam/sam_annotator/_widgets.py | 16 ++--- micro_sam/util.py | 88 +++++++++++++++---------- test/test_sam_annotator/test_widgets.py | 7 +- 3 files changed, 66 insertions(+), 45 deletions(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index c286d8eb..29adfcba 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -1,7 +1,6 @@ -from enum import Enum import os from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Literal from magicgui import magic_factory, widgets from napari.qt.threading import thread_worker @@ -13,7 +12,7 @@ ImageEmbeddings, get_sam_model, precompute_image_embeddings, - MODELS, + models, _DEFAULT_MODEL, _available_devices, ) @@ -21,20 +20,20 @@ if TYPE_CHECKING: import napari -Model = Enum("Model", MODELS.urls) available_devices_list = ["auto"] + _available_devices() @magic_factory( pbar={'visible': False, 'max': 0, 'value': 0, 'label': 'working...'}, call_button="Compute image embeddings", + # model={"choices": list(models().urls.keys())}, device = {"choices": available_devices_list}, save_path={"mode": "d"}, # choose a directory ) def embedding_widget( pbar: widgets.ProgressBar, image: "napari.layers.Image", - model: Model = Model.__getitem__(_DEFAULT_MODEL), + model: Literal[tuple(models().urls.keys())] = _DEFAULT_MODEL, device = "auto", save_path: Optional[Path] = None, # where embeddings for this image are cached (optional) ) -> ImageEmbeddings: @@ -51,7 +50,7 @@ def embedding_widget( @thread_worker(connect={'started': pbar.show, 'finished': pbar.hide}) def _compute_image_embedding(state, image_data, save_path, ndim=None, - device="auto", model=Model.__getitem__(_DEFAULT_MODEL), + device="auto", model=_DEFAULT_MODEL, ): # Make sure save directory exists and is an empty directory if save_path is not None: @@ -68,13 +67,14 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None, "The user selected 'save_path' is not a zarr array " f"or empty directory: {save_path}" ) + # Initialize the model - state.predictor = get_sam_model(device=device, model_type=model.name) + state.predictor = get_sam_model(device=device, model_type=model) # Compute the image embeddings state.image_embeddings = precompute_image_embeddings( predictor = state.predictor, input_ = image_data, - save_path = str(save_path), + save_path = save_path, ndim=ndim, ) return state # returns napari._qt.qthreading.FunctionWorker diff --git a/micro_sam/util.py b/micro_sam/util.py index bd7fed16..676ec3bf 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -49,38 +49,57 @@ # # Functionality for model download and export # -_CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam') -MODELS = pooch.create( - 
path=pooch.os_cache(os.path.join(_CACHE_DIR, "models")), - base_url="", - registry={ - # the default segment anything models - "vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e", - "vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622", - "vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912", - # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM - "vit_t": "6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f", - # first version of finetuned models on zenodo - "vit_h_lm": "9a65ee0cddc05a98d60469a12a058859c89dc3ea3ba39fed9b90d786253fbf26", - "vit_b_lm": "5a59cc4064092d54cd4d92cd967e39168f3760905431e868e474d60fe5464ecd", - "vit_h_em": "ae3798a0646c8df1d4db147998a2d37e402ff57d3aa4e571792fbb911d8a979c", - "vit_b_em": "c04a714a4e14a110f0eec055a65f7409d54e6bf733164d2933a0ce556f7d6f81", - }, - # Now specify custom URLs for some of the files in the registry. - urls={ - # the default segment anything models - "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", - "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", - "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", - # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM - "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", - # first version of finetuned models on zenodo - "vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1", - "vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1", - "vit_h_em": "https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1", - "vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1", - }, -) +def microsam_cachedir(): + """Return the micro-sam cache directory. + + Returns the top level cache directory for micro-sam models and sample data. + + Every time this function is called, we check for any user updates made to + the MICROSAM_CACHEDIR os environment variable since the last time. + """ + cache_directory = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam') + return cache_directory + + +def models(): + """Return the segmentation models registry. + + We recreate the model registry every time this function is called, + so any user changes to the default micro-sam cache directory location + are respected. + """ + models = pooch.create( + path=pooch.os_cache(os.path.join(microsam_cachedir(), "models")), + base_url="", + registry={ + # the default segment anything models + "vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e", + "vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622", + "vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912", + # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM + "vit_t": "6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f", + # first version of finetuned models on zenodo + "vit_h_lm": "9a65ee0cddc05a98d60469a12a058859c89dc3ea3ba39fed9b90d786253fbf26", + "vit_b_lm": "5a59cc4064092d54cd4d92cd967e39168f3760905431e868e474d60fe5464ecd", + "vit_h_em": "ae3798a0646c8df1d4db147998a2d37e402ff57d3aa4e571792fbb911d8a979c", + "vit_b_em": "c04a714a4e14a110f0eec055a65f7409d54e6bf733164d2933a0ce556f7d6f81", + }, + # Now specify custom URLs for some of the files in the registry. 
+ urls={ + # the default segment anything models + "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", + "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", + "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", + # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM + "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", + # first version of finetuned models on zenodo + "vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1", + "vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1", + "vit_h_em": "https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1", + "vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1", + }, + ) + return models def _get_default_device(): @@ -156,7 +175,7 @@ def get_sam_model( Returns: The segment anything predictor. """ - checkpoint = MODELS.fetch(model_type) + checkpoint = models().fetch(model_type) device = _get_device(device) # Our custom model types have a suffix "_...". This suffix needs to be stripped @@ -172,7 +191,7 @@ def get_sam_model( sam = sam_model_registry[model_type_](checkpoint=checkpoint) sam.to(device=device) predictor = SamPredictor(sam) - predictor.model_type = model_type + predictor.model_type = model_type_ if return_sam: return predictor, sam return predictor @@ -527,6 +546,7 @@ def precompute_image_embeddings( assert save_path is not None, "Tiled prediction is only supported when the embeddings are saved to file." if save_path is not None: + save_path = str(save_path) data_signature = _compute_data_signature(input_) f = zarr.open(save_path, "a") diff --git a/test/test_sam_annotator/test_widgets.py b/test/test_sam_annotator/test_widgets.py index dd4adb6f..6846200a 100644 --- a/test/test_sam_annotator/test_widgets.py +++ b/test/test_sam_annotator/test_widgets.py @@ -1,3 +1,4 @@ +from enum import Enum import json import os @@ -7,8 +8,8 @@ import zarr from micro_sam.sam_annotator._state import AnnotatorState -from micro_sam.sam_annotator._widgets import embedding_widget, Model -from micro_sam.util import _compute_data_signature +from micro_sam.sam_annotator._widgets import embedding_widget +from micro_sam.util import _compute_data_signature, models # make_napari_viewer is a pytest fixture that returns a napari viewer object @@ -22,7 +23,7 @@ def test_embedding_widget(make_napari_viewer, tmp_path): layer = viewer.open_sample('napari', 'camera')[0] my_widget = embedding_widget() # run image embedding widget - worker = my_widget(image=layer, model=Model.vit_t, device="cpu", save_path=tmp_path) + worker = my_widget(image=layer, model="vit_t", device="cpu", save_path=tmp_path) worker.await_workers() # blocks until thread worker is finished the embedding # Check in-memory state - predictor assert isinstance(AnnotatorState().predictor, (SamPredictor, MobileSamPredictor)) From c00c6607bc4ab69d99e55579b986b68c43858353 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Wed, 22 Nov 2023 18:31:37 +1100 Subject: [PATCH 06/16] Restore get_sam_model checkpoint_path kwarg to select specific model checkpoint weights --- examples/annotator_with_custom_model.py | 4 +-- examples/finetuning/use_finetuned_model.py | 2 +- .../use_as_library/instance_segmentation.py | 3 +- micro_sam/evaluation/inference.py | 4 ++- micro_sam/precompute_state.py | 4 ++- micro_sam/sam_annotator/_widgets.py | 6 +--- 
micro_sam/util.py | 35 +++++++++++++++++-- 7 files changed, 45 insertions(+), 13 deletions(-) diff --git a/examples/annotator_with_custom_model.py b/examples/annotator_with_custom_model.py index deba6638..ceb8b2cb 100644 --- a/examples/annotator_with_custom_model.py +++ b/examples/annotator_with_custom_model.py @@ -1,6 +1,6 @@ import h5py import micro_sam.sam_annotator as annotator -from micro_sam.util import get_custom_sam_model +from micro_sam.util import get_sam_model # TODO add an example for the 2d annotator with a custom model @@ -11,7 +11,7 @@ def annotator_3d_with_custom_model(): custom_model = "/home/pape/Work/data/models/sam/user-study/vit_h_nuclei_em_finetuned.pt" embedding_path = "./embeddings/nuclei3d-custom-vit-h.zarr" - predictor = get_custom_sam_model(checkpoint_path=custom_model, model_type="vit_h") + predictor = get_sam_model(checkpoint_path=custom_model, model_type="vit_h") annotator.annotator_3d(raw, embedding_path, predictor=predictor) diff --git a/examples/finetuning/use_finetuned_model.py b/examples/finetuning/use_finetuned_model.py index e07b0e36..19600241 100644 --- a/examples/finetuning/use_finetuned_model.py +++ b/examples/finetuning/use_finetuned_model.py @@ -21,7 +21,7 @@ def run_annotator_with_custom_model(): # Adapt this if you finetune a different model type, e.g. vit_h. # Load the custom model. - predictor = util.get_custom_sam_model(model_type=model_type, checkpoint_path=checkpoint) + predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint) # Run the 2d annotator with the custom model. annotator_2d( diff --git a/examples/use_as_library/instance_segmentation.py b/examples/use_as_library/instance_segmentation.py index bb58e1d3..a447ed0a 100644 --- a/examples/use_as_library/instance_segmentation.py +++ b/examples/use_as_library/instance_segmentation.py @@ -101,7 +101,8 @@ def segmentation_in_3d(): # Load the SAM model for prediction. model_type = "vit_b" # The model-type to use: vit_h, vit_l, vit_b etc. - predictor = util.get_sam_model(model_type=model_type) + checkpoint_path = None # You can specifiy the path to a custom (fine-tuned) model here. + predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path) # Run 3d segmentation for a given slice. Will segment all objects found in that slice # throughout the volume. diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index bf9ea068..bfa1b160 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -163,7 +163,9 @@ def get_predictor( ) else: # Vanilla SAM model assert not return_state - predictor = util.get_sam_model(model_type=model_type, device=device) # type: ignore + predictor = util.get_sam_model( + model_type=model_type, device=device, checkpoint_path=checkpoint_path + ) # type: ignore return predictor diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py index fbef37a4..52ee539d 100644 --- a/micro_sam/precompute_state.py +++ b/micro_sam/precompute_state.py @@ -125,13 +125,15 @@ def precompute_state( output_path: The output path were the embeddings and other state will be saved. model_type: The SegmentAnything model to use. Will use the standard vit_h model by default. and can be used to provide a glob pattern if the input is a folder with image files. + checkpoint_path: Path to a checkpoint for a custom model. + key: The key to the input file. This is needed for contaner files (e.g. hdf5 or zarr) ndim: The dimensionality of the data. 
tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. halo: Overlap of the tiles for tiled prediction. precompute_amg_state: Whether to precompute the state for automatic instance segmentation in addition to the image embeddings. """ - predictor = util.get_sam_model(model_type=model_type) + predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path) # check if we precompute the state for a single file or for a folder with image files if os.path.isdir(input_path) and Path(input_path).suffix not in (".n5", ".zarr"): pattern = "*" if key is None else key diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 29adfcba..a231eea7 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -20,21 +20,17 @@ if TYPE_CHECKING: import napari -available_devices_list = ["auto"] + _available_devices() - @magic_factory( pbar={'visible': False, 'max': 0, 'value': 0, 'label': 'working...'}, call_button="Compute image embeddings", - # model={"choices": list(models().urls.keys())}, - device = {"choices": available_devices_list}, save_path={"mode": "d"}, # choose a directory ) def embedding_widget( pbar: widgets.ProgressBar, image: "napari.layers.Image", model: Literal[tuple(models().urls.keys())] = _DEFAULT_MODEL, - device = "auto", + device: Literal[tuple(["auto"] + _available_devices())]= "auto", save_path: Optional[Path] = None, # where embeddings for this image are cached (optional) ) -> ImageEmbeddings: """Image embedding widget.""" diff --git a/micro_sam/util.py b/micro_sam/util.py index 676ec3bf..7219460d 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -99,9 +99,35 @@ def models(): "vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1", }, ) + # This extra dictionary is needed for _get_checkpoint to download weights from a specific checkpoint + models.download_names={ + "vit_h": "vit_h.pth", + "vit_l": "vit_l.pth", + "vit_b": "vit_b.pth", + "vit_t": "vit_t_mobile_sam.pth", + "vit_h_lm": "vit_h_lm.pth", + "vit_b_lm": "vit_b_lm.pth", + "vit_h_em": "vit_h_em.pth", + "vit_b_em": "vit_b_em.pth", + } return models +def _get_checkpoint(model_name, checkpoint_path): + if checkpoint_path is None: + checkpoint_url = models().urls[model_name] + checkpoint_name = models().download_names.get(model_name, checkpoint_url.split("/")[-1]) + checkpoint_path = os.path.join(microsam_cachedir()/"models", checkpoint_name) + + # download the checkpoint if necessary + if not os.path.exists(checkpoint_path): + os.makedirs(microsam_cachedir()/"models", exist_ok=True) + pooch.retrieve(url=checkpoint_url, known_hash=models().registry.get(model_name)) + elif not os.path.exists(checkpoint_path): + raise ValueError(f"The checkpoint path {checkpoint_path} that was passed does not exist.") + return checkpoint_path + + def _get_default_device(): # Use cuda enabled gpu if it's available. if torch.cuda.is_available(): @@ -150,6 +176,7 @@ def _available_devices(): def get_sam_model( model_type: str = _DEFAULT_MODEL, device: Optional[str] = None, + checkpoint_path: Optional[Union[str, os.PathLike]] = None, return_sam: bool = False, ) -> SamPredictor: r"""Get the SegmentAnything Predictor. @@ -175,14 +202,18 @@ def get_sam_model( Returns: The segment anything predictor. 
""" - checkpoint = models().fetch(model_type) device = _get_device(device) + if checkpoint_path is None: + checkpoint = models().fetch(model_type) + else: + checkpoint = _get_checkpoint(model_type, checkpoint_path) + # Our custom model types have a suffix "_...". This suffix needs to be stripped # before calling sam_model_registry. model_type_ = model_type[:5] assert model_type_ in ("vit_h", "vit_b", "vit_l", "vit_t") - if model_type == "vit_t" and not VIT_T_SUPPORT: + if model_type_ == "vit_t" and not VIT_T_SUPPORT: raise RuntimeError( "mobile_sam is required for the vit-tiny." "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" From 0387b95638ea085626434b6d99391887f83dfe7f Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Wed, 22 Nov 2023 18:47:53 +1100 Subject: [PATCH 07/16] Fix typo in docstring, mixed up order of text lines --- micro_sam/precompute_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py index 52ee539d..66e6edf5 100644 --- a/micro_sam/precompute_state.py +++ b/micro_sam/precompute_state.py @@ -124,9 +124,9 @@ def precompute_state( it can be given to provide a glob pattern to subselect files from the folder. output_path: The output path were the embeddings and other state will be saved. model_type: The SegmentAnything model to use. Will use the standard vit_h model by default. - and can be used to provide a glob pattern if the input is a folder with image files. checkpoint_path: Path to a checkpoint for a custom model. key: The key to the input file. This is needed for contaner files (e.g. hdf5 or zarr) + and can be used to provide a glob pattern if the input is a folder with image files. ndim: The dimensionality of the data. tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. halo: Overlap of the tiles for tiled prediction. From 32c8bb221bc2370002cb35a40c228b411ae5b744 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Wed, 22 Nov 2023 18:59:52 +1100 Subject: [PATCH 08/16] Remove unnecessary import --- test/test_sam_annotator/test_widgets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_sam_annotator/test_widgets.py b/test/test_sam_annotator/test_widgets.py index 6846200a..d0347c66 100644 --- a/test/test_sam_annotator/test_widgets.py +++ b/test/test_sam_annotator/test_widgets.py @@ -1,4 +1,3 @@ -from enum import Enum import json import os From 1c296857e9c5124af210f214ac29c7089a22b61a Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Wed, 22 Nov 2023 19:00:26 +1100 Subject: [PATCH 09/16] Finish restoring optional_custom_weights / checkpoint_path kwarg --- micro_sam/sam_annotator/_widgets.py | 6 ++++-- micro_sam/training/util.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index a231eea7..b0098b4a 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -32,6 +32,7 @@ def embedding_widget( model: Literal[tuple(models().urls.keys())] = _DEFAULT_MODEL, device: Literal[tuple(["auto"] + _available_devices())]= "auto", save_path: Optional[Path] = None, # where embeddings for this image are cached (optional) + optional_custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights. 
) -> ImageEmbeddings: """Image embedding widget.""" state = AnnotatorState() @@ -47,6 +48,7 @@ def embedding_widget( @thread_worker(connect={'started': pbar.show, 'finished': pbar.hide}) def _compute_image_embedding(state, image_data, save_path, ndim=None, device="auto", model=_DEFAULT_MODEL, + optional_custom_weights=None, ): # Make sure save directory exists and is an empty directory if save_path is not None: @@ -65,7 +67,7 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None, ) # Initialize the model - state.predictor = get_sam_model(device=device, model_type=model) + state.predictor = get_sam_model(device=device, model_type=model, checkpoint_path=optional_custom_weights) # Compute the image embeddings state.image_embeddings = precompute_image_embeddings( predictor = state.predictor, @@ -75,4 +77,4 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None, ) return state # returns napari._qt.qthreading.FunctionWorker - return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model) + return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model, optional_custom_weights=optional_custom_weights) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index ab75d9de..2f7964a7 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -12,6 +12,7 @@ def get_trainable_sam_model( model_type: str = "vit_h", device: Optional[str] = None, freeze: Optional[List[str]] = None, + checkpoint_path: Optional[Union[str, os.PathLike]] = None, ) -> TrainableSAM: """Get the trainable sam model. @@ -27,7 +28,7 @@ def get_trainable_sam_model( """ # set the device here so that the correct one is passed to TrainableSAM below device = _get_device(device) - _, sam = get_sam_model(model_type=model_type, device=device, return_sam=True) + _, sam = get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path, return_sam=True) # freeze components of the model if freeze was passed # ideally we would want to add components in such a way that: From 865d6136ef90a3172721d57826f62e8d4d68a014 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Wed, 22 Nov 2023 19:02:09 +1100 Subject: [PATCH 10/16] Remove another unnecessary import --- test/test_sam_annotator/test_widgets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_sam_annotator/test_widgets.py b/test/test_sam_annotator/test_widgets.py index d0347c66..dc5e26f1 100644 --- a/test/test_sam_annotator/test_widgets.py +++ b/test/test_sam_annotator/test_widgets.py @@ -8,7 +8,7 @@ from micro_sam.sam_annotator._state import AnnotatorState from micro_sam.sam_annotator._widgets import embedding_widget -from micro_sam.util import _compute_data_signature, models +from micro_sam.util import _compute_data_signature # make_napari_viewer is a pytest fixture that returns a napari viewer object From 3b40099db4f384ac1e04fcfefb520f4210b988a2 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 22 Nov 2023 12:48:42 +0100 Subject: [PATCH 11/16] Update pooch download, rename model_type->model_name in get_sam_model --- examples/annotator_2d.py | 18 ++-- examples/annotator_3d.py | 6 +- examples/annotator_tracking.py | 6 +- examples/annotator_with_custom_model.py | 2 +- examples/image_series_annotator.py | 6 +- .../use_as_library/instance_segmentation.py | 4 +- micro_sam/evaluation/inference.py | 2 +- micro_sam/evaluation/model_comparison.py | 16 
+-- micro_sam/precompute_state.py | 10 +- micro_sam/sam_annotator/_widgets.py | 16 +-- micro_sam/sam_annotator/annotator_2d.py | 6 +- micro_sam/sam_annotator/annotator_3d.py | 6 +- micro_sam/sam_annotator/annotator_tracking.py | 10 +- .../sam_annotator/image_series_annotator.py | 2 +- micro_sam/sam_annotator/util.py | 2 +- micro_sam/training/util.py | 9 +- micro_sam/util.py | 97 ++++++++++--------- test/test_util.py | 2 +- 18 files changed, 114 insertions(+), 106 deletions(-) diff --git a/examples/annotator_2d.py b/examples/annotator_2d.py index 86eff5c7..7e3ff33e 100644 --- a/examples/annotator_2d.py +++ b/examples/annotator_2d.py @@ -20,12 +20,12 @@ def livecell_annotator(use_finetuned_model): if use_finetuned_model: embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell-vit_h_lm.zarr") - model_type = "vit_h_lm" + model_name = "vit_h_lm" else: embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell.zarr") - model_type = "vit_h" + model_name = "vit_h" - annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type) + annotator_2d(image, embedding_path, show_embeddings=False, model_name=model_name) def hela_2d_annotator(use_finetuned_model): @@ -36,12 +36,12 @@ def hela_2d_annotator(use_finetuned_model): if use_finetuned_model: embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-hela2d-vit_h_lm.zarr") - model_type = "vit_h_lm" + model_name = "vit_h_lm" else: embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-hela2d.zarr") - model_type = "vit_h" + model_name = "vit_h" - annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type, precompute_amg_state=True) + annotator_2d(image, embedding_path, show_embeddings=False, model_name=model_name, precompute_amg_state=True) def wholeslide_annotator(use_finetuned_model): @@ -55,12 +55,12 @@ def wholeslide_annotator(use_finetuned_model): if use_finetuned_model: embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings-vit_h_lm.zarr") - model_type = "vit_h_lm" + model_name = "vit_h_lm" else: embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings.zarr") - model_type = "vit_h" + model_name = "vit_h" - annotator_2d(image, embedding_path, tile_shape=(1024, 1024), halo=(256, 256), model_type=model_type) + annotator_2d(image, embedding_path, tile_shape=(1024, 1024), halo=(256, 256), model_name=model_name) def main(): diff --git a/examples/annotator_3d.py b/examples/annotator_3d.py index 71d42d4a..a07e24dd 100644 --- a/examples/annotator_3d.py +++ b/examples/annotator_3d.py @@ -20,13 +20,13 @@ def em_3d_annotator(use_finetuned_model): if use_finetuned_model: embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-lucchi-vit_h_em.zarr") - model_type = "vit_h_em" + model_name = "vit_h_em" else: embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-lucchi.zarr") - model_type = "vit_h" + model_name = "vit_h" # start the annotator, cache the embeddings - annotator_3d(raw, embedding_path, model_type=model_type, show_embeddings=False) + annotator_3d(raw, embedding_path, model_name=model_name, show_embeddings=False) def main(): diff --git a/examples/annotator_tracking.py b/examples/annotator_tracking.py index 369350e9..6950703e 100644 --- a/examples/annotator_tracking.py +++ b/examples/annotator_tracking.py @@ -21,13 +21,13 @@ def track_ctc_data(use_finetuned_model): if use_finetuned_model: embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc-vit_h_lm.zarr") - model_type = "vit_h_lm" + model_name = "vit_h_lm" else: embedding_path = 
os.path.join(EMBEDDING_CACHE, "embeddings-ctc.zarr") - model_type = "vit_h" + model_name = "vit_h" # start the annotator with cached embeddings - annotator_tracking(timeseries, embedding_path=embedding_path, show_embeddings=False, model_type=model_type) + annotator_tracking(timeseries, embedding_path=embedding_path, show_embeddings=False, model_name=model_name) def main(): diff --git a/examples/annotator_with_custom_model.py b/examples/annotator_with_custom_model.py index ceb8b2cb..59efd483 100644 --- a/examples/annotator_with_custom_model.py +++ b/examples/annotator_with_custom_model.py @@ -11,7 +11,7 @@ def annotator_3d_with_custom_model(): custom_model = "/home/pape/Work/data/models/sam/user-study/vit_h_nuclei_em_finetuned.pt" embedding_path = "./embeddings/nuclei3d-custom-vit-h.zarr" - predictor = get_sam_model(checkpoint_path=custom_model, model_type="vit_h") + predictor = get_sam_model(checkpoint_path=custom_model, model_name="vit_h") annotator.annotator_3d(raw, embedding_path, predictor=predictor) diff --git a/examples/image_series_annotator.py b/examples/image_series_annotator.py index 7fcfa13b..4fcc83b9 100644 --- a/examples/image_series_annotator.py +++ b/examples/image_series_annotator.py @@ -15,15 +15,15 @@ def series_annotation(use_finetuned_model): if use_finetuned_model: embedding_path = os.path.join(EMBEDDING_CACHE, "series-embeddings-vit_h_lm") - model_type = "vit_h_lm" + model_name = "vit_h_lm" else: embedding_path = os.path.join(EMBEDDING_CACHE, "series-embeddings") - model_type = "vit_h" + model_name = "vit_h" example_data = fetch_image_series_example_data(DATA_CACHE) image_folder_annotator( example_data, "./series-segmentation-result", embedding_path=embedding_path, - pattern="*.tif", model_type=model_type, + pattern="*.tif", model_name=model_name, precompute_amg_state=True, ) diff --git a/examples/use_as_library/instance_segmentation.py b/examples/use_as_library/instance_segmentation.py index a447ed0a..ab95a6be 100644 --- a/examples/use_as_library/instance_segmentation.py +++ b/examples/use_as_library/instance_segmentation.py @@ -100,9 +100,9 @@ def segmentation_in_3d(): data = imageio.imread(path) # Load the SAM model for prediction. - model_type = "vit_b" # The model-type to use: vit_h, vit_l, vit_b etc. + model_name = "vit_b" # The model-type to use: vit_h, vit_l, vit_b etc. checkpoint_path = None # You can specifiy the path to a custom (fine-tuned) model here. - predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path) + predictor = util.get_sam_model(model_name=model_name, checkpoint_path=checkpoint_path) # Run 3d segmentation for a given slice. Will segment all objects found in that slice # throughout the volume. 
diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index 5bbf9ca0..a7d4a7bc 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -165,7 +165,7 @@ def get_predictor( else: # Vanilla SAM model assert not return_state predictor = util.get_sam_model( - model_type=model_type, device=device, checkpoint_path=checkpoint_path + model_name=model_type, device=device, checkpoint_path=checkpoint_path ) # type: ignore return predictor diff --git a/micro_sam/evaluation/model_comparison.py b/micro_sam/evaluation/model_comparison.py index 702f944f..2459e0dd 100644 --- a/micro_sam/evaluation/model_comparison.py +++ b/micro_sam/evaluation/model_comparison.py @@ -96,8 +96,8 @@ def _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, def generate_data_for_model_comparison( loader: torch.utils.data.DataLoader, output_folder: Union[str, os.PathLike], - model_type1: str, - model_type2: str, + model_name1: str, + model_name2: str, n_samples: int, ) -> None: """Generate samples for qualitative model comparison. @@ -107,10 +107,10 @@ def generate_data_for_model_comparison( Args: loader: The torch dataloader from which samples are drawn. output_folder: The folder where the samples will be saved. - model_type1: The first model to use for comparison. - The value needs to be a valid model_type for `micro_sam.util.get_sam_model`. - model_type1: The second model to use for comparison. - The value needs to be a valid model_type for `micro_sam.util.get_sam_model`. + model_name1: The first model to use for comparison. + The value needs to be a valid model_name for `micro_sam.util.get_sam_model`. + model_name1: The second model to use for comparison. + The value needs to be a valid model_name for `micro_sam.util.get_sam_model`. n_samples: The number of samples to draw from the dataloader. """ prompt_generator = PointAndBoxPromptGenerator( @@ -120,8 +120,8 @@ def generate_data_for_model_comparison( get_point_prompts=True, get_box_prompts=True, ) - predictor1 = util.get_sam_model(model_type=model_type1) - predictor2 = util.get_sam_model(model_type=model_type2) + predictor1 = util.get_sam_model(model_name=model_name1) + predictor2 = util.get_sam_model(model_name=model_name2) _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, output_folder) diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py index 66e6edf5..d5937f01 100644 --- a/micro_sam/precompute_state.py +++ b/micro_sam/precompute_state.py @@ -107,7 +107,7 @@ def _precompute_state_for_files( def precompute_state( input_path: Union[os.PathLike, str], output_path: Union[os.PathLike, str], - model_type: str = util._DEFAULT_MODEL, + model_name: str = util._DEFAULT_MODEL, checkpoint_path: Optional[Union[os.PathLike, str]] = None, key: Optional[str] = None, ndim: Optional[int] = None, @@ -123,7 +123,7 @@ def precompute_state( In case of a container file the argument `key` must be given. In case of a folder it can be given to provide a glob pattern to subselect files from the folder. output_path: The output path were the embeddings and other state will be saved. - model_type: The SegmentAnything model to use. Will use the standard vit_h model by default. + model_name: The SegmentAnything model to use. Will use the standard vit_h model by default. checkpoint_path: Path to a checkpoint for a custom model. key: The key to the input file. This is needed for contaner files (e.g. 
hdf5 or zarr) and can be used to provide a glob pattern if the input is a folder with image files. @@ -133,7 +133,7 @@ def precompute_state( precompute_amg_state: Whether to precompute the state for automatic instance segmentation in addition to the image embeddings. """ - predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path) + predictor = util.get_sam_model(model_name=model_name, checkpoint_path=checkpoint_path) # check if we precompute the state for a single file or for a folder with image files if os.path.isdir(input_path) and Path(input_path).suffix not in (".n5", ".zarr"): pattern = "*" if key is None else key @@ -158,7 +158,7 @@ def main(): parser = argparse.ArgumentParser(description="Compute the embeddings for an image.") parser.add_argument("-i", "--input_path", required=True) parser.add_argument("-o", "--output_path", required=True) - parser.add_argument("-m", "--model_type", default="vit_h") + parser.add_argument("-m", "--model_name", default=util._DEFAULT_MODEL) parser.add_argument("-c", "--checkpoint_path", default=None) parser.add_argument("-k", "--key") parser.add_argument( @@ -172,7 +172,7 @@ def main(): args = parser.parse_args() precompute_state( - args.input_path, args.output_path, args.model_type, args.checkpoint_path, + args.input_path, args.output_path, args.model_name, args.checkpoint_path, key=args.key, tile_shape=args.tile_shape, halo=args.halo, ndim=args.ndim, precompute_amg_state=args.precompute_amg_state, ) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 9f94664e..b5d264ca 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -4,7 +4,6 @@ from magicgui import magic_factory, widgets from napari.qt.threading import thread_worker -import pooch import zarr from zarr.errors import PathNotFoundError @@ -32,7 +31,7 @@ def embedding_widget( pbar: widgets.ProgressBar, image: "napari.layers.Image", model: Literal[tuple(models().urls.keys())] = _DEFAULT_MODEL, - device: Literal[tuple(["auto"] + _available_devices())]= "auto", + device: Literal[tuple(["auto"] + _available_devices())] = "auto", save_path: Optional[Path] = None, # where embeddings for this image are cached (optional) optional_custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights. 
) -> ImageEmbeddings: @@ -69,17 +68,20 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None, ) # Initialize the model - state.predictor = get_sam_model(device=device, model_type=model, checkpoint_path=optional_custom_weights) + state.predictor = get_sam_model(device=device, model_name=model, checkpoint_path=optional_custom_weights) # Compute the image embeddings state.image_embeddings = precompute_image_embeddings( - predictor = state.predictor, - input_ = image_data, - save_path = save_path, + predictor=state.predictor, + input_=image_data, + save_path=save_path, ndim=ndim, ) return state # returns napari._qt.qthreading.FunctionWorker - return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model, optional_custom_weights=optional_custom_weights) + return _compute_image_embedding( + state, image.data, save_path, ndim=ndim, device=device, model=model, + optional_custom_weights=optional_custom_weights + ) @magic_factory( diff --git a/micro_sam/sam_annotator/annotator_2d.py b/micro_sam/sam_annotator/annotator_2d.py index 9c8e002e..51d80ee1 100644 --- a/micro_sam/sam_annotator/annotator_2d.py +++ b/micro_sam/sam_annotator/annotator_2d.py @@ -195,7 +195,7 @@ def annotator_2d( embedding_path: Optional[str] = None, show_embeddings: bool = False, segmentation_result: Optional[np.ndarray] = None, - model_type: str = util._DEFAULT_MODEL, + model_name: str = util._DEFAULT_MODEL, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, @@ -214,7 +214,7 @@ def annotator_2d( segmentation_result: An initial segmentation to load. This can be used to correct segmentations with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_objects' layer. - model_type: The Segment Anything model to use. For details on the available models check out + model_name: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models. tile_shape: Shape of tiles for tiled embedding prediction. If `None` then the whole image is passed to Segment Anything. @@ -234,7 +234,7 @@ def annotator_2d( state = AnnotatorState() if predictor is None: - state.predictor = util.get_sam_model(model_type=model_type) + state.predictor = util.get_sam_model(model_name=model_name) else: state.predictor = predictor state.image_shape = _get_shape(raw) diff --git a/micro_sam/sam_annotator/annotator_3d.py b/micro_sam/sam_annotator/annotator_3d.py index 0bf8cb3c..9ca3bd78 100644 --- a/micro_sam/sam_annotator/annotator_3d.py +++ b/micro_sam/sam_annotator/annotator_3d.py @@ -211,7 +211,7 @@ def annotator_3d( embedding_path: Optional[str] = None, show_embeddings: bool = False, segmentation_result: Optional[np.ndarray] = None, - model_type: str = util._DEFAULT_MODEL, + model_name: str = util._DEFAULT_MODEL, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, @@ -228,7 +228,7 @@ def annotator_3d( segmentation_result: An initial segmentation to load. This can be used to correct segmentations with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_objects' layer. - model_type: The Segment Anything model to use. For details on the available models check out + model_name: The Segment Anything model to use. 
For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models. tile_shape: Shape of tiles for tiled embedding prediction. If `None` then the whole image is passed to Segment Anything. @@ -243,7 +243,7 @@ def annotator_3d( state = AnnotatorState() if predictor is None: - state.predictor = util.get_sam_model(model_type=model_type) + state.predictor = util.get_sam_model(model_name=model_name) else: state.predictor = predictor state.image_embeddings = util.precompute_image_embeddings( diff --git a/micro_sam/sam_annotator/annotator_tracking.py b/micro_sam/sam_annotator/annotator_tracking.py index 6edbc79f..2743b313 100644 --- a/micro_sam/sam_annotator/annotator_tracking.py +++ b/micro_sam/sam_annotator/annotator_tracking.py @@ -387,7 +387,7 @@ def annotator_tracking( embedding_path: Optional[str] = None, show_embeddings: bool = False, tracking_result: Optional[str] = None, - model_type: str = util._DEFAULT_MODEL, + model_name: str = util._DEFAULT_MODEL, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, @@ -404,14 +404,14 @@ def annotator_tracking( tracking_result: An initial tracking result to load. This can be used to correct tracking with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_tracks' layer. - model_type: The Segment Anything model to use. For details on the available models check out + model_name: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models. tile_shape: Shape of tiles for tiled embedding prediction. If `None` then the whole image is passed to Segment Anything. halo: Shape of the overlap between tiles, which is needed to segment objects on tile boarders. return_viewer: Whether to return the napari viewer to further modify it before starting the tool. predictor: The Segment Anything model. Passing this enables using fully custom models. - If you pass `predictor` then `model_type` will be ignored. + If you pass `predictor` then `model_name` will be ignored. Returns: The napari viewer, only returned if `return_viewer=True`. 
@@ -424,7 +424,7 @@ def annotator_tracking( state = AnnotatorState() if predictor is None: - state.predictor = util.get_sam_model(model_type=model_type) + state.predictor = util.get_sam_model(model_name=model_name) else: state.predictor = predictor state.image_embeddings = util.precompute_image_embeddings( @@ -587,6 +587,6 @@ def main(): annotator_tracking( raw, embedding_path=args.embedding_path, show_embeddings=args.show_embeddings, - tracking_result=tracking_result, model_type=args.model_type, + tracking_result=tracking_result, model_name=args.model_name, tile_shape=args.tile_shape, halo=args.halo, ) diff --git a/micro_sam/sam_annotator/image_series_annotator.py b/micro_sam/sam_annotator/image_series_annotator.py index d561b7d5..f556c11f 100644 --- a/micro_sam/sam_annotator/image_series_annotator.py +++ b/micro_sam/sam_annotator/image_series_annotator.py @@ -45,7 +45,7 @@ def image_series_annotator( next_image_id = 0 if predictor is None: - predictor = util.get_sam_model(model_type=kwargs.get("model_type", util._DEFAULT_MODEL)) + predictor = util.get_sam_model(model_name=kwargs.get("model_name", util._DEFAULT_MODEL)) if embedding_path is None: embedding_paths = None else: diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py index c2f39e7a..a67d8058 100644 --- a/micro_sam/sam_annotator/util.py +++ b/micro_sam/sam_annotator/util.py @@ -450,7 +450,7 @@ def _initialize_parser(description, with_segmentation_result=True, with_show_emb help="Visualize the embeddings computed by SegmentAnything. This can be helpful for debugging." ) parser.add_argument( - "--model_type", default=util._DEFAULT_MODEL, + "--model_name", default=util._DEFAULT_MODEL, help=f"The segment anything model that will be used, one of {available_models}." ) parser.add_argument( diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index eb8e4dd9..c424bddc 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -10,16 +10,17 @@ def get_trainable_sam_model( - model_type: str = "vit_h", + model_name: str = "vit_h", device: Optional[Union[str, torch.device]] = None, checkpoint_path: Optional[Union[str, os.PathLike]] = None, freeze: Optional[List[str]] = None, - checkpoint_path: Optional[Union[str, os.PathLike]] = None, ) -> TrainableSAM: """Get the trainable sam model. Args: - model_type: The type of the segment anything model. + model_name: The segment anything model that should be finetuned. + The weights of this model will be used for initialization, unless a + custom weight file is passed via `checkpoint_path`. device: The device to use for training. checkpoint_path: Path to a custom checkpoint from which to load the model weights. 
freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder @@ -30,7 +31,7 @@ def get_trainable_sam_model( """ # set the device here so that the correct one is passed to TrainableSAM below device = get_device(device) - _, sam = get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path, return_sam=True) + _, sam = get_sam_model(model_name=model_name, device=device, checkpoint_path=checkpoint_path, return_sam=True) # freeze components of the model if freeze was passed # ideally we would want to add components in such a way that: diff --git a/micro_sam/util.py b/micro_sam/util.py index b8c41549..99c30494 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -8,13 +8,11 @@ import pickle import warnings from collections import OrderedDict -from shutil import copyfileobj from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import imageio.v3 as imageio import numpy as np import pooch -import requests import torch import vigra import zarr @@ -41,6 +39,10 @@ # currently set to the default vit_h _DEFAULT_MODEL = "vit_h" +# The valid model types. Each type corresponds to the architecture of the +# vision transformer used within SAM. +_MODEL_TYPES = ("vit_h", "vit_b", "vit_l", "vit_t") + # TODO define the proper type for image embeddings ImageEmbeddings = Dict[str, Any] @@ -56,9 +58,11 @@ def get_cache_directory() -> None: cache_directory = Path(os.environ.get('MICROSAM_CACHEDIR', default_cache_directory)) return cache_directory + # # Functionality for model download and export # + def microsam_cachedir(): """Return the micro-sam cache directory. @@ -109,35 +113,9 @@ def models(): "vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1", }, ) - # This extra dictionary is needed for _get_checkpoint to download weights from a specific checkpoint - models.download_names={ - "vit_h": "vit_h.pth", - "vit_l": "vit_l.pth", - "vit_b": "vit_b.pth", - "vit_t": "vit_t_mobile_sam.pth", - "vit_h_lm": "vit_h_lm.pth", - "vit_b_lm": "vit_b_lm.pth", - "vit_h_em": "vit_h_em.pth", - "vit_b_em": "vit_b_em.pth", - } return models -def _get_checkpoint(model_name, checkpoint_path): - if checkpoint_path is None: - checkpoint_url = models().urls[model_name] - checkpoint_name = models().download_names.get(model_name, checkpoint_url.split("/")[-1]) - checkpoint_path = os.path.join(microsam_cachedir()/"models", checkpoint_name) - - # download the checkpoint if necessary - if not os.path.exists(checkpoint_path): - os.makedirs(microsam_cachedir()/"models", exist_ok=True) - pooch.retrieve(url=checkpoint_url, known_hash=models().registry.get(model_name)) - elif not os.path.exists(checkpoint_path): - raise ValueError(f"The checkpoint path {checkpoint_path} that was passed does not exist.") - return checkpoint_path - - def _get_default_device(): # check that we're in CI and use the CPU if we are # otherwise the tests may run out of memory on MAC if MPS is used. @@ -200,16 +178,23 @@ def _available_devices(): def get_sam_model( - model_type: str = _DEFAULT_MODEL, + model_name: str = _DEFAULT_MODEL, device: Optional[Union[str, torch.device]] = None, checkpoint_path: Optional[Union[str, os.PathLike]] = None, return_sam: bool = False, ) -> SamPredictor: r"""Get the SegmentAnything Predictor. - This function will download the required model checkpoint or load it from file if it - was already downloaded. - This location can be changed by setting the environment variable: MICROSAM_CACHEDIR. 
+ This function will download the required model or load it from the cached weight file. + This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. + The name of the requested model can be set via `model_name`. + See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models + for an overview of the available models + + Alternatively this function can also load a model from weights stored in a local filepath. + The corresponding file path is given via `checkpoint_path`. In this case `model_name` + must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for + a SAM model with vit_b encoder. By default the models are downloaded to a folder named 'micro_sam/models' inside your default cache directory, eg: @@ -220,35 +205,53 @@ def get_sam_model( https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html Args: - model_type: The SegmentAnything model to use. Will use the standard vit_h model by default. + model_name: The SegmentAnything model to use. Will use the standard vit_h model by default. + To get a list of all available model names you can call `get_model_names`. device: The device for the model. If none is given will use GPU if available. - checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. + checkpoint_path: The path to a file with weights that should be used instead of using the + weights corresponding to `model_name`. If given, `model_name` must match the architecture + corresponding to the weight file. E.g. if you use weights for SAM with vit_b encoder + then `model_name` must be given as "vit_b". return_sam: Return the sam model object as well as the predictor. Returns: The segment anything predictor. """ - device = _get_device(device) + device = get_device(device) + # We support passing a local filepath to a checkpoint. + # In this case we do not download any weights but just use the local weight file, + # as it is, without copying it over anywhere or checking it's hashes. + + # checkpoint_path has not been passed, we download a known model and derive the correct + # URL from the model_name. If the model_name is invalid pooch will raise an error. if checkpoint_path is None: - checkpoint = models().fetch(model_type) + model_registry = models() + checkpoint = model_registry.fetch(model_name) + # checkpoint_path has been passed, we use it instead of downloading a model. else: - checkpoint = _get_checkpoint(model_type, checkpoint_path) + # Check if the file exists and raise an error otherwise. + # We can't check any hashes here, and we don't check if the file is actually a valid weight file. + # (If it isn't the model creation will fail below.) + if not os.path.exists(checkpoint_path): + raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.") + checkpoint = checkpoint_path - # Our custom model types have a suffix "_...". This suffix needs to be stripped + # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped # before calling sam_model_registry. - model_type_ = model_type[:5] - assert model_type_ in ("vit_h", "vit_b", "vit_l", "vit_t") - if model_type_ == "vit_t" and not VIT_T_SUPPORT: + model_type = model_name[:5] + if model_type not in _MODEL_TYPES: + raise ValueError(f"Invalid model_type: {model_type}. Expect one of {_MODEL_TYPES}") + if model_type == "vit_t" and not VIT_T_SUPPORT: raise RuntimeError( "mobile_sam is required for the vit-tiny." 
"You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" ) - sam = sam_model_registry[model_type_](checkpoint=checkpoint) + sam = sam_model_registry[model_type](checkpoint=checkpoint) sam.to(device=device) predictor = SamPredictor(sam) - predictor.model_type = model_type_ + predictor.model_type = model_type if return_sam: return predictor, sam return predictor @@ -281,7 +284,7 @@ def get_custom_sam_model( Args: checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. - model_type: The SegmentAnything model to use. + model_type: The SegmentAnything model_type for the given checkpoint. device: The device for the model. If none is given will use GPU if available. return_sam: Return the sam model object as well as the predictor. return_state: Return the full state of the checkpoint in addition to the predictor. @@ -331,7 +334,7 @@ def export_custom_sam_model( Args: checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. - model_type: The SegmentAnything model type to use (vit_h, vit_b or vit_l). + model_type: The SegmentAnything model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t). save_path: Where to save the exported model. """ _, state = get_custom_sam_model( @@ -346,7 +349,9 @@ def export_custom_sam_model( def get_model_names() -> Iterable: - return _MODEL_URLS.keys() + model_registry = models() + model_names = model_registry.registry.keys() + return model_names # diff --git a/test/test_util.py b/test/test_util.py index 505d0208..c0d4e5ca 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -34,7 +34,7 @@ def check_predictor(predictor): # check predictor with checkpoint path (using the cached model) checkpoint_path = os.path.join( - get_cache_directory(), "models", "vit_t_mobile_sam.pth" if VIT_T_SUPPORT else "sam_vit_b_01ec64.pth" + get_cache_directory(), "models", "vit_t" if VIT_T_SUPPORT else "vit_b" ) predictor = get_sam_model(model_type=self.model_type, checkpoint_path=checkpoint_path) check_predictor(predictor) From 355600ffe8f958ef4976fc8cf40f2f6eb581bfe4 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 23 Nov 2023 10:47:04 +1100 Subject: [PATCH 12/16] Fix remaining merge problems --- development/benchmark.py | 2 +- micro_sam/training/util.py | 1 - micro_sam/util.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/development/benchmark.py b/development/benchmark.py index b3d3ab15..ff5a50e7 100644 --- a/development/benchmark.py +++ b/development/benchmark.py @@ -180,7 +180,7 @@ def main(): args = parser.parse_args() model_type = args.model_type - device = util._get_device(args.device) + device = util.get_device(args.device) print("Running benchmarks for", model_type) print("with device:", device) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index eb8e4dd9..51c08759 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -14,7 +14,6 @@ def get_trainable_sam_model( device: Optional[Union[str, torch.device]] = None, checkpoint_path: Optional[Union[str, os.PathLike]] = None, freeze: Optional[List[str]] = None, - checkpoint_path: Optional[Union[str, os.PathLike]] = None, ) -> TrainableSAM: """Get the trainable sam model. diff --git a/micro_sam/util.py b/micro_sam/util.py index b8c41549..137e497c 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -228,7 +228,7 @@ def get_sam_model( Returns: The segment anything predictor. 
""" - device = _get_device(device) + device = get_device(device) if checkpoint_path is None: checkpoint = models().fetch(model_type) From ce52b292160a3188f95ebb32db72ddacc5af06c8 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 23 Nov 2023 17:13:24 +1100 Subject: [PATCH 13/16] kkeep model_type kwarg naming, rename model_type_ to abbreviated_model_type in get_sam_model function --- examples/annotator_2d.py | 18 ++++++------ examples/annotator_3d.py | 6 ++-- examples/annotator_tracking.py | 6 ++-- examples/annotator_with_custom_model.py | 2 +- examples/image_series_annotator.py | 6 ++-- .../use_as_library/instance_segmentation.py | 4 +-- micro_sam/evaluation/inference.py | 2 +- micro_sam/evaluation/model_comparison.py | 4 +-- micro_sam/precompute_state.py | 10 +++---- micro_sam/sam_annotator/_widgets.py | 2 +- micro_sam/sam_annotator/annotator_2d.py | 6 ++-- micro_sam/sam_annotator/annotator_3d.py | 6 ++-- micro_sam/sam_annotator/annotator_tracking.py | 10 +++---- .../sam_annotator/image_series_annotator.py | 2 +- micro_sam/sam_annotator/util.py | 2 +- micro_sam/training/util.py | 6 ++-- micro_sam/util.py | 28 +++++++++---------- 17 files changed, 60 insertions(+), 60 deletions(-) diff --git a/examples/annotator_2d.py b/examples/annotator_2d.py index 7e3ff33e..86eff5c7 100644 --- a/examples/annotator_2d.py +++ b/examples/annotator_2d.py @@ -20,12 +20,12 @@ def livecell_annotator(use_finetuned_model): if use_finetuned_model: embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell-vit_h_lm.zarr") - model_name = "vit_h_lm" + model_type = "vit_h_lm" else: embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-livecell.zarr") - model_name = "vit_h" + model_type = "vit_h" - annotator_2d(image, embedding_path, show_embeddings=False, model_name=model_name) + annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type) def hela_2d_annotator(use_finetuned_model): @@ -36,12 +36,12 @@ def hela_2d_annotator(use_finetuned_model): if use_finetuned_model: embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-hela2d-vit_h_lm.zarr") - model_name = "vit_h_lm" + model_type = "vit_h_lm" else: embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-hela2d.zarr") - model_name = "vit_h" + model_type = "vit_h" - annotator_2d(image, embedding_path, show_embeddings=False, model_name=model_name, precompute_amg_state=True) + annotator_2d(image, embedding_path, show_embeddings=False, model_type=model_type, precompute_amg_state=True) def wholeslide_annotator(use_finetuned_model): @@ -55,12 +55,12 @@ def wholeslide_annotator(use_finetuned_model): if use_finetuned_model: embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings-vit_h_lm.zarr") - model_name = "vit_h_lm" + model_type = "vit_h_lm" else: embedding_path = os.path.join(EMBEDDING_CACHE, "whole-slide-embeddings.zarr") - model_name = "vit_h" + model_type = "vit_h" - annotator_2d(image, embedding_path, tile_shape=(1024, 1024), halo=(256, 256), model_name=model_name) + annotator_2d(image, embedding_path, tile_shape=(1024, 1024), halo=(256, 256), model_type=model_type) def main(): diff --git a/examples/annotator_3d.py b/examples/annotator_3d.py index a07e24dd..71d42d4a 100644 --- a/examples/annotator_3d.py +++ b/examples/annotator_3d.py @@ -20,13 +20,13 @@ def em_3d_annotator(use_finetuned_model): if use_finetuned_model: embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-lucchi-vit_h_em.zarr") - model_name = "vit_h_em" + model_type = 
"vit_h_em" else: embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-lucchi.zarr") - model_name = "vit_h" + model_type = "vit_h" # start the annotator, cache the embeddings - annotator_3d(raw, embedding_path, model_name=model_name, show_embeddings=False) + annotator_3d(raw, embedding_path, model_type=model_type, show_embeddings=False) def main(): diff --git a/examples/annotator_tracking.py b/examples/annotator_tracking.py index 6950703e..369350e9 100644 --- a/examples/annotator_tracking.py +++ b/examples/annotator_tracking.py @@ -21,13 +21,13 @@ def track_ctc_data(use_finetuned_model): if use_finetuned_model: embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc-vit_h_lm.zarr") - model_name = "vit_h_lm" + model_type = "vit_h_lm" else: embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc.zarr") - model_name = "vit_h" + model_type = "vit_h" # start the annotator with cached embeddings - annotator_tracking(timeseries, embedding_path=embedding_path, show_embeddings=False, model_name=model_name) + annotator_tracking(timeseries, embedding_path=embedding_path, show_embeddings=False, model_type=model_type) def main(): diff --git a/examples/annotator_with_custom_model.py b/examples/annotator_with_custom_model.py index 59efd483..ceb8b2cb 100644 --- a/examples/annotator_with_custom_model.py +++ b/examples/annotator_with_custom_model.py @@ -11,7 +11,7 @@ def annotator_3d_with_custom_model(): custom_model = "/home/pape/Work/data/models/sam/user-study/vit_h_nuclei_em_finetuned.pt" embedding_path = "./embeddings/nuclei3d-custom-vit-h.zarr" - predictor = get_sam_model(checkpoint_path=custom_model, model_name="vit_h") + predictor = get_sam_model(checkpoint_path=custom_model, model_type="vit_h") annotator.annotator_3d(raw, embedding_path, predictor=predictor) diff --git a/examples/image_series_annotator.py b/examples/image_series_annotator.py index 4fcc83b9..7fcfa13b 100644 --- a/examples/image_series_annotator.py +++ b/examples/image_series_annotator.py @@ -15,15 +15,15 @@ def series_annotation(use_finetuned_model): if use_finetuned_model: embedding_path = os.path.join(EMBEDDING_CACHE, "series-embeddings-vit_h_lm") - model_name = "vit_h_lm" + model_type = "vit_h_lm" else: embedding_path = os.path.join(EMBEDDING_CACHE, "series-embeddings") - model_name = "vit_h" + model_type = "vit_h" example_data = fetch_image_series_example_data(DATA_CACHE) image_folder_annotator( example_data, "./series-segmentation-result", embedding_path=embedding_path, - pattern="*.tif", model_name=model_name, + pattern="*.tif", model_type=model_type, precompute_amg_state=True, ) diff --git a/examples/use_as_library/instance_segmentation.py b/examples/use_as_library/instance_segmentation.py index ab95a6be..a447ed0a 100644 --- a/examples/use_as_library/instance_segmentation.py +++ b/examples/use_as_library/instance_segmentation.py @@ -100,9 +100,9 @@ def segmentation_in_3d(): data = imageio.imread(path) # Load the SAM model for prediction. - model_name = "vit_b" # The model-type to use: vit_h, vit_l, vit_b etc. + model_type = "vit_b" # The model-type to use: vit_h, vit_l, vit_b etc. checkpoint_path = None # You can specifiy the path to a custom (fine-tuned) model here. - predictor = util.get_sam_model(model_name=model_name, checkpoint_path=checkpoint_path) + predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path) # Run 3d segmentation for a given slice. Will segment all objects found in that slice # throughout the volume. 
diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index a7d4a7bc..5bbf9ca0 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -165,7 +165,7 @@ def get_predictor( else: # Vanilla SAM model assert not return_state predictor = util.get_sam_model( - model_name=model_type, device=device, checkpoint_path=checkpoint_path + model_type=model_type, device=device, checkpoint_path=checkpoint_path ) # type: ignore return predictor diff --git a/micro_sam/evaluation/model_comparison.py b/micro_sam/evaluation/model_comparison.py index 2459e0dd..c4d0ff8e 100644 --- a/micro_sam/evaluation/model_comparison.py +++ b/micro_sam/evaluation/model_comparison.py @@ -120,8 +120,8 @@ def generate_data_for_model_comparison( get_point_prompts=True, get_box_prompts=True, ) - predictor1 = util.get_sam_model(model_name=model_name1) - predictor2 = util.get_sam_model(model_name=model_name2) + predictor1 = util.get_sam_model(model_type=model_name1) + predictor2 = util.get_sam_model(model_type=model_name2) _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, output_folder) diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py index d5937f01..6f20dd7a 100644 --- a/micro_sam/precompute_state.py +++ b/micro_sam/precompute_state.py @@ -107,7 +107,7 @@ def _precompute_state_for_files( def precompute_state( input_path: Union[os.PathLike, str], output_path: Union[os.PathLike, str], - model_name: str = util._DEFAULT_MODEL, + model_type: str = util._DEFAULT_MODEL, checkpoint_path: Optional[Union[os.PathLike, str]] = None, key: Optional[str] = None, ndim: Optional[int] = None, @@ -123,7 +123,7 @@ def precompute_state( In case of a container file the argument `key` must be given. In case of a folder it can be given to provide a glob pattern to subselect files from the folder. output_path: The output path were the embeddings and other state will be saved. - model_name: The SegmentAnything model to use. Will use the standard vit_h model by default. + model_type: The SegmentAnything model to use. Will use the standard vit_h model by default. checkpoint_path: Path to a checkpoint for a custom model. key: The key to the input file. This is needed for contaner files (e.g. hdf5 or zarr) and can be used to provide a glob pattern if the input is a folder with image files. @@ -133,7 +133,7 @@ def precompute_state( precompute_amg_state: Whether to precompute the state for automatic instance segmentation in addition to the image embeddings. 
""" - predictor = util.get_sam_model(model_name=model_name, checkpoint_path=checkpoint_path) + predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path) # check if we precompute the state for a single file or for a folder with image files if os.path.isdir(input_path) and Path(input_path).suffix not in (".n5", ".zarr"): pattern = "*" if key is None else key @@ -158,7 +158,7 @@ def main(): parser = argparse.ArgumentParser(description="Compute the embeddings for an image.") parser.add_argument("-i", "--input_path", required=True) parser.add_argument("-o", "--output_path", required=True) - parser.add_argument("-m", "--model_name", default=util._DEFAULT_MODEL) + parser.add_argument("-m", "--model_type", default=util._DEFAULT_MODEL) parser.add_argument("-c", "--checkpoint_path", default=None) parser.add_argument("-k", "--key") parser.add_argument( @@ -172,7 +172,7 @@ def main(): args = parser.parse_args() precompute_state( - args.input_path, args.output_path, args.model_name, args.checkpoint_path, + args.input_path, args.output_path, args.model_type, args.checkpoint_path, key=args.key, tile_shape=args.tile_shape, halo=args.halo, ndim=args.ndim, precompute_amg_state=args.precompute_amg_state, ) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index b5d264ca..b048f788 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -68,7 +68,7 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None, ) # Initialize the model - state.predictor = get_sam_model(device=device, model_name=model, checkpoint_path=optional_custom_weights) + state.predictor = get_sam_model(device=device, model_type=model, checkpoint_path=optional_custom_weights) # Compute the image embeddings state.image_embeddings = precompute_image_embeddings( predictor=state.predictor, diff --git a/micro_sam/sam_annotator/annotator_2d.py b/micro_sam/sam_annotator/annotator_2d.py index 51d80ee1..9c8e002e 100644 --- a/micro_sam/sam_annotator/annotator_2d.py +++ b/micro_sam/sam_annotator/annotator_2d.py @@ -195,7 +195,7 @@ def annotator_2d( embedding_path: Optional[str] = None, show_embeddings: bool = False, segmentation_result: Optional[np.ndarray] = None, - model_name: str = util._DEFAULT_MODEL, + model_type: str = util._DEFAULT_MODEL, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, @@ -214,7 +214,7 @@ def annotator_2d( segmentation_result: An initial segmentation to load. This can be used to correct segmentations with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_objects' layer. - model_name: The Segment Anything model to use. For details on the available models check out + model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models. tile_shape: Shape of tiles for tiled embedding prediction. If `None` then the whole image is passed to Segment Anything. 
@@ -234,7 +234,7 @@ def annotator_2d( state = AnnotatorState() if predictor is None: - state.predictor = util.get_sam_model(model_name=model_name) + state.predictor = util.get_sam_model(model_type=model_type) else: state.predictor = predictor state.image_shape = _get_shape(raw) diff --git a/micro_sam/sam_annotator/annotator_3d.py b/micro_sam/sam_annotator/annotator_3d.py index 9ca3bd78..0bf8cb3c 100644 --- a/micro_sam/sam_annotator/annotator_3d.py +++ b/micro_sam/sam_annotator/annotator_3d.py @@ -211,7 +211,7 @@ def annotator_3d( embedding_path: Optional[str] = None, show_embeddings: bool = False, segmentation_result: Optional[np.ndarray] = None, - model_name: str = util._DEFAULT_MODEL, + model_type: str = util._DEFAULT_MODEL, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, @@ -228,7 +228,7 @@ def annotator_3d( segmentation_result: An initial segmentation to load. This can be used to correct segmentations with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_objects' layer. - model_name: The Segment Anything model to use. For details on the available models check out + model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models. tile_shape: Shape of tiles for tiled embedding prediction. If `None` then the whole image is passed to Segment Anything. @@ -243,7 +243,7 @@ def annotator_3d( state = AnnotatorState() if predictor is None: - state.predictor = util.get_sam_model(model_name=model_name) + state.predictor = util.get_sam_model(model_type=model_type) else: state.predictor = predictor state.image_embeddings = util.precompute_image_embeddings( diff --git a/micro_sam/sam_annotator/annotator_tracking.py b/micro_sam/sam_annotator/annotator_tracking.py index 2743b313..6edbc79f 100644 --- a/micro_sam/sam_annotator/annotator_tracking.py +++ b/micro_sam/sam_annotator/annotator_tracking.py @@ -387,7 +387,7 @@ def annotator_tracking( embedding_path: Optional[str] = None, show_embeddings: bool = False, tracking_result: Optional[str] = None, - model_name: str = util._DEFAULT_MODEL, + model_type: str = util._DEFAULT_MODEL, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, @@ -404,14 +404,14 @@ def annotator_tracking( tracking_result: An initial tracking result to load. This can be used to correct tracking with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_tracks' layer. - model_name: The Segment Anything model to use. For details on the available models check out + model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models. tile_shape: Shape of tiles for tiled embedding prediction. If `None` then the whole image is passed to Segment Anything. halo: Shape of the overlap between tiles, which is needed to segment objects on tile boarders. return_viewer: Whether to return the napari viewer to further modify it before starting the tool. predictor: The Segment Anything model. Passing this enables using fully custom models. - If you pass `predictor` then `model_name` will be ignored. + If you pass `predictor` then `model_type` will be ignored. Returns: The napari viewer, only returned if `return_viewer=True`. 
@@ -424,7 +424,7 @@ def annotator_tracking( state = AnnotatorState() if predictor is None: - state.predictor = util.get_sam_model(model_name=model_name) + state.predictor = util.get_sam_model(model_type=model_type) else: state.predictor = predictor state.image_embeddings = util.precompute_image_embeddings( @@ -587,6 +587,6 @@ def main(): annotator_tracking( raw, embedding_path=args.embedding_path, show_embeddings=args.show_embeddings, - tracking_result=tracking_result, model_name=args.model_name, + tracking_result=tracking_result, model_type=args.model_type, tile_shape=args.tile_shape, halo=args.halo, ) diff --git a/micro_sam/sam_annotator/image_series_annotator.py b/micro_sam/sam_annotator/image_series_annotator.py index f556c11f..d561b7d5 100644 --- a/micro_sam/sam_annotator/image_series_annotator.py +++ b/micro_sam/sam_annotator/image_series_annotator.py @@ -45,7 +45,7 @@ def image_series_annotator( next_image_id = 0 if predictor is None: - predictor = util.get_sam_model(model_name=kwargs.get("model_name", util._DEFAULT_MODEL)) + predictor = util.get_sam_model(model_type=kwargs.get("model_type", util._DEFAULT_MODEL)) if embedding_path is None: embedding_paths = None else: diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py index a67d8058..c2f39e7a 100644 --- a/micro_sam/sam_annotator/util.py +++ b/micro_sam/sam_annotator/util.py @@ -450,7 +450,7 @@ def _initialize_parser(description, with_segmentation_result=True, with_show_emb help="Visualize the embeddings computed by SegmentAnything. This can be helpful for debugging." ) parser.add_argument( - "--model_name", default=util._DEFAULT_MODEL, + "--model_type", default=util._DEFAULT_MODEL, help=f"The segment anything model that will be used, one of {available_models}." ) parser.add_argument( diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index c424bddc..46398cdb 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -10,7 +10,7 @@ def get_trainable_sam_model( - model_name: str = "vit_h", + model_type: str = "vit_h", device: Optional[Union[str, torch.device]] = None, checkpoint_path: Optional[Union[str, os.PathLike]] = None, freeze: Optional[List[str]] = None, @@ -18,7 +18,7 @@ def get_trainable_sam_model( """Get the trainable sam model. Args: - model_name: The segment anything model that should be finetuned. + model_type: The segment anything model that should be finetuned. The weights of this model will be used for initialization, unless a custom weight file is passed via `checkpoint_path`. device: The device to use for training. 
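A correspondingly minimal sketch for the training helper, assuming a hypothetical local checkpoint path; per the docstring, valid entries for `freeze` are image_encoder, prompt_encoder and mask_decoder.

from micro_sam.training.util import get_trainable_sam_model

# Initialize from the default vit_b weights and keep the image encoder frozen during finetuning.
trainable_sam = get_trainable_sam_model(model_type="vit_b", freeze=["image_encoder"])

# Or start from a local weight file; model_type must again match the encoder architecture.
# trainable_sam = get_trainable_sam_model(model_type="vit_b", checkpoint_path="/path/to/vit_b_custom.pth")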
@@ -31,7 +31,7 @@ def get_trainable_sam_model( """ # set the device here so that the correct one is passed to TrainableSAM below device = get_device(device) - _, sam = get_sam_model(model_name=model_name, device=device, checkpoint_path=checkpoint_path, return_sam=True) + _, sam = get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path, return_sam=True) # freeze components of the model if freeze was passed # ideally we would want to add components in such a way that: diff --git a/micro_sam/util.py b/micro_sam/util.py index 99c30494..2b6e61ec 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -178,7 +178,7 @@ def _available_devices(): def get_sam_model( - model_name: str = _DEFAULT_MODEL, + model_type: str = _DEFAULT_MODEL, device: Optional[Union[str, torch.device]] = None, checkpoint_path: Optional[Union[str, os.PathLike]] = None, return_sam: bool = False, @@ -187,12 +187,12 @@ def get_sam_model( This function will download the required model or load it from the cached weight file. This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. - The name of the requested model can be set via `model_name`. + The name of the requested model can be set via `model_type`. See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models for an overview of the available models Alternatively this function can also load a model from weights stored in a local filepath. - The corresponding file path is given via `checkpoint_path`. In this case `model_name` + The corresponding file path is given via `checkpoint_path`. In this case `model_type` must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for a SAM model with vit_b encoder. @@ -205,13 +205,13 @@ def get_sam_model( https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html Args: - model_name: The SegmentAnything model to use. Will use the standard vit_h model by default. + model_type: The SegmentAnything model to use. Will use the standard vit_h model by default. To get a list of all available model names you can call `get_model_names`. device: The device for the model. If none is given will use GPU if available. checkpoint_path: The path to a file with weights that should be used instead of using the - weights corresponding to `model_name`. If given, `model_name` must match the architecture + weights corresponding to `model_type`. If given, `model_type` must match the architecture corresponding to the weight file. E.g. if you use weights for SAM with vit_b encoder - then `model_name` must be given as "vit_b". + then `model_type` must be given as "vit_b". return_sam: Return the sam model object as well as the predictor. Returns: @@ -224,10 +224,10 @@ def get_sam_model( # as it is, without copying it over anywhere or checking it's hashes. # checkpoint_path has not been passed, we download a known model and derive the correct - # URL from the model_name. If the model_name is invalid pooch will raise an error. + # URL from the model_type. If the model_type is invalid pooch will raise an error. if checkpoint_path is None: model_registry = models() - checkpoint = model_registry.fetch(model_name) + checkpoint = model_registry.fetch(model_type) # checkpoint_path has been passed, we use it instead of downloading a model. else: # Check if the file exists and raise an error otherwise. @@ -239,19 +239,19 @@ def get_sam_model( # Our fine-tuned model types have a suffix "_...". 
This suffix needs to be stripped # before calling sam_model_registry. - model_type = model_name[:5] - if model_type not in _MODEL_TYPES: - raise ValueError(f"Invalid model_type: {model_type}. Expect one of {_MODEL_TYPES}") - if model_type == "vit_t" and not VIT_T_SUPPORT: + abbreviated_model_type = model_type[:5] + if abbreviated_model_type not in _MODEL_TYPES: + raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}") + if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT: raise RuntimeError( "mobile_sam is required for the vit-tiny." "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" ) - sam = sam_model_registry[model_type](checkpoint=checkpoint) + sam = sam_model_registry[abbreviated_model_type](checkpoint=checkpoint) sam.to(device=device) predictor = SamPredictor(sam) - predictor.model_type = model_type + predictor.model_type = abbreviated_model_type if return_sam: return predictor, sam return predictor From 08232370d32b17a415ff519d74919de6d7df10e8 Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 23 Nov 2023 17:14:51 +1100 Subject: [PATCH 14/16] Revert accidental changes to model_comparison.py --- micro_sam/evaluation/model_comparison.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/micro_sam/evaluation/model_comparison.py b/micro_sam/evaluation/model_comparison.py index c4d0ff8e..8e98a6c2 100644 --- a/micro_sam/evaluation/model_comparison.py +++ b/micro_sam/evaluation/model_comparison.py @@ -96,8 +96,8 @@ def _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, def generate_data_for_model_comparison( loader: torch.utils.data.DataLoader, output_folder: Union[str, os.PathLike], - model_name1: str, - model_name2: str, + model_type1: str, + model_type2: str, n_samples: int, ) -> None: """Generate samples for qualitative model comparison. @@ -107,9 +107,9 @@ def generate_data_for_model_comparison( Args: loader: The torch dataloader from which samples are drawn. output_folder: The folder where the samples will be saved. - model_name1: The first model to use for comparison. + model_type1: The first model to use for comparison. The value needs to be a valid model_name for `micro_sam.util.get_sam_model`. - model_name1: The second model to use for comparison. + model_type2: The second model to use for comparison. The value needs to be a valid model_name for `micro_sam.util.get_sam_model`. n_samples: The number of samples to draw from the dataloader. 
""" @@ -120,8 +120,8 @@ def generate_data_for_model_comparison( get_point_prompts=True, get_box_prompts=True, ) - predictor1 = util.get_sam_model(model_type=model_name1) - predictor2 = util.get_sam_model(model_type=model_name2) + predictor1 = util.get_sam_model(model_type=model_type1) + predictor2 = util.get_sam_model(model_type=model_type2) _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, output_folder) From c1081b75bd7a3f18ebb10f2b56e8c7a0ec48352d Mon Sep 17 00:00:00 2001 From: Genevieve Buckley <30920819+GenevieveBuckley@users.noreply.github.com> Date: Thu, 23 Nov 2023 17:16:39 +1100 Subject: [PATCH 15/16] Revert another accidental change to model_comparison.py --- micro_sam/evaluation/model_comparison.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/micro_sam/evaluation/model_comparison.py b/micro_sam/evaluation/model_comparison.py index 8e98a6c2..6df9e8f4 100644 --- a/micro_sam/evaluation/model_comparison.py +++ b/micro_sam/evaluation/model_comparison.py @@ -108,9 +108,9 @@ def generate_data_for_model_comparison( loader: The torch dataloader from which samples are drawn. output_folder: The folder where the samples will be saved. model_type1: The first model to use for comparison. - The value needs to be a valid model_name for `micro_sam.util.get_sam_model`. + The value needs to be a valid model_type for `micro_sam.util.get_sam_model`. model_type2: The second model to use for comparison. - The value needs to be a valid model_name for `micro_sam.util.get_sam_model`. + The value needs to be a valid model_type for `micro_sam.util.get_sam_model`. n_samples: The number of samples to draw from the dataloader. """ prompt_generator = PointAndBoxPromptGenerator( From 22f9274466693322f39af036ac8cf39964c29004 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 23 Nov 2023 09:20:19 +0100 Subject: [PATCH 16/16] Remove redundant pooch.os_cache --- micro_sam/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/micro_sam/util.py b/micro_sam/util.py index 2b6e61ec..7833ec9d 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -83,7 +83,7 @@ def models(): are respected. """ models = pooch.create( - path=pooch.os_cache(os.path.join(microsam_cachedir(), "models")), + path=os.path.join(microsam_cachedir(), "models"), base_url="", registry={ # the default segment anything models