diff --git a/examples/05_stable_diffusion/__init__.py b/examples/05_stable_diffusion/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/05_stable_diffusion/scripts/demo_alt.py b/examples/05_stable_diffusion/scripts/demo_alt.py
index 28b322f02..d571447f2 100644
--- a/examples/05_stable_diffusion/scripts/demo_alt.py
+++ b/examples/05_stable_diffusion/scripts/demo_alt.py
@@ -31,17 +31,19 @@
     help="Model weights to apply to compiled model (with --include-constants false)",
 )
 @click.option("--ckpt", default=None, help="e.g. v1-5-pruned-emaonly.ckpt")
-@click.option("--width", default=512, help="Width of generated image")
-@click.option("--height", default=512, help="Height of generated image")
+@click.option("--width", default=768, help="Width of generated image")
+@click.option("--height", default=768, help="Height of generated image")
 @click.option("--batch", default=1, help="Batch size of generated image")
 @click.option("--prompt", default="A vision of paradise, Unreal Engine", help="prompt")
 @click.option("--negative_prompt", default="", help="prompt")
 @click.option("--steps", default=50, help="Number of inference steps")
 @click.option("--cfg", default=7.5, help="Guidance scale")
+@click.option("--workdir", default="v21", help="Workdir")
 def run(
-    hf_hub_or_path, ckpt, width, height, batch, prompt, negative_prompt, steps, cfg
+    hf_hub_or_path, ckpt, width, height, batch, prompt, negative_prompt, steps, cfg, workdir
 ):
     pipe = StableDiffusionAITPipeline(
+        workdir=workdir,
         hf_hub_or_path=hf_hub_or_path,
         ckpt=ckpt,
     )
diff --git a/examples/05_stable_diffusion/scripts/demo_img2img_alt.py b/examples/05_stable_diffusion/scripts/demo_img2img_alt.py
new file mode 100644
index 000000000..78b92e253
--- /dev/null
+++ b/examples/05_stable_diffusion/scripts/demo_img2img_alt.py
@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from io import BytesIO
+
+import click
+import torch
+import requests
+
+from PIL import Image
+from aitemplate.utils.import_path import import_parent
+
+if __name__ == "__main__":
+    import_parent(filepath=__file__, level=1)
+
+from src.pipeline_stable_diffusion_ait_alt import StableDiffusionAITPipeline
+
+
+@click.command()
+@click.option(
+    "--hf-hub-or-path",
+    default="runwayml/stable-diffusion-v1-5",
+    help="Model weights to apply to compiled model (with --include-constants false)",
+)
+@click.option("--ckpt", default=None, help="e.g. v1-5-pruned-emaonly.ckpt")
+@click.option("--width", default=768, help="Width of generated image")
+@click.option("--height", default=768, help="Height of generated image")
+@click.option("--batch", default=1, help="Batch size of generated image")
+@click.option("--prompt", default="A vision of paradise, Unreal Engine", help="prompt")
+@click.option("--negative_prompt", default="", help="negative prompt")
+@click.option("--steps", default=50, help="Number of inference steps")
+@click.option("--cfg", default=7.5, help="Guidance scale")
+@click.option("--strength", default=0.8, help="Strength of the img2img transformation (0 to 1)")
+@click.option("--workdir", default="v15", help="Workdir")
+def run(
+    hf_hub_or_path, ckpt, width, height, batch, prompt, negative_prompt, steps, cfg, strength, workdir
+):
+    pipe = StableDiffusionAITPipeline(
+        workdir=workdir,
+        hf_hub_or_path=hf_hub_or_path,
+        ckpt=ckpt,
+    )
+
+    prompt = [prompt] * batch
+    negative_prompt = [negative_prompt] * batch
+
+    url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+    response = requests.get(url)
+    init_image = Image.open(BytesIO(response.content)).convert("RGB")
+    # PIL's Image.resize takes (width, height)
+    init_image = init_image.resize((width, height))
+
+    with torch.autocast("cuda"):
+        image = pipe(
+            prompt=prompt,
+            init_image=init_image,
+            height=height,
+            width=width,
+            negative_prompt=negative_prompt,
+            num_inference_steps=steps,
+            guidance_scale=cfg,
+            strength=strength,
+        ).images[0]
+        image.save("example_ait.png")
+
+
+if __name__ == "__main__":
+    run()
diff --git a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait_alt.py b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait_alt.py
index 419184628..8f3d90969 100644
--- a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait_alt.py
+++ b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_ait_alt.py
@@ -13,532 +13,27 @@
 # limitations under the License.
 #
 import inspect
-
 import os
 import re
 from typing import List, Optional, Union
 
+import PIL
 import torch
 
 from aitemplate.compiler import Model
-
-from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
-
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    LMSDiscreteScheduler,
+    EulerAncestralDiscreteScheduler
+)
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils.pil_utils import numpy_to_pil
 from tqdm import tqdm
-
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from .compile_lib.compile_vae_alt import map_vae
-
-
-def shave_segments(path, n_shave_prefix_segments=1):
-    """
-    Removes segments. Positive values shave the first segments, negative shave the last segments.
- """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) - - -def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") - - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") - - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") - - new_item = shave_segments( - new_item, n_shave_prefix_segments=n_shave_prefix_segments - ) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments( - new_item, n_shave_prefix_segments=n_shave_prefix_segments - ) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') - - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") - - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") - - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") - - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") - - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - - new_item = shave_segments( - new_item, n_shave_prefix_segments=n_shave_prefix_segments - ) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, additional_replacements=None -): - """ - This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits - attention layers, and takes into account additional replacements that may arise. - - Assigns the weights to the new checkpoint. - """ - assert isinstance( - paths, list - ), "Paths should be a list of dicts containing 'old' and 'new' keys." 
- - for path in paths: - new_path = path["new"] - - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) - - # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] - - -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] - - -# ================# -# VAE Conversion # -# ================# - - -def convert_ldm_vae_checkpoint(vae_state_dict): - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[ - "encoder.conv_out.weight" - ] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[ - "encoder.norm_out.weight" - ] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[ - "encoder.norm_out.bias" - ] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[ - "decoder.conv_out.weight" - ] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[ - "decoder.norm_out.weight" - ] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[ - "decoder.norm_out.bias" - ] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len( - { - ".".join(layer.split(".")[:3]) - for layer in vae_state_dict - if "encoder.down" in layer - } - ) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] - for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len( - { - ".".join(layer.split(".")[:3]) - for layer in vae_state_dict - if "decoder.up" in layer - } - ) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] - for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [ - key - for key in down_blocks[i] - if f"down.{i}" in key and f"down.{i}.downsample" not in key - ] - - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[ - f"encoder.down_blocks.{i}.downsamplers.0.conv.weight" - ] = 
vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight") - new_checkpoint[ - f"encoder.down_blocks.{i}.downsamplers.0.conv.bias" - ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias") - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint( - paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] - ) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint( - paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] - ) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] - ) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key - for key in up_blocks[block_id] - if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[ - f"decoder.up_blocks.{i}.upsamplers.0.conv.weight" - ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"] - new_checkpoint[ - f"decoder.up_blocks.{i}.upsamplers.0.conv.bias" - ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint( - paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] - ) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint( - paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] - ) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] - ) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint - - -# =================# -# UNet Conversion # -# =================# -def convert_ldm_unet_checkpoint(unet_state_dict, layers_per_block=2): - """ - Takes a state dict and a config, and returns a converted checkpoint. 
- """ - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict[ - "time_embed.0.weight" - ] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict[ - "time_embed.0.bias" - ] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict[ - "time_embed.2.weight" - ] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict[ - "time_embed.2.bias" - ] - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len( - { - ".".join(layer.split(".")[:2]) - for layer in unet_state_dict - if "input_blocks" in layer - } - ) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len( - { - ".".join(layer.split(".")[:2]) - for layer in unet_state_dict - if "middle_block" in layer - } - ) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len( - { - ".".join(layer.split(".")[:2]) - for layer in unet_state_dict - if "output_blocks" in layer - } - ) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] - for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (layers_per_block + 1) - layer_in_block_id = (i - 1) % (layers_per_block + 1) - - resnets = [ - key - for key in input_blocks[i] - if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[ - f"down_blocks.{block_id}.downsamplers.0.conv.weight" - ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight") - new_checkpoint[ - f"down_blocks.{block_id}.downsamplers.0.conv.bias" - ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") - - paths = renew_resnet_paths(resnets) - meta_path = { - "old": f"input_blocks.{i}.0", - "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path] - ) - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"input_blocks.{i}.1", - "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - ) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} 
- assign_to_checkpoint( - attentions_paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - ) - - for i in range(num_output_blocks): - block_id = i // (layers_per_block + 1) - layer_in_block_id = i % (layers_per_block + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [ - key for key in output_blocks[i] if f"output_blocks.{i}.1" in key - ] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = { - "old": f"output_blocks.{i}.0", - "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - ) - - output_block_list = {k: sorted(v) for k, v in output_block_list.items()} - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index( - ["conv.bias", "conv.weight"] - ) - new_checkpoint[ - f"up_blocks.{block_id}.upsamplers.0.conv.weight" - ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"] - new_checkpoint[ - f"up_blocks.{block_id}.upsamplers.0.conv.bias" - ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"] - - # Clear attentions as they have been attributed above. - if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - ) - else: - resnet_0_paths = renew_resnet_paths( - output_block_layers, n_shave_prefix_segments=1 - ) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join( - [ - "up_blocks", - str(block_id), - "resnets", - str(layer_in_block_id), - path["new"], - ] - ) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - return new_checkpoint - +from .modeling.vae import AutoencoderKL as ait_AutoencoderKL +from .pipeline_utils import convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint, preprocess, map_clip_state_dict, map_unet_state_dict textenc_conversion_lst = [ ("positional_embedding", "text_model.embeddings.position_embedding.weight"), @@ -581,7 +76,7 @@ def convert_text_enc_state_dict(state_dict): if key in textenc_conversion_map: new_state_dict[textenc_conversion_map[key]] = arr if key.startswith("transformer."): - new_key = key[len("transformer.") :] + new_key = key[len("transformer."):] if new_key.endswith(".in_proj_weight"): new_key = new_key[: -len(".in_proj_weight")] new_key = textenc_pattern.sub( @@ -589,17 +84,17 @@ def convert_text_enc_state_dict(state_dict): ) new_state_dict[new_key + ".q_proj.weight"] = arr[:d_model, :] new_state_dict[new_key + ".k_proj.weight"] = arr[ - d_model : d_model * 2, : - ] - new_state_dict[new_key + ".v_proj.weight"] = arr[d_model * 2 :, :] + d_model: d_model * 2, : + ] + new_state_dict[new_key + ".v_proj.weight"] = arr[d_model * 2:, :] elif new_key.endswith(".in_proj_bias"): new_key = new_key[: -len(".in_proj_bias")] 
new_key = textenc_pattern.sub( lambda m: protected[re.escape(m.group(0))], new_key ) new_state_dict[new_key + ".q_proj.bias"] = arr[:d_model] - new_state_dict[new_key + ".k_proj.bias"] = arr[d_model : d_model * 2] - new_state_dict[new_key + ".v_proj.bias"] = arr[d_model * 2 :] + new_state_dict[new_key + ".k_proj.bias"] = arr[d_model: d_model * 2] + new_state_dict[new_key + ".v_proj.bias"] = arr[d_model * 2:] else: new_key = textenc_pattern.sub( lambda m: protected[re.escape(m.group(0))], new_key @@ -608,61 +103,9 @@ def convert_text_enc_state_dict(state_dict): return new_state_dict -# =========================# -# AITemplate mapping # -# =========================# -def map_unet_state_dict(state_dict, dim=320): - params_ait = {} - for key, arr in state_dict.items(): - arr = arr.to("cuda", dtype=torch.float16) - if len(arr.shape) == 4: - arr = arr.permute((0, 2, 3, 1)).contiguous() - elif key.endswith("ff.net.0.proj.weight"): - # print("ff.net.0.proj.weight") - w1, w2 = arr.chunk(2, dim=0) - params_ait[key.replace(".", "_")] = w1 - params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 - continue - elif key.endswith("ff.net.0.proj.bias"): - # print("ff.net.0.proj.bias") - w1, w2 = arr.chunk(2, dim=0) - params_ait[key.replace(".", "_")] = w1 - params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 - continue - params_ait[key.replace(".", "_")] = arr - - params_ait["arange"] = ( - torch.arange(start=0, end=dim // 2, dtype=torch.float32).cuda().half() - ) - return params_ait - - -def map_clip_state_dict(state_dict): - params_ait = {} - for key, arr in state_dict.items(): - arr = arr.to("cuda", dtype=torch.float16) - name = key.replace("text_model.", "") - ait_name = name.replace(".", "_") - if name.endswith("out_proj.weight"): - ait_name = ait_name.replace("out_proj", "proj") - elif name.endswith("out_proj.bias"): - ait_name = ait_name.replace("out_proj", "proj") - elif "q_proj" in name: - ait_name = ait_name.replace("q_proj", "proj_q") - elif "k_proj" in name: - ait_name = ait_name.replace("k_proj", "proj_k") - elif "v_proj" in name: - ait_name = ait_name.replace("v_proj", "proj_v") - params_ait[ait_name] = arr - - return params_ait - - class StableDiffusionAITPipeline: - def __init__(self, hf_hub_or_path, ckpt): - self.device = torch.device("cuda") - workdir = "tmp/" - state_dict = None + def __init__(self, hf_hub_or_path, ckpt, workdir="tmp/"): + self.device = torch.device(0) if ckpt is not None: state_dict = torch.load(ckpt, map_location="cpu") while "state_dict" in state_dict: @@ -697,7 +140,7 @@ def __init__(self, hf_hub_or_path, ckpt): subfolder="text_encoder", revision="fp16", torch_dtype=torch.float16, - ).cuda() + ).to(self.device) else: config = CLIPTextConfig.from_pretrained( hf_hub_or_path, subfolder="text_encoder" @@ -713,8 +156,8 @@ def __init__(self, hf_hub_or_path, ckpt): print("Folding constants") self.clip_ait_exe.fold_constants() # cleanup - self.clip_pt = None - clip_params_ait = None + del self.clip_pt + del clip_params_ait self.unet_ait_exe = self.init_ait_module( model_name="UNet2DConditionModel", workdir=workdir @@ -727,7 +170,7 @@ def __init__(self, hf_hub_or_path, ckpt): subfolder="unet", revision="fp16", torch_dtype=torch.float16, - ).cuda() + ).to(self.device) self.unet_pt = self.unet_pt.state_dict() else: self.unet_pt = unet_state_dict @@ -737,59 +180,95 @@ def __init__(self, hf_hub_or_path, ckpt): print("Folding constants") self.unet_ait_exe.fold_constants() # cleanup - self.unet_pt = None - unet_params_ait = None + del self.unet_pt + del 
unet_params_ait self.vae_ait_exe = self.init_ait_module( model_name="AutoencoderKL", workdir=workdir ) print("Loading PyTorch VAE") if ckpt is None: - self.vae_pt = AutoencoderKL.from_pretrained( + self.vae = AutoencoderKL.from_pretrained( hf_hub_or_path, subfolder="vae", revision="fp16", torch_dtype=torch.float16, - ).cuda() + ).to(self.device) else: - self.vae_pt = dict(vae_state_dict) - + self.vae = dict(vae_state_dict) + in_channels = 3 + out_channels = 3 + down_block_types = [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ] + up_block_types = [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ] + block_out_channels = [128, 256, 512, 512] + layers_per_block = 2 + act_fn = "silu" + latent_channels = 4 + sample_size = 512 + + ait_vae = ait_AutoencoderKL( + 1, + 64, + 64, + in_channels=in_channels, + out_channels=out_channels, + down_block_types=down_block_types, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + latent_channels=latent_channels, + sample_size=sample_size, + ) print("Mapping parameters...") - vae_params_ait = map_vae(self.vae_pt) + vae_params_ait = map_vae(ait_vae, self.vae) print("Setting constants") self.vae_ait_exe.set_many_constants_with_tensors(vae_params_ait) print("Folding constants") self.vae_ait_exe.fold_constants() # cleanup - self.vae_pt = None - vae_params_ait = None + del ait_vae + del vae_params_ait self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - self.scheduler = EulerDiscreteScheduler.from_pretrained( + self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained( hf_hub_or_path, subfolder="scheduler" ) + # self.scheduler = PNDMScheduler.from_pretrained( + # hf_hub_or_path, subfolder="scheduler" + # ) self.batch = 1 def init_ait_module( - self, - model_name, - workdir, + self, + model_name, + workdir, ): mod = Model(os.path.join(workdir, model_name, "test.so")) return mod def unet_inference( - self, latent_model_input, timesteps, encoder_hidden_states, height, width + self, latent_model_input, timesteps, encoder_hidden_states, height, width ): exe_module = self.unet_ait_exe timesteps_pt = timesteps.expand(self.batch * 2) inputs = { "input0": latent_model_input.permute((0, 2, 3, 1)) .contiguous() - .cuda() + .to(self.device) .half(), - "input1": timesteps_pt.cuda().half(), - "input2": encoder_hidden_states.cuda().half(), + "input1": timesteps_pt.to(self.device).half(), + "input2": encoder_hidden_states.to(self.device).half(), } ys = [] num_outputs = len(exe_module.get_output_name_to_index_map()) @@ -798,7 +277,7 @@ def unet_inference( shape[0] = self.batch * 2 shape[1] = height // 8 shape[2] = width // 8 - ys.append(torch.empty(shape).cuda().half()) + ys.append(torch.empty(shape).to(self.device).half()) exe_module.run_with_tensors(inputs, ys, graph_mode=False) noise_pred = ys[0].permute((0, 3, 1, 2)).float() return noise_pred @@ -806,7 +285,7 @@ def unet_inference( def clip_inference(self, input_ids, seqlen=77): exe_module = self.clip_ait_exe bs = input_ids.shape[0] - position_ids = torch.arange(seqlen).expand((bs, -1)).cuda() + position_ids = torch.arange(seqlen).expand((bs, -1)).to(self.device) inputs = { "input0": input_ids, "input1": position_ids, @@ -816,13 +295,13 @@ def clip_inference(self, input_ids, seqlen=77): for i in range(num_outputs): shape = exe_module.get_output_maximum_shape(i) shape[0] = self.batch - 
ys.append(torch.empty(shape).cuda().half()) + ys.append(torch.empty(shape).to(self.device).half()) exe_module.run_with_tensors(inputs, ys, graph_mode=False) return ys[0].float() def vae_inference(self, vae_input, height, width): exe_module = self.vae_ait_exe - inputs = [torch.permute(vae_input, (0, 2, 3, 1)).contiguous().cuda().half()] + inputs = [torch.permute(vae_input, (0, 2, 3, 1)).contiguous().to(self.device).half()] ys = [] num_outputs = len(exe_module.get_output_name_to_index_map()) for i in range(num_outputs): @@ -830,25 +309,25 @@ def vae_inference(self, vae_input, height, width): shape[0] = self.batch shape[1] = height shape[2] = width - ys.append(torch.empty(shape).cuda().half()) + ys.append(torch.empty(shape).to(self.device).half()) exe_module.run_with_tensors(inputs, ys, graph_mode=False) vae_out = ys[0].permute((0, 3, 1, 2)).float() return vae_out @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, + def generate( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, ): r""" Function invoked when calling the pipeline for generation. @@ -1035,7 +514,7 @@ def __call__( if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond + noise_pred_text - noise_pred_uncond ) latents = self.scheduler.step( @@ -1046,7 +525,313 @@ def __call__( latents = 1 / 0.18215 * latents image = self.vae_inference(latents, height, width) # pytorch equivalent - # image = self.vae_pt.decode(latents).sample + # image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + has_nsfw_concept = None + + if output_type == "pil": + image = numpy_to_pil(image) + + if not return_dict: + return image, has_nsfw_concept + + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = 512, + width: Optional[int] = 512, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + latents: Optional[torch.FloatTensor] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`): + The negative prompt or prompts to guide the image generation. 
+ init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. 
+ """ + args = { + "prompt": prompt, + "height": height, + "width": width, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + "negative_prompt": negative_prompt, + "eta": eta, + "generator": generator, + "latents": latents, + "output_type": output_type, + "return_dict": return_dict + } + if init_image is not None: + args = { + "prompt": prompt, + "strength": strength, + "init_image": init_image, + "height": height, + "width": width, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + "negative_prompt": negative_prompt, + "eta": eta, + "generator": generator, + "output_type": output_type, + "return_dict": return_dict + } + return self.img2img(**args) + return self.generate(**args) + + @torch.no_grad() + def img2img( + self, + prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image], + height: Optional[int] = 512, + width: Optional[int] = 512, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`): + The negative prompt or prompts to guide the image generation. + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + self.batch = batch_size + + if strength < 0 or strength > 1: + raise ValueError( + f"The value of strength should in [0.0, 1.0] but is {strength}" + ) + + # set timesteps + accepts_offset = "offset" in set( + inspect.signature(self.scheduler.set_timesteps).parameters.keys() + ) + extra_set_kwargs = {} + if accepts_offset: + offset = 1 + extra_set_kwargs["offset"] = offset + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + if isinstance(init_image, PIL.Image.Image): + init_image = preprocess(init_image, width=width, height=height) + + # encode the init image into latents and scale the latents + init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + + # expand init_latents for batch_size + init_latents = torch.cat([init_latents] * batch_size) + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, self.device) + latent_timestep = timesteps[:1].repeat(batch_size) + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=self.device) + init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep).to( + self.device + ) + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.clip_inference(text_input.input_ids.to(self.device)) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + if isinstance(negative_prompt, list): + negative_prompt = negative_prompt[0] + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [negative_prompt] * batch_size, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + uncond_embeddings = self.clip_inference( + uncond_input.input_ids.to(self.device) + ) + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + + for i, t in enumerate(tqdm(self.scheduler.timesteps)): + t_index = i + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[t_index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5) + latent_model_input = latent_model_input.to(self.unet.dtype) + t = t.to(self.unet.dtype) + + # predict the noise residual + noise_pred = self.unet_inference( + latent_model_input, t, encoder_hidden_states=text_embeddings, height=height, width=width + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step( + noise_pred, t_index, latents, **extra_step_kwargs + ).prev_sample + else: + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae_inference(latents, width=width, height=height) image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() @@ -1057,8 +842,17 @@ def __call__( image = numpy_to_pil(image) if not return_dict: - return (image, has_nsfw_concept) + return image, has_nsfw_concept return StableDiffusionPipelineOutput( images=image, nsfw_content_detected=has_nsfw_concept ) + + def get_timesteps(self, num_inference_steps, strength, device=torch.device(0)): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:] + + return timesteps, num_inference_steps - t_start diff --git a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_controlnet_ait.py b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_controlnet_ait.py index 8c2230368..7e2ef4fbc 100644 --- a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_controlnet_ait.py +++ b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_controlnet_ait.py @@ -30,564 +30,8 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from .compile_lib.compile_vae_alt import map_vae - - -def shave_segments(path, 
n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. - """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) - - -def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") - - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") - - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") - - new_item = shave_segments( - new_item, n_shave_prefix_segments=n_shave_prefix_segments - ) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments( - new_item, n_shave_prefix_segments=n_shave_prefix_segments - ) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') - - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") - - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") - - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") - - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") - - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - - new_item = shave_segments( - new_item, n_shave_prefix_segments=n_shave_prefix_segments - ) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, additional_replacements=None -): - """ - This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits - attention layers, and takes into account additional replacements that may arise. - - Assigns the weights to the new checkpoint. 
- """ - assert isinstance( - paths, list - ), "Paths should be a list of dicts containing 'old' and 'new' keys." - - for path in paths: - new_path = path["new"] - - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) - - # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] - - -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] - - -# ================# -# VAE Conversion # -# ================# - - -def convert_ldm_vae_checkpoint(vae_state_dict): - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[ - "encoder.conv_out.weight" - ] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[ - "encoder.norm_out.weight" - ] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[ - "encoder.norm_out.bias" - ] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[ - "decoder.conv_out.weight" - ] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[ - "decoder.norm_out.weight" - ] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[ - "decoder.norm_out.bias" - ] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len( - { - ".".join(layer.split(".")[:3]) - for layer in vae_state_dict - if "encoder.down" in layer - } - ) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] - for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len( - { - ".".join(layer.split(".")[:3]) - for layer in vae_state_dict - if "decoder.up" in layer - } - ) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] - for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [ - key - for key in down_blocks[i] - if f"down.{i}" in key and f"down.{i}.downsample" not in key - ] - - if 
f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[ - f"encoder.down_blocks.{i}.downsamplers.0.conv.weight" - ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight") - new_checkpoint[ - f"encoder.down_blocks.{i}.downsamplers.0.conv.bias" - ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias") - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint( - paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] - ) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint( - paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] - ) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] - ) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key - for key in up_blocks[block_id] - if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[ - f"decoder.up_blocks.{i}.upsamplers.0.conv.weight" - ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"] - new_checkpoint[ - f"decoder.up_blocks.{i}.upsamplers.0.conv.bias" - ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint( - paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] - ) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint( - paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] - ) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] - ) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint - - -# =================# -# UNet Conversion # -# =================# -def convert_ldm_unet_checkpoint(unet_state_dict, layers_per_block=2): - """ - Takes a state dict and a config, and returns a converted checkpoint. 
- """ - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict[ - "time_embed.0.weight" - ] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict[ - "time_embed.0.bias" - ] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict[ - "time_embed.2.weight" - ] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict[ - "time_embed.2.bias" - ] - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len( - { - ".".join(layer.split(".")[:2]) - for layer in unet_state_dict - if "input_blocks" in layer - } - ) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len( - { - ".".join(layer.split(".")[:2]) - for layer in unet_state_dict - if "middle_block" in layer - } - ) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len( - { - ".".join(layer.split(".")[:2]) - for layer in unet_state_dict - if "output_blocks" in layer - } - ) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] - for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (layers_per_block + 1) - layer_in_block_id = (i - 1) % (layers_per_block + 1) - - resnets = [ - key - for key in input_blocks[i] - if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[ - f"down_blocks.{block_id}.downsamplers.0.conv.weight" - ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight") - new_checkpoint[ - f"down_blocks.{block_id}.downsamplers.0.conv.bias" - ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") - - paths = renew_resnet_paths(resnets) - meta_path = { - "old": f"input_blocks.{i}.0", - "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path] - ) - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"input_blocks.{i}.1", - "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - ) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} 
- assign_to_checkpoint( - attentions_paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - ) - - for i in range(num_output_blocks): - block_id = i // (layers_per_block + 1) - layer_in_block_id = i % (layers_per_block + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [ - key for key in output_blocks[i] if f"output_blocks.{i}.1" in key - ] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = { - "old": f"output_blocks.{i}.0", - "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - ) - - output_block_list = {k: sorted(v) for k, v in output_block_list.items()} - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index( - ["conv.bias", "conv.weight"] - ) - new_checkpoint[ - f"up_blocks.{block_id}.upsamplers.0.conv.weight" - ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"] - new_checkpoint[ - f"up_blocks.{block_id}.upsamplers.0.conv.bias" - ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"] - - # Clear attentions as they have been attributed above. - if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - ) - else: - resnet_0_paths = renew_resnet_paths( - output_block_layers, n_shave_prefix_segments=1 - ) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join( - [ - "up_blocks", - str(block_id), - "resnets", - str(layer_in_block_id), - path["new"], - ] - ) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - return new_checkpoint - - -# =========================# -# AITemplate mapping # -# =========================# -def map_unet_state_dict(state_dict, dim=320): - params_ait = {} - for key, arr in state_dict.items(): - arr = arr.to("cuda", dtype=torch.float16) - if len(arr.shape) == 4: - arr = arr.permute((0, 2, 3, 1)).contiguous() - elif key.endswith("ff.net.0.proj.weight"): - # print("ff.net.0.proj.weight") - w1, w2 = arr.chunk(2, dim=0) - params_ait[key.replace(".", "_")] = w1 - params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 - continue - elif key.endswith("ff.net.0.proj.bias"): - # print("ff.net.0.proj.bias") - w1, w2 = arr.chunk(2, dim=0) - params_ait[key.replace(".", "_")] = w1 - params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 - continue - params_ait[key.replace(".", "_")] = arr - - params_ait["arange"] = ( - torch.arange(start=0, end=dim // 2, dtype=torch.float32).cuda().half() - ) - return params_ait - - -def map_clip_state_dict(state_dict): - params_ait = {} - for key, arr in state_dict.items(): - arr = arr.to("cuda", dtype=torch.float16) - name = key.replace("text_model.", "") - ait_name = name.replace(".", "_") 
- if name.endswith("out_proj.weight"): - ait_name = ait_name.replace("out_proj", "proj") - elif name.endswith("out_proj.bias"): - ait_name = ait_name.replace("out_proj", "proj") - elif "q_proj" in name: - ait_name = ait_name.replace("q_proj", "proj_q") - elif "k_proj" in name: - ait_name = ait_name.replace("k_proj", "proj_k") - elif "v_proj" in name: - ait_name = ait_name.replace("v_proj", "proj_v") - params_ait[ait_name] = arr - - return params_ait +from .modeling.vae import AutoencoderKL as ait_AutoencoderKL +from .pipeline_utils import convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint, map_clip_state_dict, map_unet_state_dict def map_controlnet_params(pt_mod): @@ -611,16 +55,14 @@ def map_controlnet_params(pt_mod): params_ait["controlnet_cond_embedding_conv_in_weight"], (0, 1, 0, 0, 0, 0, 0, 0) ) params_ait["arange"] = ( - torch.arange(start=0, end=320 // 2, dtype=torch.float32).cuda().half() + torch.arange(start=0, end=320 // 2, dtype=torch.float32).to(torch.device(0)).half() ) return params_ait class StableDiffusionAITPipeline: - def __init__(self, hf_hub_or_path, ckpt): - self.device = torch.device("cuda") - workdir = "tmp/" - state_dict = None + def __init__(self, hf_hub_or_path, ckpt, workdir="tmp/"): + self.device = torch.device(0) if ckpt is not None: state_dict = torch.load(ckpt, map_location="cpu") while "state_dict" in state_dict: @@ -645,7 +87,6 @@ def __init__(self, hf_hub_or_path, ckpt): # clip_state_dict = convert_text_enc_state_dict(clip_state_dict) unet_state_dict = convert_ldm_unet_checkpoint(unet_state_dict) vae_state_dict = convert_ldm_vae_checkpoint(vae_state_dict) - state_dict = None self.controlnet_ait_exe = self.init_ait_module("ControlNetModel", "./tmp") print("Loading PyTorch ControlNet") @@ -666,7 +107,7 @@ def __init__(self, hf_hub_or_path, ckpt): subfolder="text_encoder", revision="fp16", torch_dtype=torch.float16, - ).cuda() + ).to(self.device) else: config = CLIPTextConfig.from_pretrained( hf_hub_or_path, subfolder="text_encoder" @@ -679,8 +120,8 @@ def __init__(self, hf_hub_or_path, ckpt): print("Folding constants") self.clip_ait_exe.fold_constants() # cleanup - self.clip_pt = None - clip_params_ait = None + del self.clip_pt + del clip_params_ait self.unet_ait_exe = self.init_ait_module( model_name="ControlNetUNet2DConditionModel", workdir=workdir @@ -693,7 +134,7 @@ def __init__(self, hf_hub_or_path, ckpt): subfolder="unet", revision="fp16", torch_dtype=torch.float16, - ).cuda() + ).to(self.device) self.unet_pt = self.unet_pt.state_dict() else: self.unet_pt = unet_state_dict @@ -703,8 +144,8 @@ def __init__(self, hf_hub_or_path, ckpt): print("Folding constants") self.unet_ait_exe.fold_constants() # cleanup - self.unet_pt = None - unet_params_ait = None + del self.unet_pt + del unet_params_ait self.vae_ait_exe = self.init_ait_module( model_name="AutoencoderKL", workdir=workdir @@ -716,19 +157,53 @@ def __init__(self, hf_hub_or_path, ckpt): subfolder="vae", revision="fp16", torch_dtype=torch.float16, - ).cuda() + ).to(self.device) else: self.vae_pt = dict(vae_state_dict) - + in_channels = 3 + out_channels = 3 + down_block_types = [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ] + up_block_types = [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ] + block_out_channels = [128, 256, 512, 512] + layers_per_block = 2 + act_fn = "silu" + latent_channels = 4 + sample_size = 512 + + ait_vae = ait_AutoencoderKL( + 1, + 64, + 64, + in_channels=in_channels, + 
out_channels=out_channels, + down_block_types=down_block_types, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + latent_channels=latent_channels, + sample_size=sample_size, + ) print("Mapping parameters...") - vae_params_ait = map_vae(self.vae_pt) + vae_params_ait = map_vae(ait_vae, self.vae_pt) print("Setting constants") self.vae_ait_exe.set_many_constants_with_tensors(vae_params_ait) print("Folding constants") self.vae_ait_exe.fold_constants() # cleanup - self.vae_pt = None - vae_params_ait = None + del self.vae_pt + del ait_vae + del vae_params_ait self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") self.scheduler = EulerDiscreteScheduler.from_pretrained( @@ -737,56 +212,56 @@ def __init__(self, hf_hub_or_path, ckpt): self.batch = 1 def init_ait_module( - self, - model_name, - workdir, + self, + model_name, + workdir, ): mod = Model(os.path.join(workdir, model_name, "test.so")) return mod def controlnet_inference( - self, latent_model_input, timesteps, encoder_hidden_states, controlnet_cond + self, latent_model_input, timesteps, encoder_hidden_states, controlnet_cond ): exe_module = self.controlnet_ait_exe timesteps_pt = timesteps.expand(latent_model_input.shape[0]) inputs = { "input0": latent_model_input.permute((0, 2, 3, 1)) .contiguous() - .cuda() + .to(self.device) .half(), - "input1": timesteps_pt.cuda().half(), - "input2": encoder_hidden_states.cuda().half(), - "input3": controlnet_cond.permute((0, 2, 3, 1)).contiguous().cuda().half(), + "input1": timesteps_pt.to(self.device).half(), + "input2": encoder_hidden_states.to(self.device).half(), + "input3": controlnet_cond.permute((0, 2, 3, 1)).contiguous().to(self.device).half(), } ys = [] num_outputs = len(exe_module.get_output_name_to_index_map()) for i in range(num_outputs): shape = exe_module.get_output_maximum_shape(i) - ys.append(torch.empty(shape).cuda().half()) + ys.append(torch.empty(shape).to(self.device).half()) exe_module.run_with_tensors(inputs, ys, graph_mode=False) down_block_residuals = (y for y in ys[:-1]) mid_block_residuals = ys[-1] return down_block_residuals, mid_block_residuals def unet_inference( - self, - latent_model_input, - timesteps, - encoder_hidden_states, - height, - width, - down_block_residuals, - mid_block_residual, + self, + latent_model_input, + timesteps, + encoder_hidden_states, + height, + width, + down_block_residuals, + mid_block_residual, ): exe_module = self.unet_ait_exe timesteps_pt = timesteps.expand(self.batch * 2) inputs = { "input0": latent_model_input.permute((0, 2, 3, 1)) .contiguous() - .cuda() + .to(self.device) .half(), - "input1": timesteps_pt.cuda().half(), - "input2": encoder_hidden_states.cuda().half(), + "input1": timesteps_pt.to(self.device).half(), + "input2": encoder_hidden_states.to(self.device).half(), } for i, y in enumerate(down_block_residuals): inputs[f"down_block_residual_{i}"] = y @@ -798,7 +273,7 @@ def unet_inference( shape[0] = self.batch * 2 shape[1] = height // 8 shape[2] = width // 8 - ys.append(torch.empty(shape).cuda().half()) + ys.append(torch.empty(shape).to(self.device).half()) exe_module.run_with_tensors(inputs, ys, graph_mode=False) noise_pred = ys[0].permute((0, 3, 1, 2)).float() return noise_pred @@ -806,7 +281,7 @@ def unet_inference( def clip_inference(self, input_ids, seqlen=77): exe_module = self.clip_ait_exe bs = input_ids.shape[0] - position_ids = torch.arange(seqlen).expand((bs, -1)).cuda() + position_ids = 
torch.arange(seqlen).expand((bs, -1)).to(self.device) inputs = { "input0": input_ids, "input1": position_ids, @@ -816,13 +291,13 @@ def clip_inference(self, input_ids, seqlen=77): for i in range(num_outputs): shape = exe_module.get_output_maximum_shape(i) shape[0] = self.batch - ys.append(torch.empty(shape).cuda().half()) + ys.append(torch.empty(shape).to(self.device).half()) exe_module.run_with_tensors(inputs, ys, graph_mode=False) return ys[0].float() def vae_inference(self, vae_input, height, width): exe_module = self.vae_ait_exe - inputs = [torch.permute(vae_input, (0, 2, 3, 1)).contiguous().cuda().half()] + inputs = [torch.permute(vae_input, (0, 2, 3, 1)).contiguous().to(self.device).half()] ys = [] num_outputs = len(exe_module.get_output_name_to_index_map()) for i in range(num_outputs): @@ -830,26 +305,26 @@ def vae_inference(self, vae_input, height, width): shape[0] = self.batch * 2 shape[1] = height shape[2] = width - ys.append(torch.empty(shape).cuda().half()) + ys.append(torch.empty(shape).to(self.device).half()) exe_module.run_with_tensors(inputs, ys, graph_mode=False) vae_out = ys[0].permute((0, 3, 1, 2)).float() return vae_out @torch.no_grad() def __call__( - self, - prompt: Union[str, List[str]], - control_cond: torch.FloatTensor, - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, + self, + prompt: Union[str, List[str]], + control_cond: torch.FloatTensor, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, ): r""" Function invoked when calling the pipeline for generation. 
@@ -1040,7 +515,7 @@ def __call__( if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond + noise_pred_text - noise_pred_uncond ) latents = self.scheduler.step( @@ -1062,7 +537,7 @@ def __call__( image = numpy_to_pil(image) if not return_dict: - return (image, has_nsfw_concept) + return image, has_nsfw_concept return StableDiffusionPipelineOutput( images=image, nsfw_content_detected=has_nsfw_concept diff --git a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_img2img_ait.py b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_img2img_ait.py index 893db028d..1e249e6d0 100644 --- a/examples/05_stable_diffusion/src/pipeline_stable_diffusion_img2img_ait.py +++ b/examples/05_stable_diffusion/src/pipeline_stable_diffusion_img2img_ait.py @@ -18,12 +18,9 @@ import os from typing import List, Optional, Union -import numpy as np - import PIL import torch from aitemplate.compiler import Model - from diffusers import ( AutoencoderKL, DDIMScheduler, @@ -37,16 +34,7 @@ StableDiffusionSafetyChecker, ) from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - - -def preprocess(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 +from .pipeline_utils import preprocess class StableDiffusionImg2ImgAITPipeline(StableDiffusionImg2ImgPipeline): @@ -78,15 +66,15 @@ class StableDiffusionImg2ImgAITPipeline(StableDiffusionImg2ImgPipeline): """ def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__( vae=vae, @@ -122,9 +110,9 @@ def __init__( self.batch = 1 def init_ait_module( - self, - model_name, - workdir, + self, + model_name, + workdir, ): mod = Model(os.path.join(workdir, model_name, "test.so")) return mod @@ -182,16 +170,16 @@ def vae_inference(self, vae_input): @torch.no_grad() def __call__( - self, - prompt: Union[str, List[str]], - init_image: Union[torch.FloatTensor, PIL.Image.Image], - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, + self, + prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image], + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, ): r""" Function invoked when calling the pipeline for generation. 
@@ -353,7 +341,7 @@ def __call__( if isinstance(self.scheduler, LMSDiscreteScheduler): sigma = self.scheduler.sigmas[t_index] # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5) latent_model_input = latent_model_input.to(self.unet.dtype) t = t.to(self.unet.dtype) @@ -366,7 +354,7 @@ def __call__( if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond + noise_pred_text - noise_pred_uncond ) # compute the previous noisy sample x_t -> x_t-1 diff --git a/examples/05_stable_diffusion/src/pipeline_utils.py b/examples/05_stable_diffusion/src/pipeline_utils.py new file mode 100644 index 000000000..133ddffd3 --- /dev/null +++ b/examples/05_stable_diffusion/src/pipeline_utils.py @@ -0,0 +1,570 @@ +import PIL +import numpy as np +import torch + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + 
+ new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, additional_replacements=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance( + paths, list + ), "Paths should be a list of dicts containing 'old' and 'new' keys." + + for path in paths: + new_path = path["new"] + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +# ================# +# VAE Conversion # +# ================# + + +def convert_ldm_vae_checkpoint(vae_state_dict): + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[ + "encoder.conv_out.weight" + ] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[ + "encoder.norm_out.weight" + ] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[ + "encoder.norm_out.bias" + ] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[ + "decoder.conv_out.weight" + ] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[ + "decoder.norm_out.weight" + ] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[ + "decoder.norm_out.bias" + ] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = 
vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len( + { + ".".join(layer.split(".")[:3]) + for layer in vae_state_dict + if "encoder.down" in layer + } + ) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] + for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len( + { + ".".join(layer.split(".")[:3]) + for layer in vae_state_dict + if "decoder.up" in layer + } + ) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] + for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [ + key + for key in down_blocks[i] + if f"down.{i}" in key and f"down.{i}.downsample" not in key + ] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[ + f"encoder.down_blocks.{i}.downsamplers.0.conv.weight" + ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight") + new_checkpoint[ + f"encoder.down_blocks.{i}.downsamplers.0.conv.bias" + ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias") + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint( + paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] + ) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint( + paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] + ) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] + ) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key + for key in up_blocks[block_id] + if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[ + f"decoder.up_blocks.{i}.upsamplers.0.conv.weight" + ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"] + new_checkpoint[ + f"decoder.up_blocks.{i}.upsamplers.0.conv.bias" + ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint( + paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] + ) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint( + paths, new_checkpoint, vae_state_dict, 
additional_replacements=[meta_path] + ) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path] + ) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +# =================# +# UNet Conversion # +# =================# +def convert_ldm_unet_checkpoint(unet_state_dict, layers_per_block=2): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict[ + "time_embed.0.weight" + ] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict[ + "time_embed.0.bias" + ] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict[ + "time_embed.2.weight" + ] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict[ + "time_embed.2.bias" + ] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "input_blocks" in layer + } + ) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "middle_block" in layer + } + ) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "output_blocks" in layer + } + ) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (layers_per_block + 1) + layer_in_block_id = (i - 1) % (layers_per_block + 1) + + resnets = [ + key + for key in input_blocks[i] + if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[ + f"down_blocks.{block_id}.downsamplers.0.conv.weight" + ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight") + new_checkpoint[ + f"down_blocks.{block_id}.downsamplers.0.conv.bias" + ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") + + paths = renew_resnet_paths(resnets) + meta_path = { + "old": f"input_blocks.{i}.0", + "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path] + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"input_blocks.{i}.1", + "new": 
f"down_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + ) + + for i in range(num_output_blocks): + block_id = i // (layers_per_block + 1) + layer_in_block_id = i % (layers_per_block + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [ + key for key in output_blocks[i] if f"output_blocks.{i}.1" in key + ] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = { + "old": f"output_blocks.{i}.0", + "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index( + ["conv.bias", "conv.weight"] + ) + new_checkpoint[ + f"up_blocks.{block_id}.upsamplers.0.conv.weight" + ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"] + new_checkpoint[ + f"up_blocks.{block_id}.upsamplers.0.conv.bias" + ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"] + + # Clear attentions as they have been attributed above. 
+ if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + ) + else: + resnet_0_paths = renew_resnet_paths( + output_block_layers, n_shave_prefix_segments=1 + ) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join( + [ + "up_blocks", + str(block_id), + "resnets", + str(layer_in_block_id), + path["new"], + ] + ) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +def preprocess(image, width=512, height=512): + width, height = map(lambda x: x - x % 32, (width, height)) # resize to integer multiple of 32 + image = image.resize((width, height), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def map_clip_state_dict(state_dict): + params_ait = {} + for key, arr in state_dict.items(): + arr = arr.to("cuda", dtype=torch.float16) + name = key.replace("text_model.", "") + ait_name = name.replace(".", "_") + if name.endswith("out_proj.weight"): + ait_name = ait_name.replace("out_proj", "proj") + elif name.endswith("out_proj.bias"): + ait_name = ait_name.replace("out_proj", "proj") + elif "q_proj" in name: + ait_name = ait_name.replace("q_proj", "proj_q") + elif "k_proj" in name: + ait_name = ait_name.replace("k_proj", "proj_k") + elif "v_proj" in name: + ait_name = ait_name.replace("v_proj", "proj_v") + params_ait[ait_name] = arr + + return params_ait + + +# =========================# +# AITemplate mapping # +# =========================# +def map_unet_state_dict(state_dict, dim=320): + params_ait = {} + for key, arr in state_dict.items(): + arr = arr.to("cuda", dtype=torch.float16) + if len(arr.shape) == 4: + arr = arr.permute((0, 2, 3, 1)).contiguous() + elif key.endswith("ff.net.0.proj.weight"): + # print("ff.net.0.proj.weight") + w1, w2 = arr.chunk(2, dim=0) + params_ait[key.replace(".", "_")] = w1 + params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 + continue + elif key.endswith("ff.net.0.proj.bias"): + # print("ff.net.0.proj.bias") + w1, w2 = arr.chunk(2, dim=0) + params_ait[key.replace(".", "_")] = w1 + params_ait[key.replace(".", "_").replace("proj", "gate")] = w2 + continue + params_ait[key.replace(".", "_")] = arr + + params_ait["arange"] = ( + torch.arange(start=0, end=dim // 2, dtype=torch.float32).cuda().half() + ) + return params_ait diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 000000000..e69de29bb
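
The sketch below is a minimal, hypothetical usage example for the helpers consolidated into src/pipeline_utils.py by this change; it is not taken from the patch itself. It assumes the working directory is examples/05_stable_diffusion (so that src is importable), a local LDM checkpoint named v1-5-pruned-emaonly.ckpt, and an input image sketch.jpg. The prefix-based splitting of the flat checkpoint follows the standard LDM key layout (model.diffusion_model.*, first_stage_model.*) and is illustrative rather than a quote of the pipeline's own loader.

import torch
from PIL import Image

from src.pipeline_utils import (
    convert_ldm_unet_checkpoint,
    convert_ldm_vae_checkpoint,
    preprocess,
)

# Load an LDM-format checkpoint and unwrap any nested "state_dict" entries,
# mirroring what StableDiffusionAITPipeline does before converting.
state_dict = torch.load("v1-5-pruned-emaonly.ckpt", map_location="cpu")
while "state_dict" in state_dict:
    state_dict = state_dict["state_dict"]

# Split the flat checkpoint by prefix (illustrative; standard LDM layout) and
# convert each part to diffusers-style parameter names.
unet_sd = {
    k[len("model.diffusion_model."):]: v
    for k, v in state_dict.items()
    if k.startswith("model.diffusion_model.")
}
vae_sd = {
    k[len("first_stage_model."):]: v
    for k, v in state_dict.items()
    if k.startswith("first_stage_model.")
}
unet_diffusers = convert_ldm_unet_checkpoint(unet_sd)
vae_diffusers = convert_ldm_vae_checkpoint(vae_sd)

# preprocess() trims width/height to multiples of 32 and scales pixels to [-1, 1],
# producing the init tensor consumed by the img2img pipeline.
init = preprocess(Image.open("sketch.jpg").convert("RGB"), width=768, height=768)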