
Commit

Merge pull request #40 from invoke-ai/feat/generate-with-lora-and-ti
Add ability to use LoRAs and TI embeddings in generate_images.py script
RyanJDick authored Dec 7, 2023
2 parents 67c9050 + 7ecdc59 commit 8c43e88
Showing 2 changed files with 54 additions and 0 deletions.
37 changes: 37 additions & 0 deletions src/invoke_training/scripts/invoke_generate_images.py
@@ -1,4 +1,5 @@
import argparse
from pathlib import Path

from invoke_training.training.shared.model_loading_utils import PipelineVersionEnum
from invoke_training.training.tools.generate_images import generate_images
@@ -25,6 +26,20 @@ def parse_args():
        "stable diffusion checkpoint file. (E.g. 'runwayml/stable-diffusion-v1-5', "
        "'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/realisticVisionV51_v51VAE.safetensors', etc.)",
    )
    parser.add_argument(
        "-l",
        "--lora",
        type=str,
        nargs="*",
        help="LoRA models to apply to the base model. The LoRA weight can optionally be provided after a colon "
        "separator. E.g. `-l path/to/lora.bin:0.5 -l path/to/lora_2.safetensors`.",
    )
    parser.add_argument(
        "--ti",
        type=str,
        nargs="*",
        help="Path(s) to Textual Inversion embeddings to apply to the base model.",
    )
    parser.add_argument(
        "--sd-version",
        type=str,
@@ -67,9 +82,29 @@ def parse_args():
    return parser.parse_args()


def parse_lora_args(lora_args: list[str] | None) -> list[tuple[Path, float]]:
    loras: list[tuple[Path, float]] = []

    lora_args = lora_args or []
    for lora in lora_args:
        lora_split = lora.split(":")

        if len(lora_split) == 1:
            # If weight is not specified, assume 1.0.
            loras.append((Path(lora_split[0]), 1.0))
        elif len(lora_split) == 2:
            loras.append((Path(lora_split[0]), float(lora_split[1])))
        else:
            raise ValueError(f"Invalid lora argument syntax: '{lora}'.")

    return loras


def main():
    args = parse_args()

    loras = parse_lora_args(args.lora)

    print(f"Generating {args.num_images} images in '{args.out_dir}'.")
    generate_images(
        out_dir=args.out_dir,
Expand All @@ -79,6 +114,8 @@ def main():
        num_images=args.num_images,
        height=args.height,
        width=args.width,
        loras=loras,
        ti_embeddings=args.ti,
        seed=args.seed,
        enable_cpu_offload=args.enable_cpu_offload,
    )
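For reference, the colon-separated weight syntax above parses as follows (a minimal sketch, assuming the invoke_training package is installed; the LoRA paths are made up):

    from pathlib import Path

    from invoke_training.scripts.invoke_generate_images import parse_lora_args

    # Explicit weight after the colon separator.
    assert parse_lora_args(["path/to/lora.bin:0.5"]) == [(Path("path/to/lora.bin"), 0.5)]

    # No weight given, so it defaults to 1.0.
    assert parse_lora_args(["path/to/lora_2.safetensors"]) == [(Path("path/to/lora_2.safetensors"), 1.0)]

    # No --lora flags passed at all.
    assert parse_lora_args(None) == []
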
17 changes: 17 additions & 0 deletions src/invoke_training/training/tools/generate_images.py
@@ -1,4 +1,6 @@
import os
from pathlib import Path
from typing import Optional

import torch
from tqdm import tqdm
@@ -17,6 +19,8 @@ def generate_images(
    num_images: int,
    height: int,
    width: int,
    loras: Optional[list[tuple[Path, float]]] = None,
    ti_embeddings: Optional[list[str]] = None,
    seed: int = 0,
    torch_dtype: torch.dtype = torch.float16,
    torch_device: str = "cuda",
@@ -35,6 +39,9 @@ def generate_images(
            with).
        width (int): The output image width in pixels (recommended to match the resolution that the model was trained
            with).
        loras (list[tuple[Path, float]], optional): Paths to LoRA models to apply to the base model, with associated
            weights.
        ti_embeddings (list[str], optional): Paths to TI embeddings to apply to the base model.
        seed (int, optional): A seed for repeatability. Defaults to 0.
        torch_dtype (torch.dtype, optional): The torch dtype. Defaults to torch.float16.
        torch_device (str, optional): The torch device. Defaults to "cuda".
@@ -44,6 +51,16 @@

    pipeline = load_pipeline(model, pipeline_version)

    loras = loras or []
    for lora in loras:
        lora_path, lora_scale = lora
        pipeline.load_lora_weights(str(lora_path), weight_name=str(lora_path.name))
        pipeline.fuse_lora(lora_scale=lora_scale)

    ti_embeddings = ti_embeddings or []
    for ti_embedding in ti_embeddings:
        pipeline.load_textual_inversion(ti_embedding)

    pipeline.to(torch_dtype=torch_dtype)
    if enable_cpu_offload:
        pipeline.enable_model_cpu_offload()
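For context, load_lora_weights, fuse_lora, and load_textual_inversion are standard diffusers pipeline methods. A minimal standalone sketch of the same load-then-fuse pattern (assuming diffusers is installed; the model name and file paths are placeholders):

    import torch
    from diffusers import StableDiffusionPipeline

    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )

    # Fusing each LoRA into the base weights at its scale lets multiple LoRAs
    # stack: each fuse_lora() bakes the current adapter in before the next
    # load_lora_weights() call replaces it.
    for lora_path, lora_scale in [("path/to/lora.safetensors", 0.5)]:
        pipeline.load_lora_weights(lora_path)
        pipeline.fuse_lora(lora_scale=lora_scale)

    # A TI embedding registers its trigger token(s) for use in prompts.
    pipeline.load_textual_inversion("path/to/ti_embedding.safetensors")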
