1
+ import torch
2
+ import timm
3
+ import numpy as np
4
+
5
+ from einops import repeat , rearrange
6
+ from einops .layers .torch import Rearrange
7
+
8
+
9
+ # 这里可以用两个timm模型进行构建我们的结果
10
+ from timm .models .layers import trunc_normal_
11
+ from timm .models .vision_transformer import Block
12
+
13
def random_indexes(size: int):
    """Return a random permutation of ``range(size)`` and its inverse.

    ``forward_indexes`` shuffles a length-``size`` sequence;
    ``backward_indexes = argsort(forward_indexes)`` restores the original
    order after shuffling.
    """
    forward_indexes = np.random.permutation(size)
    # position of each original index inside the shuffled order
    backward_indexes = np.argsort(forward_indexes)
    return forward_indexes, backward_indexes
18
+
19
def take_indexes(sequences, indexes):
    """Reorder ``sequences`` (T, B, C) along the token axis, per batch column.

    ``indexes`` has shape (T, B); it is broadcast over the channel axis so
    every feature of a token moves together:
    ``out[t, b, :] = sequences[indexes[t, b], b, :]``.
    """
    expanded = indexes.unsqueeze(-1).expand(-1, -1, sequences.shape[-1])
    return torch.gather(sequences, 0, expanded)
21
+
22
class PatchShuffle(torch.nn.Module):
    """Randomly permute the patch tokens of each sample and keep only the
    visible (unmasked) fraction, as in MAE pre-training."""

    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio  # fraction of patches to mask out

    def forward(self, patches: torch.Tensor):
        """Shuffle ``patches`` (T, B, C) and drop the masked tail.

        Returns the kept tokens of shape (T * (1 - ratio), B, C) plus the
        per-sample forward (shuffle) and backward (unshuffle) index tensors.
        """
        num_tokens, batch_size, _ = patches.shape
        visible = int(num_tokens * (1 - self.ratio))

        # one independent permutation per batch element
        pairs = [random_indexes(num_tokens) for _ in range(batch_size)]
        forward_indexes = torch.as_tensor(
            np.stack([fwd for fwd, _ in pairs], axis=-1), dtype=torch.long
        ).to(patches.device)
        backward_indexes = torch.as_tensor(
            np.stack([bwd for _, bwd in pairs], axis=-1), dtype=torch.long
        ).to(patches.device)

        # shuffle every column, then truncate to the visible prefix
        shuffled = take_indexes(patches, forward_indexes)
        kept = shuffled[:visible]

        return kept, forward_indexes, backward_indexes
39
+
40
class MAE_Encoder(torch.nn.Module):
    """MAE encoder: patch-embeds an image, masks most patch tokens, and runs a
    ViT over the visible tokens plus a CLS token."""

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        num_patches = (image_size // patch_size) ** 2
        # learnable CLS token and per-patch positional embedding (T, 1, C layout)
        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros(num_patches, 1, emb_dim))

        # shuffles the patch sequence and drops the masked portion
        self.shuffle = PatchShuffle(mask_ratio)

        # conv stem: non-overlapping patch_size strides turn the image into patch embeddings
        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)

        self.transformer = torch.nn.Sequential(*(Block(emb_dim, num_head) for _ in range(num_layer)))

        # final LayerNorm, as in ViT
        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    def init_weight(self):
        """Truncated-normal init for the CLS token and positional embedding."""
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, img):
        """Return (features, backward_indexes); features is (T_visible+1, B, C)
        with the CLS token at position 0."""
        tokens = self.patchify(img)
        tokens = rearrange(tokens, 'b c h w -> (h w) b c')
        tokens = tokens + self.pos_embedding

        # keep only the visible tokens; backward_indexes lets the decoder unshuffle
        tokens, forward_indexes, backward_indexes = self.shuffle(tokens)

        cls = self.cls_token.expand(-1, tokens.shape[1], -1)
        tokens = torch.cat([cls, tokens], dim=0)
        tokens = rearrange(tokens, 't b c -> b t c')
        features = self.layer_norm(self.transformer(tokens))
        features = rearrange(features, 'b t c -> t b c')

        return features, backward_indexes
