diff --git a/examples/community/pipeline_stable_diffusion_xl_t5.py b/examples/community/pipeline_stable_diffusion_xl_t5.py new file mode 100644 index 000000000000..7434c90bff38 --- /dev/null +++ b/examples/community/pipeline_stable_diffusion_xl_t5.py @@ -0,0 +1,194 @@ +# Copyright Philip Brown, ppbrown@github +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: At this time, the intent is to use the T5 encoder mentioned +# below, with zero changes. +# Therefore, the model deliberately does not store the T5 encoder model bytes, +# (Since they are not unique!) +# but instead takes advantage of huggingface hub cache loading + +T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly" + + +# Caller is expected to load this, or equivalent, as model name for now +# eg: pipe = StableDiffusionXL_T5Pipeline(SDXL_NAME) +SDXL_NAME = "stabilityai/stable-diffusion-xl-base-1.0" + + + +from diffusers import StableDiffusionXLPipeline, DiffusionPipeline +from transformers import T5Tokenizer, T5EncoderModel +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor + + +from typing import Optional + +import torch.nn as nn, torch, types + +import torch.nn as nn + +class LinearWithDtype(nn.Linear): + @property + def dtype(self): + return self.weight.dtype + + +class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline): + _expected_modules = [ + "vae", "unet", "scheduler", "tokenizer", + "image_encoder", "feature_extractor", + "t5_encoder", "t5_projection", "t5_pooled_projection", + ] + + _optional_components = [ + "image_encoder", "feature_extractor", + "t5_encoder", "t5_projection", "t5_pooled_projection", + ] + + def __init__( + self, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + tokenizer: CLIPTokenizer, + t5_encoder=None, + t5_projection=None, + t5_pooled_projection=None, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + DiffusionPipeline.__init__(self) + + if t5_encoder is None: + self.t5_encoder = T5EncoderModel.from_pretrained(T5_NAME, + torch_dtype=unet.dtype) + else: + self.t5_encoder = t5_encoder + + # ----- build T5 4096 => 2048 dim projection ----- + if t5_projection is None: + self.t5_projection = LinearWithDtype(4096, 2048) # trainable + else: + self.t5_projection = t5_projection + self.t5_projection.to(dtype=unet.dtype) + # ----- build T5 4096 => 1280 dim projection ----- + if t5_pooled_projection is None: + self.t5_pooled_projection = LinearWithDtype(4096, 1280) # trainable + else: + self.t5_pooled_projection = t5_pooled_projection + self.t5_pooled_projection.to(dtype=unet.dtype) + + print("dtype of Linear is ",self.t5_projection.dtype) + + self.register_modules( + vae=vae, + unet=unet, + scheduler=scheduler, + tokenizer=tokenizer, + t5_encoder=self.t5_encoder, + t5_projection=self.t5_projection, + t5_pooled_projection=self.t5_pooled_projection, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) + + self.watermark = None + + # Parts of original SDXL class complain if these attributes are not + # at least PRESENT + self.text_encoder = self.text_encoder_2 = None + + # ------------------------------------------------------------------ + # Encode a text prompt (T5-XXL + 4096→2048 projection) + # Returns exactly four tensors in the order SDXL’s __call__ expects. + # ------------------------------------------------------------------ + def encode_prompt( + self, + prompt, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + **_, + ): + """ + Returns + ------- + prompt_embeds : Tensor [B, T, 2048] + negative_prompt_embeds : Tensor [B, T, 2048] | None + pooled_prompt_embeds : Tensor [B, 1280] + negative_pooled_prompt_embeds: Tensor [B, 1280] | None + where B = batch * num_images_per_prompt + """ + + # --- helper to tokenize on the pipeline’s device ---------------- + def _tok(text: str): + tok_out = self.tokenizer( + text, + return_tensors="pt", + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + ).to(self.device) + return tok_out.input_ids, tok_out.attention_mask + + # ---------- positive stream ------------------------------------- + ids, mask = _tok(prompt) + h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state # [b, T, 4096] + tok_pos = self.t5_projection(h_pos) # [b, T, 2048] + pool_pos = self.t5_pooled_projection(h_pos.mean(dim=1)) # [b, 1280] + + # expand for multiple images per prompt + tok_pos = tok_pos.repeat_interleave(num_images_per_prompt, 0) + pool_pos = pool_pos.repeat_interleave(num_images_per_prompt, 0) + + # ---------- negative / CFG stream -------------------------------- + if do_classifier_free_guidance: + neg_text = "" if negative_prompt is None else negative_prompt + ids_n, mask_n = _tok(neg_text) + h_neg = self.t5_encoder(ids_n, attention_mask=mask_n).last_hidden_state + tok_neg = self.t5_projection(h_neg) + pool_neg = self.t5_pooled_projection(h_neg.mean(dim=1)) + + tok_neg = tok_neg.repeat_interleave(num_images_per_prompt, 0) + pool_neg = pool_neg.repeat_interleave(num_images_per_prompt, 0) + else: + tok_neg = pool_neg = None + + # ----------------- final ordered return -------------------------- + # 1) positive token embeddings + # 2) negative token embeddings (or None) + # 3) positive pooled embeddings + # 4) negative pooled embeddings (or None) + return tok_pos, tok_neg, pool_pos, pool_neg