-
Notifications
You must be signed in to change notification settings - Fork 33
/
search.py
59 lines (51 loc) · 2.1 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Code for "APQ: Joint Search for Network Architecture, Pruning and Quantization Policy"
# CVPR 2020
# Tianzhe Wang, Kuan Wang, Han Cai, Ji Lin, Zhijian Liu, Song Han
# {usedtobe, kuanwang, hancai, jilin, zhijian, songhan}@mit.edu
import argparse
import json
from methods.evolution.evo_main_gather import evolution_gather
# Command-line interface of the search driver. The parser object itself is
# also handed to evolution_gather() inside main(); parse_known_args() lets
# flags that are not declared here pass through without an error.
parser = argparse.ArgumentParser(description='Best Arch Searcher')
parser.add_argument(
    '--prepare',
    choices=['acc', 'acc_quant'],
    default=None,
    type=str,
)
parser.add_argument(
    '--acc_train_sample',
    default=None,
    type=int,
)
parser.add_argument(
    '--mode',
    choices=['evolution'],  # main() only implements the 'evolution' branch
    default='evolution',
    type=str,
)
# NOTE: argparse applies `type` only to string defaults, so the default here
# stays the int 120, while a value given on the command line becomes a float.
parser.add_argument(
    '--constraint',
    default=120,
    type=float,
)
parser.add_argument(
    '--exp_name',
    default='test',
    type=str,
)

# Only the parsed namespace is kept; unrecognised arguments are discarded.
args = parser.parse_known_args()[0]
print(args)
def main():
    """Run the evolutionary search for the configured latency constraint and save the result.

    Reads the module-level ``args`` (``--mode``, ``--constraint``,
    ``--exp_name``) and passes the module-level ``parser`` through to
    ``evolution_gather``. Prints the per-constraint result and the best
    accuracy per constraint, then writes two JSON files under
    ``exps/<exp_name>/``:

    * ``arch`` -- the first candidate, serialized as the pair
      ``[info1, info2]`` produced by ``add_arch`` below;
    * ``lat``  -- the list of latencies returned by ``evolution_gather``.
    """
    import copy
    import os

    if args.mode == 'evolution':
        # Keys stripped from the first / second copy of an architecture dict.
        # NOTE(review): by their names these look like quantization bit-width
        # settings vs. network-shape settings -- confirm against evolution_gather.
        bits_keys = ('dw_w_bits_setting', 'dw_a_bits_setting',
                     'pw_w_bits_setting', 'pw_a_bits_setting')
        shape_keys = ('wid', 'ks', 'e', 'd')

        def add_arch(info, lst):
            """Append ``(info1, info2)`` to ``lst``.

            ``info1`` is a deep copy of ``info`` without the ``*_bits_setting``
            entries; ``info2`` is a deep copy without ``wid``/``ks``/``e``/``d``.
            Raises KeyError if any of those keys is missing from ``info``.
            """
            info1 = copy.deepcopy(info)
            info2 = copy.deepcopy(info)
            for key in bits_keys:
                del info1[key]
            for key in shape_keys:
                del info2[key]
            lst.append((info1, info2))

        best_acc = {}  # constraint -> best accuracy seen for it
        candidate_archs = []
        out_dir = 'exps/{}'.format(args.exp_name)
        lats = []
        for i in [args.constraint]:
            # The third value returned by evolution_gather is not needed here.
            res, info, _ = evolution_gather(parser, force_latency=i)
            acc, arch, lat = info
            print((i, res, lat, arch, acc))
            if i not in best_acc or best_acc[i] < acc:
                best_acc[i] = acc
            lats.append(lat)
            add_arch(arch, candidate_archs)
        print('Found Best Architecture: {}'.format(best_acc))

        os.makedirs(out_dir, exist_ok=True)
        # Context managers guarantee both files are flushed and closed
        # deterministically, even if serialization raises.
        with open(os.path.join(out_dir, 'arch'), 'w') as f:
            json.dump(candidate_archs[0], f)
        with open(os.path.join(out_dir, 'lat'), 'w') as f:
            json.dump(lats, f)


if __name__ == '__main__':
    main()