-
Notifications
You must be signed in to change notification settings - Fork 0
/
options.py
144 lines (132 loc) · 8.13 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import numpy as np
import os
import glob
import torch
import argparse
def parse_args(script):
parser = argparse.ArgumentParser(description= 'few-shot script %s' %(script))
parser.add_argument('--dataset', default='miniImagenet', help='miniImagenet/cub/cars/places/plantae')
parser.add_argument('--image_size', default=224, type=int)
parser.add_argument('--model', default='ResNet10', help='model: Conv{4|6} / ResNet{10|18|34}') # we use ResNet10 in the paper
parser.add_argument('--train_n_way' , default=5, type=int, help='class num to classify for training')
parser.add_argument('--test_n_way' , default=5, type=int, help='class num to classify for testing (validation) ')
parser.add_argument('--n_shot' , default=5, type=int, help='number of labeled data in each class, same as n_support')
parser.add_argument('--train_episode' , default=100, type=int, help='training episode')
parser.add_argument('--test_episode' , default=100, type=int, help='training episode')
parser.add_argument('--train_aug', default=False, action='store_true', help='perform data augmentation or not during training ')
parser.add_argument('--name' , default='tmp', type=str, help='')
parser.add_argument('--save_dir' , default='./output', type=str, help='')
parser.add_argument('--data_dir' , default='filelists/', type=str, help='') #TO change your dataset here
parser.add_argument('--alpha' , default=0.1, type=float, help='lower band reject filter threshold')
parser.add_argument('--beta' , default=1.0, type=float, help='upper band reject filter threshold')
parser.add_argument("--meta_layers", nargs='+', default=[], type=int, help="meta layers")
parser.add_argument('--drop_type', default='static', type=str, help='channel mask type should be in ["static","MSD","MGD"]')
parser.add_argument('--aug_rate', default=0.5, type=float, help='dssc augment channel percentage')
parser.add_argument('--drop_rate', default=0.75, type=float, help='dropout ratio')
parser.add_argument('--drop_prob', default=1.0, type=float, help='probability of aplly dropout')
parser.add_argument('--fixteacher', action='store_true', help='use static teacher to train the student')
parser.add_argument('--param', default=0.3, type=float, help='initial param of MGD')
parser.add_argument('--prefix', default='baseline', type=str, help='initial param of MGD')
parser.add_argument('--outdim', default=512, type=int, help='initial param of MGD')
parser.add_argument('--method', default=None, type=str, help='running method')
parser.add_argument('--mask_type', default='hard', type=str, help='channel mask type should be in [hard, soft, random]')
if script == 'train':
parser.add_argument('--num_classes' , default=64, type=int, help='total number of classes in softmax, only used in baseline')
parser.add_argument('--save_freq' , default=100, type=int, help='Save frequency')
parser.add_argument('--target_set', default='cub', help='cub/cars/places/plantae, use the extremely labeled target data')
parser.add_argument('--modelType', default='St-Net', help='pretrain/St-Net/Tt-Net/Student')
parser.add_argument('--ckp_S' , default='', type=str,help='the ckp path of the expert St-Net model')
parser.add_argument('--ckp_A' , default='', type=str,help='the ckp path of the expert Tt-Net model')
parser.add_argument('--target_num_label', default=5, type=int, help='number of labeled target base images per class')
parser.add_argument('--start_epoch' , default=0, type=int,help ='Starting epoch')
parser.add_argument('--stop_epoch' , default=400, type=int, help ='Stopping epoch')
parser.add_argument('--resume' , default='', type=str, help='continue from previous trained model with largest epoch')
parser.add_argument('--resume_epoch', default=-1, type=int, help='')
parser.add_argument('--warmup' , default='gg3b0', type=str, help='continue from baseline, neglected if resume is true')
parser.add_argument('--temp' , default=1, type=float, help='temperature of prototypical loss in contrastive space')
parser.add_argument('--meta' , default=False, action='store_true', help='wheather drop generalizable')
parser.add_argument('--wdif' , default=0.0, type=float, help='different loss weight')
parser.add_argument('--lamb' , default=0.0, type=float, help='weight of the source data in domain mixup')
parser.add_argument('--wsim' , default=0.0, type=float, help='different loss weight')
parser.add_argument('--sigma' , default=0.2, type=float, help='different loss weight')
parser.add_argument('--aug_type' , default='cutmix', type=str, help='image augmentation type')
parser.add_argument('--reverse', default=False, action='store_true', help='wheather augment negligible channel')
parser.add_argument('--taylor_expansion', default='init', help='point where apply taylor expansion of target few shot loss of channel mask, choose from [init, ones, zeros, random]')
parser.add_argument('--fix', default=False, help='wheather search start from fix point')
parser.add_argument('--lconsist', default=5.0, type=float, help='consistency loss')
elif script == 'test':
parser.add_argument('--target_set', default='cub', help='cub/cars/places/plantae, use the extremely labeled target data')
parser.add_argument('--split' , default='novel', help='base/val/novel')
parser.add_argument('--save_epoch', default=400, type=int,help ='load the model trained in x epoch, use the best model if x is -1')
parser.add_argument('--warmup' , default='gg3bo', type = str, help = 'just for insert the test function into the training.')
parser.add_argument('--stop_epoch' , default=400, type=int, help ='Stopping epoch')
else:
raise ValueError('Unknown script')
return parser.parse_args()
def get_assigned_file(checkpoint_dir,num):
assign_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(num))
return assign_file
def get_resume_file(checkpoint_dir, resume_epoch=-1):
filelist = glob.glob(os.path.join(checkpoint_dir, '*.tar'))
print(filelist)
if len(filelist) == 0:
return None
filelist = [ x for x in filelist if os.path.basename(x) != 'best_model.tar' ]
epochs = []
for x in filelist:
try: e = int(os.path.splitext(os.path.basename(x))[0])
except: e = 0
epochs.append(e)
epochs = np.array(epochs)
max_epoch = np.max(epochs)
epoch = max_epoch if resume_epoch == -1 else resume_epoch
resume_file = os.path.join(checkpoint_dir, '{:d}.tar'.format(epoch))
return resume_file
def get_best_file(checkpoint_dir):
best_file = os.path.join(checkpoint_dir, 'best_model.tar')
if os.path.isfile(best_file):
return best_file
else:
return get_resume_file(checkpoint_dir)
def load_warmup_state(filename):
print(' load pre-trained model file: {}'.format(filename))
warmup_resume_file = get_resume_file(filename)
print(' warmup_resume_file:', warmup_resume_file)
tmp = torch.load(warmup_resume_file)
if tmp is not None:
state = tmp['state']
state_keys = list(state.keys())
for i, key in enumerate(state_keys):
if 'feature.' in key:
newkey = key.replace("feature.","")
state[newkey] = state.pop(key)
else:
state.pop(key)
else:
raise ValueError(' No pre-trained encoder file found!')
return state
'''
def load_warmup_state(filename, method):
print(' load pre-trained model file: {}'.format(filename))
warmup_resume_file = get_resume_file(filename)
print(' warmup_resume_file:', warmup_resume_file)
tmp = torch.load(warmup_resume_file)
if tmp is not None:
state = tmp['state']
state_keys = list(state.keys())
for i, key in enumerate(state_keys):
if 'relationnet' in method and "feature." in key:
newkey = key.replace("feature.","")
state[newkey] = state.pop(key)
elif method == 'gnnnet' and 'feature.' in key:
newkey = key.replace("feature.","")
state[newkey] = state.pop(key)
elif method == 'matchingnet' and 'feature.' in key and '.7.' not in key:
newkey = key.replace("feature.","")
state[newkey] = state.pop(key)
else:
state.pop(key)
else:
raise ValueError(' No pre-trained encoder file found!')
return state
'''