[PEFT / LoRA] PEFT integration - text encoder (#5058)

* more fixes

* up

* up

* style

* add in setup

* oops

* more changes

* v1 refactor CI

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* few todos

* protect torch import

* style

* fix fuse text encoder

* Update src/diffusers/loaders.py

Co-authored-by: Sayak Paul <[email protected]>

* replace with `recurse_replace_peft_layers`

* keep old modules for BC

* adjustments on `adjust_lora_scale_text_encoder`

* nit

* move tests

* add conversion utils

* remove unneeded methods

* use class method instead

* oops

* use `base_version`

* fix examples

* fix CI

* fix weird error with python 3.8

* fix

* better fix

* style

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* add comment

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* conv2d support for recurse remove

* added docstrings

* more docstring

* add deprecate

* revert

* try to fix merge conflicts

* v1 tests

* add new decorator

* add saving utilities test

* adapt tests a bit

* add save / from_pretrained tests

* add saving tests

* add scale tests

* fix deps tests

* fix lora CI

* fix tests

* add comment

* fix

* style

* add slow tests

* slow tests pass

* style

* Update src/diffusers/utils/import_utils.py

Co-authored-by: Benjamin Bossan <[email protected]>

* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <[email protected]>

* circumvents pattern finding issue

* left a todo

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* update hub path

* add lora workflow

* fix

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Benjamin Bossan <[email protected]>
4 people authored Sep 22, 2023
1 parent b32555a commit 493f952
Showing 46 changed files with 1,193 additions and 175 deletions.
67 changes: 67 additions & 0 deletions .github/workflows/pr_test_peft_backend.yml
@@ -0,0 +1,67 @@
name: Fast tests for PRs - PEFT backend

on:
  pull_request:
    branches:
      - main

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

env:
  DIFFUSERS_IS_CI: yes
  OMP_NUM_THREADS: 4
  MKL_NUM_THREADS: 4
  PYTEST_TIMEOUT: 60

jobs:
  run_fast_tests:
    strategy:
      fail-fast: false
      matrix:
        config:
          - name: LoRA
            framework: lora
            runner: docker-cpu
            image: diffusers/diffusers-pytorch-cpu
            report: torch_cpu_lora

    name: ${{ matrix.config.name }}

    runs-on: ${{ matrix.config.runner }}

    container:
      image: ${{ matrix.config.image }}
      options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/

    defaults:
      run:
        shell: bash

    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2

      - name: Install dependencies
        run: |
          apt-get update && apt-get install libsndfile1-dev libgl1 -y
          python -m pip install -e .[quality,test]
          python -m pip install git+https://github.com/huggingface/accelerate.git
          python -m pip install -U git+https://github.com/huggingface/transformers.git
          python -m pip install -U git+https://github.com/huggingface/peft.git
      - name: Environment
        run: |
          python utils/print_env.py
      - name: Run fast PyTorch LoRA CPU tests with PEFT backend
        if: ${{ matrix.config.framework == 'lora' }}
        run: |
          python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
            -s -v \
            --make-reports=tests_${{ matrix.config.report }} \
            tests/lora/test_lora_layers_peft.py
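
To reproduce this job locally (a sketch based on the steps above, not part of the commit): install the repo extras with `python -m pip install -e .[quality,test]`, install `peft`, `transformers`, and `accelerate` from their GitHub main branches as the job does, then run `python -m pytest -n 2 --dist=loadfile tests/lora/test_lora_layers_peft.py`.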
316 changes: 198 additions & 118 deletions src/diffusers/loaders.py

Large diffs are not rendered by default.

31 changes: 19 additions & 12 deletions src/diffusers/models/lora.py
@@ -25,18 +25,25 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
-    for _, attn_module in text_encoder_attn_modules(text_encoder):
-        if isinstance(attn_module.q_proj, PatchedLoraProjection):
-            attn_module.q_proj.lora_scale = lora_scale
-            attn_module.k_proj.lora_scale = lora_scale
-            attn_module.v_proj.lora_scale = lora_scale
-            attn_module.out_proj.lora_scale = lora_scale
-
-    for _, mlp_module in text_encoder_mlp_modules(text_encoder):
-        if isinstance(mlp_module.fc1, PatchedLoraProjection):
-            mlp_module.fc1.lora_scale = lora_scale
-            mlp_module.fc2.lora_scale = lora_scale
+def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
+    if use_peft_backend:
+        from peft.tuners.lora import LoraLayer
+
+        for module in text_encoder.modules():
+            if isinstance(module, LoraLayer):
+                module.scaling[module.active_adapter] = lora_scale
+    else:
+        for _, attn_module in text_encoder_attn_modules(text_encoder):
+            if isinstance(attn_module.q_proj, PatchedLoraProjection):
+                attn_module.q_proj.lora_scale = lora_scale
+                attn_module.k_proj.lora_scale = lora_scale
+                attn_module.v_proj.lora_scale = lora_scale
+                attn_module.out_proj.lora_scale = lora_scale
+
+        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+            if isinstance(mlp_module.fc1, PatchedLoraProjection):
+                mlp_module.fc1.lora_scale = lora_scale
+                mlp_module.fc2.lora_scale = lora_scale


 class LoRALinearLayer(nn.Module):
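
The PEFT branch above works because PEFT's `LoraLayer` keeps one scaling factor per adapter name in its `scaling` dict. A minimal standalone sketch of that mechanic (a hypothetical helper, not part of this commit; assumes a peft version of this era where `active_adapter` is a single string, and a text encoder that already carries a loaded adapter):

from peft.tuners.lora import LoraLayer

def set_text_encoder_lora_scale(text_encoder, lora_scale: float) -> int:
    """Rescale every LoRA layer on the text encoder; returns the layer count."""
    touched = 0
    for module in text_encoder.modules():
        if isinstance(module, LoraLayer):
            # PEFT stores one multiplier per adapter name; overwrite the
            # active adapter's entry, exactly as the diff above does.
            module.scaling[module.active_adapter] = lora_scale
            touched += 1
    return touched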
@@ -303,7 +303,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
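
The same one-line change repeats in every pipeline's `encode_prompt` below: the call now forwards `self.use_peft_backend` so scale adjustment takes the right branch for either backend. For context, `lora_scale` usually reaches `encode_prompt` from the pipeline call itself (a sketch, assuming `pipe` is any of the LoRA-enabled pipelines touched by this commit):

image = pipe(
    "a photo of an astronaut",
    cross_attention_kwargs={"scale": 0.7},  # forwarded to encode_prompt as `lora_scale`
).images[0]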
@@ -301,7 +301,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -291,7 +291,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -315,7 +315,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -442,7 +442,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -315,8 +315,8 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)

         prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -288,8 +288,8 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)

         prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -326,8 +326,8 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)

         prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -308,7 +308,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -301,7 +301,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -332,7 +332,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -213,7 +213,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -481,7 +481,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -278,7 +278,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -309,7 +309,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -302,7 +302,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -375,7 +375,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -297,7 +297,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -211,7 +211,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -272,7 +272,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -244,7 +244,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -221,7 +221,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -256,7 +256,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -446,7 +446,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -244,7 +244,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -240,7 +240,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -346,7 +346,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -296,7 +296,7 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -264,8 +264,8 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)

         prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -271,8 +271,8 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)

         prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -420,8 +420,8 @@ def encode_prompt(
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)

         prompt = [prompt] if isinstance(prompt, str) else prompt
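
Taken together, the changed pipelines let text-encoder LoRA run through PEFT end to end. A usage sketch (assumes a diffusers build containing this commit with `peft` installed so the PEFT backend is active; the checkpoint id is a common example from this era and the LoRA path is a placeholder):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# LoRA weights that target the text encoder are now injected as PEFT layers.
pipe.load_lora_weights("path/to/lora")  # placeholder path

image = pipe(
    "a pixel-art castle",
    cross_attention_kwargs={"scale": 0.8},  # rescales both UNet and text-encoder LoRA
).images[0]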