# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
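"""Utilities for DragDiffusion's drag-editing optimization: feature-space point
tracking, a convergence check for handle points, bilinear feature-patch
interpolation, and the main latent-code update that combines motion supervision
with a mask-based background regularizer.
"""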
import copy

import torch
import torch.nn.functional as F


def point_tracking(F0, F1, handle_points, handle_points_init, args):
    # Relocate each handle point by nearest-neighbor feature matching: search a
    # (2*r_p+1)^2 window of F1 around the current position for the feature vector
    # closest (in L1 distance) to the F0 feature at the point's initial position.
    with torch.no_grad():
        _, _, max_r, max_c = F1.shape
        for i in range(len(handle_points)):
            pi0, pi = handle_points_init[i], handle_points[i]
            f0 = F0[:, :, int(pi0[0]), int(pi0[1])]

            # clamp the search window to the feature map boundary
            r1, r2 = max(int(pi[0]) - args.r_p, 0), min(int(pi[0]) + args.r_p + 1, max_r)
            c1, c2 = max(int(pi[1]) - args.r_p, 0), min(int(pi[1]) + args.r_p + 1, max_c)
            F1_neighbor = F1[:, :, r1:r2, c1:c2]

            all_dist = (f0.unsqueeze(dim=-1).unsqueeze(dim=-1) - F1_neighbor).abs().sum(dim=1)
            all_dist = all_dist.squeeze(dim=0)
            row, col = divmod(all_dist.argmin().item(), all_dist.shape[-1])
            handle_points[i][0] = r1 + row
            handle_points[i][1] = c1 + col
    return handle_points
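

# A minimal sketch (not part of the original module) of how point_tracking relocates a
# handle point: shift a random feature map two pixels to the right and the tracked
# point should follow. Only the r_p field of args is assumed here.
def _demo_point_tracking():
    from types import SimpleNamespace
    F0 = torch.randn(1, 4, 64, 64)
    F1 = torch.roll(F0, shifts=2, dims=3)  # shift features right by 2 pixels
    handle = [torch.tensor([32.0, 32.0])]
    args = SimpleNamespace(r_p=3)
    new_handle = point_tracking(F0, F1, handle, copy.deepcopy(handle), args)
    print('tracked point moved to', new_handle[0])  # expect approximately [32., 34.]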


def check_handle_reach_target(handle_points, target_points):
    # all handle points are considered converged once each is within
    # 1 pixel (L2 norm) of its corresponding target point
    all_dist = list(map(lambda p, q: (p - q).norm(), handle_points, target_points))
    return (torch.tensor(all_dist) < 1.0).all()
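

# Hypothetical convergence check: both handle points below lie within 1 pixel
# (L2 norm) of their targets, so the function should report True.
def _demo_check_handle_reach_target():
    handle = [torch.tensor([10.2, 5.1]), torch.tensor([3.0, 3.0])]
    target = [torch.tensor([10.0, 5.0]), torch.tensor([3.4, 3.2])]
    print('converged:', check_handle_reach_target(handle, target).item())  # True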


# obtain the bilinear interpolated feature patch centered around (y, x) with radius r
def interpolate_feature_patch(feat, y, x, r):
    x0 = torch.floor(x).long()
    x1 = x0 + 1
    y0 = torch.floor(y).long()
    y1 = y0 + 1

    # standard bilinear weights computed from the fractional parts of (y, x)
    wa = (x1.float() - x) * (y1.float() - y)
    wb = (x1.float() - x) * (y - y0.float())
    wc = (x - x0.float()) * (y1.float() - y)
    wd = (x - x0.float()) * (y - y0.float())

    # the four integer-aligned patches surrounding the sub-pixel center
    Ia = feat[:, :, y0 - r:y0 + r + 1, x0 - r:x0 + r + 1]
    Ib = feat[:, :, y1 - r:y1 + r + 1, x0 - r:x0 + r + 1]
    Ic = feat[:, :, y0 - r:y0 + r + 1, x1 - r:x1 + r + 1]
    Id = feat[:, :, y1 - r:y1 + r + 1, x1 - r:x1 + r + 1]
    return Ia * wa + Ib * wb + Ic * wc + Id * wd
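

# Sanity-check sketch (not part of the original module): at integer coordinates the
# bilinear weights collapse to (1, 0, 0, 0), so the interpolated patch should match
# a plain slice of the feature map.
def _demo_interpolate_feature_patch():
    feat = torch.randn(1, 4, 16, 16)
    y, x, r = torch.tensor(8.0), torch.tensor(8.0), 2
    patch = interpolate_feature_patch(feat, y, x, r)
    direct = feat[:, :, 8 - r:8 + r + 1, 8 - r:8 + r + 1]
    print('max abs difference:', (patch - direct).abs().max().item())  # ~0.0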


def drag_diffusion_update(model, init_code, t, handle_points, target_points, mask, args):
    assert len(handle_points) == len(target_points), \
        "number of handle points must equal the number of target points"

    text_emb = model.get_text_embeddings(args.prompt).detach()

    # the initial output feature of the UNet
    with torch.no_grad():
        unet_output, F0 = model.forward_unet_features(init_code, t,
            encoder_hidden_states=text_emb,
            layer_idx=args.unet_feature_idx, interp_res=args.sup_res)
        x_prev_0, _ = model.step(unet_output, t, init_code)

    # prepare optimizable init_code and optimizer
    init_code.requires_grad_(True)
    optimizer = torch.optim.Adam([init_code], lr=args.lr)

    # prepare for point tracking and background regularization
    handle_points_init = copy.deepcopy(handle_points)
    interp_mask = F.interpolate(mask, (init_code.shape[2], init_code.shape[3]), mode='nearest')

    # prepare amp scaler for mixed-precision training
    scaler = torch.cuda.amp.GradScaler()
    for step_idx in range(args.n_pix_step):
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            unet_output, F1 = model.forward_unet_features(init_code, t,
                encoder_hidden_states=text_emb,
                layer_idx=args.unet_feature_idx, interp_res=args.sup_res)
            x_prev_updated, _ = model.step(unet_output, t, init_code)

            # do point tracking to update handle points before computing motion supervision loss
            if step_idx != 0:
                handle_points = point_tracking(F0, F1, handle_points, handle_points_init, args)
                print('new handle points', handle_points)

            # break if all handle points have reached their targets
            if check_handle_reach_target(handle_points, target_points):
                break

            loss = 0.0
            for i in range(len(handle_points)):
                pi, ti = handle_points[i], target_points[i]
                # skip if the distance between target and handle is less than 1
                if (ti - pi).norm() < 1.0:
                    continue

                # unit step direction from the handle point toward its target
                di = (ti - pi) / (ti - pi).norm()

                # motion supervision: pull the feature patch one step along di
                # toward the (detached) patch at the current handle point
                f0_patch = F1[:, :, int(pi[0]) - args.r_m:int(pi[0]) + args.r_m + 1,
                              int(pi[1]) - args.r_m:int(pi[1]) + args.r_m + 1].detach()
                f1_patch = interpolate_feature_patch(F1, pi[0] + di[0], pi[1] + di[1], args.r_m)
                loss += ((2 * args.r_m + 1) ** 2) * F.l1_loss(f0_patch, f1_patch)

            # the region outside the mask must stay unchanged
            loss += args.lam * ((x_prev_updated - x_prev_0) * (1.0 - interp_mask)).abs().sum()
            print('loss total=%f' % (loss.item()))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    return init_code, text_emb
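

# drag_diffusion_update depends on a DragDiffusion pipeline object exposing
# get_text_embeddings / forward_unet_features / step, so it is not exercised in the
# sketches below. For reference, a hypothetical args namespace would need at least
# the fields read above; the values shown are illustrative, not taken from this file:
#
#     args = SimpleNamespace(
#         prompt='a photo of a dog',  # text prompt fed to the diffusion model
#         unet_feature_idx=[2],       # UNet block(s) whose features are used
#         sup_res=256,                # resolution the features are interpolated to
#         r_m=1,                      # motion-supervision patch radius
#         r_p=3,                      # point-tracking search radius
#         lam=0.1,                    # weight of the background regularizer
#         lr=0.01,                    # Adam learning rate on the latent code
#         n_pix_step=80,              # max number of optimization steps
#     )

if __name__ == '__main__':
    # run the lightweight, self-contained sketches defined above
    _demo_point_tracking()
    _demo_check_handle_reach_target()
    _demo_interpolate_feature_patch()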