forked from IBM/cdfsl-benchmark
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
executable file
·115 lines (82 loc) · 4.22 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim
import torch.optim.lr_scheduler as lr_scheduler
import time
import os
import glob
import configs
import backbone
from data.datamgr import SimpleDataManager, SetDataManager
from methods.baselinetrain import BaselineTrain
from methods.protonet import ProtoNet
from io_utils import model_dict, parse_args, get_resume_file
from datasets import miniImageNet_few_shot, DTD_few_shot
def train(base_loader, model, optimization, start_epoch, stop_epoch, params):
    """Run the epoch loop for ``model`` and write periodic checkpoints.

    Args:
        base_loader: Data loader handed unchanged to ``model.train_loop``.
        model: Object exposing ``parameters()``, ``train()``,
            ``train_loop(epoch, loader, optimizer)`` and ``state_dict()``.
        optimization: Optimizer name; only ``'Adam'`` is supported.
        start_epoch: First epoch index to run (inclusive).
        stop_epoch: Epoch index at which to stop (exclusive).
        params: Namespace providing ``checkpoint_dir`` and ``save_freq``.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        ValueError: If ``optimization`` is anything other than ``'Adam'``.
    """
    if optimization != 'Adam':
        raise ValueError('Unknown optimization, please define by yourself')
    optimizer = torch.optim.Adam(model.parameters())

    # Create the checkpoint directory once, before any training happens:
    # an unwritable path then fails immediately instead of after a full
    # epoch, and exist_ok avoids the isdir()/makedirs() race.
    os.makedirs(params.checkpoint_dir, exist_ok=True)

    last_epoch = stop_epoch - 1
    for epoch in range(start_epoch, stop_epoch):
        model.train()
        model.train_loop(epoch, base_loader, optimizer)

        # Save every `save_freq` epochs, and always after the final epoch.
        if epoch % params.save_freq == 0 or epoch == last_epoch:
            outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

    return model
if __name__ == '__main__':
    # Script entry point: parse the command-line options, build the data
    # loader and model for the requested method/dataset, derive the
    # checkpoint directory, then hand everything to train().
    np.random.seed(10)
    params = parse_args('train')

    image_size = 224
    optimization = 'Adam'

    if params.method == 'baseline':
        if params.dataset == "miniImageNet":
            datamgr = miniImageNet_few_shot.SimpleDataManager(image_size, batch_size=16)
            base_loader = datamgr.get_data_loader(aug=params.train_aug)
        elif params.dataset == "CUB":
            base_file = configs.data_dir['CUB'] + 'base.json'
            base_datamgr = SimpleDataManager(image_size, batch_size=16)
            base_loader = base_datamgr.get_data_loader(base_file, aug=params.train_aug)
        elif params.dataset == "cifar100":
            # This module is not imported at the top of the file, so the
            # branch used to fail with NameError; import it on demand.
            # NOTE(review): assumes datasets/cifar_few_shot.py exists - confirm.
            from datasets import cifar_few_shot
            base_datamgr = cifar_few_shot.SimpleDataManager("CIFAR100", image_size, batch_size=16)
            base_loader = base_datamgr.get_data_loader("base", aug=True)
            params.num_classes = 100
        elif params.dataset == 'caltech256':
            # Same missing-import fix as for cifar100.
            # NOTE(review): assumes datasets/caltech256_few_shot.py exists - confirm.
            from datasets import caltech256_few_shot
            base_datamgr = caltech256_few_shot.SimpleDataManager(image_size, batch_size=16)
            base_loader = base_datamgr.get_data_loader(aug=False)
            params.num_classes = 257
        elif params.dataset == "DTD":
            base_datamgr = DTD_few_shot.SimpleDataManager(image_size, batch_size=16)
            base_loader = base_datamgr.get_data_loader(aug=True)
        else:
            raise ValueError('Unknown dataset')

        model = BaselineTrain(model_dict[params.model], params.num_classes)

    elif params.method == 'protonet':
        # If test_n_way is smaller than train_n_way, reduce n_query to keep
        # the batch size small.
        n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
        train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot)

        if params.dataset == "miniImageNet":
            datamgr = miniImageNet_few_shot.SetDataManager(image_size, n_query=n_query, **train_few_shot_params)
            base_loader = datamgr.get_data_loader(aug=params.train_aug)
        else:
            raise ValueError('Unknown dataset')

        model = ProtoNet(model_dict[params.model], **train_few_shot_params)

    else:
        # Reject unsupported methods here rather than failing later on an
        # undefined `model`.
        raise ValueError('Unknown method')

    model = model.cuda()

    # Checkpoint path encodes dataset, backbone, method, augmentation and,
    # for episodic methods, the way/shot configuration.
    save_dir = configs.save_dir
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if params.method not in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(params.checkpoint_dir, exist_ok=True)

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch

    model = train(base_loader, model, optimization, start_epoch, stop_epoch, params)