diff --git a/development/benchmark.py b/development/benchmark.py index b3d3ab15..ff5a50e7 100644 --- a/development/benchmark.py +++ b/development/benchmark.py @@ -180,7 +180,7 @@ def main(): args = parser.parse_args() model_type = args.model_type - device = util._get_device(args.device) + device = util.get_device(args.device) print("Running benchmarks for", model_type) print("with device:", device) diff --git a/micro_sam/evaluation/model_comparison.py b/micro_sam/evaluation/model_comparison.py index 702f944f..6df9e8f4 100644 --- a/micro_sam/evaluation/model_comparison.py +++ b/micro_sam/evaluation/model_comparison.py @@ -109,7 +109,7 @@ def generate_data_for_model_comparison( output_folder: The folder where the samples will be saved. model_type1: The first model to use for comparison. The value needs to be a valid model_type for `micro_sam.util.get_sam_model`. - model_type1: The second model to use for comparison. + model_type2: The second model to use for comparison. The value needs to be a valid model_type for `micro_sam.util.get_sam_model`. n_samples: The number of samples to draw from the dataloader. """ diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py index 66e6edf5..6f20dd7a 100644 --- a/micro_sam/precompute_state.py +++ b/micro_sam/precompute_state.py @@ -158,7 +158,7 @@ def main(): parser = argparse.ArgumentParser(description="Compute the embeddings for an image.") parser.add_argument("-i", "--input_path", required=True) parser.add_argument("-o", "--output_path", required=True) - parser.add_argument("-m", "--model_type", default="vit_h") + parser.add_argument("-m", "--model_type", default=util._DEFAULT_MODEL) parser.add_argument("-c", "--checkpoint_path", default=None) parser.add_argument("-k", "--key") parser.add_argument( diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 190584c5..b048f788 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -1,11 +1,9 @@ -from enum import Enum import os from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Literal from magicgui import magic_factory, widgets from napari.qt.threading import thread_worker -import pooch import zarr from zarr.errors import PathNotFoundError @@ -14,7 +12,7 @@ ImageEmbeddings, get_sam_model, precompute_image_embeddings, - _MODEL_URLS, + models, _DEFAULT_MODEL, _available_devices, get_cache_directory, @@ -23,21 +21,17 @@ if TYPE_CHECKING: import napari -Model = Enum("Model", _MODEL_URLS) -available_devices_list = ["auto"] + _available_devices() - @magic_factory( pbar={'visible': False, 'max': 0, 'value': 0, 'label': 'working...'}, call_button="Compute image embeddings", - device = {"choices": available_devices_list}, save_path={"mode": "d"}, # choose a directory ) def embedding_widget( pbar: widgets.ProgressBar, image: "napari.layers.Image", - model: Model = Model.__getitem__(_DEFAULT_MODEL), - device = "auto", + model: Literal[tuple(models().urls.keys())] = _DEFAULT_MODEL, + device: Literal[tuple(["auto"] + _available_devices())] = "auto", save_path: Optional[Path] = None, # where embeddings for this image are cached (optional) optional_custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights. 
) -> ImageEmbeddings: @@ -54,8 +48,9 @@ def embedding_widget( @thread_worker(connect={'started': pbar.show, 'finished': pbar.hide}) def _compute_image_embedding(state, image_data, save_path, ndim=None, - device="auto", model=Model.__getitem__(_DEFAULT_MODEL), - optional_custom_weights=None): + device="auto", model=_DEFAULT_MODEL, + optional_custom_weights=None, + ): # Make sure save directory exists and is an empty directory if save_path is not None: os.makedirs(save_path, exist_ok=True) @@ -71,19 +66,22 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None, "The user selected 'save_path' is not a zarr array " f"or empty directory: {save_path}" ) + # Initialize the model - state.predictor = get_sam_model(device=device, model_type=model.name, - checkpoint_path=optional_custom_weights) + state.predictor = get_sam_model(device=device, model_type=model, checkpoint_path=optional_custom_weights) # Compute the image embeddings state.image_embeddings = precompute_image_embeddings( - predictor = state.predictor, - input_ = image_data, - save_path = str(save_path), + predictor=state.predictor, + input_=image_data, + save_path=save_path, ndim=ndim, ) return state # returns napari._qt.qthreading.FunctionWorker - return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model, optional_custom_weights=optional_custom_weights) + return _compute_image_embedding( + state, image.data, save_path, ndim=ndim, device=device, model=model, + optional_custom_weights=optional_custom_weights + ) @magic_factory( diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 51c08759..46398cdb 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -18,7 +18,9 @@ def get_trainable_sam_model( """Get the trainable sam model. Args: - model_type: The type of the segment anything model. + model_type: The segment anything model that should be finetuned. + The weights of this model will be used for initialization, unless a + custom weight file is passed via `checkpoint_path`. device: The device to use for training. checkpoint_path: Path to a custom checkpoint from which to load the model weights. 
freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder diff --git a/micro_sam/util.py b/micro_sam/util.py index 50421b0a..7833ec9d 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -8,13 +8,11 @@ import pickle import warnings from collections import OrderedDict -from shutil import copyfileobj from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import imageio.v3 as imageio import numpy as np import pooch -import requests import torch import vigra import zarr @@ -36,45 +34,15 @@ except ImportError: from tqdm import tqdm -_MODEL_URLS = { - # the default segment anything models - "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", - "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", - "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", - # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM - "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", - # first version of finetuned models on zenodo - "vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1", - "vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1", - "vit_h_em": "https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1", - "vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1", -} - -_CHECKSUMS = { - # the default segment anything models - "vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e", - "vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622", - "vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912", - # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM - "vit_t": "6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f", - # first version of finetuned models on zenodo - "vit_h_lm": "9a65ee0cddc05a98d60469a12a058859c89dc3ea3ba39fed9b90d786253fbf26", - "vit_b_lm": "5a59cc4064092d54cd4d92cd967e39168f3760905431e868e474d60fe5464ecd", - "vit_h_em": "ae3798a0646c8df1d4db147998a2d37e402ff57d3aa4e571792fbb911d8a979c", - "vit_b_em": "c04a714a4e14a110f0eec055a65f7409d54e6bf733164d2933a0ce556f7d6f81", -} -# this is required so that the downloaded file is not called 'download' -_DOWNLOAD_NAMES = { - "vit_t": "vit_t_mobile_sam.pth", - "vit_h_lm": "vit_h_lm.pth", - "vit_b_lm": "vit_b_lm.pth", - "vit_h_em": "vit_h_em.pth", - "vit_b_em": "vit_b_em.pth", -} + # this is the default model used in micro_sam # currently set to the default vit_h _DEFAULT_MODEL = "vit_h" +# The valid model types. Each type corresponds to the architecture of the +# vision transformer used within SAM. +_MODEL_TYPES = ("vit_h", "vit_b", "vit_l", "vit_t") + # TODO define the proper type for image embeddings ImageEmbeddings = Dict[str, Any] @@ -90,53 +58,62 @@ def get_cache_directory() -> None: cache_directory = Path(os.environ.get('MICROSAM_CACHEDIR', default_cache_directory)) return cache_directory + # # Functionality for model download and export # +def microsam_cachedir(): + """Return the micro-sam cache directory. 
-def _download(url, path, model_type): - with requests.get(url, stream=True, verify=True) as r: - if r.status_code != 200: - r.raise_for_status() - raise RuntimeError(f"Request to {url} returned status code {r.status_code}") - file_size = int(r.headers.get("Content-Length", 0)) - desc = f"Download {url} to {path}" - if file_size == 0: - desc += " (unknown file size)" - with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f: - copyfileobj(r_raw, f) - - # validate the checksum - expected_checksum = _CHECKSUMS[model_type] - if expected_checksum is None: - return - with open(path, "rb") as f: - file_ = f.read() - checksum = hashlib.sha256(file_).hexdigest() - if checksum != expected_checksum: - raise RuntimeError( - "The checksum of the download does not match the expected checksum." - f"Expected: {expected_checksum}, got: {checksum}" - ) - print("Download successful and checksums agree.") + Returns the top level cache directory for micro-sam models and sample data. + Every time this function is called, we check for any user updates made to + the MICROSAM_CACHEDIR os environment variable since the last time. + """ + cache_directory = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam') + return cache_directory -def _get_checkpoint(model_type, checkpoint_path=None): - if checkpoint_path is None: - checkpoint_url = _MODEL_URLS[model_type] - checkpoint_name = _DOWNLOAD_NAMES.get(model_type, checkpoint_url.split("/")[-1]) - checkpoint_folder = os.path.join(get_cache_directory(), "models") - checkpoint_path = os.path.join(checkpoint_folder, checkpoint_name) - # download the checkpoint if necessary - if not os.path.exists(checkpoint_path): - os.makedirs(checkpoint_folder, exist_ok=True) - _download(checkpoint_url, checkpoint_path, model_type) - elif not os.path.exists(checkpoint_path): - raise ValueError(f"The checkpoint path {checkpoint_path} that was passed does not exist.") +def models(): + """Return the segmentation models registry. - return checkpoint_path + We recreate the model registry every time this function is called, + so any user changes to the default micro-sam cache directory location + are respected. + """ + models = pooch.create( + path=os.path.join(microsam_cachedir(), "models"), + base_url="", + registry={ + # the default segment anything models + "vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e", + "vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622", + "vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912", + # the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM + "vit_t": "6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f", + # first version of finetuned models on zenodo + "vit_h_lm": "9a65ee0cddc05a98d60469a12a058859c89dc3ea3ba39fed9b90d786253fbf26", + "vit_b_lm": "5a59cc4064092d54cd4d92cd967e39168f3760905431e868e474d60fe5464ecd", + "vit_h_em": "ae3798a0646c8df1d4db147998a2d37e402ff57d3aa4e571792fbb911d8a979c", + "vit_b_em": "c04a714a4e14a110f0eec055a65f7409d54e6bf733164d2933a0ce556f7d6f81", + }, + # Now specify custom URLs for some of the files in the registry. 
+        urls={
+            # the default segment anything models
+            "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+            "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
+            "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
+            # the model with vit tiny backbone from https://github.com/ChaoningZhang/MobileSAM
+            "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
+            # first version of finetuned models on zenodo
+            "vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1",
+            "vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1",
+            "vit_h_em": "https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1",
+            "vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1",
+        },
+    )
+    return models


 def _get_default_device():
@@ -208,9 +185,16 @@ def get_sam_model(
 ) -> SamPredictor:
     r"""Get the SegmentAnything Predictor.

-    This function will download the required model checkpoint or load it from file if it
-    was already downloaded.
-    This location can be changed by setting the environment variable: MICROSAM_CACHEDIR.
+    This function will download the required model or load it from the cached weight file.
+    The location of the cache can be changed by setting the environment variable MICROSAM_CACHEDIR.
+    The name of the requested model can be set via `model_type`.
+    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
+    for an overview of the available models.
+
+    Alternatively, this function can load a model from weights stored in a local file.
+    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
+    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
+    a SAM model with a vit_b encoder.

     By default the models are downloaded to a folder named 'micro_sam/models'
     inside your default cache directory, eg:
@@ -222,30 +206,52 @@ def get_sam_model(
     Args:
         model_type: The SegmentAnything model to use. Will use the standard vit_h model by default.
+            To get a list of all available model names, call `get_model_names`.
         device: The device for the model. If none is given will use GPU if available.
-        checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
+        checkpoint_path: The path to a file with weights that should be used instead of the
+            weights corresponding to `model_type`. If given, `model_type` must match the architecture
+            corresponding to the weight file. E.g. if you use weights for SAM with a vit_b encoder,
+            then `model_type` must be given as "vit_b".
         return_sam: Return the sam model object as well as the predictor.

     Returns:
         The segment anything predictor.
     """
-    checkpoint = _get_checkpoint(model_type, checkpoint_path)
     device = get_device(device)

-    # Our custom model types have a suffix "_...". This suffix needs to be stripped
+    # We support passing a local filepath to a checkpoint.
+    # In this case we do not download any weights but just use the local weight file
+    # as it is, without copying it anywhere or checking its hash.
+
+    # If checkpoint_path has not been passed, we download a known model and derive the correct
+    # URL from the model_type. If the model_type is invalid, pooch will raise an error.
+    if checkpoint_path is None:
+        model_registry = models()
+        checkpoint = model_registry.fetch(model_type)
+    # If checkpoint_path has been passed, we use it instead of downloading a model.
+ else: + # Check if the file exists and raise an error otherwise. + # We can't check any hashes here, and we don't check if the file is actually a valid weight file. + # (If it isn't the model creation will fail below.) + if not os.path.exists(checkpoint_path): + raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.") + checkpoint = checkpoint_path + + # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped # before calling sam_model_registry. - model_type_ = model_type[:5] - assert model_type_ in ("vit_h", "vit_b", "vit_l", "vit_t") - if model_type == "vit_t" and not VIT_T_SUPPORT: + abbreviated_model_type = model_type[:5] + if abbreviated_model_type not in _MODEL_TYPES: + raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}") + if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT: raise RuntimeError( "mobile_sam is required for the vit-tiny." "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" ) - sam = sam_model_registry[model_type_](checkpoint=checkpoint) + sam = sam_model_registry[abbreviated_model_type](checkpoint=checkpoint) sam.to(device=device) predictor = SamPredictor(sam) - predictor.model_type = model_type + predictor.model_type = abbreviated_model_type if return_sam: return predictor, sam return predictor @@ -278,7 +284,7 @@ def get_custom_sam_model( Args: checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. - model_type: The SegmentAnything model to use. + model_type: The SegmentAnything model_type for the given checkpoint. device: The device for the model. If none is given will use GPU if available. return_sam: Return the sam model object as well as the predictor. return_state: Return the full state of the checkpoint in addition to the predictor. @@ -328,7 +334,7 @@ def export_custom_sam_model( Args: checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. - model_type: The SegmentAnything model type to use (vit_h, vit_b or vit_l). + model_type: The SegmentAnything model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t). save_path: Where to save the exported model. """ _, state = get_custom_sam_model( @@ -343,7 +349,9 @@ def export_custom_sam_model( def get_model_names() -> Iterable: - return _MODEL_URLS.keys() + model_registry = models() + model_names = model_registry.registry.keys() + return model_names # @@ -600,6 +608,7 @@ def precompute_image_embeddings( assert save_path is not None, "Tiled prediction is only supported when the embeddings are saved to file." 
if save_path is not None: + save_path = str(save_path) data_signature = _compute_data_signature(input_) f = zarr.open(save_path, "a") diff --git a/test/test_sam_annotator/test_widgets.py b/test/test_sam_annotator/test_widgets.py index dd4adb6f..dc5e26f1 100644 --- a/test/test_sam_annotator/test_widgets.py +++ b/test/test_sam_annotator/test_widgets.py @@ -7,7 +7,7 @@ import zarr from micro_sam.sam_annotator._state import AnnotatorState -from micro_sam.sam_annotator._widgets import embedding_widget, Model +from micro_sam.sam_annotator._widgets import embedding_widget from micro_sam.util import _compute_data_signature @@ -22,7 +22,7 @@ def test_embedding_widget(make_napari_viewer, tmp_path): layer = viewer.open_sample('napari', 'camera')[0] my_widget = embedding_widget() # run image embedding widget - worker = my_widget(image=layer, model=Model.vit_t, device="cpu", save_path=tmp_path) + worker = my_widget(image=layer, model="vit_t", device="cpu", save_path=tmp_path) worker.await_workers() # blocks until thread worker is finished the embedding # Check in-memory state - predictor assert isinstance(AnnotatorState().predictor, (SamPredictor, MobileSamPredictor)) diff --git a/test/test_util.py b/test/test_util.py index 505d0208..c0d4e5ca 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -34,7 +34,7 @@ def check_predictor(predictor): # check predictor with checkpoint path (using the cached model) checkpoint_path = os.path.join( - get_cache_directory(), "models", "vit_t_mobile_sam.pth" if VIT_T_SUPPORT else "sam_vit_b_01ec64.pth" + get_cache_directory(), "models", "vit_t" if VIT_T_SUPPORT else "vit_b" ) predictor = get_sam_model(model_type=self.model_type, checkpoint_path=checkpoint_path) check_predictor(predictor)
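
A minimal usage sketch of the reworked model-loading API touched by this diff (the pooch-based `models()` registry, `get_sam_model` with either a registered `model_type` or a local `checkpoint_path`, and the `MICROSAM_CACHEDIR` override). This illustrates the intended usage only and is not part of the patch; the cache directory and the custom weight file name are hypothetical.

```python
import os

# Optional: point the micro-sam cache somewhere else before anything is downloaded.
# (Hypothetical path; any writable directory works. The variable is re-read on every call.)
os.environ["MICROSAM_CACHEDIR"] = "/tmp/micro_sam_cache"

from micro_sam.util import get_model_names, get_sam_model, models

# The registered model names are the keys of the pooch registry.
print(list(get_model_names()))  # vit_h, vit_l, vit_b, vit_t, vit_h_lm, ...

# Case 1: use a registered model; the weights are downloaded to the cache (or reused from it).
predictor = get_sam_model(model_type="vit_b")

# Case 2: use local weights; model_type must name the matching encoder architecture.
# "finetuned_vit_b.pth" is a hypothetical file path.
predictor = get_sam_model(model_type="vit_b", checkpoint_path="finetuned_vit_b.pth")

# The registry can also be used directly to fetch (and cache) a checkpoint file.
checkpoint_path = models().fetch("vit_t")
```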