Skip to content

Commit

Permalink
[Feat] Support SDXL Kohya-style LoRA (#4287)
Browse files Browse the repository at this point in the history
* sdxl lora changes.

* better name replacement.

* better replacement.

* debugging

* debugging

* debugging

* debugging

* debugging

* remove print.

* print state dict keys.

* print

* distingisuih better

* debuggable.

* fxi: tyests

* fix: arg from training script.

* access from class.

* run style

* debug

* save intermediate

* some simplifications for SDXL LoRA

* styling

* unet config is not needed in diffusers format.

* fix: dynamic SGM block mapping for SDXL kohya loras (#4322)

* Use lora compatible layers for linear proj_in/proj_out (#4323)

* improve condition for using the sgm_diffusers mapping

* informative comment.

* load compatible keys and embedding layer maaping.

* Get SDXL 1.0 example lora to load

* simplify

* specif ranks and hidden sizes.

* better handling of k rank and hidden

* debug

* debug

* debug

* debug

* debug

* fix: alpha keys

* add check for handling LoRAAttnAddedKVProcessor

* sanity comment

* modifications for text encoder SDXL

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* denugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* up

* up

* up

* up

* up

* up

* unneeded comments.

* unneeded comments.

* kwargs for the other attention processors.

* kwargs for the other attention processors.

* debugging

* debugging

* debugging

* debugging

* improve

* debugging

* debugging

* more print

* Fix alphas

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* clean up

* clean up.

* debugging

* fix: text

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Batuhan Taskaya <[email protected]>
  • Loading branch information
3 people committed Jul 28, 2023
1 parent c3e3a1e commit b1e5279
Show file tree
Hide file tree
Showing 10 changed files with 553 additions and 173 deletions.
50 changes: 49 additions & 1 deletion docs/source/en/training/lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -354,4 +354,52 @@ directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so:
lora_model_id = "sayakpaul/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
```
```

### Supporting Stable Diffusion XL LoRAs trained using the Kohya-trainer

With this [PR](https://github.com/huggingface/diffusers/pull/4287), there should now be better support for loading Kohya-style LoRAs trained on Stable Diffusion XL (SDXL).

Here are some example checkpoints we tried out:

* SDXL 0.9:
* https://civitai.com/models/22279?modelVersionId=118556
* https://civitai.com/models/104515/sdxlor30costumesrevue-starlight-saijoclaudine-lora
* https://civitai.com/models/108448/daiton-sdxl-test
* https://filebin.net/2ntfqqnapiu9q3zx/pixelbuildings128-v1.safetensors
* SDXL 1.0:
* https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors

Here is an example of how to perform inference with these checkpoints in `diffusers`:

```python
from diffusers import DiffusionPipeline
import torch

base_model_id = "stabilityai/stable-diffusion-xl-base-0.9"
pipeline = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
pipeline.load_lora_weights(".", weight_name="Kamepan.safetensors")

prompt = "anime screencap, glint, drawing, best quality, light smile, shy, a full body of a girl wearing wedding dress in the middle of the forest beneath the trees, fireflies, big eyes, 2d, cute, anime girl, waifu, cel shading, magical girl, vivid colors, (outline:1.1), manga anime artstyle, masterpiece, offical wallpaper, glint <lora:kame_sdxl_v2:1>"
negative_prompt = "(deformed, bad quality, sketch, depth of field, blurry:1.1), grainy, bad anatomy, bad perspective, old, ugly, realistic, cartoon, disney, bad propotions"
generator = torch.manual_seed(2947883060)
num_inference_steps = 30
guidance_scale = 7

image = pipeline(
prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps,
generator=generator, guidance_scale=guidance_scale
).images[0]
image.save("Kamepan.png")
```

`Kamepan.safetensors` comes from https://civitai.com/models/22279?modelVersionId=118556 .

If you notice carefully, the inference UX is exactly identical to what we presented in the sections above.

Thanks to [@isidentical](https://github.com/isidentical) for helping us on integrating this feature.

### Known limitations specific to the Kohya-styled LoRAs

* SDXL LoRAs that have both the text encoders are currently leading to weird results. We're actively investigating the issue.
* When images don't looks similar to other UIs such ComfyUI, it can be beacause of multiple reasons as explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).
6 changes: 3 additions & 3 deletions examples/dreambooth/train_dreambooth_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,10 +924,10 @@ def load_model_hook(models, input_dir):
else:
raise ValueError(f"unexpected save model: {model.__class__}")

lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
)

accelerator.register_save_state_pre_hook(save_model_hook)
Expand Down
8 changes: 4 additions & 4 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,13 +829,13 @@ def load_model_hook(models, input_dir):
else:
raise ValueError(f"unexpected save model: {model.__class__}")

lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
)

accelerator.register_save_state_pre_hook(save_model_hook)
Expand Down
Loading

0 comments on commit b1e5279

Please sign in to comment.