-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathrun_mae_pretraining.py
323 lines (269 loc) · 15.5 KB
/
run_mae_pretraining.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import os
from pathlib import Path
from timm.models import create_model
from optim_factory import create_optimizer
from datasets import build_pretraining_dataset
from engine_for_pretraining import train_one_epoch
from utils import NativeScalerWithGradNormCount as NativeScaler
import utils
import modeling_pretrain
def get_args():
    """Define the pre-training command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option declared below.

    Raises:
        SystemExit: raised by argparse when the command line is invalid
            (unknown option, bad choice, unparsable value).
    """

    def str2bool(value):
        """Convert a command-line token to a real boolean.

        ``type=bool`` is a trap with argparse: ``bool("False")`` is ``True``
        because any non-empty string is truthy, so the option could never be
        switched off by spelling out "False".
        """
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        # The empty string is kept as False because that was the only spelling
        # that produced False under the previous ``type=bool`` behaviour.
        if lowered in ('false', 'f', 'no', 'n', '0', 'off', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser('VideoMAE pre-training script', add_help=False)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--epochs', default=800, type=int)
    parser.add_argument('--save_ckpt_freq', default=50, type=int)

    # Model parameters
    parser.add_argument('--model', default='pretrain_videomae_base_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--decoder_depth', default=4, type=int,
                        help='depth of decoder')
    parser.add_argument('--mask_type', default='tube', choices=['random', 'tube', 'part_window'],
                        type=str, help='masked strategy of video tokens/patches')
    parser.add_argument('--part_win_size', type=int, nargs='+', default=(2,2,10), help='window size (t,h,w) for part window attn')
    parser.add_argument('--part_apply_symmetry', type=str, default=None, choices=['global', 'local'], help='perform facial symmetrical masking for part window attn pretraining')
    parser.add_argument('--mask_ratio', default=0.75, type=float,
                        help='ratio of the visual tokens/patches need be masked')
    parser.add_argument('--input_size', default=224, type=int,
                        help='videos input size for backbone')
    parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT',
                        help='Drop path rate (default: 0.0)')
    # Parsed with str2bool so that "--normlize_target False" really disables it.
    parser.add_argument('--normlize_target', default=True, type=str2bool,
                        help='normalized the target patch pixels')

    # Additional attention types
    parser.add_argument('--encoder_depth', default=None, type=int,
                        help='depth of encoder')
    parser.add_argument('--attn_type', default='joint', choices=['joint', 'local_global'],
                        type=str, help='attention type for spatiotemporal modeling')
    # Options specific to attn_type == 'local_global'
    parser.add_argument('--lg_region_size', type=int, nargs='+', default=(2,2,10),
                        help='region size (t,h,w) for local_global attention')
    parser.add_argument('--lg_first_attn_type', type=str, default='self', choices=['cross', 'self'],
                        help='the first attention layer type for local_global attention')
    parser.add_argument('--lg_third_attn_type', type=str, default='cross', choices=['cross', 'self'],
                        help='the third attention layer type for local_global attention')
    parser.add_argument('--lg_attn_param_sharing_first_third', action='store_true',
                        help='share parameters of the first and the third attention layers for local_global attention')
    parser.add_argument('--lg_attn_param_sharing_all', action='store_true',
                        help='share all the parameters of three attention layers for local_global attention')
    parser.add_argument('--lg_no_second', action='store_true',
                        help='no second (inter-region) attention for local_global attention')
    parser.add_argument('--lg_no_third', action='store_true',
                        help='no third (local-global interaction) attention for local_global attention')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
