-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathif_inference.py
144 lines (122 loc) · 5.79 KB
/
if_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from transformers import T5EncoderModel
from pipeline_if import IFPipeline as DiffusionPipeline
from scheduling_ddpm import step
from diffusers.schedulers import DDPMScheduler
import gc
import torch
from dnnlib.util import save_images, save_image
import os
import math
from collections import OrderedDict
import numpy as np
import dnnlib
import click
from tqdm import tqdm
def flush():
    """Free host and GPU memory between pipeline swaps.

    Runs the Python garbage collector first so that unreferenced tensors are
    actually destroyed, then asks PyTorch to return its cached CUDA blocks to
    the driver (a no-op when CUDA was never initialized).
    """
    gc.collect()
    torch.cuda.empty_cache()
@click.command()
@click.option('--upscale', is_flag=True, help='Upscale images', default=False)
@click.option('--checkpoint_path', type=str, help='Path to checkpoint file', required=True)
@click.option('--output_dir', type=str, help='Output directory', required=True)
@click.option('--batch_size', type=int, help='Batch size', default=16)
@click.option('--num_images', type=int, help='Number of images', default=50000)
@click.option('--batch_resume_index', type=int, help='Batch resume index', default=0)
@click.option('--save_collage', is_flag=True, help='Save collage', default=False)
@click.option('--save_separate', is_flag=True, help='Save separate', default=True)
@click.option('--corruption_probability', type=float, help='Corruption probability', default=0.8)
@click.option('--delta_probability', type=float, help='Delta probability', default=0.1)
@click.option('--corruption_pattern', type=str, help='Corruption pattern', default="dust")
def main(upscale, checkpoint_path, output_dir, batch_size, num_images, batch_resume_index, save_collage, save_separate, corruption_probability, delta_probability, corruption_pattern):
    """Sample images from a DeepFloyd IF stage-I pipeline and optionally upscale them.

    The run has three phases:
      1. Encode empty prompts once with the 8-bit T5 text encoder, then free it.
      2. Generate ``num_images`` images in batches of ``batch_size`` with the
         (optionally fine-tuned) stage-I UNet, saving them into ``output_dir``.
      3. If ``--upscale`` is set, re-read the saved images and pass them through
         the stage-II super-resolution pipeline.

    Raises:
        click.UsageError: if both ``--upscale`` and ``--save_collage`` are set
            (the upscale phase re-reads one image per file, so collages cannot
            be upscaled).
    """
    # Fail fast instead of asserting after all images were already generated
    # (the original `assert` would also be stripped under `python -O`).
    if upscale and save_collage:
        raise click.UsageError("Cannot upscale collages for now")
    os.makedirs(output_dir, exist_ok=True)

    # --- Phase 1: encode empty prompts with the 8-bit T5 encoder ------------
    text_encoder = T5EncoderModel.from_pretrained(
        "DeepFloyd/IF-I-XL-v1.0",
        subfolder="text_encoder",
        device_map="auto",
        load_in_8bit=True,
        variant="8bit"
    )
    pipe = DiffusionPipeline.from_pretrained(
        "DeepFloyd/IF-I-XL-v1.0",
        text_encoder=text_encoder,  # pass the previously instantiated 8bit text encoder
        unet=None,
        device_map="auto"
    )
    # Unconditional generation: one empty prompt per batch element.
    prompts = batch_size * [""]
    prompt_embeds, negative_embeds = pipe.encode_prompt(prompts)
    # The embeddings are all we need — free the encoder before loading the UNet.
    del text_encoder
    del pipe
    flush()

    # --- Phase 2: load the fp16 UNet pipeline and sample --------------------
    pipe = DiffusionPipeline.from_pretrained(
        "DeepFloyd/IF-I-XL-v1.0",
        text_encoder=None,
        variant="fp16",
        torch_dtype=torch.float16,
        device_map="auto"
    )
    if checkpoint_path is not None:  # defensive; the option is required=True
        # SECURITY NOTE: torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        loaded_dict = torch.load(checkpoint_path, map_location="cuda")
        # Strip the torch.compile / DDP wrapper prefix from fine-tuned weights.
        fixed_dict = OrderedDict({key.replace("_orig_mod.module.", ""): value for key, value in loaded_dict.items()})
        pipe.unet.load_state_dict(fixed_dict)
    # Use the custom DDPM step (bound as a method) with variance disabled.
    pipe.scheduler.variance_type = None
    pipe.scheduler.config.variance_type = None
    pipe.scheduler.step = step.__get__(pipe.scheduler, DDPMScheduler)

    # Loop-invariant setup, hoisted out of the batch loop.
    pipe.safety_checker = None
    num_rows = int(math.sqrt(batch_size))  # collage grid is num_rows x num_rows
    num_batches = int(np.ceil(num_images / batch_size))
    for batch_index in tqdm(range(num_batches)):
        if batch_index < batch_resume_index:
            continue  # resume support: skip batches generated by a previous run
        images = pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_embeds,
            output_type="pt",
            corruption_probability=corruption_probability,
            delta_probability=delta_probability,
            guidance_scale=0.0,  # unconditional — no classifier-free guidance
            corruption_pattern=corruption_pattern,
        ).images
        if save_separate:
            for image_index, image in enumerate(images):
                save_image(image, os.path.join(output_dir, f"{batch_index}_{image_index}.png"))
        if save_collage:
            save_images(images, os.path.join(output_dir, f"{batch_index}.png"), num_rows=num_rows, num_cols=num_rows)

    # --- Phase 3 (optional): stage-II super-resolution ----------------------
    if upscale:
        del pipe
        flush()
        super_res_pipe = DiffusionPipeline.from_pretrained(
            "DeepFloyd/IF-II-L-v1.0",
            text_encoder=None,  # no use of text encoder => memory savings!
            variant="fp16",
            torch_dtype=torch.float16,
            device_map="auto"
        )
        # Re-read the freshly saved images through the project dataset class,
        # with corruption disabled (the images are already corrupted samples).
        c = dnnlib.EasyDict()
        c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset',
                                           path=output_dir, use_labels=False, xflip=False, cache=False,
                                           corruption_probability=0.0, delta_probability=0.0)
        dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs)  # subclass of training.dataset.Dataset
        c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=1, prefetch_factor=2)
        # Single-replica sampler keeps a deterministic, unshuffled order.
        dataset_sampler = torch.utils.data.distributed.DistributedSampler(dataset_obj, num_replicas=1, rank=0, seed=42, shuffle=False)
        dataset_iterator = iter(
            torch.utils.data.DataLoader(
                dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_size, **c.data_loader_kwargs)
        )
        with torch.no_grad():
            for iter_index, dataset_iter in enumerate(dataset_iterator):
                # as_tensor avoids the redundant copy (and warning) that
                # torch.tensor() performs on an already-collated tensor.
                low_res = torch.as_tensor(dataset_iter[0], device='cuda').to(torch.float32)
                super_res_images = super_res_pipe(
                    image=low_res,
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=negative_embeds,
                    guidance_scale=0.0,
                    output_type="pt",
                ).images
                if save_separate:
                    for image_index, sr_image in enumerate(super_res_images):
                        # BUG FIX: originally used `batch_index` — the stale
                        # final index from the generation loop — so every
                        # super-res batch overwrote the same filenames.
                        save_image(sr_image, os.path.join(output_dir, f"{iter_index}_{image_index}_super_res.png"))
#----------------------------------------------------------------------------

if __name__ == "__main__":
    # click parses the command-line options and invokes main().
    main()

#----------------------------------------------------------------------------