Skip to content

Commit

Permalink
Download models with pooch (#276)
Browse files Browse the repository at this point in the history
Use pooch to download model weights and simplify some other functionality
---------

Co-authored-by: Constantin Pape <[email protected]>
Co-authored-by: Constantin Pape <[email protected]>
  • Loading branch information
3 people authored Nov 23, 2023
1 parent 23e974f commit c2a4e54
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 114 deletions.
2 changes: 1 addition & 1 deletion development/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def main():
args = parser.parse_args()

model_type = args.model_type
device = util._get_device(args.device)
device = util.get_device(args.device)
print("Running benchmarks for", model_type)
print("with device:", device)

Expand Down
2 changes: 1 addition & 1 deletion micro_sam/evaluation/model_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def generate_data_for_model_comparison(
output_folder: The folder where the samples will be saved.
model_type1: The first model to use for comparison.
The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
model_type1: The second model to use for comparison.
model_type2: The second model to use for comparison.
The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
n_samples: The number of samples to draw from the dataloader.
"""
Expand Down
2 changes: 1 addition & 1 deletion micro_sam/precompute_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def main():
parser = argparse.ArgumentParser(description="Compute the embeddings for an image.")
parser.add_argument("-i", "--input_path", required=True)
parser.add_argument("-o", "--output_path", required=True)
parser.add_argument("-m", "--model_type", default="vit_h")
parser.add_argument("-m", "--model_type", default=util._DEFAULT_MODEL)
parser.add_argument("-c", "--checkpoint_path", default=None)
parser.add_argument("-k", "--key")
parser.add_argument(
Expand Down
34 changes: 16 additions & 18 deletions micro_sam/sam_annotator/_widgets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from enum import Enum
import os
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Literal

from magicgui import magic_factory, widgets
from napari.qt.threading import thread_worker
import pooch
import zarr
from zarr.errors import PathNotFoundError

Expand All @@ -14,7 +12,7 @@
ImageEmbeddings,
get_sam_model,
precompute_image_embeddings,
_MODEL_URLS,
models,
_DEFAULT_MODEL,
_available_devices,
get_cache_directory,
Expand All @@ -23,21 +21,17 @@
if TYPE_CHECKING:
import napari

Model = Enum("Model", _MODEL_URLS)
available_devices_list = ["auto"] + _available_devices()


@magic_factory(
pbar={'visible': False, 'max': 0, 'value': 0, 'label': 'working...'},
call_button="Compute image embeddings",
device = {"choices": available_devices_list},
save_path={"mode": "d"}, # choose a directory
)
def embedding_widget(
pbar: widgets.ProgressBar,
image: "napari.layers.Image",
model: Model = Model.__getitem__(_DEFAULT_MODEL),
device = "auto",
model: Literal[tuple(models().urls.keys())] = _DEFAULT_MODEL,
device: Literal[tuple(["auto"] + _available_devices())] = "auto",
save_path: Optional[Path] = None, # where embeddings for this image are cached (optional)
optional_custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights.
) -> ImageEmbeddings:
Expand All @@ -54,8 +48,9 @@ def embedding_widget(

@thread_worker(connect={'started': pbar.show, 'finished': pbar.hide})
def _compute_image_embedding(state, image_data, save_path, ndim=None,
device="auto", model=Model.__getitem__(_DEFAULT_MODEL),
optional_custom_weights=None):
device="auto", model=_DEFAULT_MODEL,
optional_custom_weights=None,
):
# Make sure save directory exists and is an empty directory
if save_path is not None:
os.makedirs(save_path, exist_ok=True)
Expand All @@ -71,19 +66,22 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None,
"The user selected 'save_path' is not a zarr array "
f"or empty directory: {save_path}"
)

# Initialize the model
state.predictor = get_sam_model(device=device, model_type=model.name,
checkpoint_path=optional_custom_weights)
state.predictor = get_sam_model(device=device, model_type=model, checkpoint_path=optional_custom_weights)
# Compute the image embeddings
state.image_embeddings = precompute_image_embeddings(
predictor = state.predictor,
input_ = image_data,
save_path = str(save_path),
predictor=state.predictor,
input_=image_data,
save_path=save_path,
ndim=ndim,
)
return state # returns napari._qt.qthreading.FunctionWorker

return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model, optional_custom_weights=optional_custom_weights)
return _compute_image_embedding(
state, image.data, save_path, ndim=ndim, device=device, model=model,
optional_custom_weights=optional_custom_weights
)


@magic_factory(
Expand Down
4 changes: 3 additions & 1 deletion micro_sam/training/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ def get_trainable_sam_model(
"""Get the trainable sam model.
Args:
model_type: The type of the segment anything model.
model_type: The segment anything model that should be finetuned.
The weights of this model will be used for initialization, unless a
custom weight file is passed via `checkpoint_path`.
device: The device to use for training.
checkpoint_path: Path to a custom checkpoint from which to load the model weights.
freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
Expand Down
Loading

0 comments on commit c2a4e54

Please sign in to comment.