
Download models with pooch #276

Conversation

GenevieveBuckley
Collaborator

I've made a start on #36

  • This PR downloads the models using pooch
  • We are still using the SHA256 hashes for validation, although at some future point we could use a (non-cryptographic) XXH128 hash as discussed here
  • I've removed the 'checkpoint' kwarg from get_sam_model. I think it makes more sense to have custom weights of any sort go through the get_custom_sam_model function (I haven't touched this function yet), and also if we download weights from different checkpoints then the hash validation is going to break.
  • Problem: I can't test a lot of this, because...
    • we have low test coverage, and
    • the example scripts often use hard coded paths to locations on Constantin's computer, so I can't run them (eg: examples/annotator_with_custom_model.py)

@GenevieveBuckley changed the base branch from master to dev on November 14, 2023

codecov bot commented Nov 14, 2023

Codecov Report

Attention: 23 lines in your changes are missing coverage. Please review.

Comparison is base (71181d4) 39.72% compared to head (4d7a406) 40.55%.
Report is 27 commits behind head on dev.

| Files | Patch % | Lines |
|---|---|---|
| micro_sam/sample_data.py | 6.66% | 14 Missing ⚠️ |
| micro_sam/util.py | 86.84% | 5 Missing ⚠️ |
| micro_sam/sam_annotator/_widgets.py | 57.14% | 3 Missing ⚠️ |
| micro_sam/precompute_state.py | 0.00% | 1 Missing ⚠️ |
Additional details and impacted files
```
@@            Coverage Diff             @@
##              dev     #276      +/-   ##
==========================================
+ Coverage   39.72%   40.55%   +0.83%
==========================================
  Files          33       33
  Lines        4179     4167      -12
==========================================
+ Hits         1660     1690      +30
+ Misses       2519     2477      -42
```


@GenevieveBuckley
Collaborator Author

If/when we go ahead with #276, we'll need to make sure `pooch.os_cache` is recomputed in the pooch `MODELS` registry every time, instead of just once when the module is imported.
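A minimal sketch of the import-time pitfall described above (the environment-variable name and helper are hypothetical, not micro_sam's actual code): a cache directory captured at import time is frozen, while a function-based lookup stays current.

```python
import os
import tempfile

# Start from a clean environment so the "frozen" value is predictable.
os.environ.pop("MICROSAM_CACHEDIR", None)

# Computed once at import time: this value can go stale.
FROZEN_CACHE_DIR = os.environ.get("MICROSAM_CACHEDIR", tempfile.gettempdir())

def current_cache_dir():
    # Recomputed on every call, so changes to the environment (e.g. in
    # tests, or by a user configuring a custom cache) take effect without
    # re-importing the module.
    return os.environ.get("MICROSAM_CACHEDIR", tempfile.gettempdir())

os.environ["MICROSAM_CACHEDIR"] = "/tmp/micro_sam_models"
print(FROZEN_CACHE_DIR)     # still the old default
print(current_cache_dir())  # picks up the new setting
```

The same reasoning applies to a module-level pooch registry: building it inside a function (rather than at import) lets the cache path be re-evaluated on each use.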

@constantinpape
Contributor

Overall this is great!

  • We are still using the SHA256 hashes for validation, although at some future point we could use a (non-cryptographic) XXH128 hash as discussed here

That's good to know, and we can address this in a follow up.

  • I've removed the 'checkpoint' kwarg from get_sam_model. I think it makes more sense to have custom weights of any sort go through the get_custom_sam_model function (I haven't touched this function yet), and also if we download weights from different checkpoints then the hash validation is going to break.

I think that makes sense. I need to think about this a bit more because get_custom_sam_model currently also supports other kinds of weights. I will read this more carefully later and follow up.

Problem: I can't test a lot of this, because...

* we have low test coverage, and

* the example scripts often use hard coded paths to locations on Constantin's computer, so I can't run them (eg: `examples/annotator_with_custom_model.py`)
  • Regarding test coverage: we should probably add some more unit tests for the weights, I will think about it.
  • Regarding example scripts: good point. I will update those later and remove as many hard-coded paths as possible.

```diff
@@ -177,7 +131,6 @@ def _available_devices():
 def get_sam_model(
     model_type: str = _DEFAULT_MODEL,
     device: Optional[str] = None,
-    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
```
Contributor


I think we should keep this. Merging this into get_custom_sam_model would be confusing and I don't see a big issue keeping it here. See my main comment for more details.

Collaborator Author


Ok. I have restored the checkpoint_path keyword argument.

I will need you to test this (because all the example scripts involving checkpoint_path rely on data you personally have). The _get_checkpoint function is not covered by any of the existing tests, so it's safe to say we're not testing anything involving checkpoint_path right now.
Could you please run:

  • examples/annotator_with_custom_model.py
  • examples/finetuning/use_finetuned_model.py
  • examples/use_as_library/instance_segmentation.py

It would be great if we could upload the data and custom weights for examples/annotator_with_custom_model.py to zenodo or similar - or even make a similar example but with smaller 2d data and model weights. Ideally then anybody could run this example as a test.

Contributor


Hi @GenevieveBuckley,
the models are all available via zenodo already. See https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/util.py#L46-L49. (I agree though that we should list this in the doc, I have added that to #249)
and I changed the examples in #280 so that all the examples use sample data.

The _get_checkpoint function is not covered by any of the existing tests, so it's safe to say we're not testing anything involving checkpoint_path right now.

I don't think that's true anymore since merging #280, which now tests with a custom checkpoint. But you would need to merge the current dev branch in here first to get those tests.

I will still go ahead and test the examples with this branch.

```diff
@@ -203,7 +156,7 @@ def get_sam_model(
     Returns:
         The segment anything predictor.
     """
-    checkpoint = _get_checkpoint(model_type, checkpoint_path)
+    checkpoint = MODELS.fetch(model_type)
```
Contributor


We can just add something like this to support checkpoint_path:

```python
if checkpoint_path is None:
    checkpoint = MODELS.fetch(model_type)
else:
    checkpoint = checkpoint_path
```

Contributor


So I don't see any problem with mixing up hashes.

Collaborator Author


I'm confused - if the weights are different, the hash value for the object should be different too, right? So how can anybody possibly do some extra fine tuning on a particular model, save the new weights, and then use checkpoint_path and give it the hash value stored in the pooch registry?

I might misunderstand what is actually happening with checkpoint_path.

Contributor

@constantinpape (Nov 22, 2023)


Sorry, I should have written a better description here. I will try to explain what happens depending on whether a checkpoint_path is passed:

  • if checkpoint_path is given: this is a local path to SAM weights, e.g. if a user has fine-tuned on their own data. In this case we don't download any weights but initialize the model from these local weights. We do not check the weight file against any hashes.
  • if checkpoint_path is not given: we download the SAM weights corresponding to model_type (I agree that this name is confusing, see comment below). We can check their hash.

So how can anybody possibly do some extra fine tuning on a particular model, save the new weights, and then use checkpoint_path and give it the hash value stored in the pooch registry?

When finetuning (or passing any other weights via checkpoint_path) we don't check the weights against the hash values in the pooch registry. From what I understand the current code is doing exactly what I described, since we only use the model registry in the case where checkpoint_path is None, and only then we actually check the hashes.
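The behaviour described above can be sketched with stand-ins (the registry contents, "downloaded" bytes, and helper names here are made up for illustration, not micro_sam's real code): hashes are only verified for registered downloads, while a user-supplied checkpoint_path bypasses validation entirely.

```python
import hashlib
from pathlib import Path
from tempfile import TemporaryDirectory

# Hypothetical registry: model name -> expected SHA256 of its weight file.
REGISTRY = {"vit_b": hashlib.sha256(b"vit_b weights").hexdigest()}

def _fetch_registered(model_type, cache_dir):
    # Stand-in for MODELS.fetch: "download" the bytes, verify the hash,
    # and cache the file locally.
    data = b"vit_b weights"
    if hashlib.sha256(data).hexdigest() != REGISTRY[model_type]:
        raise ValueError(f"Hash mismatch for {model_type}")
    path = Path(cache_dir) / f"{model_type}.pth"
    path.write_bytes(data)
    return path

def resolve_checkpoint(model_type, checkpoint_path, cache_dir):
    if checkpoint_path is None:
        return _fetch_registered(model_type, cache_dir)
    # Local weights (e.g. fine-tuned): used as-is, no hash check.
    return Path(checkpoint_path)

with TemporaryDirectory() as tmp:
    custom = Path(tmp) / "finetuned.pth"
    custom.write_bytes(b"finetuned weights")  # any content is accepted
    assert resolve_checkpoint("vit_b", None, tmp).name == "vit_b.pth"
    assert resolve_checkpoint("vit_b", custom, tmp) == custom
```

This is why fine-tuned weights never collide with the pooch registry: the registry is only consulted on the download path.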

@constantinpape
Contributor

Hi @GenevieveBuckley ,
I had some time to look into this now:

get_sam_model

I think we should keep the (optional) checkpoint_path argument there. Mixing this up with get_custom_sam_model would be confusing. To clarify this:

  • get_sam_model loads a model from weights stored in the default format. These can come from two sources: either from weights that are available online (either the original SAM weights or our fine-tuned models) and that we have registered. Or they can be loaded from a local filepath that is specified via checkpoint_path.
  • get_custom_sam_model loads a model from a training checkpoint from our finetuning code. This function is not very relevant for most users, but is important for us to easily evaluate the finetuned models without having to convert them to the default format (which is done by export_custom_sam_model)

So these two functions do something conceptually different, and it would also complicate the implementation if get_custom_sam_model would accept both types of weight files, because we would then need to find out which of the two formats it is before loading the model.

  • and also if we download weights from different checkpoints then the hash validation is going to break.

We don't want to support downloading weights from arbitrary locations, but only for known models (either original SAM or our finetuned ones). So I don't really see any problems with hash validation here.
(If a user wants to use some other model available online they can download it themselves and use the checkpoint_path argument).

I hope that makes sense to you @GenevieveBuckley.
All of this should probably be documented better, I can follow up with a PR to improve the docstrings after we finish this one.

test coverage and example scripts

  • I have updated the example scripts so that we use get_cache_directory in there, see 1c79323
  • I have added more tests to extend the coverage for loading models in #280 (Increase test coverage). I will wait for your review before merging, since I have written the test for get_sam_model according to the current logic.

@constantinpape
Contributor

Hi @GenevieveBuckley,
I have merged #280 now to increase the test coverage. Most of the changes here should be covered by tests now, so I think you can continue here.
Let me know if something is still unclear with respect to get_sam_model vs. get_custom_sam_model or if there is some problem in resolving the merge conflicts.

```diff
-    os.makedirs(_CHECKPOINT_FOLDER, exist_ok=True)
-    _download(checkpoint_url, checkpoint_path, model_type)
+    os.makedirs(microsam_cachedir()/"models", exist_ok=True)
+    pooch.retrieve(url=checkpoint_url, known_hash=models().registry.get(model_name))
```
Collaborator Author


Does known_hash need to be None here (i.e. unknown)? I don't quite understand how someone can do fine tuning of model weights, but have the hash stay unchanged (which is I think what would have to happen for what's written here to make any sense).

Contributor


Ok, I think I don't fully understand what pooch is doing. I will check this out and follow up.

```diff
 # Our custom model types have a suffix "_...". This suffix needs to be stripped
 # before calling sam_model_registry.
 model_type_ = model_type[:5]
 assert model_type_ in ("vit_h", "vit_b", "vit_l", "vit_t")
-if model_type == "vit_t" and not VIT_T_SUPPORT:
+if model_type_ == "vit_t" and not VIT_T_SUPPORT:
```
Collaborator Author

@GenevieveBuckley (Nov 22, 2023)


It seems possible that in the future people might fine-tune models based on vit_t, so we'd have model_type="vit_t_em", etc.
I find the model_type/model_type_ distinction here confusing; I'd kinda like to rename model_type -> model_name instead. What do you think?

Contributor


I fully agree, we should have a better distinction here and the current distinction is confusing. Let's do what you suggest and rename model_type -> model_name and model_type_ -> model_type
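The agreed naming scheme can be sketched as follows (the helper name is hypothetical, not micro_sam's actual implementation): a full model name such as "vit_b_em" carries an optional finetuning suffix, while the model type is the 5-character base architecture that gets passed to sam_model_registry.

```python
# Base architectures known to sam_model_registry.
_BASE_TYPES = ("vit_h", "vit_l", "vit_b", "vit_t")

def model_type_from_name(model_name):
    # Mirrors the model_type[:5] stripping in the diff above: the first
    # five characters are the base architecture, anything after is a
    # finetuning suffix (e.g. "_em").
    model_type = model_name[:5]
    if model_type not in _BASE_TYPES:
        raise ValueError(f"Unknown base architecture in {model_name!r}")
    return model_type

print(model_type_from_name("vit_b_em"))  # -> vit_b
print(model_type_from_name("vit_t"))     # -> vit_t
```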

```diff
@@ -219,7 +222,7 @@ def get_sam_model(
     sam = sam_model_registry[model_type_](checkpoint=checkpoint)
     sam.to(device=device)
     predictor = SamPredictor(sam)
-    predictor.model_type = model_type
+    predictor.model_type = model_type_
```
Collaborator Author


Is this correct?

Contributor


Yes, that's correct.

@GenevieveBuckley
Collaborator Author

The most important next step is to check that checkpoint_path is working well.

Copying this comment from above, so it doesn't get lost/hidden if I make updates to that section of the code.

Ok. I have restored the checkpoint_path keyword argument.

I will need you to test this (because all the example scripts involving checkpoint_path rely on data you personally have). The _get_checkpoint function is not covered by any of the existing tests, so it's safe to say we're not testing anything involving checkpoint_path right now. Could you please run:

  • `examples/annotator_with_custom_model.py`
  • `examples/finetuning/use_finetuned_model.py`
  • `examples/use_as_library/instance_segmentation.py`

It would be great if we could upload the data and custom weights for examples/annotator_with_custom_model.py to zenodo or similar - or even make a similar example but with smaller 2d data and model weights. Ideally then anybody could run this example as a test.

@constantinpape
Contributor

The most important next step is to check that checkpoint_path is working well.

Yep, you're right. I think I misunderstood a few things about how pooch works, this is why we may have talked past each other in some of the comments. I don't think this is a big issue though and will just need changes in a few other places, I will check this out later and make a small PR onto this one, I think that's the easiest solution.

constantinpape and others added 2 commits November 23, 2023 09:20
Update pooch download, rename model_type->model_name in get_sam_model
@constantinpape
Contributor

Thanks @GenevieveBuckley! This cleans up things significantly and I hope now the usage of the get_sam_model functions is much clearer.

@constantinpape merged commit c2a4e54 into computational-cell-analytics:dev on Nov 23, 2023
5 checks passed
@constantinpape deleted the pooch-model-download branch on November 23, 2023
@constantinpape restored the pooch-model-download branch on November 23, 2023
@GenevieveBuckley deleted the pooch-model-download branch on December 6, 2023