# Copyright (C) 2024 NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the LICENSE file
# located at the root directory.

import numpy as np
import torch
from collections import defaultdict
from diffusers.utils.import_utils import is_xformers_available
from typing import Optional, List

from utils.general_utils import get_dynamic_threshold

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class FeatureInjector:
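    """Blends UNet hidden states across the images of a batch to keep their
    shared subject consistent: for every spatial location it looks up the
    nearest-neighbour location in the other (mapped) images via the
    precomputed `nn_map` / `nn_distances`, and mixes the two features with
    weight `alpha` wherever the subject attention mask and the distance
    threshold allow it.
    """
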
    def __init__(self, nn_map, nn_distances, attn_masks, inject_range_alpha=[(10, 20, 0.8)],
                 swap_strategy='min', dist_thr='dynamic', inject_unet_parts=['up']):
        self.nn_map = nn_map
        self.nn_distances = nn_distances
        self.attn_masks = attn_masks

        self.inject_range_alpha = inject_range_alpha if isinstance(inject_range_alpha, list) else [inject_range_alpha]
        self.swap_strategy = swap_strategy  # 'min' / 'mean' / 'first'
        self.dist_thr = dist_thr
        self.inject_unet_parts = inject_unet_parts
        self.inject_res = [64]

    def inject_outputs(self, output, curr_iter, output_res, extended_mapping, place_in_unet, anchors_cache=None):
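        """Inject features from the other images in the batch into `output`
        (shape: batch x res*res x dim). When `anchors_cache` is in cache mode,
        the injected result is also stored so `inject_anchors` can replay it."""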
        curr_unet_part = place_in_unet.split('_')[0]

        # Inject only in the specified unet parts (up, mid, down)
        if (curr_unet_part not in self.inject_unet_parts) or output_res not in self.inject_res:
            return output

        bsz = output.shape[0]
        nn_map = self.nn_map[output_res]
        nn_distances = self.nn_distances[output_res]
        attn_masks = self.attn_masks[output_res]

        vector_dim = output_res ** 2

        alpha = next((alpha for min_range, max_range, alpha in self.inject_range_alpha
                      if min_range <= curr_iter <= max_range), None)
        if alpha:
            old_output = output  # .clone()
            for i in range(bsz):
                other_outputs = []

                if self.swap_strategy == 'min':
                    curr_mapping = extended_mapping[i]

                    # If the current image is not mapped to any other image, skip
                    if not torch.any(torch.cat([curr_mapping[:i], curr_mapping[i + 1:]])):
                        continue
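
                    # Across the mapped images, pick for each of the res*res
                    # locations the match with the smallest feature distance,
                    # then keep only locations that pass both the subject mask
                    # and the distance threshold.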
                    min_dists = nn_distances[i][curr_mapping].argmin(dim=0)
                    curr_nn_map = nn_map[i][curr_mapping][min_dists, torch.arange(vector_dim)]
                    curr_nn_distances = nn_distances[i][curr_mapping][min_dists, torch.arange(vector_dim)]

                    dist_thr = get_dynamic_threshold(curr_nn_distances) if self.dist_thr == 'dynamic' else self.dist_thr
                    dist_mask = curr_nn_distances < dist_thr
                    final_mask_tgt = attn_masks[i] & dist_mask

                    other_outputs = old_output[curr_mapping][min_dists, curr_nn_map][final_mask_tgt]

                    output[i][final_mask_tgt] = alpha * other_outputs + (1 - alpha) * old_output[i][final_mask_tgt]

            if anchors_cache and anchors_cache.is_cache_mode():
                if place_in_unet not in anchors_cache.h_out_cache:
                    anchors_cache.h_out_cache[place_in_unet] = {}
                anchors_cache.h_out_cache[place_in_unet][curr_iter] = output

        return output

    def inject_anchors(self, output, curr_iter, output_res, extended_mapping, place_in_unet, anchors_cache):
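        """Same as `inject_outputs`, but the injected features come from the
        cached anchor pass (`anchors_cache.h_out_cache`) rather than from the
        other images in the current batch."""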
        curr_unet_part = place_in_unet.split('_')[0]

        # Inject only in the specified unet parts (up, mid, down)
        if (curr_unet_part not in self.inject_unet_parts) or output_res not in self.inject_res:
            return output

        bsz = output.shape[0]
        nn_map = self.nn_map[output_res]
        nn_distances = self.nn_distances[output_res]
        attn_masks = self.attn_masks[output_res]

        vector_dim = output_res ** 2

        alpha = next((alpha for min_range, max_range, alpha in self.inject_range_alpha
                      if min_range <= curr_iter <= max_range), None)
        if alpha:
            anchor_outputs = anchors_cache.h_out_cache[place_in_unet][curr_iter]
            old_output = output  # .clone()

            for i in range(bsz):
                other_outputs = []

                if self.swap_strategy == 'min':
                    min_dists = nn_distances[i].argmin(dim=0)
                    curr_nn_map = nn_map[i][min_dists, torch.arange(vector_dim)]
                    curr_nn_distances = nn_distances[i][min_dists, torch.arange(vector_dim)]

                    dist_thr = get_dynamic_threshold(curr_nn_distances) if self.dist_thr == 'dynamic' else self.dist_thr
                    dist_mask = curr_nn_distances < dist_thr
                    final_mask_tgt = attn_masks[i] & dist_mask

                    other_outputs = anchor_outputs[min_dists, curr_nn_map][final_mask_tgt]

                    output[i][final_mask_tgt] = alpha * other_outputs + (1 - alpha) * old_output[i][final_mask_tgt]

        return output


class AnchorCache:
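    """Stores intermediate activations from the anchor images' generation pass
    so they can be replayed when generating further, consistent images.
    `mode` switches between recording ('cache') and replaying ('inject')."""
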
    def __init__(self):
        self.input_h_cache = {}  # place_in_unet -> {iter: h_in}
        self.h_out_cache = {}  # place_in_unet -> {iter: h_out}
        self.anchors_last_mask = None
        self.dift_cache = None

        self.mode = 'cache'  # mode can be 'cache' or 'inject'

    def set_mode(self, mode):
        self.mode = mode

    def set_mode_inject(self):
        self.mode = 'inject'

    def set_mode_cache(self):
        self.mode = 'cache'

    def is_inject_mode(self):
        return self.mode == 'inject'

    def is_cache_mode(self):
        return self.mode == 'cache'

    def to_device(self, device):
        for key, value in self.input_h_cache.items():
            self.input_h_cache[key] = {k: v.to(device) for k, v in value.items()}

        for key, value in self.h_out_cache.items():
            self.h_out_cache[key] = {k: v.to(device) for k, v in value.items()}

        if self.anchors_last_mask:
            self.anchors_last_mask = {k: v.to(device) for k, v in self.anchors_last_mask.items()}

        if self.dift_cache is not None:
            self.dift_cache = self.dift_cache.to(device)


class QueryStore:
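    """Caches self-attention queries and blends them back into later queries,
    with a strength that ramps linearly from `strength_start` to `strength_end`
    across the timestep window `t_range`."""
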
    def __init__(self, mode='store', t_range=[0, 1000], strength_start=1, strength_end=1):
        """
        Initialize an empty QueryStore
        """
        self.query_store = defaultdict(list)
        self.mode = mode
        self.t_range = t_range
        self.strengthes = np.linspace(strength_start, strength_end, (t_range[1] - t_range[0]) + 1)

    def set_mode(self, mode):  # mode can be 'store' or 'inject'
        self.mode = mode

    def cache_query(self, query, place_in_unet: str):
        self.query_store[place_in_unet] = query

    def inject_query(self, query, place_in_unet, t):
        if self.t_range[0] <= t <= self.t_range[1]:
            relative_t = t - self.t_range[0]
            strength = self.strengthes[relative_t]
            new_query = strength * self.query_store[place_in_unet] + (1 - strength) * query
        else:
            new_query = query

        return new_query


class DIFTLatentStore:
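    """Collects intermediate UNet up-block features ("DIFT"-style diffusion
    features) at the requested timesteps and up-block indices, keyed by
    '{t}_{layer_index}'. These are the features from which the nearest-neighbour
    maps consumed by `FeatureInjector` are computed."""
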
    def __init__(self, steps: List[int], up_ft_indices: List[int]):
        self.steps = steps
        self.up_ft_indices = up_ft_indices
        self.dift_features = {}

    def __call__(self, features: torch.Tensor, t: int, layer_index: int):
        if t in self.steps and layer_index in self.up_ft_indices:
            self.dift_features[f'{int(t)}_{layer_index}'] = features

    def copy(self):
        copy_dift = DIFTLatentStore(self.steps, self.up_ft_indices)

        for key, value in self.dift_features.items():
            copy_dift.dift_features[key] = value.clone()

        return copy_dift

    def reset(self):
        self.dift_features = {}
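

# A minimal usage sketch (illustrative only: the actual wiring into the UNet
# forward hooks happens elsewhere in the pipeline, and the argument values
# below are placeholders rather than values taken from this repo):
#
#   dift_store = DIFTLatentStore(steps=[261], up_ft_indices=[1])
#   # ... inside a UNet up-block hook:
#   dift_store(up_block_features, t, layer_index)
#
#   anchors_cache = AnchorCache()
#   anchors_cache.set_mode_cache()    # record activations during the anchor pass
#   # ... run the anchor generation ...
#   anchors_cache.set_mode_inject()   # replay them when generating new images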