-
Notifications
You must be signed in to change notification settings - Fork 4
/
run_nni.py
43 lines (33 loc) · 1.01 KB
/
run_nni.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "True" # NOT Safe
from omegaconf import OmegaConf
import nni
from flashlight.runner import main_pl
from config import config as dc
def search_params_intp(params: dict) -> dict:
    """Expand flat, dot-separated search parameters into a nested dict.

    A key of the form ``"section.name"`` is stored as
    ``result["section"]["name"]``; a key without a dot is stored at the top
    level unchanged. Several keys that share the same section (for example
    ``"train.batch"`` and ``"train.lr"``) are merged into one section dict
    rather than overwriting each other.

    Args:
        params: Mapping from parameter name (zero or one ``"."``) to value.

    Returns:
        A new dict with at most one level of nesting. ``params`` itself is
        never mutated.

    Raises:
        ValueError: If a key contains more than one ``"."``.
    """
    ret = {}
    for param, value in params.items():
        # param : "train.batch"
        spl = param.split(".")
        if len(spl) == 2:
            section, leaf = spl
            existing = ret.get(section)
            if isinstance(existing, dict):
                # Merge into a copy so siblings accumulate and a dict that
                # came straight from ``params`` is not modified in place.
                ret[section] = {**existing, leaf: value}
            else:
                # First leaf for this section (or a scalar is being replaced).
                ret[section] = {leaf: value}
        elif len(spl) == 1:
            ret[spl[0]] = value
        else:
            raise ValueError(
                f"parameter name {param!r} has more than one '.'; "
                "expected 'name' or 'section.name'"
            )
    return ret
def _main(cfg=dc.DefaultConfig) -> None:
    """Run one NNI trial and report its final result.

    Fetches the next parameter set from NNI, expands its dotted keys with
    ``search_params_intp``, merges the result over the structured config
    built from ``cfg``, prints the merged config, then builds and runs
    ``main_pl.MainPL`` with the merged config sections.

    Args:
        cfg: Config class or object passed to ``OmegaConf.structured``;
            defaults to ``dc.DefaultConfig``.
    """
    # Trial-specific values proposed by the NNI tuner, in nested form.
    overrides = search_params_intp(nni.get_next_parameter())

    # Values from NNI take precedence over the structured defaults.
    args = OmegaConf.merge(OmegaConf.structured(cfg), overrides)
    print(args)

    runner = main_pl.MainPL(
        args.train,
        args.val,
        args.test,
        args.hw,
        args.network,
        args.data,
        args.opt,
        args.log,
        args.seed,
    )
    nni.report_final_result(runner.run())
# Script entry point: call _main() with its default config only when this file
# is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    _main()