[LoRA] support LyCORIS (#5102)
* better condition.

* debugging

* how about now?

* how about now?

* debugging (×20)

* support for lycoris.

* style

* add: lycoris test

* fix from_pretrained call.

* fix assertion values.
sayakpaul authored Sep 20, 2023
1 parent 8263cf0 · commit e312b23
Showing 3 changed files with 30 additions and 3 deletions.
src/diffusers/loaders.py (8 additions, 1 deletion)

```diff
@@ -1878,7 +1878,7 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                     diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
 
                 # SDXL specificity.
-                if "emb" in diffusers_name:
+                if "emb" in diffusers_name and "time" not in diffusers_name:
                     pattern = r"\.\d+(?=\D*$)"
                     diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
                 if ".in." in diffusers_name:
@@ -1890,6 +1890,13 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                 if "skip" in diffusers_name:
                     diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
 
+                # LyCORIS specificity.
+                if "time" in diffusers_name:
+                    diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
+                if "conv.shortcut" in diffusers_name:
+                    diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
+
                 # General coverage.
                 if "transformer_blocks" in diffusers_name:
                     if "attn1" in diffusers_name or "attn2" in diffusers_name:
                         diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
```
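The new branch is pure key renaming: LyCORIS checkpoints serialize module names with dots (e.g. `time.emb.proj`, `conv.shortcut`) where the diffusers UNet uses underscored attribute names (`time_emb_proj`, `conv_shortcut`). A minimal, self-contained sketch of the renaming logic above (not the library source; the sample key is a hypothetical illustration):

```python
import re

def rename_lycoris_key(diffusers_name: str) -> str:
    # SDXL specificity: strip a trailing block index from "emb" keys, but
    # leave time-embedding keys alone -- the condition this commit tightens.
    if "emb" in diffusers_name and "time" not in diffusers_name:
        diffusers_name = re.sub(r"\.\d+(?=\D*$)", "", diffusers_name, count=1)
    # LyCORIS specificity: map dotted module names onto diffusers attribute names.
    if "time" in diffusers_name:
        diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
    if "conv.shortcut" in diffusers_name:
        diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
    return diffusers_name

print(rename_lycoris_key("down.blocks.0.resnets.0.time.emb.proj"))
# -> down.blocks.0.resnets.0.time_emb_proj
```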
src/diffusers/models/embeddings.py (3 additions, 2 deletions)

```diff
@@ -19,6 +19,7 @@
 from torch import nn
 
 from .activations import get_activation
+from .lora import LoRACompatibleLinear
 
 
 def get_timestep_embedding(
@@ -166,7 +167,7 @@ def __init__(
     ):
         super().__init__()
 
-        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+        self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
 
         if cond_proj_dim is not None:
             self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -179,7 +180,7 @@ def __init__(
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
+        self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
 
         if post_act_fn is None:
             self.post_act = None
```
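Swapping `nn.Linear` for `LoRACompatibleLinear` in `TimestepEmbedding` is what lets the converted `time_emb_proj` LoRA weights attach to the time-embedding projections at all. A rough sketch of the idea behind such a layer, assuming the diffusers design of this era (the real class lives in `diffusers.models.lora`; the body and the `set_lora_layer` hook here are illustrative, not the library source):

```python
import torch
from torch import nn

class LoRACompatibleLinearSketch(nn.Linear):
    """A linear layer that can optionally add a low-rank (LoRA) correction."""

    def __init__(self, in_features: int, out_features: int, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        self.lora_layer = None  # optional low-rank adapter, attached after loading

    def set_lora_layer(self, lora_layer: nn.Module) -> None:
        self.lora_layer = lora_layer

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        out = super().forward(hidden_states)
        if self.lora_layer is not None:
            # Add the scaled low-rank update on top of the frozen base weights.
            out = out + scale * self.lora_layer(hidden_states)
        return out
```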
tests/lora/test_lora_layers.py (19 additions, 0 deletions)

```diff
@@ -1876,6 +1876,25 @@ def test_a1111(self):
 
         self.assertTrue(np.allclose(images, expected, atol=1e-3))
 
+    def test_lycoris(self):
+        generator = torch.Generator().manual_seed(0)
+
+        pipe = StableDiffusionPipeline.from_pretrained(
+            "hf-internal-testing/Amixx", safety_checker=None, use_safetensors=True, variant="fp16"
+        ).to(torch_device)
+        lora_model_id = "hf-internal-testing/edgLycorisMugler-light"
+        lora_filename = "edgLycorisMugler-light.safetensors"
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+
+        images = pipe(
+            "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+        ).images
+
+        images = images[0, -3:, -3:, -1].flatten()
+        expected = np.array([0.6463, 0.658, 0.599, 0.6542, 0.6512, 0.6213, 0.658, 0.6485, 0.6017])
+
+        self.assertTrue(np.allclose(images, expected, atol=1e-3))
+
     def test_a1111_with_model_cpu_offload(self):
         generator = torch.Generator().manual_seed(0)
```
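The test doubles as the user-facing recipe: a LyCORIS checkpoint loads through the same `load_lora_weights` entry point as any other LoRA file. A condensed usage sketch using the same checkpoints as the test:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "hf-internal-testing/Amixx", safety_checker=None, use_safetensors=True, variant="fp16"
)
pipe.load_lora_weights(
    "hf-internal-testing/edgLycorisMugler-light",
    weight_name="edgLycorisMugler-light.safetensors",
)

image = pipe(
    "masterpiece, best quality, mountain",
    generator=torch.Generator().manual_seed(0),
    num_inference_steps=2,
).images[0]
```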
