diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 562a21dbbb74..cd4738cfa03c 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -934,6 +934,27 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
         Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
         you want to load multiple adapters and free some GPU memory.
 
+        After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters
+        can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to
+        GPU before using those LoRA adapters for inference.
+
+        ```python
+        >>> pipe.load_lora_weights(path_1, adapter_name="adapter-1")
+        >>> pipe.load_lora_weights(path_2, adapter_name="adapter-2")
+        >>> pipe.set_adapters("adapter-1")
+        >>> image_1 = pipe(**kwargs)
+        >>> # switch to adapter-2, offload adapter-1
+        >>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cpu")
+        >>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
+        >>> pipe.set_adapters("adapter-2")
+        >>> image_2 = pipe(**kwargs)
+        >>> # switch back to adapter-1, offload adapter-2
+        >>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cpu")
+        >>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
+        >>> pipe.set_adapters("adapter-1")
+        >>> ...
+        ```
+
         Args:
             adapter_names (`List[str]`):
                 List of adapters to send device to.
@@ -949,6 +970,10 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
                 for module in model.modules():
                     if isinstance(module, BaseTunerLayer):
                         for adapter_name in adapter_names:
+                            if adapter_name not in module.lora_A:
+                                # it is sufficient to check lora_A
+                                continue
+
                             module.lora_A[adapter_name].to(device)
                             module.lora_B[adapter_name].to(device)
                             # this is a param, not a module, so device placement is not in-place -> re-assign
diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py
index a81128fa446b..1c5a9b00e9da 100644
--- a/tests/lora/test_lora_layers_sd.py
+++ b/tests/lora/test_lora_layers_sd.py
@@ -120,7 +120,7 @@ def test_integration_move_lora_cpu(self):
 
         self.assertTrue(
             check_if_lora_correctly_set(pipe.unet),
-            "Lora not correctly set in text encoder",
+            "Lora not correctly set in unet",
         )
 
         # We will offload the first adapter in CPU and check if the offloading
@@ -187,7 +187,7 @@ def test_integration_move_lora_dora_cpu(self):
 
         self.assertTrue(
             check_if_lora_correctly_set(pipe.unet),
-            "Lora not correctly set in text encoder",
+            "Lora not correctly set in unet",
         )
 
         for name, param in pipe.unet.named_parameters():
@@ -208,6 +208,53 @@ def test_integration_move_lora_dora_cpu(self):
             if "lora_" in name:
                 self.assertNotEqual(param.device, torch.device("cpu"))
 
+    @slow
+    @require_torch_accelerator
+    def test_integration_set_lora_device_different_target_layers(self):
+        # fixes a bug that occurred when calling set_lora_device with multiple adapters loaded that target different
+        # layers, see #11833
+        from peft import LoraConfig
+
+        path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+        pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
+        # configs partly target the same, partly different layers
+        config0 = LoraConfig(target_modules=["to_k", "to_v"])
+        config1 = LoraConfig(target_modules=["to_k", "to_q"])
+        pipe.unet.add_adapter(config0, adapter_name="adapter-0")
+        pipe.unet.add_adapter(config1, adapter_name="adapter-1")
+        pipe = pipe.to(torch_device)
+
+        self.assertTrue(
+            check_if_lora_correctly_set(pipe.unet),
+            "Lora not correctly set in unet",
+        )
+
+        # sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
+        modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
+        modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
+        self.assertNotEqual(modules_adapter_0, modules_adapter_1)
+        self.assertTrue(modules_adapter_0 - modules_adapter_1)
+        self.assertTrue(modules_adapter_1 - modules_adapter_0)
+
+        # setting both separately works
+        pipe.set_lora_device(["adapter-0"], "cpu")
+        pipe.set_lora_device(["adapter-1"], "cpu")
+
+        for name, module in pipe.unet.named_modules():
+            if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device == torch.device("cpu"))
+            elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device == torch.device("cpu"))
+
+        # setting both at once also works
+        pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
+
+        for name, module in pipe.unet.named_modules():
+            if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device != torch.device("cpu"))
+            elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+                self.assertTrue(module.weight.device != torch.device("cpu"))
+
     @slow
     @nightly
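Not part of the patch above: a minimal, self-contained sketch of why the `adapter_name not in module.lora_A` guard is needed, using PEFT directly instead of a full pipeline. The `Small` class is a hypothetical stand-in for an attention block with `to_q`/`to_k`/`to_v` projections; it only assumes `torch` and `peft` are installed and avoids downloading any model.

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer


class Small(nn.Module):
    # hypothetical stand-in for an attention block with to_q/to_k/to_v projections
    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(8, 8)
        self.to_k = nn.Linear(8, 8)
        self.to_v = nn.Linear(8, 8)


# two adapters that only partly target the same layers, mirroring the regression test
model = get_peft_model(Small(), LoraConfig(target_modules=["to_k", "to_v"]), adapter_name="adapter-0")
model.add_adapter("adapter-1", LoraConfig(target_modules=["to_k", "to_q"]))

for module in model.modules():
    if isinstance(module, BaseTunerLayer):
        for adapter_name in ["adapter-0", "adapter-1"]:
            # without this check, module.lora_A["adapter-1"] raises KeyError on the layer
            # wrapping `to_v`, which only carries "adapter-0" -- the case the patch guards against
            if adapter_name not in module.lora_A:
                continue
            module.lora_A[adapter_name].to("cpu")
            module.lora_B[adapter_name].to("cpu")
```

As the patch's inline comment notes, checking `lora_A` alone appears sufficient because `lora_A` and `lora_B` are populated with the same adapter keys for a given tuner layer.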