
Commit af38d05

MAE IN CIFAR10
1 parent 48084ba commit af38d05

9 files changed: +17829 −0 lines changed

MAE/MAE_In_CIFAR.ipynb

Lines changed: 17372 additions & 0 deletions
Large diffs are not rendered by default.

MAE/README.md

Lines changed: 57 additions & 0 deletions
## MAE implementation on CIFAR-10

Since our available resources are limited, we test the model only on CIFAR-10. We mainly want to reproduce the result that **pretraining a ViT with MAE achieves better results than supervised training directly on the labels**. This serves as evidence that **self-supervised learning is more data-efficient than supervised learning**.

We mostly follow the implementation details of the paper. However, because CIFAR-10 differs from ImageNet, we made some modifications:

- We use ViT-Tiny instead of ViT-Base (see the configuration sketch after this list).
- Since CIFAR-10 has only 50k training images, we increase the pretraining epochs from 400 to 2000 and the warmup epochs from 40 to 200. We note that the loss is still decreasing after 2000 epochs.
- We reduce the batch size for classifier training from 1024 to 512 to mitigate overfitting.
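The ViT-Tiny-scale settings correspond to the defaults in `model.py` below; a minimal instantiation sketch (these are just the default arguments, not a separate config):

```python
from model import MAE_ViT

# Defaults from model.py: 32x32 input, 2x2 patches (256 tokens),
# 192-dim embeddings, 12 encoder / 4 decoder blocks, 3 heads, 75% masking.
model = MAE_ViT(image_size=32, patch_size=2, emb_dim=192,
                encoder_layer=12, encoder_head=3,
                decoder_layer=4, decoder_head=3, mask_ratio=0.75)
```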
### Install

`pip install -r requirements.txt`

### Run

First, run the pretraining:

```
# pretrain with MAE
python mae_pretrain.py
```

Train a classifier without MAE, i.e., from scratch:

```
# train classifier from scratch
python train_classifier.py
```

Train a classifier that uses the pretrained MAE encoder as its backbone:

```
# train classifier from pretrained model
python train_classifier.py --pretrained_model_path vit-t-mae.pth --output_model_path vit-t-classifier-from_pretrained.pth
```
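`train_classifier.py` is not part of this diff; a minimal sketch of how the pretrained encoder could be wired into the classifier, using `ViT_Classifier` from `model.py` below (the exact loading logic is an assumption):

```python
import torch
from model import ViT_Classifier

# mae_pretrain.py saves the whole MAE_ViT module with torch.save(model, path),
# so torch.load returns the full model; we reuse only its encoder.
mae = torch.load('vit-t-mae.pth', map_location='cpu')
classifier = ViT_Classifier(mae.encoder, num_classes=10)
```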
TensorBoard logging is integrated; view the results with:

```
tensorboard --logdir logs
```

### Result

|Model|Validation Acc|
|-----|--------------|
|ViT-T w/o pretrain|74.13|
|ViT-T w/ pretrain|**89.77**|

Visualization of the first 16 CIFAR-10 images (also viewable in TensorBoard):

![avatar](pic/mae-cifar10-reconstruction.png)
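The figure can be regenerated from the saved checkpoint; a minimal sketch following the visualization logic in `mae_pretrain.py` below (the output path is the README's, the rest is an assumption):

```python
import torch, torchvision
from torchvision.transforms import ToTensor, Compose, Normalize
from torchvision.utils import save_image
from einops import rearrange
from model import MAE_ViT  # needed so torch.load can unpickle the full model

val_dataset = torchvision.datasets.CIFAR10(
    'data', train=False, download=True,
    transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
model = torch.load('vit-t-mae.pth', map_location='cpu').eval()

with torch.no_grad():
    val_img = torch.stack([val_dataset[i][0] for i in range(16)])
    predicted_img, mask = model(val_img)
    # keep predictions only at masked positions; visible patches come from the input
    predicted_img = predicted_img * mask + val_img * (1 - mask)
    grid = torch.cat([val_img * (1 - mask), predicted_img, val_img], dim=0)
    grid = rearrange(grid, '(v h1 w1) c h w -> c (h1 h) (w1 v w)', w1=2, v=3)
    save_image((grid + 1) / 2, 'pic/mae-cifar10-reconstruction.png')  # undo Normalize(0.5, 0.5)
```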

MAE/mae_pretrain.py

Lines changed: 85 additions & 0 deletions
import os
import argparse
import math

import numpy as np
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToTensor, Compose, Normalize
from tqdm import tqdm
from einops import rearrange

from model import *
from utils import setup_seed

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('-bs', '--batch_size', type=int, default=4096)
    parser.add_argument('--max_device_batch_size', type=int, default=128)
    parser.add_argument('--base_learning_rate', type=float, default=1.5e-4)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--mask_ratio', type=float, default=0.75)
    parser.add_argument('--total_epoch', type=int, default=2000)
    parser.add_argument('--warmup_epoch', type=int, default=200)
    parser.add_argument('--model_path', type=str, default='vit-t-mae.pth')

    args = parser.parse_args()

    setup_seed(args.seed)

    # Gradient accumulation: the effective batch size is args.batch_size,
    # but at most max_device_batch_size samples are loaded per step.
    batch_size = args.batch_size
    load_batch_size = min(args.max_device_batch_size, batch_size)
    assert batch_size % load_batch_size == 0
    steps_per_update = batch_size // load_batch_size

    train_dataset = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
    val_dataset = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
    dataloader = torch.utils.data.DataLoader(train_dataset, load_batch_size, shuffle=True, num_workers=4)
    writer = SummaryWriter(os.path.join('logs', 'cifar10', 'mae-pretrain'))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = MAE_ViT(mask_ratio=args.mask_ratio).to(device)
    if device == 'cuda' and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # Linear lr scaling rule: lr = base_lr * batch_size / 256.
    optim = torch.optim.AdamW(model.parameters(), lr=args.base_learning_rate * args.batch_size / 256, betas=(0.9, 0.95), weight_decay=args.weight_decay)
    # Linear warmup for warmup_epoch epochs, then cosine decay to zero.
    lr_func = lambda epoch: min((epoch + 1) / (args.warmup_epoch + 1e-8), 0.5 * (math.cos(epoch / args.total_epoch * math.pi) + 1))
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_func, verbose=True)

    step_count = 0
    optim.zero_grad()
    for e in range(args.total_epoch):
        model.train()
        losses = []
        train_step = len(dataloader)
        with tqdm(total=train_step, desc=f'Epoch {e + 1}/{args.total_epoch}', mininterval=0.3) as pbar:
            for img, _ in dataloader:
                step_count += 1
                img = img.to(device)
                predicted_img, mask = model(img)
                # MSE on the masked patches only, normalized by the mask ratio.
                loss = torch.mean((predicted_img - img) ** 2 * mask) / args.mask_ratio
                loss.backward()
                if step_count % steps_per_update == 0:
                    optim.step()
                    optim.zero_grad()
                losses.append(loss.item())
                pbar.set_postfix(**{'Loss': np.mean(losses)})
                pbar.update(1)
        lr_scheduler.step()
        avg_loss = sum(losses) / len(losses)
        writer.add_scalar('mae_loss', avg_loss, global_step=e)
        # print(f'In epoch {e}, average training loss is {avg_loss}.')

        ''' visualize the first 16 predicted images on val dataset '''
        model.eval()
        with torch.no_grad():
            val_img = torch.stack([val_dataset[i][0] for i in range(16)])
            val_img = val_img.to(device)
            predicted_val_img, mask = model(val_img)
            # Keep predictions only at masked positions; visible patches come from the input.
            predicted_val_img = predicted_val_img * mask + val_img * (1 - mask)
            # Rows: masked input, reconstruction, original.
            img = torch.cat([val_img * (1 - mask), predicted_val_img, val_img], dim=0)
            img = rearrange(img, '(v h1 w1) c h w -> c (h1 h) (w1 v w)', w1=2, v=3)
            writer.add_image('mae_image', (img + 1) / 2, global_step=e)

        ''' save model '''
        # unwrap DataParallel (if used) so the saved object is the plain MAE_ViT
        torch.save(getattr(model, 'module', model), args.model_path)
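With the defaults above, the linear scaling rule gives an AdamW learning rate of 1.5e-4 × 4096 / 256 = 2.4e-3, and each optimizer update accumulates gradients over 4096 / 128 = 32 loader steps. A quick check of that arithmetic:

```python
base_lr, batch_size, max_device_batch_size = 1.5e-4, 4096, 128

lr = base_lr * batch_size / 256                          # 0.0024
steps_per_update = batch_size // max_device_batch_size   # 32
print(lr, steps_per_update)
```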

MAE/model.py

Lines changed: 187 additions & 0 deletions
import torch
import numpy as np

from einops import repeat, rearrange
from einops.layers.torch import Rearrange

# Two timm building blocks are enough to assemble the model:
# trunc_normal_ for initialization and the standard ViT Block.
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block

def random_indexes(size: int):
    forward_indexes = np.arange(size)
    np.random.shuffle(forward_indexes)  # random permutation of the patch indexes
    backward_indexes = np.argsort(forward_indexes)  # inverse permutation, used to restore the original order
    return forward_indexes, backward_indexes

def take_indexes(sequences, indexes):
    return torch.gather(sequences, 0, repeat(indexes, 't b -> t b c', c=sequences.shape[-1]))

class PatchShuffle(torch.nn.Module):
    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches: torch.Tensor):
        T, B, C = patches.shape  # length, batch, dim
        remain_T = int(T * (1 - self.ratio))

        indexes = [random_indexes(T) for _ in range(B)]
        forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)
        backward_indexes = torch.as_tensor(np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)

        patches = take_indexes(patches, forward_indexes)  # shuffle the patches of every sample independently
        patches = patches[:remain_T]  # keep only the unmasked patches, shape [T * (1 - ratio), B, C]

        return patches, forward_indexes, backward_indexes

class MAE_Encoder(torch.nn.Module):
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))

        # shuffle and mask the patches
        self.shuffle = PatchShuffle(mask_ratio)

        # conv with stride = patch_size splits the image into patch embeddings;
        # its weight has shape (emb_dim, 3, patch_size, patch_size)
        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        # final LayerNorm, as in ViT
        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    # initialize the cls token and positional embeddings
    def init_weight(self):
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        patches = patches + self.pos_embedding

        patches, forward_indexes, backward_indexes = self.shuffle(patches)

        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        features = rearrange(features, 'b t c -> t b c')

        return features, backward_indexes

class MAE_Decoder(torch.nn.Module):
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size // patch_size)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, features, backward_indexes):
        T = features.shape[0]
        # shift the indexes by one to account for the cls token at position 0
        backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim=0)
        # append mask tokens for the masked positions, then restore the original patch order
        features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim=0)
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding  # add positional embeddings

        features = rearrange(features, 't b c -> b t c')
        features = self.transformer(features)
        features = rearrange(features, 'b t c -> t b c')
        features = features[1:]  # remove the cls token, keeping only the patch features

        patches = self.head(features)  # project features back to pixel patches
        mask = torch.zeros_like(patches)
        mask[T:] = 1  # the first T tokens were visible; everything after them was masked
        mask = take_indexes(mask, backward_indexes[1:] - 1)
        img = self.patch2img(patches)  # reassemble patches into the reconstructed image
        mask = self.patch2img(mask)

        return img, mask

class MAE_ViT(torch.nn.Module):
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.encoder = MAE_Encoder(image_size, patch_size, emb_dim, encoder_layer, encoder_head, mask_ratio)
        self.decoder = MAE_Decoder(image_size, patch_size, emb_dim, decoder_layer, decoder_head)

    def forward(self, img):
        features, backward_indexes = self.encoder(img)
        predicted_img, mask = self.decoder(features, backward_indexes)
        return predicted_img, mask

class ViT_Classifier(torch.nn.Module):
    def __init__(self, encoder: MAE_Encoder, num_classes=10) -> None:
        super().__init__()
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm
        self.head = torch.nn.Linear(self.pos_embedding.shape[-1], num_classes)

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        patches = patches + self.pos_embedding
        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        features = rearrange(features, 'b t c -> t b c')
        logits = self.head(features[0])  # classify from the cls token
        return logits


if __name__ == '__main__':
    # smoke test: check shapes through PatchShuffle, encoder, and decoder
    shuffle = PatchShuffle(0.75)
    a = torch.rand(16, 2, 10)
    b, forward_indexes, backward_indexes = shuffle(a)
    print(b.shape)

    img = torch.rand(2, 3, 32, 32)
    encoder = MAE_Encoder()
    decoder = MAE_Decoder()
    features, backward_indexes = encoder(img)
    print(features.shape)
    predicted_img, mask = decoder(features, backward_indexes)
    print(predicted_img.shape)
    loss = torch.mean((predicted_img - img) ** 2 * mask / 0.75)
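The forward/backward index pair is the key bookkeeping here: since `backward_indexes = argsort(forward_indexes)`, gathering with `backward_indexes` exactly undoes the shuffle, which is how the decoder returns mask tokens to their original positions. A small standalone check of this property, reusing the helpers above:

```python
import torch
from model import random_indexes, take_indexes

seq = torch.arange(8, dtype=torch.float32).reshape(8, 1, 1)  # [T, B, C]
fwd, bwd = random_indexes(8)
fwd = torch.as_tensor(fwd, dtype=torch.long).unsqueeze(-1)   # [T, B]
bwd = torch.as_tensor(bwd, dtype=torch.long).unsqueeze(-1)

shuffled = take_indexes(seq, fwd)       # shuffled[t] = seq[fwd[t]]
restored = take_indexes(shuffled, bwd)  # restored[t] = seq[fwd[bwd[t]]] = seq[t]
assert torch.equal(restored, seq)       # shuffle then unshuffle is the identity
```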
