Skip to content

Commit

Permalink
Remove checkpoint_path from default sam model download (custom weight…
Browse files Browse the repository at this point in the history
…s should belong with the custom model download function)
  • Loading branch information
GenevieveBuckley committed Nov 14, 2023
1 parent 00006fc commit 926b4ba
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 16 deletions.
4 changes: 2 additions & 2 deletions examples/annotator_with_custom_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import h5py
import micro_sam.sam_annotator as annotator
from micro_sam.util import get_sam_model
from micro_sam.util import get_custom_sam_model

# TODO add an example for the 2d annotator with a custom model

Expand All @@ -11,7 +11,7 @@ def annotator_3d_with_custom_model():

custom_model = "/home/pape/Work/data/models/sam/user-study/vit_h_nuclei_em_finetuned.pt"
embedding_path = "./embeddings/nuclei3d-custom-vit-h.zarr"
predictor = get_sam_model(checkpoint_path=custom_model, model_type="vit_h")
predictor = get_custom_sam_model(checkpoint_path=custom_model, model_type="vit_h")
annotator.annotator_3d(raw, embedding_path, predictor=predictor)


Expand Down
2 changes: 1 addition & 1 deletion examples/finetuning/use_finetuned_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def run_annotator_with_custom_model():
# Adapt this if you finetune a different model type, e.g. vit_h.

# Load the custom model.
predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint)
predictor = util.get_custom_sam_model(model_type=model_type, checkpoint_path=checkpoint)

# Run the 2d annotator with the custom model.
annotator_2d(
Expand Down
3 changes: 1 addition & 2 deletions examples/use_as_library/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ def segmentation_in_3d():

# Load the SAM model for prediction.
model_type = "vit_b" # The model-type to use: vit_h, vit_l, vit_b etc.
checkpoint_path = None # You can specify the path to a custom (fine-tuned) model here.
predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)
predictor = util.get_sam_model(model_type=model_type)

# Run 3d segmentation for a given slice. Will segment all objects found in that slice
# throughout the volume.
Expand Down
4 changes: 1 addition & 3 deletions micro_sam/evaluation/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,7 @@ def get_predictor(
)
else: # Vanilla SAM model
assert not return_state
predictor = util.get_sam_model(
model_type=model_type, device=device, checkpoint_path=checkpoint_path
) # type: ignore
predictor = util.get_sam_model(model_type=model_type, device=device) # type: ignore
return predictor


Expand Down
4 changes: 1 addition & 3 deletions micro_sam/precompute_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,14 @@ def precompute_state(
it can be given to provide a glob pattern to subselect files from the folder.
output_path: The output path where the embeddings and other state will be saved.
model_type: The SegmentAnything model to use. Will use the standard vit_h model by default.
checkpoint_path: Path to a checkpoint for a custom model.
key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
and can be used to provide a glob pattern if the input is a folder with image files.
ndim: The dimensionality of the data.
tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
halo: Overlap of the tiles for tiled prediction.
precompute_amg_state: Whether to precompute the state for automatic instance segmentation
in addition to the image embeddings.
"""
predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)
predictor = util.get_sam_model(model_type=model_type)

Check warning on line 134 in micro_sam/precompute_state.py

View check run for this annotation

Codecov / codecov/patch

micro_sam/precompute_state.py#L134

Added line #L134 was not covered by tests
# check if we precompute the state for a single file or for a folder with image files
if os.path.isdir(input_path) and Path(input_path).suffix not in (".n5", ".zarr"):
pattern = "*" if key is None else key
Expand Down
8 changes: 3 additions & 5 deletions micro_sam/sam_annotator/_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def embedding_widget(
model: Model = Model.__getitem__(_DEFAULT_MODEL),
device = "auto",
save_path: Optional[Path] = None, # where embeddings for this image are cached (optional)
optional_custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights.
) -> ImageEmbeddings:
"""Image embedding widget."""
state = AnnotatorState()
Expand All @@ -53,7 +52,7 @@ def embedding_widget(
@thread_worker(connect={'started': pbar.show, 'finished': pbar.hide})
def _compute_image_embedding(state, image_data, save_path, ndim=None,
device="auto", model=Model.__getitem__(_DEFAULT_MODEL),
optional_custom_weights=None):
):
# Make sure save directory exists and is an empty directory
if save_path is not None:
os.makedirs(save_path, exist_ok=True)
Expand All @@ -70,8 +69,7 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None,
f"or empty directory: {save_path}"
)
# Initialize the model
state.predictor = get_sam_model(device=device, model_type=model.name,
checkpoint_path=optional_custom_weights)
state.predictor = get_sam_model(device=device, model_type=model.name)

Check warning on line 72 in micro_sam/sam_annotator/_widgets.py

View check run for this annotation

Codecov / codecov/patch

micro_sam/sam_annotator/_widgets.py#L72

Added line #L72 was not covered by tests
# Compute the image embeddings
state.image_embeddings = precompute_image_embeddings(
predictor = state.predictor,
Expand All @@ -81,4 +79,4 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None,
)
return state # returns napari._qt.qthreading.FunctionWorker

return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model, optional_custom_weights=optional_custom_weights)
return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model)

0 comments on commit 926b4ba

Please sign in to comment.