Skip to content

Commit

Permalink
add inversion
Browse files Browse the repository at this point in the history
  • Loading branch information
avdravid authored Jul 19, 2024
1 parent c6b04ba commit 5ef5ed7
Show file tree
Hide file tree
Showing 8 changed files with 681 additions and 0 deletions.
Binary file added inversion/images/ood_image/mask.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added inversion/images/ood_image/ood/im/painting.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added inversion/images/real_image/mask.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added inversion/images/real_image/real/im/43883.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
290 changes: 290 additions & 0 deletions inversion/inversion_ood.ipynb

Large diffs are not rendered by default.

282 changes: 282 additions & 0 deletions inversion/inversion_real.ipynb

Large diffs are not rendered by default.

87 changes: 87 additions & 0 deletions inversion/invert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import sys
import os
sys.path.append(os.path.abspath(os.path.join("", "..")))
import torch
import torchvision
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
from lora_w2w import LoRAw2w
from utils import load_models, inference, save_model_w2w, save_model_for_diffusers
from inversion import invert
import argparse


def main():
    """Invert an image folder into a LoRAw2w network and save it to disk.

    Command-line entry point. It parses the options below, loads the
    diffusion components with ``load_models``, loads the mean / std / V
    tensors (cast to bfloat16), builds a ``LoRAw2w`` network whose ``dim``
    coefficients start at zero, optimizes it with ``invert`` using the prompt
    ``"sks person"``, and finally writes the result to ``--save_name`` with
    either ``save_model_for_diffusers`` (when ``--diffusers_format`` is
    given) or ``save_model_w2w``.
    """
    # Local import: only needed here to inspect which keywords invert() takes.
    import inspect

    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda:0", type=str, help="Torch device to run on.")
    parser.add_argument("--mean_path", default="/files/mean.pt", type=str, help="Path to file with parameter means")
    parser.add_argument("--std_path", default="/files/std.pt", type=str, help="Path to file with parameter standard deviations.")
    parser.add_argument("--v_path", default="/files/V.pt", type=str, help="Path to V orthogonal projection/unprojection matrix.")
    parser.add_argument("--dim_path", default="/files/weight_dimensions.pt", type=str, help="Path to file with dimensions of LoRA layers. Used for saving in Diffusers pipeline format.")
    parser.add_argument("--imfolder", default="/inversion/images/real_image/real/", type=str, help="Path to folder containing image.")
    parser.add_argument("--mask_path", default=None, type=str, help="Path to mask file.")
    parser.add_argument("--epochs", default=400, type=int, help="Number of optimization epochs.")
    parser.add_argument("--lr", default=1e-1, type=float, help="Learning rate for the inversion.")
    parser.add_argument("--weight_decay", default=1e-10, type=float, help="Weight decay for the inversion.")
    parser.add_argument("--dim", default=10000, type=int, help="Number of principal component coefficients to optimize.")
    parser.add_argument("--diffusers_format", default=False, action="store_true", help="Whether to save in mode that can be loaded in Diffusers pipeline")
    parser.add_argument("--save_name", default="/files/inversion1.pt", type=str, help="Output path + filename.")

    ### variables
    args = parser.parse_args()
    device = args.device

    ### load models
    unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)

    ### load files
    # Deserialize onto the CPU first so the checkpoints load regardless of the
    # device they were saved from, and so only the bfloat16 copy (not the
    # full-precision original) is transferred to the target device.
    mean = torch.load(args.mean_path, map_location="cpu").bfloat16().to(device)
    std = torch.load(args.std_path, map_location="cpu").bfloat16().to(device)
    v = torch.load(args.v_path, map_location="cpu").bfloat16().to(device)
    weight_dimensions = torch.load(args.dim_path)

    ### initialize network
    # Coefficients start at zero; only the first `dim` columns of V are used.
    proj = torch.zeros(1, args.dim).bfloat16().to(device)
    network = LoRAw2w(
        proj, mean, std, v[:, :args.dim],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    ### run inversion
    # --lr and --weight_decay used to be parsed and then dropped. Forward each
    # one only if invert() declares a parameter of that name, so this script
    # keeps working against an invert() that does not expose them.
    requested = {"lr": args.lr, "weight_decay": args.weight_decay}
    accepted = inspect.signature(invert).parameters
    optim_kwargs = {}
    for flag, value in requested.items():
        if flag in accepted:
            optim_kwargs[flag] = value
        else:
            print(f"note: invert() has no '{flag}' parameter; --{flag} is ignored", file=sys.stderr)

    network = invert(
        network=network, unet=unet, vae=vae,
        text_encoder=text_encoder, tokenizer=tokenizer,
        prompt="sks person", noise_scheduler=noise_scheduler, epochs=args.epochs,
        image_path=args.imfolder, mask_path=args.mask_path, device=device,
        **optim_kwargs,
    )

    ### save model
    if args.diffusers_format:
        save_model_for_diffusers(network, std, mean, v, weight_dimensions,
                                 path=args.save_name)
    else:
        save_model_w2w(network, path=args.save_name)


if __name__ == "__main__":
    main()
22 changes: 22 additions & 0 deletions inversion/run_inversion.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Example invocation of inversion/invert.py on the bundled real-image sample.
# All paths are relative, so run this from the repository root.
python inversion/invert.py \
  --device="cuda:0" \
  --mean_path="files/mean.pt" \
  --std_path="files/std.pt" \
  --v_path="files/V.pt" \
  --dim_path="files/weight_dimensions.pt" \
  --imfolder="inversion/images/real_image/real/" \
  --mask_path="inversion/images/real_image/mask.png" \
  --epochs=400 \
  --lr=1e-1 \
  --weight_decay=1e-10 \
  --dim=10000 \
  --save_name="files/real_inversion1.pt"








0 comments on commit 5ef5ed7

Please sign in to comment.