FIX set_lora_device when target layers differ #11844
base: main
Changes from all commits: 82a7f92, f023991, 62769f3, 37e544c, 41b6206
@@ -120,7 +120,7 @@ def test_integration_move_lora_cpu(self):
 
         self.assertTrue(
             check_if_lora_correctly_set(pipe.unet),
-            "Lora not correctly set in text encoder",
+            "Lora not correctly set in unet",
         )
 
         # We will offload the first adapter in CPU and check if the offloading
@@ -187,7 +187,7 @@ def test_integration_move_lora_dora_cpu(self):
 
         self.assertTrue(
             check_if_lora_correctly_set(pipe.unet),
-            "Lora not correctly set in text encoder",
+            "Lora not correctly set in unet",
         )
 
         for name, param in pipe.unet.named_parameters():
@@ -208,6 +208,53 @@ def test_integration_move_lora_dora_cpu(self):
             if "lora_" in name:
                 self.assertNotEqual(param.device, torch.device("cpu"))
 
+    @slow
+    @require_torch_accelerator
+    def test_integration_set_lora_device_different_target_layers(self):
[Review thread]
- I wonder if it makes sense to have this as a fast test and have it available in …
- I don't know what qualifies as a slow test in diffusers and what doesn't. As this test is very similar to …. If we decide to move it, would you still keep …?
- I think the … I think a majority of the slow tests should essentially be "integration" tests (i.e., the ones that test with real checkpoints), barring a few where fast testing isn't really possible.
- I won't keep the …
- Okay, I can do that in this PR. Same for …
- As I'm not sure what the organization of tests is here, I can't say.
- Okay, so I'd remove the …
- Yeah sure. I am also happy to merge the PR after #11844 (comment) and do the organizing myself in a future PR, and ask for your reviews.
- Oh, I just saw your comment above: … So it seems the test should stay here after all?
- It still didn't make sense to keep it there after the number of pipelines supporting LoRAs increased from just SD and SDXL. So, I still abide by what I said in the above comments, i.e., moving those tests to …
- I see. I checked what it would take to move this to …
- Yes, I will take care of it after this PR is merged.
+        # fixes a bug that occurred when calling set_lora_device with multiple adapters loaded that target different
+        # layers, see #11833
+        from peft import LoraConfig
[Review thread]
- Could move it here: line 50 in a79c3af.
- You mean in case we move the test there?
- I meant moving the import there (i.e., under …).
- Moving the import to another file would not help us avoid it here, or am I missing something?
- Sorry, in my mind, I had already assumed that this test is under …
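For illustration, the kind of shared, guarded module-level import being discussed might look like the sketch below. `is_peft_available` is an existing `diffusers.utils` helper; where such an import would actually live is exactly the open question in the thread above.

```python
# Sketch of a guarded module-level import, so the test module does not
# hard-require peft at import time. The placement is illustrative only.
from diffusers.utils import is_peft_available

if is_peft_available():
    from peft import LoraConfig
```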
+
+        path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
+        # configs partly target the same, partly different layers
+        config0 = LoraConfig(target_modules=["to_k", "to_v"])
+        config1 = LoraConfig(target_modules=["to_k", "to_q"])
+        pipe.unet.add_adapter(config0, adapter_name="adapter-0")
+        pipe.unet.add_adapter(config1, adapter_name="adapter-1")
+        pipe = pipe.to(torch_device)
+
+        self.assertTrue(
+            check_if_lora_correctly_set(pipe.unet),
+            "Lora not correctly set in unet",
+        )
+
+        # sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
+        modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
+        modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
+        self.assertNotEqual(modules_adapter_0, modules_adapter_1)
+        self.assertTrue(modules_adapter_0 - modules_adapter_1)
+        self.assertTrue(modules_adapter_1 - modules_adapter_0)
+
+        # setting both separately works
+        pipe.set_lora_device(["adapter-0"], "cpu")
+        pipe.set_lora_device(["adapter-1"], "cpu")
+
+        for name, module in pipe.unet.named_modules():
+            if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device == torch.device("cpu"))
+            elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device == torch.device("cpu"))
+
+        # setting both at once also works
+        pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
+
+        for name, module in pipe.unet.named_modules():
+            if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device != torch.device("cpu"))
+            elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device != torch.device("cpu"))
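For context on the failure mode the test's comment references (#11833): `set_lora_device` walks the model's LoRA layers and moves each requested adapter's weights, but a given module only holds weights for the adapters that target it. Below is a minimal sketch of the kind of membership guard this requires; it is an illustration under PEFT-style assumptions (`lora_A`/`lora_B` as `ModuleDict`s keyed by adapter name), not the PR's actual diff, and `move_lora_adapters` is a hypothetical name.

```python
# Illustrative sketch only, not the actual diffusers implementation.
import torch.nn as nn


def move_lora_adapters(model: nn.Module, adapter_names: list[str], device) -> None:
    for module in model.modules():
        if not hasattr(module, "lora_A"):
            continue  # not a LoRA layer
        for adapter_name in adapter_names:
            # Without this membership check, adapters that target different
            # layers raise a KeyError on modules they do not occupy.
            if adapter_name not in module.lora_A:
                continue
            module.lora_A[adapter_name].to(device)
            module.lora_B[adapter_name].to(device)
```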
[Review thread on lines +243 to +256]
- Could we also test inference by invoking …?
- I can add inference. I didn't do it because it is not done in …
- Okay then. Let's keep inference out of the question. But IMO, it is just helpful to also test that the pipeline isn't impacted as far as running inference is concerned.
- I think having another test for inference would make sense; I guess it could be similar but only check for inference. Presumably, that would be an actual …
- I wasn't suggesting another test, though. Just adding the …
- No, I was suggesting a separate test to check …
- Okay, that's good to know. If this is the case, I think we should prefer to show / comment on what the users are supposed to do to make pipeline inference work.
- You mean updating the docs? If yes, I'd add a sentence to the …
- Either docs or the tests. Updating the docstring is more sensible.
- I updated the docstring.
 
-    @slow
+    @nightly
[Review thread]
- Would include a code snippet to also show how users can do it, just to keep it complete.
- I added an example below, does that work?
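For reference, a sketch of the kind of user-facing snippet being discussed. The adapter paths are hypothetical placeholders, and this is not necessarily the exact example added in the PR; it only assumes the public `load_lora_weights` and `set_lora_device` APIs.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# hypothetical adapter checkpoints, loaded under distinct adapter names
pipe.load_lora_weights("path/to/lora-0", adapter_name="adapter-0")
pipe.load_lora_weights("path/to/lora-1", adapter_name="adapter-1")

# Offload one adapter to CPU while it is not needed; with this fix, that
# works even when the adapters target different layers.
pipe.set_lora_device(["adapter-0"], "cpu")

# Move it back before running inference so that all active adapters sit on
# the pipeline's device; otherwise the forward pass would mix devices.
pipe.set_lora_device(["adapter-0"], "cuda")
image = pipe("a photo of an astronaut riding a horse on the moon").images[0]
```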