diff --git a/adaptive/learner/balancing_learner.py b/adaptive/learner/balancing_learner.py index d836bdbec..a86932f04 100644 --- a/adaptive/learner/balancing_learner.py +++ b/adaptive/learner/balancing_learner.py @@ -119,27 +119,34 @@ def strategy(self, strategy): ' strategy="npoints", or strategy="cycle" is implemented.' ) + def _ask_all_learners(self, total_points): + to_select = [] + for index, learner in enumerate(self.learners): + # Take the points from the cache + if index not in self._ask_cache: + self._ask_cache[index] = learner.ask(n=1, tell_pending=False) + points, loss_improvements = self._ask_cache[index] + if not points: # cannot ask for more points + return to_select + to_select.append( + ((index, points[0]), (loss_improvements[0], -total_points[index])) + ) + return to_select + def _ask_and_tell_based_on_loss_improvements(self, n): selected = [] # tuples ((learner_index, point), loss_improvement) total_points = [l.npoints + len(l.pending_points) for l in self.learners] for _ in range(n): - to_select = [] - for index, learner in enumerate(self.learners): - # Take the points from the cache - if index not in self._ask_cache: - self._ask_cache[index] = learner.ask(n=1, tell_pending=False) - points, loss_improvements = self._ask_cache[index] - to_select.append( - ((index, points[0]), (loss_improvements[0], -total_points[index])) - ) - + to_select = self._ask_all_learners(total_points) + if not to_select: # cannot ask for more points + break # Choose the optimal improvement. (index, point), (loss_improvement, _) = max(to_select, key=itemgetter(1)) total_points[index] += 1 selected.append(((index, point), loss_improvement)) self.tell_pending((index, point)) - points, loss_improvements = map(list, zip(*selected)) + points, loss_improvements = map(list, zip(*selected)) if selected else [], [] return points, loss_improvements def _ask_and_tell_based_on_loss(self, n): @@ -156,11 +163,12 @@ def _ask_and_tell_based_on_loss(self, n): if index not in self._ask_cache: self._ask_cache[index] = self.learners[index].ask(n=1) points, loss_improvements = self._ask_cache[index] - + if not points: # cannot ask for more points + break selected.append(((index, points[0]), loss_improvements[0])) self.tell_pending((index, points[0])) - points, loss_improvements = map(list, zip(*selected)) + points, loss_improvements = map(list, zip(*selected)) if selected else [], [] return points, loss_improvements def _ask_and_tell_based_on_npoints(self, n): @@ -172,11 +180,13 @@ def _ask_and_tell_based_on_npoints(self, n): if index not in self._ask_cache: self._ask_cache[index] = self.learners[index].ask(n=1) points, loss_improvements = self._ask_cache[index] + if not points: # cannot ask for more points + break total_points[index] += 1 selected.append(((index, points[0]), loss_improvements[0])) self.tell_pending((index, points[0])) - points, loss_improvements = map(list, zip(*selected)) + points, loss_improvements = map(list, zip(*selected)) if selected else [], [] return points, loss_improvements def _ask_and_tell_based_on_cycle(self, n): diff --git a/adaptive/tests/test_balancing_learner.py b/adaptive/tests/test_balancing_learner.py index fff0ff186..485cc0229 100644 --- a/adaptive/tests/test_balancing_learner.py +++ b/adaptive/tests/test_balancing_learner.py @@ -2,7 +2,7 @@ import pytest -from adaptive.learner import BalancingLearner, Learner1D +from adaptive.learner import BalancingLearner, Learner1D, SequenceLearner from adaptive.runner import simple @@ -44,6 +44,16 @@ def test_distribute_first_points_over_learners(strategy): assert len(set(i_learner)) == len(learners) +@pytest.mark.parametrize("strategy", strategies) +def test_asking_more_points_than_available(strategy): + def dummy(x): + return x + + bl = BalancingLearner([SequenceLearner(dummy, range(5))], strategy=strategy) + bl.ask(100) + bl.ask(100) + + @pytest.mark.parametrize("strategy", strategies) def test_ask_0(strategy): learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]