-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add multi-armed bandit sampler #155
Changes from 3 commits
8f0c8b4
e18ea6a
a1c1784
371556f
888441c
0bc6cb7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2024 <Ryota Nishijima> | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
--- | ||
author: Ryota Nishijima | ||
title: Multi-armed Bandit Sampler | ||
description: Sampler based on multi-armed bandit algorithm with epsilon-greedy arm selection. | ||
tags: [sampler, multi-armed bandit] | ||
optuna_versions: [4.0.0] | ||
license: MIT License | ||
--- | ||
|
||
## Class or Function Names | ||
|
||
- MultiArmedBanditSampler | ||
|
||
## Example | ||
|
||
```python | ||
mod = optunahub.load_module("samplers/multi_armed_bandit") | ||
sampler = mod.MultiArmedBanditSampler() | ||
``` | ||
|
||
See [`example.py`](https://github.com/optuna/optunahub-registry/blob/main/package/samplers/multi_armed_bandit/example.py) for more details. | ||
|
||
## Others | ||
|
||
This package provides a sampler based on Multi-armed bandit algorithm with epsilon-greedy selection. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .multi_armed_bandit import MultiArmedBanditSampler | ||
|
||
|
||
__all__ = ["MultiArmedBanditSampler"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import optuna | ||
import optunahub | ||
|
||
|
||
if __name__ == "__main__": | ||
module = optunahub.load_module( | ||
package="samplers/multi_armed_bandit", | ||
) | ||
sampler = module.MultiArmedBanditSampler() | ||
|
||
def objective(trial: optuna.Trial) -> float: | ||
x = trial.suggest_categorical("arm_1", [1, 2, 3]) | ||
y = trial.suggest_categorical("arm_2", [1, 2]) | ||
|
||
return x + y | ||
|
||
study = optuna.create_study(sampler=sampler) | ||
study.optimize(objective, n_trials=20) | ||
|
||
print(study.best_trial.value, study.best_trial.params) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from collections import defaultdict | ||
from typing import Any | ||
from typing import Optional | ||
|
||
from optuna.distributions import BaseDistribution | ||
from optuna.samplers import RandomSampler | ||
from optuna.study import Study | ||
from optuna.study._study_direction import StudyDirection | ||
from optuna.trial import FrozenTrial | ||
from optuna.trial import TrialState | ||
|
||
|
||
class MultiArmedBanditSampler(RandomSampler): | ||
"""Sampler based on Multi-armed Bandit Algorithm. | ||
|
||
Args: | ||
epsilon (float): | ||
Params for epsolon-greedy algorithm. | ||
epsilon is probability of selecting arm randomly. | ||
seed (int): | ||
ryota717 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Seed for random number generator and arm selection. | ||
|
||
""" | ||
|
||
def __init__( | ||
self, | ||
epsilon: float = 0.7, | ||
seed: Optional[int] = None, | ||
) -> None: | ||
super().__init__(seed) | ||
self._epsilon = epsilon | ||
|
||
def sample_independent( | ||
self, | ||
study: Study, | ||
trial: FrozenTrial, | ||
param_name: str, | ||
param_distribution: BaseDistribution, | ||
) -> Any: | ||
if self._rng.rng.rand() < self._epsilon: | ||
return self._rng.rng.choice(param_distribution.choices) | ||
else: | ||
states = (TrialState.COMPLETE, TrialState.PRUNED) | ||
trials = study._get_trials(deepcopy=False, states=states, use_cache=True) | ||
|
||
rewards_by_choice: defaultdict = defaultdict(float) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [QUESTION] This There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That is a very good point actually:)
So usually, we start from the random initialization. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for your suggestion!
This looks me to good and changed initialization in 371556f |
||
cnt_by_choice: defaultdict = defaultdict(int) | ||
for t in trials: | ||
rewards_by_choice[t.params[param_name]] += t.value | ||
cnt_by_choice[t.params[param_name]] += 1 | ||
|
||
if study.direction == StudyDirection.MINIMIZE: | ||
return min( | ||
param_distribution.choices, | ||
key=lambda x: rewards_by_choice[x] / max(cnt_by_choice[x], 1), | ||
) | ||
else: | ||
return max( | ||
param_distribution.choices, | ||
key=lambda x: rewards_by_choice[x] / max(cnt_by_choice[x], 1), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I confirmed that the example works!