forked from optuna/optuna-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
chainermn_integration.py
140 lines (110 loc) · 4.9 KB
/
chainermn_integration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
Optuna example that demonstrates a pruner for ChainerMN.
In this example, we optimize the validation accuracy of fashion product recognition using
ChainerMN and FashionMNIST, where the architecture of the neural network is optimized.
Throughout the training of neural networks,
a pruner observes intermediate results and stops unpromising trials.
ChainerMN and its Optuna integration are supposed to be invoked via MPI. You can run this example
as follows:
$ STORAGE_URL=sqlite:///example.db
$ STUDY_NAME=`optuna create-study --storage $STORAGE_URL --direction maximize`
$ mpirun -n 2 -- python chainermn_integration.py $STUDY_NAME $STORAGE_URL
"""
import sys
import chainermn
import numpy as np
import optuna
from optuna.trial import TrialState
import chainer
import chainer.functions as F
import chainer.links as L
# Size of the training subset that objective() carves out of the loaded training data.
N_TRAIN_EXAMPLES = 3000
# Size of the validation subset that objective() carves out of the loaded validation data.
N_VALID_EXAMPLES = 1000
# Minibatch size passed to both the training and the validation iterator.
BATCHSIZE = 128
# Number of epochs after which the trainer stops.
EPOCH = 10
# Interval, in epochs, at which the pruning extension is triggered.
PRUNER_INTERVAL = 3
def create_model(trial):
    """Assemble a fully connected network whose shape is sampled from ``trial``.

    The number of hidden layers (1 to 3) and the width of each hidden layer
    (4 to 128, sampled on a log scale) are suggested by ``trial``. Every hidden
    linear layer is followed by a ReLU, and a final linear layer with 10 outputs
    closes the stack.

    Args:
        trial: Object providing ``suggest_int`` for hyperparameter sampling.

    Returns:
        A ``chainer.Sequential`` holding the sampled layers in order.
    """
    depth = trial.suggest_int("n_layers", 1, 3)

    stack = []
    for index in range(depth):
        width = trial.suggest_int("n_units_l{}".format(index), 4, 128, log=True)
        # Input size is left as None, so only the output width is fixed here.
        stack += [L.Linear(None, width), F.relu]

    # Output layer.
    stack.append(L.Linear(None, 10))
    return chainer.Sequential(*stack)
def objective(trial, comm):
    """Train one sampled network across all workers and return its validation accuracy.

    Args:
        trial: Trial used to sample the architecture and handed to the pruning extension.
        comm: ChainerMN communicator shared by the participating workers.

    Returns:
        The ``"main/accuracy"`` entry of the report produced by the final
        multi-node evaluation on the validation data.
    """
    # Wrap a freshly sampled architecture in a classifier.
    classifier = L.Classifier(create_model(trial))

    # Plain MomentumSGD, wrapped so that it operates through the communicator.
    base_optimizer = chainer.optimizers.MomentumSGD()
    base_optimizer.setup(classifier)
    optimizer = chainermn.create_multi_node_optimizer(base_optimizer, comm)

    # Only worker 0 loads the data; every other worker passes None to scatter_dataset,
    # which then distributes worker 0's subsets among all workers.
    train, valid = None, None
    if comm.rank == 0:
        full_train, full_valid = chainer.datasets.get_fashion_mnist()
        rng = np.random.RandomState(0)
        train_order = rng.permutation(len(full_train))
        train = chainer.datasets.SubDataset(full_train, 0, N_TRAIN_EXAMPLES, order=train_order)
        valid_order = rng.permutation(len(full_valid))
        valid = chainer.datasets.SubDataset(full_valid, 0, N_VALID_EXAMPLES, order=valid_order)

    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    valid = chainermn.scatter_dataset(valid, comm)

    train_iter = chainer.iterators.SerialIterator(train, BATCHSIZE, shuffle=True)
    valid_iter = chainer.iterators.SerialIterator(valid, BATCHSIZE, repeat=False, shuffle=False)

    # Trainer that stops after EPOCH epochs.
    trainer = chainer.training.Trainer(
        chainer.training.StandardUpdater(train_iter, optimizer), (EPOCH, "epoch")
    )

    # Pruning hook: watches the validation accuracy every PRUNER_INTERVAL epochs.
    pruning_extension = optuna.integration.ChainerPruningExtension(
        trial, "validation/main/accuracy", (PRUNER_INTERVAL, "epoch")
    )
    trainer.extend(pruning_extension)

    # Multi-node evaluation on the validation data during training.
    training_evaluator = chainer.training.extensions.Evaluator(valid_iter, classifier)
    trainer.extend(chainermn.create_multi_node_evaluator(training_evaluator, comm))

    trainer.extend(chainer.training.extensions.LogReport(log_name=None))

    # The progress bar is shown by worker 0 only.
    if comm.rank == 0:
        trainer.extend(chainer.training.extensions.ProgressBar())

    # ChainerPruningExtension stops training by raising TrialPruned, and the trainer
    # prints a message each time it receives that exception. Passing
    # show_loop_exception_msg=False suppresses those messages.
    trainer.run(show_loop_exception_msg=False)

    # Final multi-node evaluation with a fresh evaluator.
    final_evaluator = chainermn.create_multi_node_evaluator(
        chainer.training.extensions.Evaluator(valid_iter, classifier), comm
    )
    return final_evaluator()["main/accuracy"]
if __name__ == "__main__":
    # Every worker must be started with the same study name and storage URL so that
    # all of them attach to one common study.
    study_name = sys.argv[1]
    storage_url = sys.argv[2]

    # Pass study_name/storage by keyword: newer Optuna releases deprecate positional
    # arguments for load_study, while the keyword form is accepted by older ones too.
    study = optuna.load_study(
        study_name=study_name,
        storage=storage_url,
        pruner=optuna.pruners.MedianPruner(),
    )
    comm = chainermn.create_communicator("naive")

    # Only worker 0 prints, so the output is not repeated once per worker.
    if comm.rank == 0:
        print("Study name:", study_name)
        print("Storage URL:", storage_url)
        print("Number of nodes:", comm.size)

    # Run optimization!
    chainermn_study = optuna.integration.ChainerMNStudy(study, comm)
    chainermn_study.optimize(objective, n_trials=25)

    if comm.rank == 0:
        pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
        complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

        print("Study statistics: ")
        print(" Number of finished trials: ", len(study.trials))
        print(" Number of pruned trials: ", len(pruned_trials))
        print(" Number of complete trials: ", len(complete_trials))

        print("Best trial:")
        trial = study.best_trial

        print(" Value: ", trial.value)

        print(" Params: ")
        for key, value in trial.params.items():
            print(" {}: {}".format(key, value))