FIX set_lora_device when target layers differ #11844

Open · wants to merge 5 commits into main

25 changes: 25 additions & 0 deletions src/diffusers/loaders/lora_base.py
@@ -934,6 +934,27 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
you want to load multiple adapters and free some GPU memory.

After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters
can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to
GPU before using those LoRA adapters for inference.

Member: Would include a code snippet to also show how users can do it just to keep it complete.

Member Author: I added an example below, does that work?

```python
>>> pipe.load_lora_weights(<path-1>, adapter_name="adapter-1")
>>> pipe.load_lora_weights(<path-2>, adapter_name="adapter-2")
>>> pipe.set_adapters("adapter-1")
>>> image_1 = pipe(**kwargs)
>>> # switch to adapter-2, offload adapter-1
>>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cpu")
>>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
>>> pipe.set_adapters("adapter-2")
>>> image_2 = pipe(**kwargs)
>>> # switch back to adapter-1, offload adapter-2
>>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cpu")
>>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
>>> pipe.set_adapters("adapter-1")
>>> ...
```

Args:
    adapter_names (`List[str]`):
        List of adapters whose LoRA layers should be moved to `device`.
@@ -949,6 +970,10 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
for module in model.modules():
    if isinstance(module, BaseTunerLayer):
        for adapter_name in adapter_names:
+           if adapter_name not in module.lora_A:
+               # it is sufficient to check lora_A
+               continue
+
            module.lora_A[adapter_name].to(device)
            module.lora_B[adapter_name].to(device)
            # this is a param, not a module, so device placement is not in-place -> re-assign
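For context, here is a minimal sketch (not part of the PR) of the situation the new guard handles. The tiny module, its layer names, and the adapter names are made up, and it uses plain peft rather than a diffusers pipeline, but it runs the same loop as the fixed `set_lora_device`: when two adapters target partly different layers, some `BaseTunerLayer` modules hold only one of them, and indexing `module.lora_A[adapter_name]` for the missing adapter would raise a `KeyError` without the check.

```python
# Illustrative sketch only: TinyAttention, its layer names, and the adapter names
# are made up; the loop mirrors the guarded loop in set_lora_device above.
import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model
from peft.tuners.tuners_utils import BaseTunerLayer


class TinyAttention(nn.Module):
    """Stand-in for an attention block with to_q/to_k/to_v projections."""

    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(8, 8)
        self.to_k = nn.Linear(8, 8)
        self.to_v = nn.Linear(8, 8)


model = TinyAttention()
# the adapters target partly different layers: only adapter-0 is on to_v,
# only adapter-1 is on to_q, and both are on to_k
inject_adapter_in_model(LoraConfig(target_modules=["to_k", "to_v"]), model, adapter_name="adapter-0")
inject_adapter_in_model(LoraConfig(target_modules=["to_k", "to_q"]), model, adapter_name="adapter-1")

for module in model.modules():
    if isinstance(module, BaseTunerLayer):
        for adapter_name in ["adapter-0", "adapter-1"]:
            if adapter_name not in module.lora_A:
                # e.g. to_q holds no "adapter-0"; indexing lora_A directly
                # would raise a KeyError here without this check
                continue
            module.lora_A[adapter_name].to("cpu")
            module.lora_B[adapter_name].to("cpu")
```
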
51 changes: 49 additions & 2 deletions tests/lora/test_lora_layers_sd.py
@@ -120,7 +120,7 @@ def test_integration_move_lora_cpu(self):

self.assertTrue(
    check_if_lora_correctly_set(pipe.unet),
-   "Lora not correctly set in text encoder",
+   "Lora not correctly set in unet",
)

# We will offload the first adapter in CPU and check if the offloading
@@ -187,7 +187,7 @@ def test_integration_move_lora_dora_cpu(self):

self.assertTrue(
    check_if_lora_correctly_set(pipe.unet),
-   "Lora not correctly set in text encoder",
+   "Lora not correctly set in unet",
)

for name, param in pipe.unet.named_parameters():
@@ -208,6 +208,53 @@ def test_integration_move_lora_dora_cpu(self):
if "lora_" in name:
self.assertNotEqual(param.device, torch.device("cpu"))

@slow
@require_torch_accelerator
def test_integration_set_lora_device_different_target_layers(self):
Member: I wonder if it makes sense to have this as a fast test and have it available in tests/lora/utils.py?

Member Author: I don't know what qualifies as a slow test in diffusers and what doesn't. As this test is very similar to test_integration_move_lora_cpu, which is marked as @slow, I opted to do the same here. As to where to put the test, I can move it to utils.py, but I chose to keep it here because test_integration_move_lora_cpu is also here.

If we decide to move it, would you still keep @require_torch_accelerator or do everything on CPU (as I mentioned, GPU is not strictly required to trigger the bug)?

Member: I think the test_integration_move_lora_cpu was added like that to quickly fix the underlying issue and validate the solution while doing so. Yes, that test also deserved to be moved to tests/lora/utils.py. Do you see any reason not to?

I think a majority of the slow tests should essentially be "integration" tests (i.e., the ones that test with real checkpoints), barring a few where fast testing isn't really possible.

> If we decide to move it, would you still keep @require_torch_accelerator or do everything on CPU (as I mentioned, GPU is not strictly required to trigger the bug)?

I won't keep the @require_torch_accelerator then if CPU (or more generally, torch_device) is enough to test it.

Member Author:
> I think the test_integration_move_lora_cpu was added like that to quickly fix the underlying issue and validate the solution while doing so. Yes, that test also deserved to be moved to tests/lora/utils.py.

Okay, I can do that in this PR. Same for test_integration_move_lora_dora_cpu?

> Do you see any reason not to?

As I'm not sure what the organization of tests is here, I can't say.

> I won't keep the @require_torch_accelerator then if CPU (or more generally, torch_device) is enough to test it.

Okay, so I'd remove the @require_torch_accelerator decorator but keep on using torch_device as is presently done.

Member:
> Okay, I can do that in this PR. Same for test_integration_move_lora_dora_cpu?

Yeah, sure.

I am also happy to merge the PR after #11844 (comment) and do the organizing myself in a future PR, and ask for your reviews.

Member Author: Oh, I just saw your comment above:

> # Keeping this test here makes sense because it doesn't look any integration
> # (value assertions on logits).

So it seems the test should stay here after all?

Member: It still didn't make sense to keep "keeping" it there after the number of pipelines supporting LoRAs grew beyond just SD and SDXL. So I still abide by what I said in the comments above, i.e., moving those tests to tests/lora/utils.py. But as also mentioned in #11844 (comment), it's fine if you don't do that and we tackle it in an immediate future PR.

Member Author: I see. I checked what it would take to move this to utils.py. From my understanding, this is a base test class that checks a bunch of different models. Therefore, the test must be made more generic: I cannot just hard-code target_modules, and I can't assume that there is a unet component. I think this exceeds my diffusers knowledge, so I'd either need your guidance or must defer to you to implement it.

Member: Yes, I will take care of it after this PR is merged.

# fixes a bug that occurred when calling set_lora_device with multiple adapters loaded that target different
# layers, see #11833
from peft import LoraConfig
Member: Could move it here:

> if is_peft_available():

Member Author: You mean in case we move the test there?

Member (@sayakpaul), Jul 1, 2025: I meant moving the import there (i.e., under is_peft_available()).

Member Author: Moving the import to another file would not help us avoid it here, or am I missing something?

Member: Sorry, in my mind I had already assumed that this test lives under tests/lora/utils.py. As it seems like we are going to move the test there, let's make sure to move the LoraConfig import under the block specified.
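
For clarity, the guarded-import pattern being suggested would look roughly like the sketch below; its exact placement in tests/lora/utils.py is an assumption, not something shown in this diff.

```python
# sketch of the suggested guarded import; the final location in
# tests/lora/utils.py is assumed, not shown in this PR
from diffusers.utils import is_peft_available

if is_peft_available():
    from peft import LoraConfig
```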


path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
# configs partly target the same, partly different layers
config0 = LoraConfig(target_modules=["to_k", "to_v"])
config1 = LoraConfig(target_modules=["to_k", "to_q"])
pipe.unet.add_adapter(config0, adapter_name="adapter-0")
pipe.unet.add_adapter(config1, adapter_name="adapter-1")
pipe = pipe.to(torch_device)

self.assertTrue(
    check_if_lora_correctly_set(pipe.unet),
    "Lora not correctly set in unet",
)

# sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
self.assertNotEqual(modules_adapter_0, modules_adapter_1)
self.assertTrue(modules_adapter_0 - modules_adapter_1)
self.assertTrue(modules_adapter_1 - modules_adapter_0)

# setting both separately works
pipe.set_lora_device(["adapter-0"], "cpu")
pipe.set_lora_device(["adapter-1"], "cpu")

for name, module in pipe.unet.named_modules():
    if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
        self.assertTrue(module.weight.device == torch.device("cpu"))
    elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
        self.assertTrue(module.weight.device == torch.device("cpu"))

# setting both at once also works
pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)

for name, module in pipe.unet.named_modules():
    if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
        self.assertTrue(module.weight.device != torch.device("cpu"))
    elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
        self.assertTrue(module.weight.device != torch.device("cpu"))
Comment on lines +243 to +256
Member: Could we also test inference by invoking pipe(...)?

Member Author: I can add inference. I didn't do it because it is not done in test_integration_move_lora_cpu either, and I thought it might distract from the main point of the test, but I can still add it.

Member: Okay then, let's keep inference out of the question. But IMO, it is still helpful to also test that the pipeline isn't impacted as far as running inference is concerned.

Member Author: I think having another test for inference would make sense; I guess it could be similar but only check inference. Presumably, that would be an actual @slow test?

Member: I wasn't suggesting another test, though. Just adding the pipe(**inputs) line likely wouldn't be too much trouble in the current test?

Member Author: No, I was suggesting a separate test to check pipe(**inputs). I still tried it in this test, but there is actually an error: when moving the LoRAs to CPU while the base model is still on GPU, there is a device error. I don't think that's really surprising, but it means we can't really test that part.

Member: Okay, that's good to know. If this is the case, I think we should prefer to show / comment on what users are supposed to do to make pipeline inference work.

Member Author:
> I think we should prefer to show / comment on what users are supposed to do to make pipeline inference work.

You mean updating the docs? If yes, I'd add a sentence to the set_lora_device docstring about that.

Member: Either docs or the tests. Updating the docstring is more sensible.

Member Author: I updated the docstring.



@slow
@nightly