From 157c9011d87e52632113024c1dc5125426971556 Mon Sep 17 00:00:00 2001 From: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Date: Thu, 21 Sep 2023 21:35:35 +0530 Subject: [PATCH] Add BLIP Diffusion (#4388) * Add BLIP Diffusion skeleton * Add other model components * Add BLIP2, need to change it for now * Fix pipeline imports * Load pretrained ViT * Make qformer fwd pass same * Replicate fwd passes * Fix device bug * Add accelerate functions * Remove extra functions from Blip2 * Minor bug * Integrate initial review changes * Refactoring * Refactoring * Refactor * Add controlnet * Refactor * Update conversion script * Add image processor * Shift postprocessing to ImageProcessor * Refactor * Fix device * Add fast tests * Update conversion script * Fix checkpoint conversion script * Integrate review changes * Integrate reivew changes * Remove unused functions from test * Reuse HF image processor in Cond image * Create new BlipImageProcessor based on transfomers * Fix image preprocessor * Minor * Minor * Add canny preprocessing * Fix controlnet preprocessing * Fix blip diffusion test * Add controlnet test * Add initial doc strings * Integrate review changes * Refactor * Update examples * Remove DDIM comments * Add copied from for prepare_latents * Add type anotations * Add docstrings * Do black formatting * Add batch support * Make tests pass * Make controlnet tests pass * Black formatting * Fix progress bar * Fix some licensing comments * Fix imports * Refactor controlnet * Make tests faster * Edit examples * Black formatting/Ruff * Add doc * Minor Co-authored-by: Patrick von Platen * Move controlnet pipeline * Make tests faster * Fix imports * Fix formatting * Fix make errors * Fix make errors * Minor * Add suggested doc changes Co-authored-by: Sayak Paul * Edit docs * Fix 16 bit loading * Update examples * Edit toctree * Update docs/source/en/api/pipelines/blip_diffusion.md Co-authored-by: Sayak Paul * Minor * Add tips * Edit examples * Update model paths --------- Co-authored-by: Patrick von Platen Co-authored-by: Sayak Paul --- docs/source/en/_toctree.yml | 2 + .../source/en/api/pipelines/blip_diffusion.md | 29 + scripts/convert_blipdiffusion_to_diffusers.py | 343 +++++++ src/diffusers/__init__.py | 4 + src/diffusers/pipelines/__init__.py | 916 +++++++++--------- .../pipelines/blip_diffusion/__init__.py | 20 + .../blip_diffusion/blip_image_processing.py | 318 ++++++ .../blip_diffusion/modeling_blip2.py | 642 ++++++++++++ .../blip_diffusion/modeling_ctx_clip.py | 212 ++++ .../blip_diffusion/pipeline_blip_diffusion.py | 339 +++++++ .../pipelines/controlnet/__init__.py | 156 +-- .../pipeline_controlnet_blip_diffusion.py | 405 ++++++++ src/diffusers/utils/dummy_pt_objects.py | 30 + tests/pipelines/blipdiffusion/__init__.py | 0 .../blipdiffusion/test_blipdiffusion.py | 196 ++++ .../test_controlnet_blip_diffusion.py | 216 +++++ 16 files changed, 3295 insertions(+), 533 deletions(-) create mode 100644 docs/source/en/api/pipelines/blip_diffusion.md create mode 100644 scripts/convert_blipdiffusion_to_diffusers.py create mode 100644 src/diffusers/pipelines/blip_diffusion/__init__.py create mode 100644 src/diffusers/pipelines/blip_diffusion/blip_image_processing.py create mode 100644 src/diffusers/pipelines/blip_diffusion/modeling_blip2.py create mode 100644 src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py create mode 100644 src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py create mode 100644 
src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py create mode 100644 tests/pipelines/blipdiffusion/__init__.py create mode 100644 tests/pipelines/blipdiffusion/test_blipdiffusion.py create mode 100644 tests/pipelines/controlnet/test_controlnet_blip_diffusion.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b56d9c094dab..cc50a956439c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -216,6 +216,8 @@ title: AudioLDM 2 - local: api/pipelines/auto_pipeline title: AutoPipeline + - local: api/pipelines/blip_diffusion + title: BLIP Diffusion - local: api/pipelines/consistency_models title: Consistency Models - local: api/pipelines/controlnet diff --git a/docs/source/en/api/pipelines/blip_diffusion.md b/docs/source/en/api/pipelines/blip_diffusion.md new file mode 100644 index 000000000000..698e1f05fd7e --- /dev/null +++ b/docs/source/en/api/pipelines/blip_diffusion.md @@ -0,0 +1,29 @@ +# Blip Diffusion + +Blip Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://arxiv.org/abs/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation. + + +The abstract from the paper is: + +*Subject-driven text-to-image generation models create novel renditions of an input subject based on text prompts. Existing models suffer from lengthy fine-tuning and difficulties preserving the subject fidelity. To overcome these limitations, we introduce BLIP-Diffusion, a new subject-driven image generation model that supports multimodal control which consumes inputs of subject images and text prompts. Unlike other subject-driven generation models, BLIP-Diffusion introduces a new multimodal encoder which is pre-trained to provide subject representation. We first pre-train the multimodal encoder following BLIP-2 to produce visual representation aligned with the text. Then we design a subject representation learning task which enables a diffusion model to leverage such visual representation and generates new subject renditions. Compared with previous methods such as DreamBooth, our model enables zero-shot subject-driven generation, and efficient fine-tuning for customized subject with up to 20x speedup. We also demonstrate that BLIP-Diffusion can be flexibly combined with existing techniques such as ControlNet and prompt-to-prompt to enable novel subject-driven generation and editing applications.* + +The original codebase can be found at [salesforce/LAVIS](https://github.com/salesforce/LAVIS/tree/main/projects/blip-diffusion). You can find the official BLIP Diffusion checkpoints under the [hf.co/SalesForce](https://hf.co/SalesForce) organization. + +`BlipDiffusionPipeline` and `BlipDiffusionControlNetPipeline` were contributed by [`ayushtues`](https://github.com/ayushtues/). + + + +Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. 
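+
+A minimal usage sketch for zero-shot subject-driven generation is shown below. The checkpoint id (`Salesforce/blipdiffusion`) and the reference image path are illustrative assumptions — pick a checkpoint from the [hf.co/SalesForce](https://hf.co/SalesForce) organization and your own subject image, and confirm the exact argument names against the `__call__` reference further down this page.
+
+```python
+import torch
+
+from diffusers import BlipDiffusionPipeline
+from diffusers.utils import load_image
+
+# Load the pipeline in fp16 (assumed checkpoint id under the SalesForce organization).
+pipe = BlipDiffusionPipeline.from_pretrained(
+    "Salesforce/blipdiffusion", torch_dtype=torch.float16
+).to("cuda")
+
+# Reference image of the subject to re-render (placeholder path).
+reference_image = load_image("path/to/dog.png")
+
+output = pipe(
+    "swimming underwater",  # text prompt for the new rendition
+    reference_image,        # subject image
+    "dog",                  # source subject category
+    "dog",                  # target subject category
+    guidance_scale=7.5,
+    num_inference_steps=25,
+    neg_prompt="lowres, cropped, worst quality, low quality",
+    height=512,
+    width=512,
+).images
+output[0].save("blip_diffusion_out.png")
+```
+
+`BlipDiffusionControlNetPipeline` follows the same pattern but additionally takes a conditioning image (for example, a Canny edge map) for control-guided generation; see its `__call__` reference below.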
+ + + + +## BlipDiffusionPipeline +[[autodoc]] BlipDiffusionPipeline + - all + - __call__ + +## BlipDiffusionControlNetPipeline +[[autodoc]] BlipDiffusionControlNetPipeline + - all + - __call__ diff --git a/scripts/convert_blipdiffusion_to_diffusers.py b/scripts/convert_blipdiffusion_to_diffusers.py new file mode 100644 index 000000000000..03cf67e5476b --- /dev/null +++ b/scripts/convert_blipdiffusion_to_diffusers.py @@ -0,0 +1,343 @@ +""" +This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main. +""" + +import argparse +import os +import tempfile + +import torch +from lavis.models import load_model_and_preprocess +from transformers import CLIPTokenizer +from transformers.models.blip_2.configuration_blip_2 import Blip2Config + +from diffusers import ( + AutoencoderKL, + PNDMScheduler, + UNet2DConditionModel, +) +from diffusers.pipelines import BlipDiffusionPipeline +from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor +from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel +from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel + + +BLIP2_CONFIG = { + "vision_config": { + "hidden_size": 1024, + "num_hidden_layers": 23, + "num_attention_heads": 16, + "image_size": 224, + "patch_size": 14, + "intermediate_size": 4096, + "hidden_act": "quick_gelu", + }, + "qformer_config": { + "cross_attention_frequency": 1, + "encoder_hidden_size": 1024, + "vocab_size": 30523, + }, + "num_query_tokens": 16, +} +blip2config = Blip2Config(**BLIP2_CONFIG) + + +def qformer_model_from_original_config(): + qformer = Blip2QFormerModel(blip2config) + return qformer + + +def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix): + embeddings = {} + embeddings.update( + { + f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[ + f"{original_embeddings_prefix}.word_embeddings.weight" + ] + } + ) + embeddings.update( + { + f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[ + f"{original_embeddings_prefix}.position_embeddings.weight" + ] + } + ) + embeddings.update( + {f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]} + ) + embeddings.update( + {f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]} + ) + return embeddings + + +def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix): + proj_layer = {} + proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]}) + proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]}) + proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]}) + proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]}) + proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]}) + proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]}) + return proj_layer + + +def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix): + attention = {} + attention.update( + { + f"{diffuser_attention_prefix}.attention.query.weight": model[ + 
f"{original_attention_prefix}.self.query.weight" + ] + } + ) + attention.update( + {f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]} + ) + attention.update( + {f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]} + ) + attention.update( + {f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]} + ) + attention.update( + { + f"{diffuser_attention_prefix}.attention.value.weight": model[ + f"{original_attention_prefix}.self.value.weight" + ] + } + ) + attention.update( + {f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]} + ) + attention.update( + {f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]} + ) + attention.update( + {f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]} + ) + attention.update( + { + f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[ + f"{original_attention_prefix}.output.LayerNorm.weight" + ] + } + ) + attention.update( + { + f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[ + f"{original_attention_prefix}.output.LayerNorm.bias" + ] + } + ) + return attention + + +def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix): + output_layers = {} + output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]}) + output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]}) + output_layers.update( + {f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]} + ) + output_layers.update( + {f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]} + ) + return output_layers + + +def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix): + encoder = {} + for i in range(blip2config.qformer_config.num_hidden_layers): + encoder.update( + attention_from_original_checkpoint( + model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention" + ) + ) + encoder.update( + attention_from_original_checkpoint( + model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention" + ) + ) + + encoder.update( + { + f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[ + f"{original_encoder_prefix}.{i}.intermediate.dense.weight" + ] + } + ) + encoder.update( + { + f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[ + f"{original_encoder_prefix}.{i}.intermediate.dense.bias" + ] + } + ) + encoder.update( + { + f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[ + f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight" + ] + } + ) + encoder.update( + { + f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[ + f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias" + ] + } + ) + + encoder.update( + output_layers_from_original_checkpoint( + model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output" + ) + ) + encoder.update( + output_layers_from_original_checkpoint( + model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query" + ) + ) + return encoder + + +def 
visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix): + visual_encoder_layer = {} + + visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]}) + visual_encoder_layer.update( + {f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]} + ) + visual_encoder_layer.update( + {f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]} + ) + visual_encoder_layer.update( + {f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]} + ) + visual_encoder_layer.update( + {f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]} + ) + visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]}) + + return visual_encoder_layer + + +def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix): + visual_encoder = {} + + visual_encoder.update( + { + f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"] + .unsqueeze(0) + .unsqueeze(0) + } + ) + visual_encoder.update( + { + f"{diffuser_prefix}.embeddings.position_embedding": model[ + f"{original_prefix}.positional_embedding" + ].unsqueeze(0) + } + ) + visual_encoder.update( + {f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]} + ) + visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]}) + visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]}) + + for i in range(blip2config.vision_config.num_hidden_layers): + visual_encoder.update( + visual_encoder_layer_from_original_checkpoint( + model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}" + ) + ) + + visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]}) + visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]}) + + return visual_encoder + + +def qformer_original_checkpoint_to_diffusers_checkpoint(model): + qformer_checkpoint = {} + qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings")) + qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]}) + qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer")) + qformer_checkpoint.update( + encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer") + ) + qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder")) + return qformer_checkpoint + + +def get_qformer(model): + print("loading qformer") + 
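+    # Build the diffusers Q-Former from the BLIP-2 config above and load the remapped LAVIS weights into it.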
+ qformer = qformer_model_from_original_config() + qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model) + + load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer) + + print("done loading qformer") + return qformer + + +def load_checkpoint_to_model(checkpoint, model): + with tempfile.NamedTemporaryFile(delete=False) as file: + torch.save(checkpoint, file.name) + del checkpoint + model.load_state_dict(torch.load(file.name), strict=False) + + os.remove(file.name) + + +def save_blip_diffusion_model(model, args): + qformer = get_qformer(model) + qformer.eval() + + text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae") + + unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") + vae.eval() + text_encoder.eval() + scheduler = PNDMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + set_alpha_to_one=False, + skip_prk_steps=True, + ) + tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer") + image_processor = BlipImageProcessor() + blip_diffusion = BlipDiffusionPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + unet=unet, + scheduler=scheduler, + qformer=qformer, + image_processor=image_processor, + ) + blip_diffusion.save_pretrained(args.checkpoint_path) + + +def main(args): + model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True) + save_blip_diffusion_model(model.state_dict(), args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") + args = parser.parse_args() + + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8fdf8df60f4c..cc82cb09f969 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -197,6 +197,8 @@ "AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel", "AudioLDMPipeline", + "BlipDiffusionControlNetPipeline", + "BlipDiffusionPipeline", "CLIPImageProjection", "CycleDiffusionPipeline", "IFImg2ImgPipeline", @@ -458,6 +460,8 @@ AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image, + BlipDiffusionControlNetPipeline, + BlipDiffusionPipeline, CLIPImageProjection, ConsistencyModelPipeline, DanceDiffusionPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ef3566cf61c0..94e849a68788 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,456 +1,460 @@ -from typing import TYPE_CHECKING - -from ..utils import ( - OptionalDependencyNotAvailable, - _LazyModule, - get_objects_from_module, - is_flax_available, - is_k_diffusion_available, - is_librosa_available, - is_note_seq_available, - is_onnx_available, - is_torch_available, - is_transformers_available, -) - - -# These modules contain pipelines from multiple libraries/frameworks -_dummy_objects = {} -_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []} - -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_pt_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) -else: - _import_structure["auto_pipeline"] = [ - 
"AutoPipelineForImage2Image", - "AutoPipelineForInpainting", - "AutoPipelineForText2Image", - ] - _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] - _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] - _import_structure["ddim"] = ["DDIMPipeline"] - _import_structure["ddpm"] = ["DDPMPipeline"] - _import_structure["dit"] = ["DiTPipeline"] - _import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"]) - _import_structure["latent_diffusion_uncond"] = ["LDMPipeline"] - _import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"] - _import_structure["pndm"] = ["PNDMPipeline"] - _import_structure["repaint"] = ["RePaintPipeline"] - _import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"] - _import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"] -try: - if not (is_torch_available() and is_librosa_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_torch_and_librosa_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects)) -else: - _import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"] -try: - if not (is_torch_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_torch_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) -else: - _import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"] - _import_structure["audioldm"] = ["AudioLDMPipeline"] - _import_structure["audioldm2"] = [ - "AudioLDM2Pipeline", - "AudioLDM2ProjectionModel", - "AudioLDM2UNet2DConditionModel", - ] - _import_structure["controlnet"].extend( - [ - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - "StableDiffusionControlNetPipeline", - "StableDiffusionXLControlNetImg2ImgPipeline", - "StableDiffusionXLControlNetInpaintPipeline", - "StableDiffusionXLControlNetPipeline", - ] - ) - _import_structure["deepfloyd_if"] = [ - "IFImg2ImgPipeline", - "IFImg2ImgSuperResolutionPipeline", - "IFInpaintingPipeline", - "IFInpaintingSuperResolutionPipeline", - "IFPipeline", - "IFSuperResolutionPipeline", - ] - _import_structure["kandinsky"] = [ - "KandinskyCombinedPipeline", - "KandinskyImg2ImgCombinedPipeline", - "KandinskyImg2ImgPipeline", - "KandinskyInpaintCombinedPipeline", - "KandinskyInpaintPipeline", - "KandinskyPipeline", - "KandinskyPriorPipeline", - ] - _import_structure["kandinsky2_2"] = [ - "KandinskyV22CombinedPipeline", - "KandinskyV22ControlnetImg2ImgPipeline", - "KandinskyV22ControlnetPipeline", - "KandinskyV22Img2ImgCombinedPipeline", - "KandinskyV22Img2ImgPipeline", - "KandinskyV22InpaintCombinedPipeline", - "KandinskyV22InpaintPipeline", - "KandinskyV22Pipeline", - "KandinskyV22PriorEmb2EmbPipeline", - "KandinskyV22PriorPipeline", - ] - _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"]) - _import_structure["musicldm"] = ["MusicLDMPipeline"] - _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] - _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] - _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] - _import_structure["stable_diffusion"].extend( - [ - "CLIPImageProjection", - "CycleDiffusionPipeline", - "StableDiffusionAttendAndExcitePipeline", - 
"StableDiffusionDepth2ImgPipeline", - "StableDiffusionDiffEditPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENTextImagePipeline", - "StableDiffusionImageVariationPipeline", - "StableDiffusionImg2ImgPipeline", - "StableDiffusionInpaintPipeline", - "StableDiffusionInpaintPipelineLegacy", - "StableDiffusionInstructPix2PixPipeline", - "StableDiffusionLatentUpscalePipeline", - "StableDiffusionLDM3DPipeline", - "StableDiffusionModelEditingPipeline", - "StableDiffusionPanoramaPipeline", - "StableDiffusionParadigmsPipeline", - "StableDiffusionPipeline", - "StableDiffusionPix2PixZeroPipeline", - "StableDiffusionSAGPipeline", - "StableDiffusionUpscalePipeline", - "StableUnCLIPImg2ImgPipeline", - "StableUnCLIPPipeline", - ] - ) - _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"] - _import_structure["stable_diffusion_xl"] = [ - "StableDiffusionXLImg2ImgPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLPipeline", - ] - _import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"] - _import_structure["text_to_video_synthesis"] = [ - "TextToVideoSDPipeline", - "TextToVideoZeroPipeline", - "VideoToVideoSDPipeline", - ] - _import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"] - _import_structure["unidiffuser"] = [ - "ImageTextPipelineOutput", - "UniDiffuserModel", - "UniDiffuserPipeline", - "UniDiffuserTextDecoder", - ] - _import_structure["versatile_diffusion"] = [ - "VersatileDiffusionDualGuidedPipeline", - "VersatileDiffusionImageVariationPipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionTextToImagePipeline", - ] - _import_structure["vq_diffusion"] = ["VQDiffusionPipeline"] - _import_structure["wuerstchen"] = [ - "WuerstchenCombinedPipeline", - "WuerstchenDecoderPipeline", - "WuerstchenPriorPipeline", - ] -try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_onnx_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_onnx_objects)) -else: - _import_structure["onnx_utils"] = ["OnnxRuntimeModel"] -try: - if not (is_torch_available() and is_transformers_available() and is_onnx_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects)) -else: - _import_structure["stable_diffusion"].extend( - [ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ] - ) -try: - if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) -else: - _import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"]) -try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_flax_objects # noqa F403 - - 
_dummy_objects.update(get_objects_from_module(dummy_flax_objects)) -else: - _import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"] -try: - if not (is_flax_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_flax_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) -else: - _import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"]) - _import_structure["stable_diffusion"].extend( - [ - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - ] - ) -try: - if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) -else: - _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"] - -if TYPE_CHECKING: - try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_pt_objects import * # noqa F403 - - else: - from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image - from .consistency_models import ConsistencyModelPipeline - from .dance_diffusion import DanceDiffusionPipeline - from .ddim import DDIMPipeline - from .ddpm import DDPMPipeline - from .dit import DiTPipeline - from .latent_diffusion import LDMSuperResolutionPipeline - from .latent_diffusion_uncond import LDMPipeline - from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput - from .pndm import PNDMPipeline - from .repaint import RePaintPipeline - from .score_sde_ve import ScoreSdeVePipeline - from .stochastic_karras_ve import KarrasVePipeline - - try: - if not (is_torch_available() and is_librosa_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_librosa_objects import * - else: - from .audio_diffusion import AudioDiffusionPipeline, Mel - - try: - if not (is_torch_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_transformers_objects import * - else: - from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline - from .audioldm import AudioLDMPipeline - from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel - from .controlnet import ( - StableDiffusionControlNetImg2ImgPipeline, - StableDiffusionControlNetInpaintPipeline, - StableDiffusionControlNetPipeline, - StableDiffusionXLControlNetImg2ImgPipeline, - StableDiffusionXLControlNetInpaintPipeline, - StableDiffusionXLControlNetPipeline, - ) - from .deepfloyd_if import ( - IFImg2ImgPipeline, - IFImg2ImgSuperResolutionPipeline, - IFInpaintingPipeline, - IFInpaintingSuperResolutionPipeline, - IFPipeline, - IFSuperResolutionPipeline, - ) - from .kandinsky import ( - KandinskyCombinedPipeline, - KandinskyImg2ImgCombinedPipeline, - KandinskyImg2ImgPipeline, - KandinskyInpaintCombinedPipeline, - KandinskyInpaintPipeline, - KandinskyPipeline, - KandinskyPriorPipeline, - ) - from .kandinsky2_2 import ( - 
KandinskyV22CombinedPipeline, - KandinskyV22ControlnetImg2ImgPipeline, - KandinskyV22ControlnetPipeline, - KandinskyV22Img2ImgCombinedPipeline, - KandinskyV22Img2ImgPipeline, - KandinskyV22InpaintCombinedPipeline, - KandinskyV22InpaintPipeline, - KandinskyV22Pipeline, - KandinskyV22PriorEmb2EmbPipeline, - KandinskyV22PriorPipeline, - ) - from .latent_diffusion import LDMTextToImagePipeline - from .musicldm import MusicLDMPipeline - from .paint_by_example import PaintByExamplePipeline - from .semantic_stable_diffusion import SemanticStableDiffusionPipeline - from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline - from .stable_diffusion import ( - CLIPImageProjection, - CycleDiffusionPipeline, - StableDiffusionAttendAndExcitePipeline, - StableDiffusionDepth2ImgPipeline, - StableDiffusionDiffEditPipeline, - StableDiffusionGLIGENPipeline, - StableDiffusionGLIGENTextImagePipeline, - StableDiffusionImageVariationPipeline, - StableDiffusionImg2ImgPipeline, - StableDiffusionInpaintPipeline, - StableDiffusionInpaintPipelineLegacy, - StableDiffusionInstructPix2PixPipeline, - StableDiffusionLatentUpscalePipeline, - StableDiffusionLDM3DPipeline, - StableDiffusionModelEditingPipeline, - StableDiffusionPanoramaPipeline, - StableDiffusionParadigmsPipeline, - StableDiffusionPipeline, - StableDiffusionPix2PixZeroPipeline, - StableDiffusionSAGPipeline, - StableDiffusionUpscalePipeline, - StableUnCLIPImg2ImgPipeline, - StableUnCLIPPipeline, - ) - from .stable_diffusion_safe import StableDiffusionPipelineSafe - from .stable_diffusion_xl import ( - StableDiffusionXLImg2ImgPipeline, - StableDiffusionXLInpaintPipeline, - StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLPipeline, - ) - from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline - from .text_to_video_synthesis import ( - TextToVideoSDPipeline, - TextToVideoZeroPipeline, - VideoToVideoSDPipeline, - ) - from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline - from .unidiffuser import ( - ImageTextPipelineOutput, - UniDiffuserModel, - UniDiffuserPipeline, - UniDiffuserTextDecoder, - ) - from .versatile_diffusion import ( - VersatileDiffusionDualGuidedPipeline, - VersatileDiffusionImageVariationPipeline, - VersatileDiffusionPipeline, - VersatileDiffusionTextToImagePipeline, - ) - from .vq_diffusion import VQDiffusionPipeline - from .wuerstchen import ( - WuerstchenCombinedPipeline, - WuerstchenDecoderPipeline, - WuerstchenPriorPipeline, - ) - - try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_onnx_objects import * # noqa F403 - - else: - from .onnx_utils import OnnxRuntimeModel - - try: - if not (is_torch_available() and is_transformers_available() and is_onnx_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_transformers_and_onnx_objects import * - else: - from .stable_diffusion import ( - OnnxStableDiffusionImg2ImgPipeline, - OnnxStableDiffusionInpaintPipeline, - OnnxStableDiffusionInpaintPipelineLegacy, - OnnxStableDiffusionPipeline, - OnnxStableDiffusionUpscalePipeline, - StableDiffusionOnnxPipeline, - ) - - try: - if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * - else: - from .stable_diffusion import 
StableDiffusionKDiffusionPipeline - - try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_flax_objects import * # noqa F403 - else: - from .pipeline_flax_utils import FlaxDiffusionPipeline - - try: - if not (is_flax_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_flax_and_transformers_objects import * - else: - from .controlnet import FlaxStableDiffusionControlNetPipeline - from .stable_diffusion import ( - FlaxStableDiffusionImg2ImgPipeline, - FlaxStableDiffusionInpaintPipeline, - FlaxStableDiffusionPipeline, - ) - - try: - if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 - - else: - from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline - -else: - import sys - - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - ) - for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) +from typing import TYPE_CHECKING + +from ..utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_k_diffusion_available, + is_librosa_available, + is_note_seq_available, + is_onnx_available, + is_torch_available, + is_transformers_available, +) + + +# These modules contain pipelines from multiple libraries/frameworks +_dummy_objects = {} +_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_pt_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["auto_pipeline"] = [ + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + "AutoPipelineForText2Image", + ] + _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] + _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] + _import_structure["ddim"] = ["DDIMPipeline"] + _import_structure["ddpm"] = ["DDPMPipeline"] + _import_structure["dit"] = ["DiTPipeline"] + _import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"]) + _import_structure["latent_diffusion_uncond"] = ["LDMPipeline"] + _import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"] + _import_structure["pndm"] = ["PNDMPipeline"] + _import_structure["repaint"] = ["RePaintPipeline"] + _import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"] + _import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"] +try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_librosa_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects)) +else: + _import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"] +try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_transformers_objects # noqa F403 + + 
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"] + _import_structure["audioldm"] = ["AudioLDMPipeline"] + _import_structure["audioldm2"] = [ + "AudioLDM2Pipeline", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", + ] + _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"] + _import_structure["controlnet"].extend( + [ + "BlipDiffusionControlNetPipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPipeline", + ] + ) + _import_structure["deepfloyd_if"] = [ + "IFImg2ImgPipeline", + "IFImg2ImgSuperResolutionPipeline", + "IFInpaintingPipeline", + "IFInpaintingSuperResolutionPipeline", + "IFPipeline", + "IFSuperResolutionPipeline", + ] + _import_structure["kandinsky"] = [ + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyImg2ImgPipeline", + "KandinskyInpaintCombinedPipeline", + "KandinskyInpaintPipeline", + "KandinskyPipeline", + "KandinskyPriorPipeline", + ] + _import_structure["kandinsky2_2"] = [ + "KandinskyV22CombinedPipeline", + "KandinskyV22ControlnetImg2ImgPipeline", + "KandinskyV22ControlnetPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22Img2ImgPipeline", + "KandinskyV22InpaintCombinedPipeline", + "KandinskyV22InpaintPipeline", + "KandinskyV22Pipeline", + "KandinskyV22PriorEmb2EmbPipeline", + "KandinskyV22PriorPipeline", + ] + _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"]) + _import_structure["musicldm"] = ["MusicLDMPipeline"] + _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] + _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] + _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] + _import_structure["stable_diffusion"].extend( + [ + "CLIPImageProjection", + "CycleDiffusionPipeline", + "StableDiffusionAttendAndExcitePipeline", + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionDiffEditPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + "StableDiffusionImageVariationPipeline", + "StableDiffusionImg2ImgPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionInstructPix2PixPipeline", + "StableDiffusionLatentUpscalePipeline", + "StableDiffusionLDM3DPipeline", + "StableDiffusionModelEditingPipeline", + "StableDiffusionPanoramaPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionPipeline", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionSAGPipeline", + "StableDiffusionUpscalePipeline", + "StableUnCLIPImg2ImgPipeline", + "StableUnCLIPPipeline", + ] + ) + _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"] + _import_structure["stable_diffusion_xl"] = [ + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPipeline", + ] + _import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"] + _import_structure["text_to_video_synthesis"] = [ + "TextToVideoSDPipeline", + "TextToVideoZeroPipeline", + "VideoToVideoSDPipeline", + ] + _import_structure["unclip"] = 
["UnCLIPImageVariationPipeline", "UnCLIPPipeline"] + _import_structure["unidiffuser"] = [ + "ImageTextPipelineOutput", + "UniDiffuserModel", + "UniDiffuserPipeline", + "UniDiffuserTextDecoder", + ] + _import_structure["versatile_diffusion"] = [ + "VersatileDiffusionDualGuidedPipeline", + "VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + ] + _import_structure["vq_diffusion"] = ["VQDiffusionPipeline"] + _import_structure["wuerstchen"] = [ + "WuerstchenCombinedPipeline", + "WuerstchenDecoderPipeline", + "WuerstchenPriorPipeline", + ] +try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_onnx_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_onnx_objects)) +else: + _import_structure["onnx_utils"] = ["OnnxRuntimeModel"] +try: + if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects)) +else: + _import_structure["stable_diffusion"].extend( + [ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ] + ) +try: + if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) +else: + _import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"]) +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_flax_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_objects)) +else: + _import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"] +try: + if not (is_flax_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_flax_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) +else: + _import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"]) + _import_structure["stable_diffusion"].extend( + [ + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + ] + ) +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) +else: + _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"] + +if TYPE_CHECKING: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from 
..utils.dummy_pt_objects import * # noqa F403 + + else: + from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image + from .consistency_models import ConsistencyModelPipeline + from .dance_diffusion import DanceDiffusionPipeline + from .ddim import DDIMPipeline + from .ddpm import DDPMPipeline + from .dit import DiTPipeline + from .latent_diffusion import LDMSuperResolutionPipeline + from .latent_diffusion_uncond import LDMPipeline + from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput + from .pndm import PNDMPipeline + from .repaint import RePaintPipeline + from .score_sde_ve import ScoreSdeVePipeline + from .stochastic_karras_ve import KarrasVePipeline + + try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_librosa_objects import * + else: + from .audio_diffusion import AudioDiffusionPipeline, Mel + + try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_objects import * + else: + from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline + from .audioldm import AudioLDMPipeline + from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel + from .blip_diffusion import BlipDiffusionPipeline + from .controlnet import ( + BlipDiffusionControlNetPipeline, + StableDiffusionControlNetImg2ImgPipeline, + StableDiffusionControlNetInpaintPipeline, + StableDiffusionControlNetPipeline, + StableDiffusionXLControlNetImg2ImgPipeline, + StableDiffusionXLControlNetInpaintPipeline, + StableDiffusionXLControlNetPipeline, + ) + from .deepfloyd_if import ( + IFImg2ImgPipeline, + IFImg2ImgSuperResolutionPipeline, + IFInpaintingPipeline, + IFInpaintingSuperResolutionPipeline, + IFPipeline, + IFSuperResolutionPipeline, + ) + from .kandinsky import ( + KandinskyCombinedPipeline, + KandinskyImg2ImgCombinedPipeline, + KandinskyImg2ImgPipeline, + KandinskyInpaintCombinedPipeline, + KandinskyInpaintPipeline, + KandinskyPipeline, + KandinskyPriorPipeline, + ) + from .kandinsky2_2 import ( + KandinskyV22CombinedPipeline, + KandinskyV22ControlnetImg2ImgPipeline, + KandinskyV22ControlnetPipeline, + KandinskyV22Img2ImgCombinedPipeline, + KandinskyV22Img2ImgPipeline, + KandinskyV22InpaintCombinedPipeline, + KandinskyV22InpaintPipeline, + KandinskyV22Pipeline, + KandinskyV22PriorEmb2EmbPipeline, + KandinskyV22PriorPipeline, + ) + from .latent_diffusion import LDMTextToImagePipeline + from .musicldm import MusicLDMPipeline + from .paint_by_example import PaintByExamplePipeline + from .semantic_stable_diffusion import SemanticStableDiffusionPipeline + from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline + from .stable_diffusion import ( + CLIPImageProjection, + CycleDiffusionPipeline, + StableDiffusionAttendAndExcitePipeline, + StableDiffusionDepth2ImgPipeline, + StableDiffusionDiffEditPipeline, + StableDiffusionGLIGENPipeline, + StableDiffusionGLIGENTextImagePipeline, + StableDiffusionImageVariationPipeline, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionInpaintPipelineLegacy, + StableDiffusionInstructPix2PixPipeline, + StableDiffusionLatentUpscalePipeline, + StableDiffusionLDM3DPipeline, + StableDiffusionModelEditingPipeline, + StableDiffusionPanoramaPipeline, + 
StableDiffusionParadigmsPipeline, + StableDiffusionPipeline, + StableDiffusionPix2PixZeroPipeline, + StableDiffusionSAGPipeline, + StableDiffusionUpscalePipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + ) + from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .stable_diffusion_xl import ( + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLInstructPix2PixPipeline, + StableDiffusionXLPipeline, + ) + from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline + from .text_to_video_synthesis import ( + TextToVideoSDPipeline, + TextToVideoZeroPipeline, + VideoToVideoSDPipeline, + ) + from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline + from .unidiffuser import ( + ImageTextPipelineOutput, + UniDiffuserModel, + UniDiffuserPipeline, + UniDiffuserTextDecoder, + ) + from .versatile_diffusion import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) + from .vq_diffusion import VQDiffusionPipeline + from .wuerstchen import ( + WuerstchenCombinedPipeline, + WuerstchenDecoderPipeline, + WuerstchenPriorPipeline, + ) + + try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_onnx_objects import * # noqa F403 + + else: + from .onnx_utils import OnnxRuntimeModel + + try: + if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_and_onnx_objects import * + else: + from .stable_diffusion import ( + OnnxStableDiffusionImg2ImgPipeline, + OnnxStableDiffusionInpaintPipeline, + OnnxStableDiffusionInpaintPipelineLegacy, + OnnxStableDiffusionPipeline, + OnnxStableDiffusionUpscalePipeline, + StableDiffusionOnnxPipeline, + ) + + try: + if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * + else: + from .stable_diffusion import StableDiffusionKDiffusionPipeline + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_flax_objects import * # noqa F403 + else: + from .pipeline_flax_utils import FlaxDiffusionPipeline + + try: + if not (is_flax_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_flax_and_transformers_objects import * + else: + from .controlnet import FlaxStableDiffusionControlNetPipeline + from .stable_diffusion import ( + FlaxStableDiffusionImg2ImgPipeline, + FlaxStableDiffusionInpaintPipeline, + FlaxStableDiffusionPipeline, + ) + + try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 + + else: + from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in 
_dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/blip_diffusion/__init__.py b/src/diffusers/pipelines/blip_diffusion/__init__.py new file mode 100644 index 000000000000..af6c879d5ce8 --- /dev/null +++ b/src/diffusers/pipelines/blip_diffusion/__init__.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL +from PIL import Image + +from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline +else: + from .blip_image_processing import BlipImageProcessor + from .modeling_blip2 import Blip2QFormerModel + from .modeling_ctx_clip import ContextCLIPTextModel + from .pipeline_blip_diffusion import BlipDiffusionPipeline diff --git a/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py b/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py new file mode 100644 index 000000000000..2c2911eb9522 --- /dev/null +++ b/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py @@ -0,0 +1,318 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for BLIP.""" + +from typing import Dict, List, Optional, Union + +import numpy as np +import torch +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format +from transformers.image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, +) +from transformers.utils import TensorType, is_vision_available, logging + +from diffusers.utils import numpy_to_pil + + +if is_vision_available(): + import PIL + + +logger = logging.get_logger(__name__) + + +# We needed some extra functions on top of the ones in transformers.image_processing_utils.BaseImageProcessor, namely center crop +# Copy-pasted from transformers.models.blip.image_processing_blip.BlipImageProcessor +class BlipImageProcessor(BaseImageProcessor): + r""" + Constructs a BLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. 
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be + overridden by the `resample` parameter in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be + overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + do_center_crop: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + size = get_size_dict(size, default_to_square=True) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + self.do_center_crop = do_center_crop + + # Copy-pasted from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. 
+            size (`Dict[str, int]`):
+                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
+            data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the output image. If unset, the channel dimension format of the input
+                image is used. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+        Returns:
+            `np.ndarray`: The resized image.
+        """
+        size = get_size_dict(size)
+        if "height" not in size or "width" not in size:
+            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
+        output_size = (size["height"], size["width"])
+        return resize(
+            image,
+            size=output_size,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )
+
+    def preprocess(
+        self,
+        images: ImageInput,
+        do_resize: Optional[bool] = None,
+        size: Optional[Dict[str, int]] = None,
+        resample: PILImageResampling = None,
+        do_rescale: Optional[bool] = None,
+        do_center_crop: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
+        do_normalize: Optional[bool] = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        do_convert_rgb: bool = None,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> PIL.Image.Image:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                Controls the size of the image after `resize`. The image is resized to `(size["height"],
+                size["width"])`.
+            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image values to the `[0, 1]` range.
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to normalize the image by if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to normalize the image by if `do_normalize` is set to `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." 
+            )
+
+        if do_resize and (size is None or resample is None):
+            raise ValueError("Size and resample must be specified if do_resize is True.")
+
+        if do_rescale and rescale_factor is None:
+            raise ValueError("Rescale factor must be specified if do_rescale is True.")
+
+        if do_normalize and (image_mean is None or image_std is None):
+            raise ValueError("Image mean and std must be specified if do_normalize is True.")
+
+        # PIL RGBA images are converted to RGB
+        if do_convert_rgb:
+            images = [convert_to_rgb(image) for image in images]
+
+        # All transformations expect numpy arrays.
+        images = [to_numpy_array(image) for image in images]
+
+        if is_scaled_image(images[0]) and do_rescale:
+            logger.warning_once(
+                "It looks like you are trying to rescale already rescaled images. If the input"
+                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+            )
+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(images[0])
+
+        if do_resize:
+            images = [
+                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+                for image in images
+            ]
+
+        if do_rescale:
+            images = [
+                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+                for image in images
+            ]
+        if do_normalize:
+            images = [
+                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+                for image in images
+            ]
+        if do_center_crop:
+            images = [self.center_crop(image, size, input_data_format=input_data_format) for image in images]
+
+        images = [
+            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+        ]
+
+        encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
+        return encoded_outputs
+
+    # Follows diffusers.VaeImageProcessor.postprocess
+    def postprocess(self, sample: torch.FloatTensor, output_type: str = "pil"):
+        if output_type not in ["pt", "np", "pil"]:
+            raise ValueError(
+                f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
+            )
+
+        # Equivalent to diffusers.VaeImageProcessor.denormalize
+        sample = (sample / 2 + 0.5).clamp(0, 1)
+        if output_type == "pt":
+            return sample
+
+        # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
+        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "np":
+            return sample
+        # output_type must be "pil"
+        sample = numpy_to_pil(sample)
+        return sample
diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
new file mode 100644
index 000000000000..e2862af23283
--- /dev/null
+++ b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
@@ -0,0 +1,642 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
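+"""
+BLIP-2 building blocks (text embeddings, vision encoder, and Q-Former) used by BLIP-Diffusion.
+
+`Blip2QFormerModel` consumes a subject image together with subject-category text and returns projected query
+embeddings; the BLIP-Diffusion pipelines then inject these as context tokens into the CLIP text encoder
+(see `ContextCLIPTextModel` in `modeling_ctx_clip.py`).
+"""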
+from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers import BertTokenizer +from transformers.activations import QuickGELUActivation as QuickGELU +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig +from transformers.models.blip_2.modeling_blip_2 import ( + Blip2Encoder, + Blip2PreTrainedModel, + Blip2QFormerAttention, + Blip2QFormerIntermediate, + Blip2QFormerOutput, +) +from transformers.pytorch_utils import apply_chunking_to_forward +from transformers.utils import ( + logging, + replace_return_docstrings, +) + + +logger = logging.get_logger(__name__) + + +# There is an implementation of Blip2 in `transformers` : https://github.com/huggingface/transformers/blob/main/src/transformers/models/blip_2/modeling_blip_2.py. +# But it doesn't support getting multimodal embeddings. So, this module can be +# replaced with a future `transformers` version supports that. +class Blip2TextEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + batch_size = embeddings.shape[0] + # repeat the query embeddings for batch size + query_embeds = query_embeds.repeat(batch_size, 1, 1) + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + embeddings = embeddings.to(query_embeds.dtype) + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2 +class Blip2VisionEmbeddings(nn.Module): + def __init__(self, config: Blip2VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = 
nn.Conv2d( + in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype) + return embeddings + + +# The Qformer encoder, which takes the visual embeddings, and the text input, to get multimodal embeddings +class Blip2QFormerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions, query_length) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if layer_module.has_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# The layers making up the Qformer encoder +class Blip2QFormerLayer(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = Blip2QFormerAttention(config) + + self.layer_idx = layer_idx + + if layer_idx % config.cross_attention_frequency == 0: + self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate = Blip2QFormerIntermediate(config) + self.intermediate_query = Blip2QFormerIntermediate(config) + self.output_query = Blip2QFormerOutput(config) + self.output = Blip2QFormerOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + if encoder_hidden_states is None: + raise ValueError("encoder_hidden_states must be given for cross-attention layers") + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + 
self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +# ProjLayer used to project the multimodal Blip2 embeddings to be used in the text encoder +class ProjLayer(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12): + super().__init__() + + # Dense1 -> Act -> Dense2 -> Drop -> Res -> Norm + self.dense1 = nn.Linear(in_dim, hidden_dim) + self.act_fn = QuickGELU() + self.dense2 = nn.Linear(hidden_dim, out_dim) + self.dropout = nn.Dropout(drop_p) + + self.LayerNorm = nn.LayerNorm(out_dim, eps=eps) + + def forward(self, x): + x_in = x + + x = self.LayerNorm(x) + x = self.dropout(self.dense2(self.act_fn(self.dense1(x)))) + x_in + + return x + + +# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2 +class Blip2VisionModel(Blip2PreTrainedModel): + main_input_name = "pixel_values" + config_class = Blip2VisionConfig + + def __init__(self, config: Blip2VisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + self.embeddings = Blip2VisionEmbeddings(config) + self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = Blip2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.post_init() + + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layernorm(hidden_states) + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = last_hidden_state[:, 0, :] + pooled_output = 
self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.embeddings + + +# Qformer model, used to get multimodal embeddings from the text and image inputs +class Blip2QFormerModel(Blip2PreTrainedModel): + """ + Querying Transformer (Q-Former), used in BLIP-2. + """ + + def __init__(self, config: Blip2Config): + super().__init__(config) + self.config = config + self.embeddings = Blip2TextEmbeddings(config.qformer_config) + self.visual_encoder = Blip2VisionModel(config.vision_config) + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + if not hasattr(config, "tokenizer") or config.tokenizer is None: + self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right") + else: + self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer, truncation_side="right") + self.tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + self.proj_layer = ProjLayer( + in_dim=config.qformer_config.hidden_size, + out_dim=config.qformer_config.hidden_size, + hidden_dim=config.qformer_config.hidden_size * 4, + drop_p=0.1, + eps=1e-12, + ) + + self.encoder = Blip2QFormerEncoder(config.qformer_config) + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int], + device: torch.device, + has_query: bool = False, + ) -> torch.Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`Tuple[int]`): + The shape of the input to the model. + device (`torch.device`): + The device of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. 
+ # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + text_input=None, + image_input=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of: + shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and + value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are + used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key + value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape + `(batch_size, sequence_length)`. + use_cache (`bool`, `optional`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). 
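+
+        Examples:
+
+        A minimal usage sketch. The `Salesforce/blipdiffusion` subfolder names below are assumptions made for
+        illustration; inside [`BlipDiffusionPipeline`] the equivalent call is made by `get_query_embeddings`:
+
+        ```py
+        >>> from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
+        >>> from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
+        >>> from diffusers.utils import load_image
+
+        >>> # assumed checkpoint layout: the pipeline repo stores these components in subfolders
+        >>> image_processor = BlipImageProcessor.from_pretrained("Salesforce/blipdiffusion", subfolder="image_processor")
+        >>> qformer = Blip2QFormerModel.from_pretrained("Salesforce/blipdiffusion", subfolder="qformer")
+
+        >>> image = load_image(
+        ...     "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg"
+        ... )
+        >>> pixel_values = image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
+        >>> # with `return_dict=False` the projected query embeddings are returned directly
+        >>> query_embeds = qformer(image_input=pixel_values, text_input=["dog"], return_dict=False)
+        ```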
+ """ + + text = self.tokenizer(text_input, return_tensors="pt", padding=True) + text = text.to(self.device) + input_ids = text.input_ids + batch_size = input_ids.shape[0] + query_atts = torch.ones((batch_size, self.query_tokens.size()[1]), dtype=torch.long).to(self.device) + attention_mask = torch.cat([query_atts, text.attention_mask], dim=1) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 + ) + + query_length = self.query_tokens.shape[1] + + embedding_output = self.embeddings( + input_ids=input_ids, + query_embeds=self.query_tokens, + past_key_values_length=past_key_values_length, + ) + + # embedding_output = self.layernorm(query_embeds) + # embedding_output = self.dropout(embedding_output) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + image_embeds_frozen = self.visual_encoder(image_input).last_hidden_state + # image_embeds_frozen = torch.ones_like(image_embeds_frozen) + encoder_hidden_states = image_embeds_frozen + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, list): + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if isinstance(encoder_attention_mask, list): + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.qformer_config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + return self.proj_layer(sequence_output[:, :query_length, :]) + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py b/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py new file mode 100644 index 000000000000..53d57188743d --- /dev/null +++ b/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py @@ -0,0 +1,212 @@ +# Copyright 2023 Salesforce.com, inc. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from transformers import CLIPPreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.models.clip.configuration_clip import CLIPTextConfig +from transformers.models.clip.modeling_clip import ( + CLIPEncoder, + _expand_mask, +) + + +# This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip +# Which allows for an extra input of "context embeddings", which are the query embeddings used in Qformer +# They pass through the clip model, along with the text embeddings, and interact with them using self attention +class ContextCLIPTextModel(CLIPPreTrainedModel): + config_class = CLIPTextConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + self.text_model = ContextCLIPTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + ctx_embeddings: torch.Tensor = None, + ctx_begin_pos: list = None, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + return self.text_model( + ctx_embeddings=ctx_embeddings, + ctx_begin_pos=ctx_begin_pos, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class ContextCLIPTextTransformer(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = ContextCLIPTextEmbeddings(config) + self.encoder = CLIPEncoder(config) + 
self.final_layer_norm = nn.LayerNorm(embed_dim) + + def forward( + self, + ctx_embeddings: torch.Tensor, + ctx_begin_pos: list, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify either input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + ctx_embeddings=ctx_embeddings, + ctx_begin_pos=ctx_begin_pos, + ) + + bsz, seq_len = input_shape + if ctx_embeddings is not None: + seq_len += ctx_embeddings.size(1) + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( + hidden_states.device + ) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=input_ids.device), + input_ids.to(torch.int).argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def _build_causal_attention_mask(self, bsz, seq_len, dtype): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) + mask.fill_(torch.tensor(torch.finfo(dtype).min)) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(1) # expand mask + return mask + + +class ContextCLIPTextEmbeddings(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory 
and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward( + self, + ctx_embeddings: torch.Tensor, + ctx_begin_pos: list, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + if ctx_embeddings is None: + ctx_len = 0 + else: + ctx_len = ctx_embeddings.shape[1] + + seq_length = (input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]) + ctx_len + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + # for each input embeddings, add the ctx embeddings at the correct position + input_embeds_ctx = [] + bsz = inputs_embeds.shape[0] + + if ctx_embeddings is not None: + for i in range(bsz): + cbp = ctx_begin_pos[i] + + prefix = inputs_embeds[i, :cbp] + # remove the special token embedding + suffix = inputs_embeds[i, cbp:] + + input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0)) + + inputs_embeds = torch.stack(input_embeds_ctx, dim=0) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings diff --git a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py new file mode 100644 index 000000000000..3ca456c6f459 --- /dev/null +++ b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py @@ -0,0 +1,339 @@ +# Copyright 2023 Salesforce.com, inc. +# Copyright 2023 The HuggingFace Team. All rights reserved.# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Union + +import PIL +import torch +from transformers import CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import PNDMScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .blip_image_processing import BlipImageProcessor +from .modeling_blip2 import Blip2QFormerModel +from .modeling_ctx_clip import ContextCLIPTextModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers.pipelines import BlipDiffusionPipeline + >>> from diffusers.utils import load_image + >>> import torch + + >>> blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained( + ... "Salesforce/blipdiffusion", torch_dtype=torch.float16 + ... ).to("cuda") + + + >>> cond_subject = "dog" + >>> tgt_subject = "dog" + >>> text_prompt_input = "swimming underwater" + + >>> cond_image = load_image( + ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg" + ... 
) + >>> guidance_scale = 7.5 + >>> num_inference_steps = 25 + >>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate" + + + >>> output = blip_diffusion_pipe( + ... text_prompt_input, + ... cond_image, + ... cond_subject, + ... tgt_subject, + ... guidance_scale=guidance_scale, + ... num_inference_steps=num_inference_steps, + ... neg_prompt=negative_prompt, + ... height=512, + ... width=512, + ... ).images + >>> output[0].save("image.png") + ``` +""" + + +class BlipDiffusionPipeline(DiffusionPipeline): + """ + Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer ([`CLIPTokenizer`]): + Tokenizer for the text encoder + text_encoder ([`ContextCLIPTextModel`]): + Text encoder to encode the text prompt + vae ([`AutoencoderKL`]): + VAE model to map the latents to the image + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + scheduler ([`PNDMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + qformer ([`Blip2QFormerModel`]): + QFormer model to get multi-modal embeddings from the text and image. + image_processor ([`BlipImageProcessor`]): + Image Processor to preprocess and postprocess the image. + ctx_begin_pos (int, `optional`, defaults to 2): + Position of the context token in the text encoder. + """ + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: ContextCLIPTextModel, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + scheduler: PNDMScheduler, + qformer: Blip2QFormerModel, + image_processor: BlipImageProcessor, + ctx_begin_pos: int = 2, + mean: List[float] = None, + std: List[float] = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + unet=unet, + scheduler=scheduler, + qformer=qformer, + image_processor=image_processor, + ) + self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std) + + def get_query_embeddings(self, input_image, src_subject): + return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False) + + # from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it + def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20): + rv = [] + for prompt, tgt_subject in zip(prompts, tgt_subjects): + prompt = f"a {tgt_subject} {prompt.strip()}" + # a trick to amplify the prompt + rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps))) + + return rv + + # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def encode_prompt(self, query_embeds, prompt): + # embeddings for prompt, with query_embeds as context + max_len = self.text_encoder.text_model.config.max_position_embeddings + max_len -= self.qformer.config.num_query_tokens + + tokenized_prompt = self.tokenizer( + prompt, + padding="max_length", + truncation=True, + max_length=max_len, + return_tensors="pt", + ).to(self.device) + + batch_size = query_embeds.shape[0] + ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size + + text_embeddings = self.text_encoder( + input_ids=tokenized_prompt.input_ids, + ctx_embeddings=query_embeds, + ctx_begin_pos=ctx_begin_pos, + )[0] + + return text_embeddings + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: List[str], + reference_image: PIL.Image.Image, + source_subject_category: List[str], + target_subject_category: List[str], + latents: Optional[torch.FloatTensor] = None, + guidance_scale: float = 7.5, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + neg_prompt: Optional[str] = "", + prompt_strength: float = 1.0, + prompt_reps: int = 20, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`List[str]`): + The prompt or prompts to guide the image generation. + reference_image (`PIL.Image.Image`): + The reference image to condition the generation on. + source_subject_category (`List[str]`): + The source subject category. + target_subject_category (`List[str]`): + The target subject category. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by random sampling. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + height (`int`, *optional*, defaults to 512): + The height of the generated image. + width (`int`, *optional*, defaults to 512): + The width of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + neg_prompt (`str`, *optional*, defaults to ""): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). 
+ prompt_strength (`float`, *optional*, defaults to 1.0): + The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps + to amplify the prompt. + prompt_reps (`int`, *optional*, defaults to 20): + The number of times the prompt is repeated along with prompt_strength to amplify the prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + + reference_image = self.image_processor.preprocess( + reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt" + )["pixel_values"] + reference_image = reference_image.to(self.device) + + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(source_subject_category, str): + source_subject_category = [source_subject_category] + if isinstance(target_subject_category, str): + target_subject_category = [target_subject_category] + + batch_size = len(prompt) + + prompt = self._build_prompt( + prompts=prompt, + tgt_subjects=target_subject_category, + prompt_strength=prompt_strength, + prompt_reps=prompt_reps, + ) + query_embeds = self.get_query_embeddings(reference_image, source_subject_category) + text_embeddings = self.encode_prompt(query_embeds, prompt) + do_classifier_free_guidance = guidance_scale > 1.0 + if do_classifier_free_guidance: + max_length = self.text_encoder.text_model.config.max_position_embeddings + + uncond_input = self.tokenizer( + [neg_prompt] * batch_size, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder( + input_ids=uncond_input.input_ids.to(self.device), + ctx_embeddings=None, + )[0] + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1) + latents = self.prepare_latents( + batch_size=batch_size, + num_channels=self.unet.config.in_channels, + height=height // scale_down_factor, + width=width // scale_down_factor, + generator=generator, + latents=latents, + dtype=self.unet.dtype, + device=self.device, + ) + # set timesteps + extra_set_kwargs = {} + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + do_classifier_free_guidance = guidance_scale > 1.0 + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + noise_pred = self.unet( + latent_model_input, + timestep=t, + encoder_hidden_states=text_embeddings, + down_block_additional_residuals=None, + mid_block_additional_residual=None, + )["sample"] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + t, + latents, + )["prev_sample"] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/controlnet/__init__.py b/src/diffusers/pipelines/controlnet/__init__.py index 5c551533f3a8..e14fd438a5bb 100644 --- a/src/diffusers/pipelines/controlnet/__init__.py +++ b/src/diffusers/pipelines/controlnet/__init__.py @@ -1,77 +1,79 @@ -from typing import TYPE_CHECKING - -from ...utils import ( - OptionalDependencyNotAvailable, - _LazyModule, - get_objects_from_module, - is_flax_available, - is_torch_available, - is_transformers_available, -) - - -_dummy_objects = {} -_import_structure = {} - -try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import dummy_torch_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) -else: - _import_structure["multicontrolnet"] = ["MultiControlNetModel"] - _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"] - _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"] - _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"] - _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"] - _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"] - _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"] -try: - if not (is_transformers_available() and is_flax_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import dummy_flax_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) -else: - _import_structure["pipeline_flax_controlnet"] = 
["FlaxStableDiffusionControlNetPipeline"] - - -if TYPE_CHECKING: - try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() - - except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import * - else: - from .multicontrolnet import MultiControlNetModel - from .pipeline_controlnet import StableDiffusionControlNetPipeline - from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline - from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline - from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline - from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline - from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline - - try: - if not (is_transformers_available() and is_flax_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 - else: - from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline - - -else: - import sys - - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - ) - for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["multicontrolnet"] = ["MultiControlNetModel"] + _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"] + _import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"] + _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"] + _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"] + _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"] + _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"] + _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"] +try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_flax_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) +else: + _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"] + + +if TYPE_CHECKING: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .multicontrolnet import MultiControlNetModel + from .pipeline_controlnet import StableDiffusionControlNetPipeline + from .pipeline_controlnet_blip_diffusion import 
BlipDiffusionControlNetPipeline + from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline + from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline + from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline + from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline + from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py new file mode 100644 index 000000000000..1a7efa3212bd --- /dev/null +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py @@ -0,0 +1,405 @@ +# Copyright 2023 Salesforce.com, inc. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Union + +import PIL +import torch +from transformers import CLIPTokenizer + +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...schedulers import PNDMScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..blip_diffusion.blip_image_processing import BlipImageProcessor +from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel +from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers.pipelines import BlipDiffusionControlNetPipeline + >>> from diffusers.utils import load_image + >>> from controlnet_aux import CannyDetector + >>> import torch + + >>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained( + ... "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16 + ... ).to("cuda") + + >>> style_subject = "flower" + >>> tgt_subject = "teapot" + >>> text_prompt = "on a marble table" + + >>> cldm_cond_image = load_image( + ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg" + ... ).resize(512, 512) + >>> canny = CannyDetector() + >>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil") + >>> style_image = load_image( + ... 
"https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg" + ... ) + >>> guidance_scale = 7.5 + >>> num_inference_steps = 50 + >>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate" + + + >>> output = blip_diffusion_pipe( + ... text_prompt, + ... style_image, + ... cldm_cond_image, + ... style_subject, + ... tgt_subject, + ... guidance_scale=guidance_scale, + ... num_inference_steps=num_inference_steps, + ... neg_prompt=negative_prompt, + ... height=512, + ... width=512, + ... ).images + >>> output[0].save("image.png") + ``` +""" + + +class BlipDiffusionControlNetPipeline(DiffusionPipeline): + """ + Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer ([`CLIPTokenizer`]): + Tokenizer for the text encoder + text_encoder ([`ContextCLIPTextModel`]): + Text encoder to encode the text prompt + vae ([`AutoencoderKL`]): + VAE model to map the latents to the image + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + scheduler ([`PNDMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + qformer ([`Blip2QFormerModel`]): + QFormer model to get multi-modal embeddings from the text and image. + controlnet ([`ControlNetModel`]): + ControlNet model to get the conditioning image embedding. + image_processor ([`BlipImageProcessor`]): + Image Processor to preprocess and postprocess the image. + ctx_begin_pos (int, `optional`, defaults to 2): + Position of the context token in the text encoder. 
+ """ + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: ContextCLIPTextModel, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + scheduler: PNDMScheduler, + qformer: Blip2QFormerModel, + controlnet: ControlNetModel, + image_processor: BlipImageProcessor, + ctx_begin_pos: int = 2, + mean: List[float] = None, + std: List[float] = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + unet=unet, + scheduler=scheduler, + qformer=qformer, + controlnet=controlnet, + image_processor=image_processor, + ) + self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std) + + def get_query_embeddings(self, input_image, src_subject): + return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False) + + # from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it + def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20): + rv = [] + for prompt, tgt_subject in zip(prompts, tgt_subjects): + prompt = f"a {tgt_subject} {prompt.strip()}" + # a trick to amplify the prompt + rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps))) + + return rv + + # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def encode_prompt(self, query_embeds, prompt): + # embeddings for prompt, with query_embeds as context + max_len = self.text_encoder.text_model.config.max_position_embeddings + max_len -= self.qformer.config.num_query_tokens + + tokenized_prompt = self.tokenizer( + prompt, + padding="max_length", + truncation=True, + max_length=max_len, + return_tensors="pt", + ).to(self.device) + + batch_size = query_embeds.shape[0] + ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size + + text_embeddings = self.text_encoder( + input_ids=tokenized_prompt.input_ids, + ctx_embeddings=query_embeds, + ctx_begin_pos=ctx_begin_pos, + )[0] + + return text_embeddings + + # Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.image_processor.preprocess( + image, + size={"width": width, "height": height}, + do_rescale=True, + do_center_crop=False, + do_normalize=False, + return_tensors="pt", + )["pixel_values"].to(self.device) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: List[str], + reference_image: PIL.Image.Image, + condtioning_image: PIL.Image.Image, + source_subject_category: List[str], + target_subject_category: List[str], + latents: Optional[torch.FloatTensor] = None, + guidance_scale: float = 7.5, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + neg_prompt: Optional[str] = "", + prompt_strength: float = 1.0, + prompt_reps: int = 20, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`List[str]`): + The prompt or prompts to guide the image generation. + reference_image (`PIL.Image.Image`): + The reference image to condition the generation on. + condtioning_image (`PIL.Image.Image`): + The conditioning canny edge image to condition the generation on. + source_subject_category (`List[str]`): + The source subject category. + target_subject_category (`List[str]`): + The target subject category. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by random sampling. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. 
of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+            height (`int`, *optional*, defaults to 512):
+                The height of the generated image.
+            width (`int`, *optional*, defaults to 512):
+                The width of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            neg_prompt (`str`, *optional*, defaults to ""):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            prompt_strength (`float`, *optional*, defaults to 1.0):
+                The strength of the prompt. Together with `prompt_reps`, it determines how many times the prompt is
+                repeated to amplify it.
+            prompt_reps (`int`, *optional*, defaults to 20):
+                The number of times the prompt is repeated along with `prompt_strength` to amplify the prompt.
+        Examples:
+
+        Returns:
+            [`~pipelines.ImagePipelineOutput`] or `tuple`
+        """
+
+        reference_image = self.image_processor.preprocess(
+            reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
+        )["pixel_values"]
+        reference_image = reference_image.to(self.device)
+
+        if isinstance(prompt, str):
+            prompt = [prompt]
+        if isinstance(source_subject_category, str):
+            source_subject_category = [source_subject_category]
+        if isinstance(target_subject_category, str):
+            target_subject_category = [target_subject_category]
+
+        batch_size = len(prompt)
+
+        prompt = self._build_prompt(
+            prompts=prompt,
+            tgt_subjects=target_subject_category,
+            prompt_strength=prompt_strength,
+            prompt_reps=prompt_reps,
+        )
+        query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
+        text_embeddings = self.encode_prompt(query_embeds, prompt)
+        # 3. unconditional embedding
+        do_classifier_free_guidance = guidance_scale > 1.0
+        if do_classifier_free_guidance:
+            max_length = self.text_encoder.text_model.config.max_position_embeddings
+
+            uncond_input = self.tokenizer(
+                [neg_prompt] * batch_size,
+                padding="max_length",
+                max_length=max_length,
+                return_tensors="pt",
+            )
+            uncond_embeddings = self.text_encoder(
+                input_ids=uncond_input.input_ids.to(self.device),
+                ctx_embeddings=None,
+            )[0]
+            # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1) + latents = self.prepare_latents( + batch_size=batch_size, + num_channels=self.unet.config.in_channels, + height=height // scale_down_factor, + width=width // scale_down_factor, + generator=generator, + latents=latents, + dtype=self.unet.dtype, + device=self.device, + ) + # set timesteps + extra_set_kwargs = {} + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + cond_image = self.prepare_control_image( + image=condtioning_image, + width=width, + height=height, + batch_size=batch_size, + num_images_per_prompt=1, + device=self.device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + do_classifier_free_guidance = guidance_scale > 1.0 + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + down_block_res_samples, mid_block_res_sample = self.controlnet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + controlnet_cond=cond_image, + return_dict=False, + ) + + noise_pred = self.unet( + latent_model_input, + timestep=t, + encoder_hidden_states=text_embeddings, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + )["sample"] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + t, + latents, + )["prev_sample"] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 0c7b3117fa47..8e95dde52caf 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -315,6 +315,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class BlipDiffusionControlNetPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class BlipDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CLIPImageProjection(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/blipdiffusion/__init__.py b/tests/pipelines/blipdiffusion/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/blipdiffusion/test_blipdiffusion.py b/tests/pipelines/blipdiffusion/test_blipdiffusion.py new file 
mode 100644 index 000000000000..480581928c77 --- /dev/null +++ b/tests/pipelines/blipdiffusion/test_blipdiffusion.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import CLIPTokenizer +from transformers.models.blip_2.configuration_blip_2 import Blip2Config +from transformers.models.clip.configuration_clip import CLIPTextConfig + +from diffusers import AutoencoderKL, BlipDiffusionPipeline, PNDMScheduler, UNet2DConditionModel +from diffusers.utils.testing_utils import enable_full_determinism +from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor +from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel +from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = BlipDiffusionPipeline + params = [ + "prompt", + "reference_image", + "source_subject_category", + "target_subject_category", + ] + batch_params = [ + "prompt", + "reference_image", + "source_subject_category", + "target_subject_category", + ] + required_optional_params = [ + "generator", + "height", + "width", + "latents", + "guidance_scale", + "num_inference_steps", + "neg_prompt", + "guidance_scale", + "prompt_strength", + "prompt_reps", + ] + + def get_dummy_components(self): + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + vocab_size=1000, + hidden_size=16, + intermediate_size=16, + projection_dim=16, + num_hidden_layers=1, + num_attention_heads=1, + max_position_embeddings=77, + ) + text_encoder = ContextCLIPTextModel(text_encoder_config) + + vae = AutoencoderKL( + in_channels=4, + out_channels=4, + down_block_types=("DownEncoderBlock2D",), + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(32,), + layers_per_block=1, + act_fn="silu", + latent_channels=4, + norm_num_groups=16, + sample_size=16, + ) + + blip_vision_config = { + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 1, + "num_attention_heads": 1, + "image_size": 224, + "patch_size": 14, + "hidden_act": "quick_gelu", + } + + blip_qformer_config = { + "vocab_size": 1000, + "hidden_size": 16, + "num_hidden_layers": 1, + "num_attention_heads": 1, + "intermediate_size": 16, + "max_position_embeddings": 512, + "cross_attention_frequency": 1, + "encoder_hidden_size": 16, + } + qformer_config = Blip2Config( + vision_config=blip_vision_config, + qformer_config=blip_qformer_config, + num_query_tokens=16, + tokenizer="hf-internal-testing/tiny-random-bert", + ) + qformer = Blip2QFormerModel(qformer_config) + + unet = UNet2DConditionModel( + block_out_channels=(16, 32), + norm_num_groups=16, + layers_per_block=1, + sample_size=16, + in_channels=4, + out_channels=4, + 
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=16, + ) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + scheduler = PNDMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + set_alpha_to_one=False, + skip_prk_steps=True, + ) + + vae.eval() + qformer.eval() + text_encoder.eval() + + image_processor = BlipImageProcessor() + + components = { + "text_encoder": text_encoder, + "vae": vae, + "qformer": qformer, + "unet": unet, + "tokenizer": tokenizer, + "scheduler": scheduler, + "image_processor": image_processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + np.random.seed(seed) + reference_image = np.random.rand(32, 32, 3) * 255 + reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA") + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "swimming underwater", + "generator": generator, + "reference_image": reference_image, + "source_subject_category": "dog", + "target_subject_category": "dog", + "height": 32, + "width": 32, + "guidance_scale": 7.5, + "num_inference_steps": 2, + "output_type": "np", + } + return inputs + + def test_blipdiffusion(self): + device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + image = pipe(**self.get_dummy_inputs(device))[0] + image_slice = image[0, -3:, -3:, 0] + + assert image.shape == (1, 16, 16, 4) + + expected_slice = np.array([0.7096, 0.5900, 0.6703, 0.4032, 0.7766, 0.3629, 0.5447, 0.4149, 0.8172]) + + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}" diff --git a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py new file mode 100644 index 000000000000..f15da0a67653 --- /dev/null +++ b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py @@ -0,0 +1,216 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import CLIPTokenizer +from transformers.models.blip_2.configuration_blip_2 import Blip2Config +from transformers.models.clip.configuration_clip import CLIPTextConfig + +from diffusers import ( + AutoencoderKL, + BlipDiffusionControlNetPipeline, + ControlNetModel, + PNDMScheduler, + UNet2DConditionModel, +) +from diffusers.utils.testing_utils import enable_full_determinism +from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor +from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel +from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = BlipDiffusionControlNetPipeline + params = [ + "prompt", + "reference_image", + "source_subject_category", + "target_subject_category", + "condtioning_image", + ] + batch_params = [ + "prompt", + "reference_image", + "source_subject_category", + "target_subject_category", + "condtioning_image", + ] + required_optional_params = [ + "generator", + "height", + "width", + "latents", + "guidance_scale", + "num_inference_steps", + "neg_prompt", + "guidance_scale", + "prompt_strength", + "prompt_reps", + ] + + def get_dummy_components(self): + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + vocab_size=1000, + hidden_size=16, + intermediate_size=16, + projection_dim=16, + num_hidden_layers=1, + num_attention_heads=1, + max_position_embeddings=77, + ) + text_encoder = ContextCLIPTextModel(text_encoder_config) + + vae = AutoencoderKL( + in_channels=4, + out_channels=4, + down_block_types=("DownEncoderBlock2D",), + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(32,), + layers_per_block=1, + act_fn="silu", + latent_channels=4, + norm_num_groups=16, + sample_size=16, + ) + + blip_vision_config = { + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 1, + "num_attention_heads": 1, + "image_size": 224, + "patch_size": 14, + "hidden_act": "quick_gelu", + } + + blip_qformer_config = { + "vocab_size": 1000, + "hidden_size": 16, + "num_hidden_layers": 1, + "num_attention_heads": 1, + "intermediate_size": 16, + "max_position_embeddings": 512, + "cross_attention_frequency": 1, + "encoder_hidden_size": 16, + } + qformer_config = Blip2Config( + vision_config=blip_vision_config, + qformer_config=blip_qformer_config, + num_query_tokens=16, + tokenizer="hf-internal-testing/tiny-random-bert", + ) + qformer = Blip2QFormerModel(qformer_config) + + unet = UNet2DConditionModel( + block_out_channels=(4, 16), + layers_per_block=1, + norm_num_groups=4, + sample_size=16, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=16, + ) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + scheduler = PNDMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + set_alpha_to_one=False, + skip_prk_steps=True, + ) + controlnet = ControlNetModel( + block_out_channels=(4, 16), + layers_per_block=1, + in_channels=4, + norm_num_groups=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=16, + conditioning_embedding_out_channels=(8, 16), + ) + + vae.eval() + 
qformer.eval() + text_encoder.eval() + + image_processor = BlipImageProcessor() + + components = { + "text_encoder": text_encoder, + "vae": vae, + "qformer": qformer, + "unet": unet, + "tokenizer": tokenizer, + "scheduler": scheduler, + "controlnet": controlnet, + "image_processor": image_processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + np.random.seed(seed) + reference_image = np.random.rand(32, 32, 3) * 255 + reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA") + cond_image = np.random.rand(32, 32, 3) * 255 + cond_image = Image.fromarray(cond_image.astype("uint8")).convert("RGBA") + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "swimming underwater", + "generator": generator, + "reference_image": reference_image, + "condtioning_image": cond_image, + "source_subject_category": "dog", + "target_subject_category": "dog", + "height": 32, + "width": 32, + "guidance_scale": 7.5, + "num_inference_steps": 2, + "output_type": "np", + } + return inputs + + def test_blipdiffusion_controlnet(self): + device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + image = pipe(**self.get_dummy_inputs(device))[0] + image_slice = image[0, -3:, -3:, 0] + + assert image.shape == (1, 16, 16, 4) + expected_slice = np.array([0.7953, 0.7136, 0.6597, 0.4779, 0.7389, 0.4111, 0.5826, 0.4150, 0.8422]) + + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
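A note on the prompt-amplification trick used by `_build_prompt` in the pipelines above: the target subject is prepended to the prompt ("on a marble table" becomes "a teapot on a marble table") and the result is repeated `int(prompt_strength * prompt_reps)` times before tokenization. The sketch below is not part of the patch; it re-implements the helper as a standalone function under the hypothetical name `build_prompt`, with example strings borrowed from the documentation example, purely to make the behavior concrete.

```py
# Standalone illustration of the prompt-amplification step in
# BlipDiffusionControlNetPipeline._build_prompt; re-implemented here so the
# sketch runs without diffusers installed. `build_prompt` is a hypothetical
# helper name, not part of the library API.
from typing import List


def build_prompt(
    prompts: List[str], tgt_subjects: List[str], prompt_strength: float = 1.0, prompt_reps: int = 20
) -> List[str]:
    amplified = []
    for prompt, tgt_subject in zip(prompts, tgt_subjects):
        # prepend the target subject, e.g. "on a marble table" -> "a teapot on a marble table"
        prompt = f"a {tgt_subject} {prompt.strip()}"
        # repeat the prompt int(prompt_strength * prompt_reps) times to amplify it
        amplified.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
    return amplified


if __name__ == "__main__":
    out = build_prompt(["on a marble table"], ["teapot"], prompt_strength=1.0, prompt_reps=20)
    print(out[0].count("a teapot on a marble table"))  # 20 repetitions at full strength
    out = build_prompt(["on a marble table"], ["teapot"], prompt_strength=0.5, prompt_reps=20)
    print(out[0].count("a teapot on a marble table"))  # 10 repetitions at half strength
```

Because `encode_prompt` tokenizes with `truncation=True` and a `max_length` equal to the text encoder's `max_position_embeddings` minus the number of query tokens, repetitions beyond that token budget are simply cut off, so very large `prompt_reps` values add no further emphasis.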