-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmodels_cross.py
336 lines (272 loc) · 13.8 KB
/
models_cross.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
from functools import partial
import numpy as np
import math
import torch
import torch.nn as nn
from transformer_utils import Block, CrossAttentionBlock, PatchEmbed
from util.pos_embed import get_2d_sincos_pos_embed
class WeightedFeatureMaps(nn.Module):
    """Mix ``k`` encoder feature maps into ``decoder_depth`` learned combinations.

    A bias-free ``nn.Linear(k, decoder_depth)`` is applied along a new trailing
    axis that indexes the feature maps. Output slice ``[..., j]`` is therefore the
    weighted sum (not normalized, so not an average) of the input maps using the
    weights in row ``j``.

    Args:
        k: number of feature maps that will be passed to ``forward``.
        embed_dim: accepted for interface compatibility; not used here.
        norm_layer: accepted for interface compatibility; not used here.
        decoder_depth: number of distinct mixtures to produce.
    """

    def __init__(self, k, embed_dim, *, norm_layer=nn.LayerNorm, decoder_depth):
        super().__init__()
        self.linear = nn.Linear(k, decoder_depth, bias=False)
        # small random mixing weights, std scaled by the number of inputs
        nn.init.normal_(self.linear.weight, mean=0., std=1. / math.sqrt(k))

    def forward(self, feature_maps):
        """Return the stacked feature maps projected onto ``decoder_depth`` mixtures.

        Args:
            feature_maps: list of ``k`` equally-shaped tensors
                (presumably (B, L, C) encoder outputs; TODO confirm at call sites).

        Returns:
            Tensor shaped like one input map with an extra trailing axis of size
            ``decoder_depth``.
        """
        assert isinstance(feature_maps, list), "Input should be a list of feature maps"
        # weight is (decoder_depth, k), so dim 1 is the expected number of maps
        expected_maps = self.linear.weight.shape[1]
        assert len(feature_maps) == expected_maps, "Number of feature maps and weights should match"
        return self.linear(torch.stack(feature_maps, dim=-1))
