Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue bandit scheduler #489

Merged
merged 11 commits into from
Jan 6, 2025
2 changes: 2 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#### Improvements

- Fix `BanditScheduler` behaviour: the number of active emitters remains stable
({pr}`489`)
- Skip qdax tests if qdax not installed ({pr}`491`)
- Move yapf after isort in pre-commit ({pr}`490`)
- Remove `_cells` attribute from ArchiveBase ({pr}`475`)
Expand Down
29 changes: 21 additions & 8 deletions ribs/schedulers/_bandit_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,16 @@ def ask(self):
# Reselect all emitters.
reselect = self._active_arr.copy()

# If no emitters are active, activate the first num_active.
if not self._active_arr.any():
reselect[:] = False
self._active_arr[:self._num_active] = True
# If not enough emitters are active, activate the first num_active.
# This always happens on the first iteration(s).
num_needed = self._num_active - self._active_arr.sum()
i = 0
while num_needed > 0:
reselect[i] = False
if not self._active_arr[i]:
self._active_arr[i] = True
num_needed -= 1
i += 1

# Deactivate emitters to be reselected.
self._active_arr[reselect] = False
Expand All @@ -258,17 +264,24 @@ def ask(self):
# terminated/restarted will be reselected. Otherwise, if reselect is
# "all", then all emitters are reselected.
if reselect.any():
ucb1 = np.full_like(self._emitter_pool, np.inf)
ucb1 = np.full_like(
self._emitter_pool, np.inf
) # np.inf forces to select emitters that were not yet selected
update_ucb = self._selection != 0
if update_ucb.any():
ucb1[update_ucb] = (
self._success[update_ucb] / self._selection[update_ucb] +
self._zeta * np.sqrt(
np.log(self._success.sum()) /
self._selection[update_ucb]))
# Activate top emitters based on UCB1.
activate = np.argsort(ucb1)[-reselect.sum():]
self._active_arr[activate] = True
# Activate top emitters based on UCB1, until there are num_active
# active emitters. Activate only inactive emitters.
activate = np.argsort(ucb1)[::-1]
for i in activate:
if self._active_arr.sum() == self._num_active:
break
if not self._active_arr[i]:
self._active_arr[i] = True

self._cur_solutions = []

Expand Down
30 changes: 30 additions & 0 deletions tests/schedulers/scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,33 @@ def test_tell_fails_with_wrong_shapes(scheduler_fixture, array):
scheduler.tell(objective_batch[:-1], measures_batch)
elif array == "measures_batch":
scheduler.tell(objective_batch, measures_batch[:-1])


def test_constant_active_emitters_bandit_scheduler():
solution_dim = 2
num_solutions = 4
expected_active = 3
archive = GridArchive(solution_dim=solution_dim,
dims=[100, 100],
ranges=[(-1, 1), (-1, 1)])
emitters = [
GaussianEmitter(archive,
sigma=1,
x0=[0.0, 0.0],
batch_size=num_solutions) for _ in range(10)
]
scheduler = BanditScheduler(archive, emitters, num_active=expected_active)
num_loops = 10

rng = np.random.default_rng(42)

for _ in range(num_loops):
solutions = scheduler.ask()
assert scheduler.emitters.sum() == expected_active

# Mock objective and measures for tell
objective = rng.random(len(solutions))
measures = rng.random((len(solutions), 2))
scheduler.tell(objective, measures)

assert scheduler.emitters.sum() == expected_active
Loading