# tutorial_train_faceid.py
# Forked from tencent-ailab/IP-Adapter.
import os
import random
import argparse
from pathlib import Path
import json
import itertools
import time
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from transformers import CLIPImageProcessor
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ip_adapter.ip_adapter_faceid import MLPProjModel
from ip_adapter.utils import is_torch2_available
from ip_adapter.attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
# Dataset
class MyDataset(torch.utils.data.Dataset):
    """Dataset of (image, tokenized caption, face-ID embedding) training samples.

    ``json_file`` must contain a JSON list of dicts.  Each entry is read for the
    keys ``"text"`` (caption), ``"image_file"`` (joined with ``image_root_path``)
    and ``"id_embed_file"`` (passed to ``torch.load`` as-is, i.e. it is NOT
    joined with ``image_root_path``).

    For every sample one random number decides whether the face-ID embedding,
    the caption, or both are dropped:

    * with probability ``i_drop_rate`` the embedding is replaced by zeros,
    * with probability ``t_drop_rate`` the caption is replaced by ``""``,
    * with probability ``ti_drop_rate`` both are dropped.
    """

    def __init__(self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""):
        super().__init__()

        self.tokenizer = tokenizer
        self.size = size
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        self.image_root_path = image_root_path

        # list of dict: [{"image_file": "1.png", "id_embed_file": "faceid.bin"}]
        # Use a context manager so the file handle is closed deterministically.
        with open(json_file) as f:
            self.data = json.load(f)

        # Resize the short side to `size`, center-crop to a square, then map
        # pixel values from [0, 1] to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    @staticmethod
    def _as_tensor(embed):
        """Return *embed* as a tensor, converting from a NumPy array if needed."""
        if isinstance(embed, torch.Tensor):
            return embed
        return torch.from_numpy(embed)

    def __getitem__(self, idx):
        """Return a dict with keys ``image``, ``text_input_ids``, ``face_id_embed`` and ``drop_image_embed``."""
        item = self.data[idx]
        text = item["text"]
        image_file = item["image_file"]

        # read image
        raw_image = Image.open(os.path.join(self.image_root_path, image_file))
        image = self.transform(raw_image.convert("RGB"))

        # NOTE(security): torch.load unpickles the file; only use embedding
        # files that come from a trusted source.
        face_id_embed = torch.load(item["id_embed_file"], map_location="cpu")
        # The file may hold either a NumPy array or an already-converted tensor.
        face_id_embed = self._as_tensor(face_id_embed)

        # drop
        drop_image_embed = 0
        rand_num = random.random()
        if rand_num < self.i_drop_rate:
            # drop the face-ID embedding only
            drop_image_embed = 1
        elif rand_num < (self.i_drop_rate + self.t_drop_rate):
            # drop the caption only
            text = ""
        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
            # drop both caption and embedding
            text = ""
            drop_image_embed = 1
        if drop_image_embed:
            face_id_embed = torch.zeros_like(face_id_embed)

        # get text and tokenize
        text_input_ids = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids

        return {
            "image": image,
            "text_input_ids": text_input_ids,
            "face_id_embed": face_id_embed,
            "drop_image_embed": drop_image_embed
        }

    def __len__(self):
        """Return the number of entries loaded from the JSON file."""
        return len(self.data)
def collate_fn(data):
    """Combine a list of dataset samples into a single batch dict.

    ``image`` and ``face_id_embed`` tensors are stacked along a new leading
    dimension, ``text_input_ids`` tensors are concatenated along dim 0, and the
    ``drop_image_embed`` flags are collected into a plain Python list.
    """
    def column(key):
        # Gather one field from every sample, preserving sample order.
        return [sample[key] for sample in data]

    return {
        "images": torch.stack(column("image")),
        "text_input_ids": torch.cat(column("text_input_ids"), dim=0),
        "face_id_embed": torch.stack(column("face_id_embed")),
        "drop_image_embeds": column("drop_image_embed"),
    }
