Add LoRA Implementation #611
Conversation
Looks good overall, only a few minor comments.
I think it would make more sense to design the PEFT adapters as proper PyTorch modules; see the comments for details.
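For readers following along, here is a minimal sketch of what such an adapter might look like as a proper `torch.nn.Module`. The class name `LoRAAdapter`, the default rank, and the initialization scheme are illustrative assumptions, not the PR's actual API:

```python
import torch
import torch.nn as nn


class LoRAAdapter(nn.Module):
    """Wraps a frozen linear layer and adds a trainable low-rank update,
    computing base(x) + B(A(x)) with rank << min(in_dim, out_dim)."""

    def __init__(self, base_layer: nn.Linear, rank: int = 4):
        super().__init__()
        self.base_layer = base_layer
        for param in self.base_layer.parameters():
            param.requires_grad = False  # the original weights stay frozen
        in_dim, out_dim = base_layer.in_features, base_layer.out_features
        # Low-rank factors; only these receive gradients during fine-tuning.
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        nn.init.kaiming_uniform_(self.lora_a.weight)
        nn.init.zeros_(self.lora_b.weight)  # the update starts at zero

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base_layer(x) + self.lora_b(self.lora_a(x))
```

With this design the adapter can be swapped in for individual layers (e.g. the attention projections of SAM's image encoder), and the standard module mechanisms (`parameters()`, `state_dict()`, device moves) work without special casing.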
@constantinpape Thanks for all the suggestions. I thought about all your feedback and incorporated it now in a much more flexible design. Let me know if I missed something or if you spot more room for improvement. Thanks!
This is going in the right direction! A few more changes to make it more modular.
In addition, it would also be good to have a test for this that instantiates the PEFT SAM, applies a forward pass, and checks the gradients afterwards.
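A self-contained sketch of the kind of test suggested here: build a small model with a LoRA-style adapter, run a forward pass, and verify that gradients reach only the adapter parameters. `DummyModel` is a stand-in; the real test would instantiate the PR's PEFT SAM instead:

```python
import unittest

import torch
import torch.nn as nn


class DummyModel(nn.Module):
    """Stand-in for the PEFT SAM: a frozen base layer plus a LoRA update."""

    def __init__(self, dim: int = 16, rank: int = 2):
        super().__init__()
        self.base = nn.Linear(dim, dim)
        for param in self.base.parameters():
            param.requires_grad = False  # base weights stay frozen
        self.lora_a = nn.Linear(dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, dim, bias=False)

    def forward(self, x):
        return self.base(x) + self.lora_b(self.lora_a(x))


class TestPEFTGradients(unittest.TestCase):
    def test_gradients_flow_to_adapter_only(self):
        model = DummyModel()
        x = torch.randn(4, 16)
        loss = model(x).sum()
        loss.backward()
        # The adapter parameters receive gradients ...
        self.assertIsNotNone(model.lora_a.weight.grad)
        self.assertIsNotNone(model.lora_b.weight.grad)
        # ... while the frozen base parameters do not.
        self.assertIsNone(model.base.weight.grad)


if __name__ == "__main__":
    unittest.main()
```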
@constantinpape Thanks for the feedback. I've updated the PEFT adapter design, and I also added a simple test for the PEFT SAM. Let me know how things look now. Thanks!
To add, looks like the tests at …
This looks good now, only one minor thing.
```python
# Check the expected shape of the outputs.
mask_shapes = [output["masks"].shape[-2:] for output in outputs]
for shape in mask_shapes:
    self.assertEqual(shape, input_shape[1:])
```
Ideally we would check here that gradients are propagated through the model. But I can also do this in a follow-up PR.
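A hedged sketch of what that follow-up check might add after the shape assertions above; the `low_res_logits` output key and the `model` reference are assumptions about the test's setup, not the PR's actual code:

```python
# Build a scalar loss from a differentiable output and backpropagate.
loss = sum(output["low_res_logits"].sum() for output in outputs)
loss.backward()
for name, param in model.named_parameters():
    if param.requires_grad:  # only the PEFT parameters should be trainable
        self.assertIsNotNone(param.grad, msg=f"No gradient for {name}")
```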
Okay, I've left this in favor of the follow-up PR.
Hi @constantinpape, this PR is GTG from my side now. Thanks!