
Download models with pooch #276

Merged
Changes from 4 commits (18 commits total)
00006fc
Use pooch to download model weights
GenevieveBuckley Nov 14, 2023
926b4ba
Remove checkpoint_path from default sam model download (custom weight…
GenevieveBuckley Nov 14, 2023
09bc898
Allow MICROSAM_CACHEDIR os environment variable
GenevieveBuckley Nov 14, 2023
e5b1433
Remove checkpoint kwarg
GenevieveBuckley Nov 14, 2023
8460954
Bugfixes for image embedding widget
GenevieveBuckley Nov 22, 2023
c00c660
Restore get_sam_model checkpoint_path kwarg to select specific model …
GenevieveBuckley Nov 22, 2023
0387b95
Fix typo in docstring, mixed up order of text lines
GenevieveBuckley Nov 22, 2023
32c8bb2
Remove unnecessary import
GenevieveBuckley Nov 22, 2023
1c29685
Finish restoring optional_custom_weights / checkpoint_path kwarg
GenevieveBuckley Nov 22, 2023
865d613
Remove another unnecessary import
GenevieveBuckley Nov 22, 2023
a7a19b9
Merge branch 'dev' into pooch-model-download
GenevieveBuckley Nov 22, 2023
3b40099
Update pooch download, rename model_type->model_name in get_sam_model
constantinpape Nov 22, 2023
355600f
Fix remaining merge problems
GenevieveBuckley Nov 22, 2023
ce52b29
kkeep model_type kwarg naming, rename model_type_ to abbreviated_mode…
GenevieveBuckley Nov 23, 2023
0823237
Revert accidental changes to model_comparison.py
GenevieveBuckley Nov 23, 2023
c1081b7
Revert another accidental change to model_comparison.py
GenevieveBuckley Nov 23, 2023
22f9274
Remove redundant pooch.os_cache
constantinpape Nov 23, 2023
4d7a406
Merge pull request #4 from computational-cell-analytics/pooch-update
constantinpape Nov 23, 2023
4 changes: 2 additions & 2 deletions examples/annotator_with_custom_model.py
@@ -1,6 +1,6 @@
import h5py
import micro_sam.sam_annotator as annotator
-from micro_sam.util import get_sam_model
+from micro_sam.util import get_custom_sam_model

# TODO add an example for the 2d annotator with a custom model

@@ -11,7 +11,7 @@ def annotator_3d_with_custom_model():

custom_model = "/home/pape/Work/data/models/sam/user-study/vit_h_nuclei_em_finetuned.pt"
embedding_path = "./embeddings/nuclei3d-custom-vit-h.zarr"
-predictor = get_sam_model(checkpoint_path=custom_model, model_type="vit_h")
+predictor = get_custom_sam_model(checkpoint_path=custom_model, model_type="vit_h")
annotator.annotator_3d(raw, embedding_path, predictor=predictor)


2 changes: 1 addition & 1 deletion examples/finetuning/use_finetuned_model.py
@@ -21,7 +21,7 @@ def run_annotator_with_custom_model():
# Adapt this if you finetune a different model type, e.g. vit_h.

# Load the custom model.
-predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint)
+predictor = util.get_custom_sam_model(model_type=model_type, checkpoint_path=checkpoint)

# Run the 2d annotator with the custom model.
annotator_2d(
3 changes: 1 addition & 2 deletions examples/use_as_library/instance_segmentation.py
@@ -101,8 +101,7 @@ def segmentation_in_3d():

# Load the SAM model for prediction.
model_type = "vit_b" # The model-type to use: vit_h, vit_l, vit_b etc.
-checkpoint_path = None # You can specifiy the path to a custom (fine-tuned) model here.
-predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)
+predictor = util.get_sam_model(model_type=model_type)

# Run 3d segmentation for a given slice. Will segment all objects found in that slice
# throughout the volume.
4 changes: 1 addition & 3 deletions micro_sam/evaluation/inference.py
@@ -163,9 +163,7 @@ def get_predictor(
)
else: # Vanilla SAM model
assert not return_state
-predictor = util.get_sam_model(
-model_type=model_type, device=device, checkpoint_path=checkpoint_path
-) # type: ignore
+predictor = util.get_sam_model(model_type=model_type, device=device) # type: ignore
return predictor


4 changes: 1 addition & 3 deletions micro_sam/precompute_state.py
@@ -124,16 +124,14 @@
it can be given to provide a glob pattern to subselect files from the folder.
output_path: The output path where the embeddings and other state will be saved.
model_type: The SegmentAnything model to use. Will use the standard vit_h model by default.
-checkpoint_path: Path to a checkpoint for a custom model.
key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
and can be used to provide a glob pattern if the input is a folder with image files.
ndim: The dimensionality of the data.
tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
halo: Overlap of the tiles for tiled prediction.
precompute_amg_state: Whether to precompute the state for automatic instance segmentation
in addition to the image embeddings.
"""
-predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)
+predictor = util.get_sam_model(model_type=model_type)

Check warning (Codecov / codecov/patch): added line micro_sam/precompute_state.py#L134 was not covered by tests.
# check if we precompute the state for a single file or for a folder with image files
if os.path.isdir(input_path) and Path(input_path).suffix not in (".n5", ".zarr"):
pattern = "*" if key is None else key
12 changes: 5 additions & 7 deletions micro_sam/sam_annotator/_widgets.py
@@ -13,15 +13,15 @@
ImageEmbeddings,
get_sam_model,
precompute_image_embeddings,
-_MODEL_URLS,
+MODELS,
_DEFAULT_MODEL,
_available_devices,
)

if TYPE_CHECKING:
import napari

-Model = Enum("Model", _MODEL_URLS)
+Model = Enum("Model", MODELS.urls)
available_devices_list = ["auto"] + _available_devices()


@@ -37,7 +37,6 @@
model: Model = Model.__getitem__(_DEFAULT_MODEL),
device = "auto",
save_path: Optional[Path] = None, # where embeddings for this image are cached (optional)
-optional_custom_weights: Optional[Path] = None, # A filepath or URL to custom model weights.
) -> ImageEmbeddings:
"""Image embedding widget."""
state = AnnotatorState()
@@ -53,7 +52,7 @@
@thread_worker(connect={'started': pbar.show, 'finished': pbar.hide})
def _compute_image_embedding(state, image_data, save_path, ndim=None,
device="auto", model=Model.__getitem__(_DEFAULT_MODEL),
-optional_custom_weights=None):
+):
# Make sure save directory exists and is an empty directory
if save_path is not None:
os.makedirs(save_path, exist_ok=True)
@@ -70,8 +69,7 @@
f"or empty directory: {save_path}"
)
# Initialize the model
-state.predictor = get_sam_model(device=device, model_type=model.name,
-checkpoint_path=optional_custom_weights)
+state.predictor = get_sam_model(device=device, model_type=model.name)

Check warning (Codecov / codecov/patch): added line micro_sam/sam_annotator/_widgets.py#L72 was not covered by tests.
# Compute the image embeddings
state.image_embeddings = precompute_image_embeddings(
predictor = state.predictor,
@@ -81,4 +79,4 @@
)
return state # returns napari._qt.qthreading.FunctionWorker

-return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model, optional_custom_weights=optional_custom_weights)
+return _compute_image_embedding(state, image.data, save_path, ndim=ndim, device=device, model=model)
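As context for the `Model = Enum("Model", MODELS.urls)` change above: the functional `Enum` API accepts a mapping, so the registry's url dict turns directly into a dropdown-friendly enum whose member names are the model types. A small self-contained illustration (the urls here are placeholders, not the real registry):

```python
from enum import Enum

# Placeholder urls standing in for MODELS.urls (hypothetical values).
urls = {
    "vit_h": "https://example.com/vit_h.pth",
    "vit_b": "https://example.com/vit_b.pth",
}

# Member names come from the dict keys, values from the dict values.
Model = Enum("Model", urls)

print(Model.vit_h.name)      # "vit_h"
print(Model["vit_b"].value)  # "https://example.com/vit_b.pth"
```

The widget then only needs `model.name` to look the weights up in the registry, which is exactly what the `get_sam_model(device=device, model_type=model.name)` call above does.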
3 changes: 1 addition & 2 deletions micro_sam/training/util.py
@@ -11,7 +11,6 @@
def get_trainable_sam_model(
model_type: str = "vit_h",
device: Optional[str] = None,
-checkpoint_path: Optional[Union[str, os.PathLike]] = None,
freeze: Optional[List[str]] = None,
) -> TrainableSAM:
"""Get the trainable sam model.
@@ -28,7 +27,7 @@ def get_trainable_sam_model(
"""
# set the device here so that the correct one is passed to TrainableSAM below
device = _get_device(device)
-_, sam = get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path, return_sam=True)
+_, sam = get_sam_model(model_type=model_type, device=device, return_sam=True)

# freeze components of the model if freeze was passed
# ideally we would want to add components in such a way that:
115 changes: 34 additions & 81 deletions micro_sam/util.py
@@ -35,42 +35,7 @@
except ImportError:
from tqdm import tqdm

-_MODEL_URLS = {
-# the default segment anything models
-"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
-"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
-"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
-# the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM
-"vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
-# first version of finetuned models on zenodo
-"vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1",
-"vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1",
-"vit_h_em": "https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1",
-"vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1",
-}
-_CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam')
-_CHECKPOINT_FOLDER = os.path.join(_CACHE_DIR, 'models')
-_CHECKSUMS = {
-# the default segment anything models
-"vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e",
-"vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622",
-"vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912",
-# the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM
-"vit_t": "6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f",
-# first version of finetuned models on zenodo
-"vit_h_lm": "9a65ee0cddc05a98d60469a12a058859c89dc3ea3ba39fed9b90d786253fbf26",
-"vit_b_lm": "5a59cc4064092d54cd4d92cd967e39168f3760905431e868e474d60fe5464ecd",
-"vit_h_em": "ae3798a0646c8df1d4db147998a2d37e402ff57d3aa4e571792fbb911d8a979c",
-"vit_b_em": "c04a714a4e14a110f0eec055a65f7409d54e6bf733164d2933a0ce556f7d6f81",
-}
-# this is required so that the downloaded file is not called 'download'
-_DOWNLOAD_NAMES = {
-"vit_t": "vit_t_mobile_sam.pth",
-"vit_h_lm": "vit_h_lm.pth",
-"vit_b_lm": "vit_b_lm.pth",
-"vit_h_em": "vit_h_em.pth",
-"vit_b_em": "vit_b_em.pth",
-}

# this is the default model used in micro_sam
# currently set to the default vit_h
_DEFAULT_MODEL = "vit_h"
@@ -84,49 +49,38 @@
#
# Functionality for model download and export
#


-def _download(url, path, model_type):
-with requests.get(url, stream=True, verify=True) as r:
-if r.status_code != 200:
-r.raise_for_status()
-raise RuntimeError(f"Request to {url} returned status code {r.status_code}")
-file_size = int(r.headers.get("Content-Length", 0))
-desc = f"Download {url} to {path}"
-if file_size == 0:
-desc += " (unknown file size)"
-with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f:
-copyfileobj(r_raw, f)
-
-# validate the checksum
-expected_checksum = _CHECKSUMS[model_type]
-if expected_checksum is None:
-return
-with open(path, "rb") as f:
-file_ = f.read()
-checksum = hashlib.sha256(file_).hexdigest()
-if checksum != expected_checksum:
-raise RuntimeError(
-"The checksum of the download does not match the expected checksum."
-f"Expected: {expected_checksum}, got: {checksum}"
-)
-print("Download successful and checksums agree.")
-
-
-def _get_checkpoint(model_type, checkpoint_path=None):
-if checkpoint_path is None:
-checkpoint_url = _MODEL_URLS[model_type]
-checkpoint_name = _DOWNLOAD_NAMES.get(model_type, checkpoint_url.split("/")[-1])
-checkpoint_path = os.path.join(_CHECKPOINT_FOLDER, checkpoint_name)
-
-# download the checkpoint if necessary
-if not os.path.exists(checkpoint_path):
-os.makedirs(_CHECKPOINT_FOLDER, exist_ok=True)
-_download(checkpoint_url, checkpoint_path, model_type)
-elif not os.path.exists(checkpoint_path):
-raise ValueError(f"The checkpoint path {checkpoint_path} that was passed does not exist.")
-
-return checkpoint_path
+_CACHE_DIR = os.environ.get('MICROSAM_CACHEDIR') or pooch.os_cache('micro_sam')
GenevieveBuckley marked this conversation as resolved.
+MODELS = pooch.create(
+path=pooch.os_cache(os.path.join(_CACHE_DIR, "models")),
+base_url="",
+registry={
+# the default segment anything models
+"vit_h": "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e",
+"vit_l": "3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622",
+"vit_b": "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912",
+# the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM
+"vit_t": "6dbb90523a35330fedd7f1d3dfc66f995213d81b29a5ca8108dbcdd4e37d6c2f",
+# first version of finetuned models on zenodo
+"vit_h_lm": "9a65ee0cddc05a98d60469a12a058859c89dc3ea3ba39fed9b90d786253fbf26",
+"vit_b_lm": "5a59cc4064092d54cd4d92cd967e39168f3760905431e868e474d60fe5464ecd",
+"vit_h_em": "ae3798a0646c8df1d4db147998a2d37e402ff57d3aa4e571792fbb911d8a979c",
+"vit_b_em": "c04a714a4e14a110f0eec055a65f7409d54e6bf733164d2933a0ce556f7d6f81",
+},
+# Now specify custom URLs for some of the files in the registry.
+urls={
+# the default segment anything models
+"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
+"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
+# the model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM
+"vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
+# first version of finetuned models on zenodo
+"vit_h_lm": "https://zenodo.org/record/8250299/files/vit_h_lm.pth?download=1",
+"vit_b_lm": "https://zenodo.org/record/8250281/files/vit_b_lm.pth?download=1",
+"vit_h_em": "https://zenodo.org/record/8250291/files/vit_h_em.pth?download=1",
+"vit_b_em": "https://zenodo.org/record/8250260/files/vit_b_em.pth?download=1",
+},
+)
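For context on what the registry above buys us: `MODELS.fetch(name)` looks the file up in the local cache, downloads it from the registered URL on first use, and verifies its SHA-256 checksum before returning a local path. A rough stdlib-only sketch of that behavior (the `fetch` and `fake_download` helpers here are illustrative stand-ins, not pooch or micro_sam code):

```python
import hashlib
import os
import tempfile

def fetch(name, registry, urls, cache_dir, download):
    """Return a checksum-verified local path for `name`, downloading it if missing."""
    path = os.path.join(cache_dir, name)
    if not os.path.exists(path):
        os.makedirs(cache_dir, exist_ok=True)
        download(urls[name], path)
    # Verify the sha256 checksum against the registry, as pooch does.
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).hexdigest()
    if digest != registry[name]:
        raise RuntimeError(f"Checksum mismatch for {name}: {digest}")
    return path

# Offline stand-in for an HTTP download (hypothetical).
def fake_download(url, path):
    with open(path, "wb") as f:
        f.write(b"model-weights")

payload = b"model-weights"
registry = {"vit_demo": hashlib.sha256(payload).hexdigest()}
urls = {"vit_demo": "https://example.com/vit_demo.pth"}
cache = tempfile.mkdtemp()
print(fetch("vit_demo", registry, urls, cache, fake_download))
```

This is also why the hand-rolled `_download`/`_get_checkpoint` pair removed above is no longer needed: caching, download naming, and checksum validation all come from pooch, and `MICROSAM_CACHEDIR` only has to steer the cache location.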


def _get_default_device():
@@ -177,7 +131,6 @@ def _available_devices():
def get_sam_model(
model_type: str = _DEFAULT_MODEL,
device: Optional[str] = None,
-checkpoint_path: Optional[Union[str, os.PathLike]] = None,
Contributor:

I think we should keep this. Merging this into get_custom_sam_model would be confusing and I don't see a big issue keeping it here. See my main comment for more details.

Collaborator Author:

Ok. I have restored the checkpoint_path keyword argument.

I will need you to test this (because all the example scripts involving checkpoint_path rely on data you personally have). The _get_checkpoint function is not covered by any of the existing tests, so it's safe to say we're not testing anything involving checkpoint_path right now.
Could you please run:

  • examples/annotator_with_custom_model.py
  • examples/finetuning/use_finetuned_model.py
  • examples/use_as_library/instance_segmentation.py

It would be great if we could upload the data and custom weights for examples/annotator_with_custom_model.py to zenodo or similar - or even make a similar example but with smaller 2d data and model weights. Ideally then anybody could run this example as a test.

Contributor:

Hi @GenevieveBuckley,
the models are all available via zenodo already. See https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/util.py#L46-L49. (I agree though that we should list this in the doc, I have added that to #249) and I changed the examples in #280 so that all the examples use sample data.

> The _get_checkpoint function is not covered by any of the existing tests, so it's safe to say we're not testing anything involving checkpoint_path right now.

I don't think that's true either since merging #280. This now uses a custom checkpoint. But you would need to merge the current dev branch in here first to get those tests.

I will still go ahead and test the examples with this branch.

return_sam: bool = False,
) -> SamPredictor:
r"""Get the SegmentAnything Predictor.
@@ -203,7 +156,7 @@ def get_sam_model(
Returns:
The segment anything predictor.
"""
-checkpoint = _get_checkpoint(model_type, checkpoint_path)
+checkpoint = MODELS.fetch(model_type)
Contributor:

We can just add something like this to support checkpoint_path:

    if checkpoint_path is None:
        checkpoint = MODELS.fetch(model_type)
    else:
        checkpoint = checkpoint_path

Contributor:

So I don't see any problem with mixing up hashes.

Collaborator Author:

I'm confused - if the weights are different, the hash value for the object should be different too, right? So how can anybody possibly do some extra fine-tuning on a particular model, save the new weights, and then use checkpoint_path and give it the hash value stored in the pooch registry?

I might misunderstand what is actually happening with checkpoint_path.

Contributor (constantinpape, Nov 22, 2023):

Sorry, I should have written a better description here. I will try to explain what happens if a checkpoint_path is passed or not:

  • if checkpoint_path is given: this is a local path to SAM weights, e.g. if a user has fine-tuned on their own data. In this case we don't download any weights but initialize the model from these local weights. We do not check the weight file against any hashes.
  • if checkpoint_path is not given: we download the SAM weights corresponding to model_type. (I agree that this name is confusing, see comment below.) We can check their hash.

> So how can anybody possibly do some extra fine tuning on a particular model, save the new weights, and then use checkpoint_path and give it the hash value stored in the pooch registry?

When finetuning (or passing any other weights via checkpoint_path) we don't check the weights against the hash values in the pooch registry. From what I understand the current code is doing exactly what I described, since we only use the model registry in the case where checkpoint_path is None, and only then we actually check the hashes.

device = _get_device(device)

# Our custom model types have a suffix "_...". This suffix needs to be stripped