-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathutils.py
115 lines (82 loc) · 3.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import json
import os
# Public names of this module: the only ones picked up by a star-import.
__all__ = ["ConfLoader", "directory_setter", "config_overwriter"]
class ConfLoader:
    """
    Load a JSON config file into attribute-accessible dictionaries.

    ``ConfLoader(conf_name).opt`` holds the parsed configuration. Every JSON
    object in the file, at any nesting depth, becomes a
    ``DictWithAttributeAccess`` so values can be read as ``opt.key`` as well
    as ``opt['key']``.
    """

    class _MissingKeyError(AttributeError, KeyError):
        """
        Raised when attribute-style access names a key that is not present.

        It is an ``AttributeError`` so the attribute protocol works
        (``hasattr``, ``getattr`` with a default, ``copy.deepcopy``), and it
        is still a ``KeyError`` so code that catches ``KeyError`` around
        ``opt.key`` keeps working.
        """

    class DictWithAttributeAccess(dict):
        """
        A dict whose items can also be read and written as attributes.

        For example, you can use ``opt.key`` instead of ``opt['key']``.
        """

        def __getattr__(self, key):
            # __getattr__ only runs after normal attribute lookup has failed,
            # so dict methods such as ``items`` are never shadowed by keys.
            try:
                return self[key]
            except KeyError:
                # A bare KeyError here would leak through hasattr()/getattr()
                # and the special-method probes done by the copy module.
                raise ConfLoader._MissingKeyError(key) from None

        def __setattr__(self, key, value):
            # Attribute assignment stores an item; nothing is written to the
            # instance __dict__.
            self[key] = value

    def __init__(self, conf_name):
        """Remember the path ``conf_name`` and load it into ``self.opt``."""
        self.conf_name = conf_name
        self.opt = self.__get_opt()

    def __load_conf(self):
        """Parse the JSON file, converting every JSON object on the way."""
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(self.conf_name, "r", encoding="utf-8") as conf:
            return json.load(conf, object_hook=self.DictWithAttributeAccess)

    def __get_opt(self):
        """Return the loaded config wrapped as a ``DictWithAttributeAccess``."""
        return self.DictWithAttributeAccess(self.__load_conf())
def directory_setter(path="./results", make_dir=False):
    """
    Ensure that ``path`` is an existing directory.

    If ``path`` does not exist and ``make_dir`` is true, the directory
    (including missing parents) is created and a message is printed.

    Raises:
        NotADirectoryError: if, after the optional creation step, ``path``
            is still not a directory (it is missing and ``make_dir`` is
            false, or it exists but is not a directory).
    """
    if make_dir and not os.path.exists(path):
        try:
            os.makedirs(path)
        except FileExistsError:
            # Something created the path between the check above and
            # makedirs (e.g. a parallel run); the isdir check below decides
            # whether what is there now is acceptable.
            pass
        else:
            print("directory %s is created" % path)

    if not os.path.isdir(path):
        raise NotADirectoryError(
            "%s is not valid. set make_dir=True to make dir." % path
        )
def config_overwriter(opt, args):
    """
    Apply parsed command-line arguments on top of a loaded configuration.

    Each argument listed in the table below that is not None replaces the
    corresponding entry of ``opt``; arguments that are None leave the entry
    untouched. ``opt`` is modified in place and is also returned.
    """
    # (attribute read from ``args``, dotted location written inside ``opt``),
    # processed top to bottom.
    overrides = (
        ("dataset_name", "data_setups.dataset_name"),
        ("batch_size", "data_setups.batch_size"),
        ("n_clients", "data_setups.n_clients"),
        ("partition_method", "data_setups.partition.method"),
        ("partition_s", "data_setups.partition.shard_per_user"),
        ("partition_alpha", "data_setups.partition.alpha"),
        ("model_name", "train_setups.model.name"),
        ("n_rounds", "train_setups.scenario.n_rounds"),
        ("sample_ratio", "train_setups.scenario.sample_ratio"),
        ("local_epochs", "train_setups.scenario.local_epochs"),
        ("device", "train_setups.scenario.device"),
        ("lr", "train_setups.optimizer.params.lr"),
        ("momentum", "train_setups.optimizer.params.momentum"),
        ("wd", "train_setups.optimizer.params.weight_decay"),
        ("algo_name", "train_setups.algo.name"),
        ("seed", "train_setups.seed"),
        ("group", "wandb_setups.group"),
        ("exp_name", "wandb_setups.name"),
    )

    for arg_name, target in overrides:
        new_value = getattr(args, arg_name)
        if new_value is None:
            continue  # option not given: keep the configured value

        # Walk down to the section that owns the final field, then assign it.
        *sections, field = target.split(".")
        node = opt
        for section in sections:
            node = getattr(node, section)
        setattr(node, field, new_value)

    return opt