diff --git a/examples/annotator_with_custom_model.py b/examples/annotator_with_custom_model.py
index ceb8b2cb..deba6638 100644
--- a/examples/annotator_with_custom_model.py
+++ b/examples/annotator_with_custom_model.py
@@ -1,6 +1,6 @@
 import h5py
 import micro_sam.sam_annotator as annotator
-from micro_sam.util import get_sam_model
+from micro_sam.util import get_custom_sam_model
 
 
 # TODO add an example for the 2d annotator with a custom model
@@ -11,7 +11,7 @@ def annotator_3d_with_custom_model():
 
     custom_model = "/home/pape/Work/data/models/sam/user-study/vit_h_nuclei_em_finetuned.pt"
     embedding_path = "./embeddings/nuclei3d-custom-vit-h.zarr"
-    predictor = get_sam_model(checkpoint_path=custom_model, model_type="vit_h")
+    predictor = get_custom_sam_model(checkpoint_path=custom_model, model_type="vit_h")
     annotator.annotator_3d(raw, embedding_path, predictor=predictor)
 
 
diff --git a/examples/finetuning/use_finetuned_model.py b/examples/finetuning/use_finetuned_model.py
index 19600241..e07b0e36 100644
--- a/examples/finetuning/use_finetuned_model.py
+++ b/examples/finetuning/use_finetuned_model.py
@@ -21,7 +21,7 @@ def run_annotator_with_custom_model():
     # Adapt this if you finetune a different model type, e.g. vit_h.
 
     # Load the custom model.
-    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint)
+    predictor = util.get_custom_sam_model(model_type=model_type, checkpoint_path=checkpoint)
 
     # Run the 2d annotator with the custom model.
     annotator_2d(
diff --git a/examples/use_as_library/instance_segmentation.py b/examples/use_as_library/instance_segmentation.py
index a447ed0a..bb58e1d3 100644
--- a/examples/use_as_library/instance_segmentation.py
+++ b/examples/use_as_library/instance_segmentation.py
@@ -101,8 +101,7 @@ def segmentation_in_3d():
 
     # Load the SAM model for prediction.
     model_type = "vit_b"  # The model-type to use: vit_h, vit_l, vit_b etc.
-    checkpoint_path = None  # You can specifiy the path to a custom (fine-tuned) model here.
-    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)
+    predictor = util.get_sam_model(model_type=model_type)
 
     # Run 3d segmentation for a given slice. Will segment all objects found in that slice
     # throughout the volume.
diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py
index bfa1b160..bf9ea068 100644
--- a/micro_sam/evaluation/inference.py
+++ b/micro_sam/evaluation/inference.py
@@ -163,9 +163,7 @@ def get_predictor(
         )
     else:  # Vanilla SAM model
         assert not return_state
-        predictor = util.get_sam_model(
-            model_type=model_type, device=device, checkpoint_path=checkpoint_path
-        )  # type: ignore
+        predictor = util.get_sam_model(model_type=model_type, device=device)  # type: ignore
     return predictor
 
 
diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py
index 66e6edf5..fbef37a4 100644
--- a/micro_sam/precompute_state.py
+++ b/micro_sam/precompute_state.py
@@ -124,8 +124,6 @@ def precompute_state(
             it can be given to provide a glob pattern to subselect files from the folder.
         output_path: The output path were the embeddings and other state will be saved.
         model_type: The SegmentAnything model to use. Will use the standard vit_h model by default.
-        checkpoint_path: Path to a checkpoint for a custom model.
-        key: The key to the input file. This is needed for contaner files (e.g. hdf5 or zarr) and can
             be used to provide a glob pattern if the input is a folder with image files.
         ndim: The dimensionality of the data.
         tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
@@ -133,7 +131,7 @@ def precompute_state(
         precompute_amg_state: Whether to precompute the state for automatic instance segmentation
             in addition to the image embeddings.
     """
-    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)
+    predictor = util.get_sam_model(model_type=model_type)
     # check if we precompute the state for a single file or for a folder with image files
    if os.path.isdir(input_path) and Path(input_path).suffix not in (".n5", ".zarr"):
         pattern = "*" if key is None else key
diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py
index dc2bf2fe..c286d8eb 100644
--- a/micro_sam/sam_annotator/_widgets.py
+++ b/micro_sam/sam_annotator/_widgets.py
@@ -37,7 +37,6 @@ def embedding_widget(
     model: Model = Model.__getitem__(_DEFAULT_MODEL),
     device = "auto",
     save_path: Optional[Path] = None,  # where embeddings for this image are cached (optional)
-    optional_custom_weights: Optional[Path] = None,  # A filepath or URL to custom model weights.
 ) -> ImageEmbeddings:
     """Image embedding widget."""
     state = AnnotatorState()
@@ -53,7 +52,7 @@ def embedding_widget(
     @thread_worker(connect={'started': pbar.show, 'finished': pbar.hide})
     def _compute_image_embedding(state, image_data, save_path, ndim=None, device="auto",
                                  model=Model.__getitem__(_DEFAULT_MODEL),
-                                 optional_custom_weights=None):
+                                 ):
         # Make sure save directory exists and is an empty directory
         if save_path is not None:
             os.makedirs(save_path, exist_ok=True)
@@ -70,8 +69,7 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None,
                 f"or empty directory: {save_path}"
             )
         # Initialize the model
-        state.predictor = get_sam_model(device=device, model_type=model.name,
-                                        checkpoint_path=optional_custom_weights)
+        state.predictor = get_sam_model(device=device, model_type=model.name)
         # Compute the image embeddings
         state.image_embeddings = precompute_image_embeddings(
             predictor = state.predictor,
@@ -81,4 +79,4 @@ def _compute_image_embedding(state, image_data, save_path, ndim=None,
         )
         return state
     # returns napari._qt.qthreading.FunctionWorker
-    return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model, optional_custom_weights=optional_custom_weights)
+    return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model)
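Usage note: this patch routes custom (finetuned) checkpoints through `util.get_custom_sam_model`, while `util.get_sam_model` is called with the model type only. A minimal sketch of the resulting split, assuming imageio for data loading and placeholder paths (neither is part of this patch):

import imageio.v3 as imageio

from micro_sam.sam_annotator import annotator_3d
from micro_sam.util import get_custom_sam_model, get_sam_model

# Placeholder paths for illustration only; substitute your own data and checkpoint.
raw = imageio.imread("./data/example_volume.tif")
checkpoint = "./checkpoints/vit_b_finetuned.pt"
embedding_path = "./embeddings/example-vit-b.zarr"

# Standard SAM models are loaded by model type alone after this change.
predictor = get_sam_model(model_type="vit_b")

# Finetuned checkpoints go through the dedicated loader instead.
predictor = get_custom_sam_model(model_type="vit_b", checkpoint_path=checkpoint)

# Run the 3d annotator with whichever predictor was loaded.
annotator_3d(raw, embedding_path, predictor=predictor)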