-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmain.py
82 lines (61 loc) · 2.25 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import argparse
import yaml
import torch
import numpy as np
from collections import defaultdict, OrderedDict
from src.model_handler import ModelHandler
def set_random_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def get_config(config_path="config.yml"):
with open(config_path, "r") as setting:
config = yaml.load(setting, Loader=yaml.FullLoader)
return config
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('-config', '--config', required=True, type=str, help='path to the config file')
parser.add_argument('--multi_run', action='store_true', help='whether open multiple run')
args = vars(parser.parse_args())
return args
def print_config(config):
print("******************** MODEL CONFIGURATION ********************")
for key in sorted(config.keys()):
val = config[key]
keystr = "{}".format(key) + (" " * (24 - len(key)))
print("{} ==> {}".format(keystr, val))
print("********************* END CONFIGURATION *********************")
def grid(kwargs):
class MncDc:
def __init__(self, a):
self.a = a
def __call__(self):
return self.a
def merge_dicts(*dicts):
from functools import reduce
def merge_two_dicts(x, y):
z = x.copy()
z.update(y)
return z
return reduce(lambda a, nd: merge_two_dicts(a, nd if nd else {}), dicts, {})
sin = OrderedDict({k: v for k, v in kwargs.items() if isinstance(v, list)})
for k, v in sin.items():
copy_v = []
for e in v:
copy_v.append(MncDc(e) if isinstance(e, tuple) else e)
sin[k] = copy_v
grd = np.array(np.meshgrid(*sin.values()), dtype=object).T.reshape(-1, len(sin.values()))
return [merge_dicts(
{k: v for k, v in kwargs.items() if not isinstance(v, list)},
{k: vv[i]() if isinstance(vv[i], MncDc) else vv[i] for i, k in enumerate(sin)}
) for vv in grd]
def main(config):
print_config(config)
set_random_seed(config['random_seed'])
model = ModelHandler(config)
model.train()
model.test()
if __name__ == '__main__':
args = get_args()
config = get_config(args['config'])
main(config)