weight decay. We use a cosine schedule for WD.
(Set the same value with args.weight_decay to keep weight decay no change)""")

    parser.add_argument('--lr', type=float, default=1.5e-4, metavar='LR',
                        help='learning rate (default: 1.5e-4)')
    parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min_lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
                        help='steps to warmup LR, if scheduler supports')

    # Augmentation parameters
    parser.add_argument('--color_jitter', type=float, default=0.0, metavar='PCT',
                        help='Color jitter factor (default: 0.0)')
    parser.add_argument('--train_interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    # Dataset parameters
    parser.add_argument('--data_path', default='/path/to/list_kinetics-400', type=str,
                        help='dataset path')
    # NOTE: store_true combined with default=True means this value is always
    # True; passing the flag has no effect. Left unchanged on purpose.
    parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
    parser.add_argument('--num_frames', type=int, default= 16)
    parser.add_argument('--sampling_rate', type=int, default= 4)
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default=None,
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    # auto_resume is on by default; --no_auto_resume switches it off.
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
    parser.set_defaults(auto_resume=True)

    parser.add_argument('--num_samples', type=int, default=1, help='number of clips sampled from each video')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    # pin_mem is on by default; --no_pin_mem switches it off.
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    parser.add_argument('--no_augmentation', action='store_true', help="do not use GroupMultiScaleCrop during pre-training")

    parser.add_argument('--use_frame_diff_as_target', action='store_true', help="use frame difference as partial reconstruction targets")
    parser.add_argument('--frame_diff_group_size', default=2, type=int, help="the group size for diff calculation, for example, "
                        "the target is [f0, f1-f0, f2-f0, f3-f0], [f4, f5-f4, f6-f4, f7-f4], ..., when the size is 4.")
    parser.add_argument('--target_diff_weight', default=None, type=float, help="the weight (0-1) for frame diff target ")

    return parser.parse_args()
def get_model(args):
    """Instantiate the pre-training network named by ``args.model``.

    The model is built through ``create_model`` with ``pretrained=False``.
    ``encoder_depth`` is forwarded only when the model name contains
    ``'no_depth'`` and ``args.encoder_depth`` is not None; in every other case
    the keyword is omitted from the call.

    Args:
        args: parsed options; reads ``model``, ``drop_path``, ``decoder_depth``,
            ``encoder_depth``, ``attn_type`` and the ``lg_*`` attributes.

    Returns:
        Whatever ``create_model`` returns for the requested model name.
    """
    print(f"Creating model: {args.model}")

    # Keywords shared by both construction paths, in the original call order.
    build_kwargs = dict(
        pretrained=False,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
        decoder_depth=args.decoder_depth,
    )

    # Only pass an explicit encoder depth when the name asks for it and a
    # depth was actually supplied.
    if 'no_depth' in args.model and args.encoder_depth is not None:
        build_kwargs['encoder_depth'] = args.encoder_depth

    # Attention configuration; the lg_* entries are for local_global attention.
    build_kwargs.update(
        attn_type=args.attn_type,
        lg_region_size=args.lg_region_size,
        lg_first_attn_type=args.lg_first_attn_type,
        lg_third_attn_type=args.lg_third_attn_type,
        lg_attn_param_sharing_first_third=args.lg_attn_param_sharing_first_third,
        lg_attn_param_sharing_all=args.lg_attn_param_sharing_all,
        lg_no_second=args.lg_no_second,
        lg_no_third=args.lg_no_third,
    )

    return create_model(args.model, **build_kwargs)
def main(args):
    """Run masked-autoencoder pre-training end to end.

    Steps, in order: initialise distributed mode, seed the RNGs, build the
    model and dataset, scale the learning rates by the total batch size,
    build the optimizer and per-step LR / weight-decay schedules, optionally
    resume from a checkpoint, then train for ``args.epochs`` epochs while
    saving checkpoints and appending one JSON line of stats per epoch.

    Args:
        args: namespace produced by ``get_args``. It is mutated in place:
            ``window_size`` and ``patch_size`` are added, ``lr`` / ``min_lr``
            / ``warmup_lr`` are rescaled, and ``weight_decay_end`` is filled
            in when it was None.
    """
    # NOTE(review): ``args.distributed`` and ``args.gpu`` are read below but
    # never set in this function; presumably this call populates them.
    utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility; the rank offset gives each process
    # its own seed
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    model = get_model(args)
    patch_size = model.encoder.patch_embed.patch_size
    print("Patch size = %s" % str(patch_size))
    # Token grid (t, h, w). The temporal divisor 2 is hard-coded here;
    # presumably it matches the model's temporal patch size - TODO confirm.
    args.window_size = (args.num_frames // 2, args.input_size // patch_size[0], args.input_size // patch_size[1])
    args.patch_size = patch_size

    # get dataset
    dataset_train = build_pretraining_dataset(args)

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    sampler_rank = global_rank
    num_training_steps_per_epoch = len(dataset_train) // args.batch_size // num_tasks

    # A DistributedSampler is used in every mode, including single process.
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=sampler_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))

    # Only rank 0 writes TensorBoard logs, and only when a log dir is given.
    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
        worker_init_fn=utils.seed_worker
    )

    model.to(device)
    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model = %s" % str(model_without_ddp))
    print('number of params: {} M'.format(n_parameters / 1e6))

    # Scale the learning rates linearly with the total batch size, using 256
    # as the reference batch size.
    total_batch_size = args.batch_size * utils.get_world_size()
    args.lr = args.lr * total_batch_size / 256
    args.min_lr = args.min_lr * total_batch_size / 256
    args.warmup_lr = args.warmup_lr * total_batch_size / 256
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Number of training steps = %d" % num_training_steps_per_epoch)
    print("Number of training examples per epoch = %d" % (total_batch_size * num_training_steps_per_epoch))

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        # Keep a handle on the unwrapped module for the optimizer and checkpoints.
        model_without_ddp = model.module

    optimizer = create_optimizer(
        args, model_without_ddp)
    loss_scaler = NativeScaler()

    print("Use step level LR & WD scheduler!")
    lr_schedule_values = utils.cosine_scheduler(
        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
    )
    # With no explicit end value the weight decay schedule starts and ends
    # at the same value.
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_schedule_values = utils.cosine_scheduler(
        args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
    print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))

    # NOTE(review): presumably restores model/optimizer/scaler state and may
    # update args.start_epoch - confirm in utils.auto_load_model.
    utils.auto_load_model(
        args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
    torch.cuda.empty_cache()
    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        # The sampler is a DistributedSampler in every mode, and it derives
        # its shuffle order from the epoch it was last given. Set it on every
        # run (not only distributed ones); otherwise a single-process run
        # replays the same sample order each epoch.
        data_loader_train.sampler.set_epoch(epoch)
        if log_writer is not None:
            log_writer.set_step(epoch * num_training_steps_per_epoch)
        train_stats = train_one_epoch(
            model, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, log_writer=log_writer,
            start_steps=epoch * num_training_steps_per_epoch,
            lr_schedule_values=lr_schedule_values,
            wd_schedule_values=wd_schedule_values,
            patch_size=patch_size[0],
            normlize_target=args.normlize_target,
            num_samples=args.num_samples,
            use_frame_diff_as_target=args.use_frame_diff_as_target,
            frame_diff_group_size=args.frame_diff_group_size,
            target_diff_weight=args.target_diff_weight,
        )
        # Checkpoint every save_ckpt_freq epochs and always after the last one.
        if args.output_dir:
            if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
                utils.save_model(
                    args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                    loss_scaler=loss_scaler, epoch=epoch)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch, 'n_parameters': n_parameters}

        # Only the main process appends to log.txt, one JSON object per epoch.
        if args.output_dir and utils.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
    # Script entry point: parse the command line once, then hand off to main().
    cli_args = get_args()
    # An empty output_dir means "do not save", so only create the directory
    # (with any missing parents) when one was requested.
    if cli_args.output_dir:
        Path(cli_args.output_dir).mkdir(parents=True, exist_ok=True)
    main(cli_args)