-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
59 lines (53 loc) · 2.08 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import apex.amp as amp
import argparse
import torch
import torch.nn as nn
import torchvision
from cifar_data import get_datasets
from trainer import train, test
from logger import Logger
import utils
def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace: namespace whose ``config`` attribute holds the
        value given via ``--config`` (``'./cnfg.yml'`` when omitted).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--config', type=str, default='./cnfg.yml')
    namespace = cli.parse_args()
    return namespace
def main():
    """Run training and evaluation as described by the config file.

    Steps, in order: parse the ``--config`` path, load the config via
    ``utils.parse_config``, build the train/test loaders, seed via
    ``utils.set_seed``, create the logger, model, cross-entropy criterion
    and SGD optimizer, wrap model and optimizer with apex AMP, build the
    scheduler, then for every epoch call ``train`` followed by ``test`` and
    periodically save and log a checkpoint.
    """
    # config
    args = parse_args()
    cnfg = utils.parse_config(args.config)

    # data
    tr_loader, tst_loader = get_datasets(cnfg['data']['flag'],
                                         cnfg['data']['dir'],
                                         cnfg['data']['batch_size'])

    # initialization
    utils.set_seed(cnfg['seed'])
    # Default to the first CUDA device when the config leaves 'gpu' unset.
    device = torch.device('cuda:0' if cnfg['gpu'] is None else cnfg['gpu'])
    logger = Logger(cnfg)
    model = utils.get_model(cnfg['model']).to(device)
    criterion = nn.CrossEntropyLoss()
    opt = torch.optim.SGD(model.parameters(),
                          lr=cnfg['train']['lr'],
                          momentum=cnfg['train']['momentum'],
                          weight_decay=cnfg['train']['weight_decay'])

    opt_level = cnfg['opt']['level']
    amp_args = dict(opt_level=opt_level,
                    loss_scale=cnfg['opt']['loss_scale'], verbosity=False)
    # Apex opt levels are spelled with the letter O ('O0'..'O3'). The former
    # comparison against '02' (digit zero) could never match a level that
    # apex accepts, so the master_weights override was silently dropped.
    # The override is only applied when 'store' is configured, so O2 configs
    # without that key keep working.
    if opt_level == 'O2' and 'store' in cnfg['opt']:
        amp_args['master_weights'] = cnfg['opt']['store']
    model, opt = amp.initialize(model, opt, **amp_args)

    # Third argument is epochs * batches-per-epoch; presumably the total
    # number of scheduler steps -- confirm against utils.get_scheduler.
    scheduler = utils.get_scheduler(
        opt, cnfg['train'], cnfg['train']['epochs'] * len(tr_loader))

    # train+test
    for epoch in range(cnfg['train']['epochs']):
        train(epoch, model, criterion,
              opt, scheduler, tr_loader, device, logger)
        # testing
        test(epoch, model, tst_loader, criterion, device, logger)
        # save: every cnfg['save']['epochs'] epochs, never after epoch 0.
        # The file name carries the zero-based epoch index.
        if (epoch + 1) % cnfg['save']['epochs'] == 0 and epoch > 0:
            pth = (f"models/{cnfg['logger']['project']}_"
                   f"{cnfg['logger']['run']}_{epoch}.pth")
            utils.save_model(model, cnfg, epoch, pth)
            logger.log_model(pth)
# Script entry point: run the pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()