-
Notifications
You must be signed in to change notification settings - Fork 177
/
max_trials_callback.py
45 lines (33 loc) · 1.25 KB
/
max_trials_callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""
Optuna example to demonstrate setting the maximum number of trials in
a shared database when multiple workers are used.
In this example, we optimize a simple quadratic function. We run the same
script several times (as workers) to demonstrate ``MaxTrialsCallback``,
which lets the user set a maximum number of trials for the shared study
regardless of how many workers/scripts are running the trials.
"""
from time import sleep
import optuna
from optuna.study import MaxTrialsCallback
from optuna.trial import TrialState
def objective(trial):
    """Return the square of ``x``, suggested as a float from the range [0, 10].

    The one-second sleep presumably stands in for an expensive evaluation so
    that several workers running at the same time actually overlap.
    """
    sleep(1)
    sampled_x = trial.suggest_float("x", 0, 10)
    squared = sampled_x**2
    return squared
def main():
    """Run this worker's share of the shared study, then report the results.

    Every worker opens the same SQLite-backed study named "test". Each worker
    requests up to 50 trials, but ``MaxTrialsCallback`` stops optimization once
    the study holds 10 trials in the COMPLETE state, so the cap is shared across
    workers. Trials already in flight in other workers may still finish, so the
    final count can end up slightly above 10.
    """
    study = optuna.create_study(
        study_name="test",
        storage="sqlite:///database.sqlite",
        # Lets every worker after the first join the existing study instead of failing.
        load_if_exists=True,
    )
    study.optimize(
        objective,
        n_trials=50,
        callbacks=[MaxTrialsCallback(10, states=(TrialState.COMPLETE,))],
    )

    # Count completed trials through the study API rather than
    # ``trials_dataframe()``, which pulls in pandas just to take a length.
    completed = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
    print("Number of completed trials: {}".format(len(completed)))

    print("Best trial:")
    trial = study.best_trial
    print(" Value: ", trial.value)
    print(" Params: ")
    for key, value in trial.params.items():
        print(" {}: {}".format(key, value))


if __name__ == "__main__":
    main()