class IPAdapter(torch.nn.Module):
    """IP-Adapter"""

    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
        """Store the three sub-modules and optionally load weights from ``ckpt_path``."""
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

        if ckpt_path is None:
            return
        self.load_from_checkpoint(ckpt_path)

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        """Project ``image_embeds`` to tokens, append them to the text states, run the UNet."""
        image_tokens = self.image_proj_model(image_embeds)
        # The projected tokens are appended after the text tokens along dim 1.
        conditioning = torch.cat([encoder_hidden_states, image_tokens], dim=1)
        # Predict the noise residual
        return self.unet(noisy_latents, timesteps, conditioning).sample

    def load_from_checkpoint(self, ckpt_path: str):
        """Load ``image_proj`` and ``ip_adapter`` weights from ``ckpt_path``.

        The sum of all parameters is compared before and after loading; an
        ``AssertionError`` is raised if either sum is unchanged.
        """
        def weight_checksum(module):
            # Scalar tensor: sum over every parameter of the module.
            return torch.sum(torch.stack([torch.sum(param) for param in module.parameters()]))

        # Checksums of the current (pre-load) weights
        proj_before = weight_checksum(self.image_proj_model)
        adapter_before = weight_checksum(self.adapter_modules)

        checkpoint = torch.load(ckpt_path, map_location="cpu")

        # Both state dicts must match their module exactly (strict loading).
        self.image_proj_model.load_state_dict(checkpoint["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(checkpoint["ip_adapter"], strict=True)

        # Checksums of the freshly loaded weights
        proj_after = weight_checksum(self.image_proj_model)
        adapter_after = weight_checksum(self.adapter_modules)

        # Verify that loading actually changed the weights
        assert proj_before != proj_after, "Weights of image_proj_model did not change!"
        assert adapter_before != adapter_after, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")
def parse_args():
    """Parse the command-line arguments of the training script.

    Returns:
        argparse.Namespace: the parsed options.  If the ``LOCAL_RANK``
        environment variable is set and differs from ``--local_rank``, the
        environment value overrides ``args.local_rank``.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_ip_adapter_path",
        type=str,
        default=None,
        help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
    )
    parser.add_argument(
        "--data_json_file",
        type=str,
        default=None,
        required=True,
        help="Training data",
    )
    # Optional: the default "" makes image paths in the JSON file be used as-is.
    parser.add_argument(
        "--data_root_path",
        type=str,
        default="",
        help="Training data root path",
    )
    # Optional: this script trains from precomputed face-ID embeddings and does
    # not load a CLIP image encoder, so the path must not be mandatory.  The
    # flag is still accepted for backward compatibility.
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default=None,
        help="Path to CLIP image encoder",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-ip_adapter",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images"
        ),
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=2000,
        help=(
            "Save a checkpoint of the training state every X updates"
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()

    # A LOCAL_RANK set in the environment takes precedence over the CLI flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
def main():
    """Train the FaceID IP-Adapter modules on top of a frozen diffusion model.

    Steps performed:

    1. Parse CLI arguments and create the ``Accelerator``.
    2. Load scheduler, tokenizer, text encoder, VAE and UNet; freeze UNet, VAE
       and text encoder.
    3. Build the face-ID projection model and install LoRA attention
       processors on every UNet attention layer.
    4. Optimize only the projection model and the attention processors with
       AdamW on an MSE noise-prediction loss.
    5. Save the accelerator state every ``--save_steps`` steps.
    """
    args = parse_args()
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
    # NOTE: no CLIP image encoder is loaded here; the dataset supplies
    # face-ID embeddings that were computed ahead of time.

    # freeze parameters of models to save more memory
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # ip-adapter: projects a 512-dim face-ID embedding to 4 context tokens
    image_proj_model = MLPProjModel(
        cross_attention_dim=unet.config.cross_attention_dim,
        id_embeddings_dim=512,
        num_tokens=4,
    )

    # init adapter modules
    lora_rank = 128
    attn_procs = {}
    unet_sd = unet.state_dict()
    for name in unet.attn_processors.keys():
        # attn1 is self-attention (no cross-attention dim); everything else is cross-attention.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        else:
            # Fail loudly instead of silently reusing the previous layer's size.
            raise ValueError(f"Unexpected attention processor name: {name}")
        if cross_attention_dim is None:
            attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank)
        else:
            # Initialize the image key/value projections from the UNet's own
            # text key/value projection weights of the same layer.
            layer_name = name.split(".processor")[0]
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            attn_procs[name] = LoRAIPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank)
            # strict=False: only the two *_ip weights are provided here.
            attn_procs[name].load_state_dict(weights, strict=False)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)

    # Frozen VAE / text encoder are cast to the mixed-precision dtype; the
    # UNet is left in its loaded dtype.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    # optimizer: only the projection model and the attention processors are trained
    params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
    optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)

    # dataloader
    train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Prepare everything with our `accelerator`.
    ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)

    global_step = 0
    for epoch in range(0, args.num_train_epochs):
        begin = time.perf_counter()
        for step, batch in enumerate(train_dataloader):
            # Time spent waiting for the batch since the previous step ended.
            load_data_time = time.perf_counter() - begin
            with accelerator.accumulate(ip_adapter):
                # Convert images to latent space
                with torch.no_grad():
                    latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image.  The value is read
                # from the scheduler's config object; direct attribute access
                # on the scheduler is deprecated in diffusers.
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                image_embeds = batch["face_id_embed"].to(accelerator.device, dtype=weight_dtype)

                with torch.no_grad():
                    encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]

                noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)

                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()

                # Backpropagate
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                if accelerator.is_main_process:
                    print("Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
                        epoch, step, load_data_time, time.perf_counter() - begin, avg_loss))

            global_step += 1

            if global_step % args.save_steps == 0:
                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path)

            begin = time.perf_counter()
# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()