85
+
86
class MAE_Decoder(torch.nn.Module):
    """MAE decoder: pads the encoder output with mask tokens, restores the
    original patch order, runs a small transformer, and reconstructs both the
    image and the binary mask used to restrict the loss to masked patches."""

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        # one shared learnable token stands in for every masked patch
        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        # +1 position for the CLS token coming from the encoder
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        # project each token back to a flattened RGB patch, then fold to an image
        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size // patch_size)

        self.init_weight()

    def init_weight(self):
        """Truncated-normal init for the mask token and positional embedding."""
        trunc_normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, features, backward_indexes):
        """Return (img, mask): the reconstructed image and a 0/1 mask that is
        1 exactly on the patches the encoder did not see."""
        T = features.shape[0]  # visible patch tokens + 1 CLS token
        # shift all indexes by one and prepend 0 so the CLS token stays in front
        backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim=0)
        # pad with mask tokens up to the full sequence length, then unshuffle
        features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim=0)
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        features = rearrange(features, 't b c -> b t c')
        features = self.transformer(features)
        features = rearrange(features, 'b t c -> t b c')
        features = features[1:]  # drop the CLS token; remaining tokens are the patches

        patches = self.head(features)
        mask = torch.zeros_like(patches)
        # BUGFIX: T counts the CLS token, so only T-1 visible patches exist.
        # Masked patches therefore start at shuffled position T-1, not T — the
        # old `mask[T:] = 1` silently excluded one masked patch from the loss.
        mask[T - 1:] = 1
        # the 0/1 pattern is defined in shuffled order; unshuffle it to match patches
        mask = take_indexes(mask, backward_indexes[1:] - 1)
        img = self.patch2img(patches)
        mask = self.patch2img(mask)

        return img, mask
130
+
131
class MAE_ViT(torch.nn.Module):
    """Full masked autoencoder: MAE encoder and decoder glued end to end."""

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.encoder = MAE_Encoder(image_size, patch_size, emb_dim, encoder_layer, encoder_head, mask_ratio)
        self.decoder = MAE_Decoder(image_size, patch_size, emb_dim, decoder_layer, decoder_head)

    def forward(self, img):
        """Encode the visible patches, then decode a reconstruction and the
        loss mask; returns (predicted_img, mask)."""
        encoded, restore_indexes = self.encoder(img)
        return self.decoder(encoded, restore_indexes)
151
+
152
class ViT_Classifier(torch.nn.Module):
    """Classifier built on a (pre-trained) MAE encoder: reuses its modules and
    adds a linear head on the CLS token."""

    def __init__(self, encoder: MAE_Encoder, num_classes=10) -> None:
        super().__init__()
        # share the encoder's parameters/modules directly (no copies)
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm
        self.head = torch.nn.Linear(self.pos_embedding.shape[-1], num_classes)

    def forward(self, img):
        """Return class logits of shape (B, num_classes)."""
        tokens = self.patchify(img)
        tokens = rearrange(tokens, 'b c h w -> (h w) b c')
        tokens = tokens + self.pos_embedding
        cls = self.cls_token.expand(-1, tokens.shape[1], -1)
        tokens = torch.cat([cls, tokens], dim=0)
        tokens = rearrange(tokens, 't b c -> b t c')
        features = self.layer_norm(self.transformer(tokens))
        features = rearrange(features, 'b t c -> t b c')
        # classify from the CLS token only
        return self.head(features[0])
172
+
173
+
174
if __name__ == '__main__':
    # sanity check PatchShuffle: with ratio 0.75, 16 tokens -> 4 kept
    sampler = PatchShuffle(0.75)
    tokens = torch.rand(16, 2, 10)
    kept, fwd_idx, bwd_idx = sampler(tokens)
    print(kept.shape)

    # round-trip a random batch through encoder and decoder
    batch = torch.rand(2, 3, 32, 32)
    enc = MAE_Encoder()
    dec = MAE_Decoder()
    feats, restore_idx = enc(batch)
    print(fwd_idx.shape)
    recon, mask = dec(feats, restore_idx)
    print(recon.shape)
    # per-pixel MSE restricted to masked patches, rescaled by the mask ratio
    loss = torch.mean((recon - batch) ** 2 * mask / 0.75)