Add LoRA Implementation #611

Status: Merged (19 commits into dev, Jun 19, 2024)

Conversation

@anwai98 (Contributor) commented May 21, 2024

No description provided.

@anwai98 marked this pull request as ready for review May 23, 2024 15:16
@anwai98 requested a review from constantinpape May 23, 2024 15:16
@constantinpape (Contributor) left a comment

Looks good overall, only a few minor comments.

Review threads (resolved): micro_sam/training/trainable_sam.py (2), micro_sam/training/util.py
@anwai98 requested a review from constantinpape June 14, 2024 13:03
@constantinpape (Contributor) left a comment

I think it would make more sense to design the PEFT adapters as proper PyTorch modules; see the comments for details.
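
For illustration, a minimal sketch of what such an adapter could look like as a standalone module that wraps one attention block's qkv projection. The class name LoRASurgery, the rank argument, and the q/v update rule are assumptions for this sketch, not necessarily the code that was merged:

import torch
import torch.nn as nn

class LoRASurgery(nn.Module):
    """Illustrative LoRA adapter: adds trainable low-rank updates to the
    q and v parts of a frozen qkv projection of one attention block."""

    def __init__(self, rank: int, block: nn.Module):
        super().__init__()
        self.qkv_proj = block.attn.qkv  # original projection, kept frozen
        self.dim = self.qkv_proj.in_features

        # Trainable low-rank factors for the query and value projections.
        self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
        self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
        self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False)
        self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False)

        # Standard LoRA initialization: the B matrices start at zero,
        # so the adapter initially leaves the projection unchanged.
        nn.init.zeros_(self.w_b_linear_q.weight)
        nn.init.zeros_(self.w_b_linear_v.weight)

        block.attn.qkv = self  # swap the original projection for this module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qkv = self.qkv_proj(x)  # (..., 3 * dim)
        qkv[..., :self.dim] += self.w_b_linear_q(self.w_a_linear_q(x))   # update q
        qkv[..., -self.dim:] += self.w_b_linear_v(self.w_a_linear_v(x))  # update v
        return qkv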

Review threads (resolved): micro_sam/training/peft_sam.py (5)
@anwai98 (Contributor, Author) commented Jun 16, 2024

@constantinpape Thanks for all the suggestions. I thought about all your feedback and incorporated it now in a much more flexible design. Let me know if I missed something or you spot some more room for improvement. Thanks!

@anwai98 requested a review from constantinpape June 16, 2024 19:19
@constantinpape (Contributor) left a comment

This is going in the right direction! A few more changes to make it more modular.

In addition, it would also be good to have a test for this that instantiates the PEFT SAM, applies a forward pass, and checks the gradients afterwards.
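
A rough sketch of what such a test could look like, built on the vanilla segment_anything model registry. The import path and the PEFT_Sam / LoRASurgery constructor used here are assumptions based on the files touched in this PR, so the exact names may differ from the merged test; a gradient check could then be added on top, as discussed further below:

import unittest

import torch
from segment_anything import sam_model_registry

# Assumed import path, based on the files reviewed in this PR.
from micro_sam.training.peft_sam import PEFT_Sam, LoRASurgery

class TestPEFTSam(unittest.TestCase):
    def test_peft_sam(self):
        # Build an (untrained) SAM backbone and wrap it with the PEFT adapters.
        sam = sam_model_registry["vit_b"](checkpoint=None)
        peft_sam = PEFT_Sam(sam, rank=4, peft_module=LoRASurgery)

        # A single random image in SAM's batched input format.
        input_shape = (3, 1024, 1024)
        batched_input = [{"image": torch.rand(*input_shape), "original_size": input_shape[1:]}]
        outputs = peft_sam(batched_input, multimask_output=True)

        # Check the expected shape of the predicted masks.
        mask_shapes = [output["masks"].shape[-2:] for output in outputs]
        for shape in mask_shapes:
            self.assertEqual(shape, input_shape[1:])

if __name__ == "__main__":
    unittest.main()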

Review threads (resolved): micro_sam/training/peft_sam.py (3)
@anwai98 (Contributor, Author) commented Jun 17, 2024

@constantinpape Thanks for the feedback. I've updated the PEFT_Sam block as suggested, so that it no longer makes any assumptions about which parts of the backbone to update; instead, self.peft_module (i.e. the surgery class provided by the user) is now responsible for handling everything for the attention blocks.
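
For reference, a minimal sketch of this design, in which the wrapper only freezes the backbone and hands each attention block to the user-supplied surgery module. The constructor arguments shown here (rank, attention_layers_to_update) are assumptions, and the merged signature may differ:

from typing import List, Optional, Type

import torch.nn as nn

class PEFT_Sam(nn.Module):
    """Illustrative wrapper: all backbone modifications are delegated to peft_module."""

    def __init__(
        self,
        model: nn.Module,                      # a segment_anything Sam instance
        rank: int,
        peft_module: Type[nn.Module],          # e.g. a LoRA surgery class
        attention_layers_to_update: Optional[List[int]] = None,
    ):
        super().__init__()

        # By default, apply the surgery to every attention block of the image encoder.
        blocks = model.image_encoder.blocks
        layers = attention_layers_to_update or list(range(len(blocks)))

        # Freeze the whole image encoder; the surgery module re-introduces
        # the small set of trainable parameters.
        for param in model.image_encoder.parameters():
            param.requires_grad = False

        # The surgery module decides what to replace inside each attention block.
        self.peft_blocks = nn.ModuleList(
            [peft_module(rank=rank, block=blk) for i, blk in enumerate(blocks) if i in layers]
        )
        self.sam = model

    def forward(self, batched_input, multimask_output):
        return self.sam(batched_input, multimask_output)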

I also added a simple test for the PEFT_Sam class.

Let me know how things look now. Thanks!

@anwai98 requested a review from constantinpape June 17, 2024 11:01
@anwai98 (Contributor, Author) commented Jun 17, 2024

To add: it looks like the tests at test_bioimageio/test_model_export.py are failing for some reason.

@constantinpape (Contributor) left a comment

This looks good now, only one minor thing.

Review thread (resolved): micro_sam/training/peft_sam.py
# Check the expected shape of the outputs
mask_shapes = [output["masks"].shape[-2:] for output in outputs]
for shape in mask_shapes:
    self.assertEqual(shape, input_shape[1:])
Review comment (Contributor):

Ideally we would check here that gradients are propagated through the model. But I can also do this in a follow-up PR.
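
For the record, such a gradient check could roughly look like the following, appended after the shape check above. It backpropagates through the low_res_logits rather than the binarized masks, and the w_a_linear / w_b_linear parameter names and the peft_sam instance are assumptions borrowed from the sketches earlier in this thread:

# Backpropagate through the (non-thresholded) logits returned by SAM.
loss = sum(output["low_res_logits"].sum() for output in outputs)
loss.backward()

for name, param in peft_sam.named_parameters():
    if "w_a_linear" in name or "w_b_linear" in name:
        # The low-rank adapter parameters should receive gradients ...
        self.assertIsNotNone(param.grad)
    elif "image_encoder" in name:
        # ... while the frozen image encoder weights should not.
        self.assertIsNone(param.grad)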

Reply (Contributor Author):

Okay, I've left this out in favor of the follow-up PR.

@anwai98 (Contributor, Author) commented Jun 18, 2024

Hi @constantinpape, this PR is GTG from my side now. Thanks!

@constantinpape merged commit 22edc30 into dev Jun 19, 2024
0 of 3 checks passed
@constantinpape deleted the lora-sam branch June 19, 2024 07:15