config.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import re
import math

from ml_collections import config_dict

import utils
from utils import RGB_DATASETS


def get_configs(args, **kwargs):
    """Create config dicts for dataset, model and optimizer."""
    data_config = config_dict.ConfigDict()
    data_config.name = args.dataset.lower()
    # minimum number of instances per class
    data_config.min_samples = args.min_samples
    # dataset imbalance is a function of p
    data_config.class_probs = kwargs.pop("p_mass")
    # number of classes that are overrepresented in the data
    data_config.n_frequent_classes = args.overrepresented_classes
    # whether to balance mini-batches
    data_config.sampling = kwargs.pop("sampling")
    # maximum number of triplets
    data_config.num_sets = kwargs.pop("num_sets")
    # input dimensionality
    data_config.input_dim = kwargs.pop("input_dim")
    # average number of instances per class
    data_config.n_samples = kwargs.pop("n_samples")
    data_config.n_classes = args.n_classes
    data_config.k = kwargs.pop("num_odds")
    data_config.targets = args.targets
    data_config.apply_augmentations = args.apply_augmentations
    data_config.label_noise = args.label_noise
    if data_config.name in RGB_DATASETS:
        data_config.is_rgb_dataset = True
        data_config.max_pixel_value = 255.0
        # per-channel means and standard deviations for input normalization
        means, stds = utils.get_data_statistics(data_config.name)
        data_config.means = means
        data_config.stds = stds
    else:
        data_config.is_rgb_dataset = False
    data_config.oko_batch_size = kwargs.pop("oko_batch_size")
    data_config.main_batch_size = kwargs.pop("main_batch_size")

    model_config = config_dict.ConfigDict()
    # split the network name into its architecture family (letters) and
    # depth (digits), e.g. "resnet18" -> type "resnet", depth "18"
    model_config.type = re.search(r"[a-zA-Z]+", args.network).group()
    try:
        model_config.depth = re.search(r"\d+", args.network).group()
    except AttributeError:
        # network name carries no depth suffix
        model_config.depth = ""
    model_config.regularization = args.regularization
    if args.regularization:
        if args.network.lower().startswith("resnet"):
            model_config.weight_decay = 1e-2
        else:
            model_config.weight_decay = 1e-3
    else:
        model_config.weight_decay = None
    model_config.n_classes = args.n_classes
    if data_config.k == 0:
        model_config.task = "Pair"
    else:
        model_config.task = (
            f"Odd-$k$-out ($k$={data_config.k}; {data_config.targets})"
        )
    # TODO: enable half precision when running things on TPU
    model_config.half_precision = False

    optimizer_config = config_dict.ConfigDict()
    optimizer_config.name = args.optim
    optimizer_config.burnin = args.burnin
    optimizer_config.patience = args.patience
    optimizer_config.lr = kwargs.pop("eta")
    optimizer_config.epochs = kwargs.pop("epochs")
    optimizer_config.warmup_epochs = args.warmup_epochs
    # number of gradient updates per epoch
    optimizer_config.steps_per_epoch = math.ceil(
        data_config.num_sets / data_config.oko_batch_size
    )
    optimizer_config.clip_val = 1.0
    data_config.epochs = optimizer_config.epochs
    data_config.initial_lr = optimizer_config.lr
    if optimizer_config.name.lower() == "sgd":
        # add momentum param if optim is sgd
        optimizer_config.momentum = 0.9
    else:
        optimizer_config.momentum = None

    # make config dicts immutable (same type as model param dicts)
    freeze = lambda cfg: config_dict.FrozenConfigDict(cfg)
    # freeze = lambda cfg: flax.core.frozen_dict.FrozenDict(cfg)
    data_config = freeze(data_config)
    model_config = freeze(model_config)
    optimizer_config = freeze(optimizer_config)
    return data_config, model_config, optimizer_config
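

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). Assumptions: `args`
# would normally be an argparse.Namespace; the attribute names below mirror
# the reads in get_configs, and the keyword arguments mirror its
# kwargs.pop(...) calls. All concrete values are illustrative placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        dataset="cifar10",  # hypothetical; RGB branch depends on utils.RGB_DATASETS
        min_samples=5,
        overrepresented_classes=3,
        n_classes=10,
        targets="hard",  # placeholder target mode
        apply_augmentations=False,
        label_noise=0.0,
        network="resnet18",
        regularization=True,
        optim="sgd",
        burnin=20,
        patience=10,
        warmup_epochs=5,
    )
    data_cfg, model_cfg, opt_cfg = get_configs(
        args,
        p_mass=0.9,
        sampling="uniform",
        num_sets=10_000,
        input_dim=32 * 32 * 3,
        n_samples=500,
        num_odds=1,
        oko_batch_size=128,
        main_batch_size=256,
        eta=1e-3,
        epochs=100,
    )
    # e.g. "Odd-$k$-out ($k$=1; hard)" and ceil(10000 / 128) = 79
    print(model_cfg.task, opt_cfg.steps_per_epoch)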