Skip to content

Commit

Permalink
Merge pull request #72 from noskill/lpw
Browse files Browse the repository at this point in the history
handle clip_skip and lora scale in lpw pipelines
  • Loading branch information
Necr0x0Der authored Aug 1, 2024
2 parents df5157c + c568a71 commit 5d6d975
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 29 deletions.
40 changes: 25 additions & 15 deletions multigen/lpw_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,19 @@
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
PIL_INTERPOLATION,
deprecate,
logging,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor


# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -227,8 +223,7 @@ def get_unweighted_text_embeddings(
prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device))
text_embedding = prompt_embeds[0]
else:
prompt_embeds = pipe.text_encoder(
text_input_chunk.to(pipe.device), output_hidden_states=True)
prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device), output_hidden_states=True)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
Expand Down Expand Up @@ -372,11 +367,7 @@ def get_weighted_text_embeddings(

# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe,
prompt_tokens,
pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle,
clip_skip=clip_skip
pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, clip_skip=clip_skip
)
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
if uncond_prompt is not None:
Expand All @@ -385,7 +376,7 @@ def get_weighted_text_embeddings(
uncond_tokens,
pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle,
clip_skip=clip_skip
clip_skip=clip_skip,
)
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)

Expand Down Expand Up @@ -454,7 +445,11 @@ def preprocess_mask(mask, batch_size, scale_factor=8):


class StableDiffusionLongPromptWeightingPipeline(
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
DiffusionPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
LoraLoaderMixin,
FromSingleFileMixin,
):
r"""
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
Expand Down Expand Up @@ -590,6 +585,8 @@ def _encode_prompt(
max_embeddings_multiples=3,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
clip_skip: Optional[int] = None,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Expand Down Expand Up @@ -639,6 +636,7 @@ def _encode_prompt(
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=clip_skip,
lora_scale=lora_scale,
)
if prompt_embeds is None:
prompt_embeds = prompt_embeds1
Expand Down Expand Up @@ -832,6 +830,7 @@ def __call__(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
clip_skip: Optional[int] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
Expand Down Expand Up @@ -907,6 +906,9 @@ def __call__(
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Expand Down Expand Up @@ -945,6 +947,7 @@ def __call__(
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None

# 3. Encode input prompt
prompt_embeds = self._encode_prompt(
Expand All @@ -956,6 +959,8 @@ def __call__(
max_embeddings_multiples,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
clip_skip=clip_skip,
lora_scale=lora_scale,
)
dtype = prompt_embeds.dtype

Expand Down Expand Up @@ -1086,6 +1091,7 @@ def text2img(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
clip_skip=None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
Expand Down Expand Up @@ -1143,6 +1149,9 @@ def text2img(
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Expand Down Expand Up @@ -1177,6 +1186,7 @@ def text2img(
return_dict=return_dict,
callback=callback,
is_cancelled_callback=is_cancelled_callback,
clip_skip=clip_skip,
callback_steps=callback_steps,
cross_attention_kwargs=cross_attention_kwargs,
)
Expand Down
33 changes: 19 additions & 14 deletions multigen/lpw_stable_diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,31 @@

from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.loaders import (
FromSingleFileMixin,
IPAdapterMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_invisible_watermark_available,
logging,
replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor


if is_invisible_watermark_available():
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
Expand Down Expand Up @@ -263,7 +265,7 @@ def get_weighted_text_embeddings_sdxl(
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
clip_skip: Optional[int] = None,
lora_scale: Optional[int] = None
lora_scale: Optional[int] = None,
):
"""
This function can process long prompt with weights, no length limitation
Expand Down Expand Up @@ -580,7 +582,7 @@ class SDXLLongPromptWeightingPipeline(
StableDiffusionMixin,
FromSingleFileMixin,
IPAdapterMixin,
LoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
):
r"""
Expand All @@ -592,8 +594,8 @@ class SDXLLongPromptWeightingPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
Args:
Expand Down Expand Up @@ -774,7 +776,7 @@ def encode_prompt(

# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
self._lora_scale = lora_scale

if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -1643,7 +1645,9 @@ def __call__(
image_embeds = torch.cat([negative_image_embeds, image_embeds])

# 3. Encode input prompt
(self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None)
lora_scale = (
self._cross_attention_kwargs.get("scale", None) if self._cross_attention_kwargs is not None else None
)

negative_prompt = negative_prompt if negative_prompt is not None else ""

Expand All @@ -1658,6 +1662,7 @@ def __call__(
neg_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
clip_skip=clip_skip,
lora_scale=lora_scale,
)
dtype = prompt_embeds.dtype

Expand Down

0 comments on commit 5d6d975

Please sign in to comment.