class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone and a cross-attention decoder.

    The encoder runs only on the visible patches (plus a cls token). The decoder
    builds one query per patch selected for reconstruction, from the decoder
    positional embedding plus a learned mask token. Those queries cross-attend to
    encoder features through ``CrossAttentionBlock``.

    With ``weight_fm=True``, each decoder block attends to its own learned mixture
    of intermediate encoder feature maps (see ``WeightedFeatureMaps``). Otherwise
    all decoder blocks attend to the normalized final encoder output.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 weight_fm=False,
                 use_fm=(-1,), use_input=False, self_attn=False,
                 ):
        """
        Args:
            img_size, patch_size, in_chans: image/patch geometry, passed to ``PatchEmbed``.
            embed_dim, depth, num_heads: encoder width, number of blocks, attention heads.
            decoder_embed_dim, decoder_depth, decoder_num_heads: decoder width, blocks, heads.
            mlp_ratio: MLP hidden-size multiplier for encoder and decoder blocks.
            norm_layer: callable building a normalization layer from a feature size.
            norm_pix_loss: if True, normalize each target patch (mean/var) before the loss.
            weight_fm: if True, feed each decoder block a learned mixture of encoder
                feature maps instead of the final encoder output.
            use_fm: encoder block indices whose outputs are mixed when ``weight_fm``
                is True. Negative indices count from the end. The special value
                ``(-1,)`` (a single -1) selects *every* block. A list is accepted too.
            use_input: also include the encoder input tokens among the mixed maps.
            self_attn: forwarded to ``CrossAttentionBlock``.
        """
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.decoder_embed_dim = decoder_embed_dim
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)  # these are needed regardless of the patch sampling method
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(depth)])

        # --------------------------------------------------------------------------
        # weighted feature maps for cross attention
        self.weight_fm = weight_fm
        self.use_input = use_input  # use input as one of the feature maps
        if len(use_fm) == 1 and use_fm[0] == -1:
            self.use_fm = list(range(depth))
        else:
            self.use_fm = [i if i >= 0 else depth + i for i in use_fm]

        if self.weight_fm:
            # one norm layer per decoder block, applied to that block's mixed feature map
            self.dec_norms = nn.ModuleList([norm_layer(embed_dim) for _ in range(decoder_depth)])
            # feature weighting
            self.wfm = WeightedFeatureMaps(len(self.use_fm) + (1 if self.use_input else 0), embed_dim, norm_layer=norm_layer, decoder_depth=decoder_depth)
        else:
            # final encoder norm, used by forward_encoder on the non-weighted path
            # (it was previously referenced there but never created)
            self.norm = norm_layer(embed_dim)

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        print("use self attention: ", self_attn)
        self.decoder_blocks = nn.ModuleList([
            CrossAttentionBlock(embed_dim, decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer, self_attn=self_attn)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch

        # --------------------------------------------------------------------------
        # Dealing with positional embedding, patch sampling
        # encoder
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
        # decoder
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the fixed sin-cos position embeddings and initialize learnable weights."""
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        # NOTE(review): apply() also visits self.wfm.linear (an nn.Linear) and replaces
        # the normal(std=1/sqrt(k)) init done in WeightedFeatureMaps with xavier_uniform_.
        # Confirm this override is intended.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform for Linear weights, zero biases, and unit/zero LayerNorm affine."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """Split square images into flattened non-overlapping patches.

        imgs: (N, C, H, W) with H == W and H divisible by the patch size
        x: (N, L, patch_size**2 * C), each patch flattened in (row, col, channel) order
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        n, c = imgs.shape[0], imgs.shape[1]
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(n, c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(n, h * w, p**2 * c))
        return x

    def unpatchify(self, x):
        """Inverse of ``patchify``.

        x: (N, L, patch_size**2 * C) with L a perfect square
        imgs: (N, C, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        c = x.shape[2] // (p * p)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def random_masking(self, x, mask_ratio, kept_mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.

        x: [N, L, D], sequence
        mask_ratio: fraction of patches hidden from the encoder.
        kept_mask_ratio: fraction of all patches the decoder must reconstruct.
            NOTE(review): assumes kept_mask_ratio <= mask_ratio.

        Returns:
            x_masked: [N, int(L * (1 - mask_ratio)), D], the visible patches.
            mask: [N, L], 1 for patches to reconstruct, 0 for visible patches and
                for masked patches that are dropped (not reconstructed).
            ids_restore: [N, L], inverse of the per-sample shuffle.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        len_masked = int(L * (mask_ratio - kept_mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        # (the first len_keep + len_masked shuffled positions are not reconstructed)
        mask = torch.ones([N, L], device=x.device)
        mask[:, :(len_keep + len_masked)] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def grid_patchify(self, x):
        """Embed image patches and add the (non-cls) fixed positional embedding."""
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]
        return x

    def forward_encoder(self, x, mask_ratio, kept_mask_ratio):
        """Encode the visible patches.

        Returns (latent, mask, ids_restore, coords). ``coords`` is always None.
        ``latent`` is a list of token tensors (the feature maps selected by
        ``use_fm``, preceded by the encoder input when ``use_input``) if
        ``weight_fm``. Otherwise it is the normalized final encoder output.
        All of them include the cls token at position 0.
        """
        x = self.grid_patchify(x)
        coords = None

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio, kept_mask_ratio)

        # append cls token (no positional embedding is added to it here)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        x_feats = []
        if self.use_input:
            x_feats.append(x)
        for idx, blk in enumerate(self.blocks):
            x = blk(x)
            if self.weight_fm and idx in self.use_fm:
                x_feats.append(x)

        if self.weight_fm:
            return x_feats, mask, ids_restore, coords
        else:
            x = self.norm(x)
            return x, mask, ids_restore, coords

    def mask_tokens_grid(self, mask, ids_restore):
        """Build one decoder query per patch with mask == 1 (in ascending patch order).

        Requires every sample to have the same number of such patches, as produced
        by ``random_masking``. Only the batch size is read from ``ids_restore``.
        """
        N, L = ids_restore.shape

        # construct mask tokens: positional embedding of each selected patch + mask token
        x = self.decoder_pos_embed[:, 1:].masked_select(mask.bool().unsqueeze(-1)).reshape(N, -1, self.mask_token.shape[-1])
        x = x + self.mask_token
        return x

    def forward_decoder(self, y, mask, ids_restore, coords, mask_ratio, kept_mask_ratio):
        """Decode the selected patches by cross-attending to encoder features ``y``.

        Returns (pred, None), where pred is [N, M, p*p*in_chans] with M = patches
        having mask == 1 per sample.
        """
        x = self.mask_tokens_grid(mask, ids_restore)

        if self.weight_fm:
            # y input: a list of Tensors (B, C, D)
            y = self.wfm(y)

        for i, blk in enumerate(self.decoder_blocks):
            if self.weight_fm:
                # block i attends to its own (normalized) mixture of feature maps
                x = blk(x, self.dec_norms[i](y[..., i]))
            else:
                x = blk(x, y)
        x = self.decoder_norm(x)

        x = self.decoder_pred(x)  # N, M, patch_size**2 * in_chans
        return x, None

    def forward_loss(self, imgs, pred, mask, coords):
        """Mean squared error over the reconstructed patches.

        imgs: [N, C, H, W]
        pred: [N, M, p*p*C], one row per patch with mask == 1 (ascending patch order)
        mask: [N, L], 0 is keep, 1 is remove (reconstructed)
        coords: unused.

        Returns (loss, target), where target is [N, M, p*p*C] and is normalized per
        patch when ``norm_pix_loss`` is set.
        """
        target = self.patchify(imgs)
        target = target.masked_select(mask.bool().unsqueeze(-1)).reshape(target.shape[0], -1, target.shape[-1])
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean()
        return loss, target

    def forward(self, imgs, mask_ratio=0.75, kept_mask_ratio=0.75, vis=False):
        """Run encoder, decoder and loss.

        Returns ``loss``. With ``vis=True``, returns
        ``(loss, pred_combined, target_combined, mask)`` instead, where the combined
        tensors are [N, L, p*p*C] with zeros at patches that were not reconstructed.
        """
        with torch.cuda.amp.autocast():
            latent, mask, ids_restore, coords = self.forward_encoder(imgs, mask_ratio, kept_mask_ratio)
            pred, combined = self.forward_decoder(latent, mask, ids_restore, coords, mask_ratio, kept_mask_ratio)  # [N, M, p*p*C]
            loss, target = self.forward_loss(imgs, pred, mask, coords)

        if vis:
            # assumes mask ratio is the same as kept_mask_ratio for visualizations
            assert mask_ratio == kept_mask_ratio, "mask_ratio needs to be the same as kept_mask_ratio for visualizations. Otherwise we have unpredicted patches."
            with torch.no_grad():
                N, L = mask.shape[0], mask.shape[1]
                selected = mask.bool()
                # scatter predictions back onto the full patch grid
                pred_combined = torch.zeros(N, L, pred.shape[2], device=pred.device, dtype=pred.dtype)
                pred_combined[selected] = pred.reshape(-1, pred.shape[2])
                # use target's own dtype: under autocast pred may be half precision while
                # target stays float32, and boolean-mask assignment requires matching dtypes
                target_combined = torch.zeros(N, L, target.shape[2], device=target.device, dtype=target.dtype)
                target_combined[selected] = target.reshape(-1, target.shape[2])
            return loss, pred_combined, target_combined, mask
        else:
            return loss
def mae_vit_small_patch16_dec512d8b(**kwargs):
    """Build a ViT-Small/16 masked autoencoder (384-dim, 12 blocks, 6 heads).

    NOTE: despite ``dec512`` in the name, this variant's decoder width is 256
    (8 heads). Decoder depth is the constructor default unless ``decoder_depth``
    is passed in ``kwargs``. Any other ``kwargs`` go to ``MaskedAutoencoderViT``.
    """
    eps_layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        decoder_embed_dim=256,
        decoder_num_heads=8,
        mlp_ratio=4,
        norm_layer=eps_layer_norm,
        **kwargs,
    )
def mae_vit_base_patch16_dec512d8b(**kwargs):
    """Build a ViT-Base/16 masked autoencoder (768-dim, 12 blocks, 12 heads).

    The decoder is 512 wide with 16 heads. Its depth is the constructor default
    unless ``decoder_depth`` is passed in ``kwargs``. Any other ``kwargs`` go to
    ``MaskedAutoencoderViT``.
    """
    eps_layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=eps_layer_norm,
        **kwargs,
    )
def mae_vit_large_patch16_dec512d8b(**kwargs):
    """Build a ViT-Large/16 masked autoencoder (1024-dim, 24 blocks, 16 heads).

    The decoder is 512 wide with 16 heads. Its depth is the constructor default
    unless ``decoder_depth`` is passed in ``kwargs``. Any other ``kwargs`` go to
    ``MaskedAutoencoderViT``.
    """
    eps_layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=eps_layer_norm,
        **kwargs,
    )
def mae_vit_huge_patch14_dec512d8b(**kwargs):
    """Build a ViT-Huge/14 masked autoencoder (1280-dim, 32 blocks, 16 heads).

    The decoder is 512 wide with 16 heads. Its depth is the constructor default
    unless ``decoder_depth`` is passed in ``kwargs``. Any other ``kwargs`` go to
    ``MaskedAutoencoderViT``.
    """
    eps_layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=eps_layer_norm,
        **kwargs,
    )
# Recommended architectures: short aliases for the factory functions above.
# Decoder depth is the MaskedAutoencoderViT default (8 blocks) unless overridden.
mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b  # decoder: 256 dim, 8 blocks
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks