-
Notifications
You must be signed in to change notification settings - Fork 1
/
speed_test.py
84 lines (62 loc) · 2.18 KB
/
speed_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import warnings
import torch
import torch.multiprocessing as mp
# torch.multiprocessing.set_sharing_strategy('file_system')
import wandb
import sys
sys.path.append('lib/')
from lib.utils import set_seed, dist_setup, get_conf
import lib.trainers as trainers
def main():
    """Load the configuration and launch the speed-test worker(s).

    The configuration returned by ``get_conf()`` is switched to test mode,
    the seed is applied, and the distributed-related fields
    (``world_size``, ``distributed``, ``ngpus_per_node``) are filled in.
    ``main_worker`` is then either spawned once per visible CUDA device or
    called directly in the current process.
    """
    args = get_conf()
    args.test = True

    # set seed if required
    set_seed(args.seed)

    if not args.multiprocessing_distributed:
        if args.gpu is not None:
            warnings.warn('You have chosen a specific GPU. This will completely '
                          'disable data parallelism.')

    # With the "env://" rendezvous an unset world size (-1) is read from the
    # WORLD_SIZE environment variable.
    world_size_from_env = args.dist_url == "env://" and args.world_size == -1
    if world_size_from_env:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    gpu_count = torch.cuda.device_count()
    args.ngpus_per_node = gpu_count

    if not args.multiprocessing_distributed:
        print("single process")
        main_worker(args.gpu, args)
        return

    # One process per local GPU: scale the world size by the local device
    # count before spawning the workers.
    args.world_size = gpu_count * args.world_size
    mp.spawn(main_worker, nprocs=gpu_count, args=(args,))
def main_worker(gpu, args):
    """Build the configured trainer in this process and run its speed test.

    Args:
        gpu: Device index for this worker. It is stored on ``args.gpu``.
            When launched through ``mp.spawn`` the first positional argument
            is the index of the spawned process; in the single-process path
            the caller passes ``args.gpu`` itself.
        args: Configuration object. It is mutated (``args.gpu``) and passed
            on to ``dist_setup`` and to the trainer constructor.

    Raises:
        ValueError: If the ``trainers`` module has no attribute named
            ``args.trainer_name``.
    """
    args.gpu = gpu
    dist_setup(args.ngpus_per_node, args)

    # init trainer
    trainer_name = f'{args.trainer_name}'
    trainer_class = getattr(trainers, trainer_name, None)
    if trainer_class is None:
        # Explicit raise rather than ``assert`` so the check is not stripped
        # when Python runs with -O.
        raise ValueError(f"Trainer class {args.trainer_name} is not defined")
    trainer = trainer_class(args)

    # NOTE: this script only builds the model and runs the speedometer.
    # wandb run setup/teardown, build_optimizer(), resume(),
    # build_dataloader() and vis_policy() are not invoked here.

    # create model
    trainer.build_model()

    trainer.speedometerv2()
# Script entry point: run the speed test only when this file is executed
# directly, not when it is imported (e.g. by spawned worker processes).
if __name__ == "__main__":
    main()