Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix problem when learner has no more points for BalancingLearner #214

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions adaptive/learner/balancing_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,20 +119,27 @@ def strategy(self, strategy):
' strategy="npoints", or strategy="cycle" is implemented.'
)

def _to_select(self, total_points):
to_select = []
for index, learner in enumerate(self.learners):
# Take the points from the cache
if index not in self._ask_cache:
self._ask_cache[index] = learner.ask(n=1, tell_pending=False)
points, loss_improvements = self._ask_cache[index]
if not points: # cannot ask for more points
return to_select
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a couple of things I don't understand here:

  • seems this return should be a continue; just because learner i could not give any more points does not mean that no other learners can give any!
  • I cannot see when this branch will ever be executed. learner.ask(1) is guaranteed to return a point. At the moment there is no way for a learner to indicate that it has "no more points". If a learner returns no points then it is in violation of the API and other stuff is liable to break

to_select.append(
((index, points[0]), (loss_improvements[0], -total_points[index]))
)
return to_select

def _ask_and_tell_based_on_loss_improvements(self, n):
selected = [] # tuples ((learner_index, point), loss_improvement)
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
for _ in range(n):
to_select = []
for index, learner in enumerate(self.learners):
# Take the points from the cache
if index not in self._ask_cache:
self._ask_cache[index] = learner.ask(n=1, tell_pending=False)
points, loss_improvements = self._ask_cache[index]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here one would need a double break, but because that doesn't exist I make it into a function.

to_select.append(
((index, points[0]), (loss_improvements[0], -total_points[index]))
)

to_select = self._to_select(total_points)
if not to_select: # cannot ask for more points
break
# Choose the optimal improvement.
(index, point), (loss_improvement, _) = max(to_select, key=itemgetter(1))
total_points[index] += 1
Expand All @@ -156,7 +163,8 @@ def _ask_and_tell_based_on_loss(self, n):
if index not in self._ask_cache:
self._ask_cache[index] = self.learners[index].ask(n=1)
points, loss_improvements = self._ask_cache[index]

if not points: # cannot ask for more points
break
selected.append(((index, points[0]), loss_improvements[0]))
self.tell_pending((index, points[0]))

Expand All @@ -172,6 +180,8 @@ def _ask_and_tell_based_on_npoints(self, n):
if index not in self._ask_cache:
self._ask_cache[index] = self.learners[index].ask(n=1)
points, loss_improvements = self._ask_cache[index]
if not points: # cannot ask for more points
break
total_points[index] += 1
selected.append(((index, points[0]), loss_improvements[0]))
self.tell_pending((index, points[0]))
Expand Down