From 9ea2047e9fe66810f757ed97a8c763cb2c34b6eb Mon Sep 17 00:00:00 2001
From: Amil Dravid <46203730+avdravid@users.noreply.github.com>
Date: Thu, 13 Jun 2024 14:32:00 -0700
Subject: [PATCH] add utils

---
 utils.py | 161 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 161 insertions(+)
 create mode 100644 utils.py

diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..8e22df1
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,161 @@
+import torch
+import torchvision
+import os
+import shutil
+import tqdm
+import matplotlib.pyplot as plt
+import torchvision.transforms as transforms
+from PIL import Image
+from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
+from safetensors.torch import save_file
+from lora_w2w import LoRAw2w
+from diffusers import (
+    AutoencoderKL,
+    DDPMScheduler,
+    DiffusionPipeline,
+    DPMSolverMultistepScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+    StableDiffusionPipeline,
+    UNet2DConditionModel,
+)
+import warnings
+warnings.filterwarnings("ignore")
+
+
+######## Basic utilities
+
+### load base models
+def load_models(device):
+    pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51"
+    revision = None
+    weight_dtype = torch.bfloat16
+
+    # load the full pipeline once just to grab its scheduler, then free it
+    pipe = StableDiffusionPipeline.from_pretrained(
+        pretrained_model_name_or_path,
+        torch_dtype=torch.float16,
+        safety_checker=None,
+        requires_safety_checker=False,
+    ).to(device)
+    noise_scheduler = pipe.scheduler
+    del pipe
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
+    )
+    text_encoder = CLIPTextModel.from_pretrained(
+        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
+    )
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision)
+    unet = UNet2DConditionModel.from_pretrained(
+        pretrained_model_name_or_path, subfolder="unet", revision=revision
+    )
+
+    # freeze all weights and move everything to the target device/dtype
+    unet.requires_grad_(False)
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+    unet.to(device, dtype=weight_dtype)
+    vae.to(device, dtype=weight_dtype)
+    text_encoder.to(device, dtype=weight_dtype)
+
+    return unet, vae, text_encoder, tokenizer, noise_scheduler
+
+
+### basic inference to generate images conditioned on text prompts
+@torch.no_grad()
+def inference(network, unet, vae, text_encoder, tokenizer, prompt, negative_prompt, guidance_scale, noise_scheduler, ddim_steps, seed, generator, device):
+    generator = generator.manual_seed(seed)
+    latents = torch.randn(
+        (1, unet.config.in_channels, 512 // 8, 512 // 8),
+        generator=generator,
+        device=device,
+    ).bfloat16()
+
+    # encode the prompt and the negative prompt
+    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
+
+    max_length = text_input.input_ids.shape[-1]
+    uncond_input = tokenizer(
+        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
+    )
+    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
+    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+    noise_scheduler.set_timesteps(ddim_steps)
+    latents = latents * noise_scheduler.init_noise_sigma
+
+    for t in tqdm.tqdm(noise_scheduler.timesteps):
+        # duplicate latents for classifier-free guidance (unconditional + text branch)
+        latent_model_input = torch.cat([latents] * 2)
+        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
+
+        # predict the noise with the w2w LoRA network applied to the UNet
+        with network:
+            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
+
+        # classifier-free guidance
+        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
+
+    # rescale by the SD VAE scaling factor and decode back to pixel space
+    latents = 1 / 0.18215 * latents
+    image = vae.decode(latents).sample
+    image = (image / 2 + 0.5).clamp(0, 1)
+
+    return image
+
+
+### save model in w2w space (principal component representation)
+def save_model_w2w(network, path):
+    proj = network.proj.clone().detach().float()
+    os.makedirs(path, exist_ok=True)
+    torch.save(proj, os.path.join(path, "w2wmodel.pt"))
+
+
+### save model in format compatible with Diffusers
+def save_model_for_diffusers(network, std, mean, v, weight_dimensions, path):
+    # map the principal-component coordinates back to flattened LoRA weights
+    proj = network.proj.clone().detach()
+    unproj = torch.matmul(proj, v.T) * std + mean
+
+    # unflatten the weight vector into per-layer LoRA tensors
+    final_weights0 = {}
+    counter = 0
+    for key in weight_dimensions.keys():
+        final_weights0[key] = unproj[0, counter:counter + weight_dimensions[key][0][0]].unflatten(0, weight_dimensions[key][1])
+        counter += weight_dimensions[key][0][0]
+
+    # rename keys to be compatible with Diffusers/PEFT
+    for key in list(final_weights0.keys()):
+        new_key = (
+            key.replace("lora_unet_", "base_model.model.")
+            .replace("A", "down")
+            .replace("B", "up")
+            .replace("weight", "identity1.weight")
+            .replace("_lora", ".lora")
+            .replace("lora_down", "lora_A")
+            .replace("lora_up", "lora_B")
+        )
+        final_weights0[new_key] = final_weights0.pop(key)
+
+    final_weights = {key: final_weights0[key] for key in sorted(final_weights0.keys())}
+
+    os.makedirs(os.path.join(path, "unet"), exist_ok=True)
+
+    # add config for PeftConfig
+    shutil.copyfile("../files/adapter_config.json", os.path.join(path, "unet/adapter_config.json"))
+
+    save_file(final_weights, os.path.join(path, "unet/adapter_model.safetensors"))
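+
+
+### Minimal usage sketch (illustrative, not part of the training/editing code):
+### it shows how load_models() and inference() fit together. It assumes a CUDA
+### device and that the realistic-vision checkpoint can be downloaded; the
+### prompt, seed, and guidance values are placeholders. For a plain,
+### non-personalized sample, a no-op context manager can stand in for the
+### LoRAw2w `network` instance that the rest of this repo constructs.
+if __name__ == "__main__":
+    import contextlib
+
+    device = "cuda"
+    unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
+    image = inference(
+        network=contextlib.nullcontext(),  # stand-in; pass a LoRAw2w network for w2w sampling
+        unet=unet, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
+        prompt="a photo of a person", negative_prompt="",  # illustrative prompts
+        guidance_scale=3.0, noise_scheduler=noise_scheduler, ddim_steps=50,
+        seed=0, generator=torch.Generator(device=device), device=device,
+    )
+    # image is a (1, 3, 512, 512) tensor in [0, 1]; cast to float32 before saving
+    torchvision.utils.save_image(image.float(), "sample.png")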