from typing import Tuple, Sequence, Dict, Union, Optional, Callable
import copy

import numpy as np
import timm
import torch
import torch.nn as nn
import torchvision
class train_utils:
def __init__(self):
pass
#@markdown ### **Vision Encoder**
#@markdown
#@markdown Defines helper functions:
#@markdown - `get_resnet` to initialize standard ResNet vision encoder
#@markdown - `replace_bn_with_gn` to replace all BatchNorm layers with GroupNorm
def get_resnet(self, name, weights=None, **kwargs):
"""
name: resnet18, resnet34, resnet50
weights: "IMAGENET1K_V1", "r3m"
"""
# load r3m weights
if (weights == "r3m") or (weights == "R3M"):
return self.get_r3m(name=name, **kwargs)
func = getattr(torchvision.models, name)
resnet = func(weights=weights, **kwargs)
resnet.fc = torch.nn.Identity()
return resnet
def get_r3m(self, name, **kwargs):
"""
name: resnet18, resnet34, resnet50
"""
import r3m
r3m.device = 'cpu'
model = r3m.load_r3m(name)
r3m_model = model.module
resnet_model = r3m_model.convnet
resnet_model = resnet_model.to('cpu')
return resnet_model
def replace_submodules(self,
root_module: nn.Module,
predicate: Callable[[nn.Module], bool],
func: Callable[[nn.Module], nn.Module]) -> nn.Module:
"""
Replace all submodules selected by the predicate with
the output of func.
predicate: Return true if the module is to be replaced.
func: Return new module to use.
"""
if predicate(root_module):
return func(root_module)
bn_list = [k.split('.') for k, m
in root_module.named_modules(remove_duplicate=True)
if predicate(m)]
for *parent, k in bn_list:
parent_module = root_module
if len(parent) > 0:
parent_module = root_module.get_submodule('.'.join(parent))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# verify that all modules are replaced
bn_list = [k.split('.') for k, m
in root_module.named_modules(remove_duplicate=True)
if predicate(m)]
assert len(bn_list) == 0
return root_module
def replace_bn_with_gn(self,
root_module: nn.Module,
features_per_group: int=16) -> nn.Module:
"""
        Replace all BatchNorm layers with GroupNorm.
"""
self.replace_submodules(
root_module=root_module,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features//features_per_group,
num_channels=x.num_features)
)
return root_module
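# Usage sketch (hypothetical, not part of the original API): build a ResNet-18
# encoder with ImageNet weights and swap its BatchNorm layers for GroupNorm,
# which tends to behave better with small batches and EMA weight averaging.
#
#   utils = train_utils()
#   encoder = utils.get_resnet('resnet18', weights='IMAGENET1K_V1')
#   encoder = utils.replace_bn_with_gn(encoder, features_per_group=16)
#   feats = encoder(torch.zeros(1, 3, 224, 224))  # -> (1, 512) feature vector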
class SimpleViTEncoder(nn.Module):
def __init__(self, model_name: str = 'vit_base_patch16_224', pretrained: bool = True, frozen: bool = False):
super().__init__()
# Load the ViT model
self.vision_encoder = timm.create_model(model_name, pretrained=pretrained, num_classes=0) # Remove classifier
# Optionally freeze the model if required
if frozen:
for param in self.vision_encoder.parameters():
param.requires_grad = False
def forward(self, x):
# Pass the input through the ViT model
return self.vision_encoder(x)
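# Usage sketch (hypothetical): encode a batch of images with a frozen, pretrained ViT.
# With num_classes=0, timm returns the pooled embedding, e.g. 768-dim for ViT-Base.
#
#   vit = SimpleViTEncoder('vit_base_patch16_224', pretrained=True, frozen=True)
#   emb = vit(torch.rand(4, 3, 224, 224))  # -> (4, 768)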
class TransformerObsEncoder(nn.Module):
def __init__(self,
shape_meta: dict,
model_name: str = 'vit_base_patch16_clip_224.openai',
global_pool: str = '',
transforms: list = None,
n_emb: int = 768,
pretrained: bool = True,
frozen: bool = False,
use_group_norm: bool = True,
share_rgb_model: bool = False,
feature_aggregation: str = None,
downsample_ratio: int = 32):
"""
Assumes rgb input: B,T,C,H,W
Assumes low_dim input: B,T,D
"""
super().__init__()
rgb_keys = list()
low_dim_keys = list()
key_model_map = nn.ModuleDict()
key_transform_map = nn.ModuleDict()
key_projection_map = nn.ModuleDict()
key_shape_map = dict()
assert global_pool == ''
model = timm.create_model(
model_name=model_name,
pretrained=pretrained,
global_pool=global_pool, # '' means no pooling
num_classes=0 # remove classification layer
)
self.model_name = model_name
if frozen:
assert pretrained
for param in model.parameters():
param.requires_grad = False
feature_dim = None
if model_name.startswith('resnet'):
if downsample_ratio == 32:
modules = list(model.children())[:-2]
model = torch.nn.Sequential(*modules)
feature_dim = 512
elif downsample_ratio == 16:
modules = list(model.children())[:-3]
model = torch.nn.Sequential(*modules)
feature_dim = 256
else:
raise NotImplementedError(f"Unsupported downsample_ratio: {downsample_ratio}")
elif model_name.startswith('convnext'):
if downsample_ratio == 32:
modules = list(model.children())[:-2]
model = torch.nn.Sequential(*modules)
feature_dim = 1024
else:
raise NotImplementedError(f"Unsupported downsample_ratio: {downsample_ratio}")
if use_group_norm and not pretrained:
model = self.replace_batch_norm_with_group_norm(model)
# handle feature aggregation
self.feature_aggregation = feature_aggregation
if model_name.startswith('vit'):
if self.feature_aggregation is None:
pass
elif self.feature_aggregation != 'cls':
print(f'vit will use the CLS token. feature_aggregation ({self.feature_aggregation}) is ignored!')
self.feature_aggregation = 'cls'
if self.feature_aggregation == 'soft_attention':
self.attention = nn.Sequential(
nn.Linear(feature_dim, 1, bias=False),
nn.Softmax(dim=1)
)
image_shape = None
obs_shape_meta = shape_meta['obs']
for key, attr in obs_shape_meta.items():
shape = tuple(attr['shape'])
type = attr.get('type', 'low_dim')
if type == 'rgb':
assert image_shape is None or image_shape == shape[1:]
image_shape = shape[1:]
if transforms is not None and not isinstance(transforms[0], torch.nn.Module):
assert transforms[0].type == 'RandomCrop'
ratio = transforms[0].ratio
transforms = [
torchvision.transforms.RandomCrop(size=int(image_shape[0] * ratio)),
torchvision.transforms.Resize(size=image_shape[0], antialias=True)
] + transforms[1:]
transform = nn.Identity() if transforms is None else torch.nn.Sequential(*transforms)
for key, attr in obs_shape_meta.items():
shape = tuple(attr['shape'])
type = attr.get('type', 'low_dim')
key_shape_map[key] = shape
if type == 'rgb':
rgb_keys.append(key)
this_model = model if share_rgb_model else copy.deepcopy(model)
key_model_map[key] = this_model
with torch.no_grad():
example_img = torch.zeros((1,) + tuple(shape))
example_feature_map = this_model(example_img)
example_features = self.aggregate_feature(example_feature_map)
feature_shape = example_features.shape
feature_size = feature_shape[-1]
proj = nn.Identity()
if feature_size != n_emb:
proj = nn.Linear(in_features=feature_size, out_features=n_emb)
key_projection_map[key] = proj
this_transform = transform
key_transform_map[key] = this_transform
elif type == 'low_dim':
dim = np.prod(shape)
proj = nn.Identity()
if dim != n_emb:
proj = nn.Linear(in_features=dim, out_features=n_emb)
key_projection_map[key] = proj
low_dim_keys.append(key)
else:
raise RuntimeError(f"Unsupported obs type: {type}")
rgb_keys = sorted(rgb_keys)
low_dim_keys = sorted(low_dim_keys)
self.n_emb = n_emb
self.shape_meta = shape_meta
self.key_model_map = key_model_map
self.key_transform_map = key_transform_map
self.key_projection_map = key_projection_map
self.share_rgb_model = share_rgb_model
self.rgb_keys = rgb_keys
self.low_dim_keys = low_dim_keys
self.key_shape_map = key_shape_map
def aggregate_feature(self, feature):
if self.model_name.startswith('vit'):
if self.feature_aggregation == 'cls':
return feature[:, [0], :]
assert self.feature_aggregation is None
return feature
assert len(feature.shape) == 4
feature = torch.flatten(feature, start_dim=-2) # B, 512, 7*7
feature = torch.transpose(feature, 1, 2) # B, 7*7, 512
if self.feature_aggregation == 'avg':
return torch.mean(feature, dim=[1], keepdim=True)
elif self.feature_aggregation == 'max':
return torch.amax(feature, dim=[1], keepdim=True)
elif self.feature_aggregation == 'soft_attention':
weight = self.attention(feature)
return torch.sum(feature * weight, dim=1, keepdim=True)
else:
assert self.feature_aggregation is None
return feature
def forward(self, obs_dict):
embeddings = list()
batch_size = next(iter(obs_dict.values())).shape[0]
for key in self.rgb_keys:
img = obs_dict[key]
B, T = img.shape[:2]
assert B == batch_size
assert img.shape[2:] == self.key_shape_map[key]
img = img.reshape(B * T, *img.shape[2:])
img = self.key_transform_map[key](img)
raw_feature = self.key_model_map[key](img)
feature = self.aggregate_feature(raw_feature)
emb = self.key_projection_map[key](feature)
assert len(emb.shape) == 3 and emb.shape[0] == B * T and emb.shape[-1] == self.n_emb
emb = emb.reshape(B, -1, self.n_emb)
embeddings.append(emb)
for key in self.low_dim_keys:
data = obs_dict[key]
B, T = data.shape[:2]
assert B == batch_size
assert data.shape[2:] == self.key_shape_map[key]
data = data.reshape(B, T, -1)
emb = self.key_projection_map[key](data)
assert emb.shape[-1] == self.n_emb
embeddings.append(emb)
result = torch.cat(embeddings, dim=1)
return result
def replace_batch_norm_with_group_norm(self, model):
def replace_submodules(root_module, predicate, func):
for name, module in root_module.named_children():
if predicate(module):
root_module.add_module(name, func(module))
else:
replace_submodules(module, predicate, func)
return root_module
return replace_submodules(
root_module=model,
predicate=lambda x: isinstance(x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(
num_groups=(x.num_features // 16) if (x.num_features % 16 == 0) else (x.num_features // 8),
num_channels=x.num_features
)
)
def test():
shape_meta = {
'obs': {
'rgb': {'shape': (3, 224, 224), 'type': 'rgb'},
'low_dim': {'shape': (10,), 'type': 'low_dim'}
}
}
encoder = TransformerObsEncoder(shape_meta=shape_meta)
obs_dict = {
'rgb': torch.rand(2, 5, 3, 224, 224),
'low_dim': torch.rand(2, 5, 10)
}
result = encoder(obs_dict)
print(result.shape)
# test()
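# Note (assumption): with the default 'vit_base_patch16_clip_224.openai' backbone and
# feature_aggregation=None, each 224x224 frame contributes ~197 tokens (196 patches +
# CLS), so test() should print roughly (2, 5*197 + 5, 768) = (2, 990, 768).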