from collections import defaultdict
from typing import Any
from typing import Optional

from optuna.distributions import BaseDistribution
from optuna.samplers import RandomSampler
from optuna.study import Study
from optuna.study._study_direction import StudyDirection
from optuna.trial import FrozenTrial
from optuna.trial import TrialState


class MultiArmedBanditSampler(RandomSampler):
    """Sampler based on the epsilon-greedy multi-armed bandit algorithm.

    Each categorical choice is treated as an arm. With probability
    ``epsilon`` an arm is selected uniformly at random (exploration);
    otherwise the arm with the best mean observed objective value so far
    is selected (exploitation). Numerical sampling falls back to the
    parent :class:`~optuna.samplers.RandomSampler` behavior.

    Args:
        epsilon (float):
            Probability of selecting an arm uniformly at random
            (the exploration rate of the epsilon-greedy algorithm).
        seed (int):
            Seed for the random number generator used for arm selection.
    """

    def __init__(
        self,
        epsilon: float = 0.7,
        seed: Optional[int] = None,
    ) -> None:
        super().__init__(seed)
        self._epsilon = epsilon

    def sample_independent(
        self,
        study: Study,
        trial: FrozenTrial,
        param_name: str,
        param_distribution: BaseDistribution,
    ) -> Any:
        # Exploration: with probability epsilon, pick an arm uniformly.
        if self._rng.rng.rand() < self._epsilon:
            return self._rng.rng.choice(param_distribution.choices)

        # Exploitation: pick the arm with the best mean observed value.
        states = (TrialState.COMPLETE, TrialState.PRUNED)
        trials = study._get_trials(deepcopy=False, states=states, use_cache=True)

        rewards_by_choice: defaultdict = defaultdict(float)
        cnt_by_choice: defaultdict = defaultdict(int)
        for t in trials:
            # Skip trials that never sampled this parameter (conditional
            # search spaces, other samplers) and pruned trials that have
            # no objective value — indexing/adding them would raise.
            if param_name not in t.params or t.value is None:
                continue
            rewards_by_choice[t.params[param_name]] += t.value
            cnt_by_choice[t.params[param_name]] += 1

        # max(cnt, 1) guards against division by zero for arms that were
        # never tried; their mean reward defaults to 0.0.
        # NOTE(review): an untried arm therefore competes with mean 0.0,
        # which may not be a neutral prior for all objectives — confirm.
        best = min if study.direction == StudyDirection.MINIMIZE else max
        return best(
            param_distribution.choices,
            key=lambda x: rewards_by_choice[x] / max(cnt_by_choice[x], 1),
        )