-
Notifications
You must be signed in to change notification settings - Fork 238
/
brute.py
98 lines (75 loc) · 2.56 KB
/
brute.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""Iterate over every combination of hyperparameters."""
import itertools
import logging

from tqdm import tqdm

from network import Network
# Route every log record (DEBUG and above) to brute-log.txt, each line
# prefixed with a timestamp and the record's level name.
logging.basicConfig(
    filename='brute-log.txt',
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%m/%d/%Y %I:%M:%S %p',
)
def train_networks(networks, dataset):
    """Train each network, then report the five most accurate.

    Each network is trained on ``dataset`` and asked to print itself. A tqdm
    progress bar tracks how many networks have been trained.

    Args:
        networks (list): Current population of networks
        dataset (str): Dataset to use for training/evaluating

    Returns:
        list: The same networks sorted by ``accuracy``, highest first.
    """
    # Using tqdm as a context manager guarantees the bar is closed even if
    # a training run raises part-way through the population.
    with tqdm(total=len(networks)) as pbar:
        for network in networks:
            network.train(dataset)
            network.print_network()
            pbar.update(1)

    # Sort our final population, most accurate first.
    ranked = sorted(networks, key=lambda x: x.accuracy, reverse=True)

    # Print out the top 5 networks.
    print_networks(ranked[:5])

    # Returned so callers can use the ranking. Callers that ignored the
    # previous implicit None are unaffected.
    return ranked
def print_networks(networks):
    """Log a separator line, then let each network print its own summary.

    Args:
        networks (list): The population of networks
    """
    separator = '-' * 80
    logging.info(separator)
    for net in networks:
        net.print_network()
def generate_network_list(nn_param_choices):
    """Generate a list of all possible networks.

    Takes the Cartesian product of the choices under ``nb_neurons``,
    ``nb_layers``, ``activation`` and ``optimizer``, nested in that order
    (``optimizer`` varies fastest). Each combination is wrapped in a freshly
    created ``Network`` via ``create_set``. Any other keys in
    ``nn_param_choices`` are ignored.

    Args:
        nn_param_choices (dict): The parameter choices

    Returns:
        networks (list): A list of network objects
    """
    # Fixed tuple so both the iteration order and the key order of each
    # parameter dict match the previous hand-nested loops exactly.
    param_names = ('nb_neurons', 'nb_layers', 'activation', 'optimizer')
    choice_lists = [nn_param_choices[name] for name in param_names]

    networks = []
    for values in itertools.product(*choice_lists):
        # Set the parameters.
        network = dict(zip(param_names, values))
        # Instantiate a network object with set parameters.
        network_obj = Network()
        network_obj.create_set(network)
        networks.append(network_obj)
    return networks
def main():
    """Train and rank every network in the hyperparameter grid.

    All combinations are trained on the 'cifar10' dataset.
    """
    choices = dict(
        nb_neurons=[64, 128, 256, 512, 768, 1024],
        nb_layers=[1, 2, 3, 4],
        activation=['relu', 'elu', 'tanh', 'sigmoid'],
        optimizer=['rmsprop', 'adam', 'sgd', 'adagrad',
                   'adadelta', 'adamax', 'nadam'],
    )
    logging.info("***Brute forcing networks***")
    train_networks(generate_network_list(choices), 'cifar10')


if __name__ == '__main__':
    main()