
Commit

[LoRA] fix cross_attention_kwargs problems and tighten tests (#7388)
* debugging

* let's see the numbers

* let's see the numbers

* let's see the numbers

* restrict tolerance.

* increase inference steps.

* shallow copy of cross_attention_kwargs

* remove print
sayakpaul committed Mar 20, 2024
1 parent 5584e1c commit 84bc0e4
Showing 2 changed files with 8 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/diffusers/models/unets/unet_2d_condition.py
@@ -1178,6 +1178,7 @@ def forward(
         # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
         # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
         if cross_attention_kwargs is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
             lora_scale = cross_attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0
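
Context for the UNet change: `dict.pop` mutates the dictionary the caller passed in, so the `scale` entry vanished from `cross_attention_kwargs` after the first UNet call; because a pipeline invokes the UNet once per denoising step, the LoRA scale was effectively applied only on the first step. That is presumably also why the tests below use more inference steps and a tighter tolerance. A minimal, self-contained sketch of the failure mode and the shallow-copy fix (toy functions, not the actual `UNet2DConditionModel.forward`):

```python
# Toy reproduction of the bug fixed above (illustrative only).
def forward_buggy(cross_attention_kwargs=None):
    if cross_attention_kwargs is not None:
        # pop() removes "scale" from the dict the caller still owns
        lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    return lora_scale


def forward_fixed(cross_attention_kwargs=None):
    if cross_attention_kwargs is not None:
        # shallow-copy first so the caller's dict is left untouched
        cross_attention_kwargs = cross_attention_kwargs.copy()
        lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    return lora_scale


kwargs = {"scale": 0.5}
print(forward_buggy(kwargs), forward_buggy(kwargs))  # 0.5 1.0 -- scale is lost on the second call
kwargs = {"scale": 0.5}
print(forward_fixed(kwargs), forward_fixed(kwargs))  # 0.5 0.5 -- scale survives repeated calls
```
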
9 changes: 7 additions & 2 deletions tests/lora/test_lora_layers_peft.py
@@ -158,7 +158,7 @@ def get_dummy_inputs(self, with_generator=True):

         pipeline_inputs = {
             "prompt": "A painting of a squirrel eating a burger",
-            "num_inference_steps": 2,
+            "num_inference_steps": 5,
             "guidance_scale": 6.0,
             "output_type": "np",
         }
@@ -589,7 +589,7 @@ def test_simple_inference_with_text_unet_lora_and_scale(self):
                 **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
             ).images
             self.assertTrue(
-                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+                not np.allclose(output_lora, output_lora_scale, atol=1e-4, rtol=1e-4),
                 "Lora + scale should change the output",
             )
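
For reference, `np.allclose(a, b, rtol, atol)` checks `|a - b| <= atol + rtol * |b|` elementwise, so lowering both tolerances to 1e-4 means even small differences between the scaled and unscaled outputs now count as "different". A quick illustration with made-up values:

```python
import numpy as np

a = np.array([0.1000, 0.2000])
b = np.array([0.1004, 0.2004])  # differs by 4e-4 elementwise

# np.allclose is True only if |a - b| <= atol + rtol * |b| for every element.
print(np.allclose(a, b, atol=1e-3, rtol=1e-3))  # True:  4e-4 falls within the old tolerance
print(np.allclose(a, b, atol=1e-4, rtol=1e-4))  # False: 4e-4 exceeds the new, tighter tolerance
```
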

@@ -1300,6 +1300,11 @@ def test_integration_logits_with_scale(self):
         pipe.load_lora_weights(lora_id)
         pipe = pipe.to("cuda")
 
+        self.assertTrue(
+            self.check_if_lora_correctly_set(pipe.unet),
+            "Lora not correctly set in UNet",
+        )
+
         self.assertTrue(
             self.check_if_lora_correctly_set(pipe.text_encoder),
             "Lora not correctly set in text encoder 2",
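
`check_if_lora_correctly_set` is a helper defined elsewhere in the test suite and is not shown in this diff; a plausible minimal version (an assumption about its behavior, not the real implementation) would simply look for a PEFT tuner layer among the model's submodules:

```python
from peft.tuners.tuners_utils import BaseTunerLayer


def check_if_lora_correctly_set(model) -> bool:
    # Assumed sketch: treat LoRA as "correctly set" if at least one submodule
    # has been wrapped in a PEFT tuner (LoRA) layer.
    return any(isinstance(module, BaseTunerLayer) for module in model.modules())
```
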
