
Commit 157c901

ayushtues, patrickvonplaten, and sayakpaul authored
Add BLIP Diffusion (#4388)
* Add BLIP Diffusion skeleton
* Add other model components
* Add BLIP2, need to change it for now
* Fix pipeline imports
* Load pretrained ViT
* Make qformer fwd pass same
* Replicate fwd passes
* Fix device bug
* Add accelerate functions
* Remove extra functions from Blip2
* Minor bug
* Integrate initial review changes
* Refactoring
* Refactoring
* Refactor
* Add controlnet
* Refactor
* Update conversion script
* Add image processor
* Shift postprocessing to ImageProcessor
* Refactor
* Fix device
* Add fast tests
* Update conversion script
* Fix checkpoint conversion script
* Integrate review changes
* Integrate review changes
* Remove unused functions from test
* Reuse HF image processor in Cond image
* Create new BlipImageProcessor based on transformers
* Fix image preprocessor
* Minor
* Minor
* Add canny preprocessing
* Fix controlnet preprocessing
* Fix blip diffusion test
* Add controlnet test
* Add initial doc strings
* Integrate review changes
* Refactor
* Update examples
* Remove DDIM comments
* Add copied from for prepare_latents
* Add type annotations
* Add docstrings
* Do black formatting
* Add batch support
* Make tests pass
* Make controlnet tests pass
* Black formatting
* Fix progress bar
* Fix some licensing comments
* Fix imports
* Refactor controlnet
* Make tests faster
* Edit examples
* Black formatting/Ruff
* Add doc
* Minor
  Co-authored-by: Patrick von Platen <[email protected]>
* Move controlnet pipeline
* Make tests faster
* Fix imports
* Fix formatting
* Fix make errors
* Fix make errors
* Minor
* Add suggested doc changes
  Co-authored-by: Sayak Paul <[email protected]>
* Edit docs
* Fix 16 bit loading
* Update examples
* Edit toctree
* Update docs/source/en/api/pipelines/blip_diffusion.md
  Co-authored-by: Sayak Paul <[email protected]>
* Minor
* Add tips
* Edit examples
* Update model paths

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent 24563ca commit 157c901

16 files changed: +3295 -533 lines changed

docs/source/en/_toctree.yml

@@ -216,6 +216,8 @@
     title: AudioLDM 2
   - local: api/pipelines/auto_pipeline
     title: AutoPipeline
+  - local: api/pipelines/blip_diffusion
+    title: BLIP Diffusion
   - local: api/pipelines/consistency_models
     title: Consistency Models
   - local: api/pipelines/controlnet
docs/source/en/api/pipelines/blip_diffusion.md

@@ -0,0 +1,29 @@
# BLIP Diffusion

BLIP Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://arxiv.org/abs/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.

The abstract from the paper is:

*Subject-driven text-to-image generation models create novel renditions of an input subject based on text prompts. Existing models suffer from lengthy fine-tuning and difficulties preserving the subject fidelity. To overcome these limitations, we introduce BLIP-Diffusion, a new subject-driven image generation model that supports multimodal control which consumes inputs of subject images and text prompts. Unlike other subject-driven generation models, BLIP-Diffusion introduces a new multimodal encoder which is pre-trained to provide subject representation. We first pre-train the multimodal encoder following BLIP-2 to produce visual representation aligned with the text. Then we design a subject representation learning task which enables a diffusion model to leverage such visual representation and generates new subject renditions. Compared with previous methods such as DreamBooth, our model enables zero-shot subject-driven generation, and efficient fine-tuning for customized subject with up to 20x speedup. We also demonstrate that BLIP-Diffusion can be flexibly combined with existing techniques such as ControlNet and prompt-to-prompt to enable novel subject-driven generation and editing applications.*

The original codebase can be found at [salesforce/LAVIS](https://github.com/salesforce/LAVIS/tree/main/projects/blip-diffusion). You can find the official BLIP Diffusion checkpoints under the [hf.co/SalesForce](https://hf.co/SalesForce) organization.

`BlipDiffusionPipeline` and `BlipDiffusionControlNetPipeline` were contributed by [`ayushtues`](https://github.com/ayushtues/).

<Tip>

Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>
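As a quick illustration of the scheduler tip above, here is a minimal, hedged sketch that swaps the default scheduler for `DDIMScheduler`. The `Salesforce/blipdiffusion` checkpoint name is an assumption and may differ from the published weights.

```python
import torch
from diffusers import DDIMScheduler
from diffusers.pipelines import BlipDiffusionPipeline

# Load the pipeline (checkpoint name assumed) and swap in a different scheduler,
# reusing the configuration of the scheduler it shipped with.
pipe = BlipDiffusionPipeline.from_pretrained("Salesforce/blipdiffusion", torch_dtype=torch.float16).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
```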
## BlipDiffusionPipeline

[[autodoc]] BlipDiffusionPipeline
- all
- __call__
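The snippet below is a minimal sketch of zero-shot subject-driven generation with `BlipDiffusionPipeline`. The checkpoint name, the local image path, and the exact positional-argument order are assumptions based on this PR and should be treated as illustrative, not canonical.

```python
import torch
from diffusers.pipelines import BlipDiffusionPipeline
from diffusers.utils import load_image

# Checkpoint name is an assumption; see the Salesforce org on the Hub for the published weights.
pipe = BlipDiffusionPipeline.from_pretrained("Salesforce/blipdiffusion", torch_dtype=torch.float16).to("cuda")

# Path or URL to a reference image of the subject we want to re-render in a new context.
cond_image = load_image("dog.jpg")

output = pipe(
    "swimming underwater",  # text prompt describing the new rendition
    cond_image,             # reference image of the subject
    "dog",                  # source subject category
    "dog",                  # target subject category
    guidance_scale=7.5,
    num_inference_steps=25,
    height=512,
    width=512,
).images
output[0].save("blip_diffusion_dog.png")
```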
## BlipDiffusionControlNetPipeline

[[autodoc]] BlipDiffusionControlNetPipeline
- all
- __call__
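For the ControlNet variant, a similarly hedged sketch follows. It assumes a `Salesforce/blipdiffusion-controlnet` checkpoint, local image files, a Canny edge map produced with the third-party `controlnet_aux` package, and the argument order used at the time of this PR; none of these are guaranteed by this commit.

```python
import torch
from controlnet_aux import CannyDetector
from diffusers.pipelines import BlipDiffusionControlNetPipeline
from diffusers.utils import load_image

# Checkpoint name is an assumption.
pipe = BlipDiffusionControlNetPipeline.from_pretrained(
    "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
).to("cuda")

style_image = load_image("flower.jpg")  # reference image providing the subject
condition_image = load_image("teapot.jpg").resize((512, 512))
canny_map = CannyDetector()(condition_image, 30, 70, output_type="pil")  # structure map for ControlNet

output = pipe(
    "on a marble table",  # text prompt
    style_image,          # reference image of the subject
    canny_map,            # ControlNet conditioning image (Canny edges)
    "flower",             # source subject category
    "teapot",             # target subject category
    guidance_scale=7.5,
    num_inference_steps=50,
    height=512,
    width=512,
).images
output[0].save("blip_diffusion_controlnet.png")
```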
@@ -0,0 +1,343 @@
"""
This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main.
"""

import argparse
import os
import tempfile

import torch
from lavis.models import load_model_and_preprocess
from transformers import CLIPTokenizer
from transformers.models.blip_2.configuration_blip_2 import Blip2Config

from diffusers import (
    AutoencoderKL,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines import BlipDiffusionPipeline
from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel


# BLIP-2 configuration (ViT vision encoder + Q-Former) used to instantiate the
# diffusers Blip2QFormerModel that serves as BLIP Diffusion's subject encoder.
BLIP2_CONFIG = {
    "vision_config": {
        "hidden_size": 1024,
        "num_hidden_layers": 23,
        "num_attention_heads": 16,
        "image_size": 224,
        "patch_size": 14,
        "intermediate_size": 4096,
        "hidden_act": "quick_gelu",
    },
    "qformer_config": {
        "cross_attention_frequency": 1,
        "encoder_hidden_size": 1024,
        "vocab_size": 30523,
    },
    "num_query_tokens": 16,
}
blip2config = Blip2Config(**BLIP2_CONFIG)


def qformer_model_from_original_config():
    qformer = Blip2QFormerModel(blip2config)
    return qformer


# The helpers below remap state-dict keys from the original LAVIS checkpoint to the
# key names expected by the corresponding diffusers modules.
def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):
    embeddings = {}
    embeddings.update(
        {
            f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[
                f"{original_embeddings_prefix}.word_embeddings.weight"
            ]
        }
    )
    embeddings.update(
        {
            f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[
                f"{original_embeddings_prefix}.position_embeddings.weight"
            ]
        }
    )
    embeddings.update(
        {f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]}
    )
    embeddings.update(
        {f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]}
    )
    return embeddings


def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):
    proj_layer = {}
    proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]})
    proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]})
    proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]})
    proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]})
    proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]})
    proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]})
    return proj_layer


def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):
    attention = {}
    attention.update(
        {
            f"{diffuser_attention_prefix}.attention.query.weight": model[
                f"{original_attention_prefix}.self.query.weight"
            ]
        }
    )
    attention.update(
        {f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]}
    )
    attention.update(
        {f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]}
    )
    attention.update(
        {f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]}
    )
    attention.update(
        {
            f"{diffuser_attention_prefix}.attention.value.weight": model[
                f"{original_attention_prefix}.self.value.weight"
            ]
        }
    )
    attention.update(
        {f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]}
    )
    attention.update(
        {f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]}
    )
    attention.update(
        {f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]}
    )
    attention.update(
        {
            f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[
                f"{original_attention_prefix}.output.LayerNorm.weight"
            ]
        }
    )
    attention.update(
        {
            f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[
                f"{original_attention_prefix}.output.LayerNorm.bias"
            ]
        }
    )
    return attention


def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):
    output_layers = {}
    output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]})
    output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]})
    output_layers.update(
        {f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]}
    )
    output_layers.update(
        {f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]}
    )
    return output_layers


def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):
    encoder = {}
    for i in range(blip2config.qformer_config.num_hidden_layers):
        encoder.update(
            attention_from_original_checkpoint(
                model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention"
            )
        )
        encoder.update(
            attention_from_original_checkpoint(
                model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention"
            )
        )

        encoder.update(
            {
                f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[
                    f"{original_encoder_prefix}.{i}.intermediate.dense.weight"
                ]
            }
        )
        encoder.update(
            {
                f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[
                    f"{original_encoder_prefix}.{i}.intermediate.dense.bias"
                ]
            }
        )
        encoder.update(
            {
                f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[
                    f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight"
                ]
            }
        )
        encoder.update(
            {
                f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[
                    f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias"
                ]
            }
        )

        encoder.update(
            output_layers_from_original_checkpoint(
                model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output"
            )
        )
        encoder.update(
            output_layers_from_original_checkpoint(
                model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query"
            )
        )
    return encoder


def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):
    visual_encoder_layer = {}

    visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]})
    visual_encoder_layer.update(
        {f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]}
    )
    visual_encoder_layer.update(
        {f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]}
    )
    visual_encoder_layer.update(
        {f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]}
    )
    visual_encoder_layer.update(
        {f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]}
    )
    visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]})

    return visual_encoder_layer


def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):
    visual_encoder = {}

    visual_encoder.update(
        {
            f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"]
            .unsqueeze(0)
            .unsqueeze(0)
        }
    )
    visual_encoder.update(
        {
            f"{diffuser_prefix}.embeddings.position_embedding": model[
                f"{original_prefix}.positional_embedding"
            ].unsqueeze(0)
        }
    )
    visual_encoder.update(
        {f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]}
    )
    visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]})
    visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]})

    for i in range(blip2config.vision_config.num_hidden_layers):
        visual_encoder.update(
            visual_encoder_layer_from_original_checkpoint(
                model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}"
            )
        )

    visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]})
    visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]})

    return visual_encoder


def qformer_original_checkpoint_to_diffusers_checkpoint(model):
    qformer_checkpoint = {}
    qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings"))
    qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]})
    qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer"))
    qformer_checkpoint.update(
        encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer")
    )
    qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder"))
    return qformer_checkpoint


def get_qformer(model):
    print("loading qformer")

    qformer = qformer_model_from_original_config()
    qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)

    load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)

    print("done loading qformer")
    return qformer


def load_checkpoint_to_model(checkpoint, model):
    # Round-trip the converted state dict through a temporary file so the large
    # in-memory dict can be freed before it is loaded into the model.
    with tempfile.NamedTemporaryFile(delete=False) as file:
        torch.save(checkpoint, file.name)
        del checkpoint
        model.load_state_dict(torch.load(file.name), strict=False)

    os.remove(file.name)


def save_blip_diffusion_model(model, args):
    qformer = get_qformer(model)
    qformer.eval()

    # The text encoder, VAE, UNet, and tokenizer are initialized from Stable Diffusion v1-5.
    text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

    unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
    vae.eval()
    text_encoder.eval()
    scheduler = PNDMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        set_alpha_to_one=False,
        skip_prk_steps=True,
    )
    tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
    image_processor = BlipImageProcessor()
    blip_diffusion = BlipDiffusionPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        scheduler=scheduler,
        qformer=qformer,
        image_processor=image_processor,
    )
    blip_diffusion.save_pretrained(args.checkpoint_path)


def main(args):
    model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True)
    save_blip_diffusion_model(model.state_dict(), args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    args = parser.parse_args()

    main(args)
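Once the script has run, the converted pipeline can be loaded back from the output directory for a quick sanity check. This is a hypothetical usage sketch; the directory is simply whatever value was passed to `--checkpoint_path`.

```python
from diffusers.pipelines import BlipDiffusionPipeline

# Load the pipeline that save_blip_diffusion_model() wrote to --checkpoint_path
# (the "./blip-diffusion-converted" path here is just an example).
pipe = BlipDiffusionPipeline.from_pretrained("./blip-diffusion-converted")
print(pipe.components.keys())  # tokenizer, text_encoder, vae, unet, scheduler, qformer, image_processor
```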
