Skip to content

Commit

Permalink
pass the sequence with the function, fixes #265
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Apr 14, 2020
1 parent 3d9397d commit 396ed34
Showing 1 changed file with 18 additions and 24 deletions.
42 changes: 18 additions & 24 deletions adaptive/learner/sequence_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from adaptive.learner.base_learner import BaseLearner


class _IgnoreFirstArgument:
"""Remove the first argument from the call signature.
class _IndexToPoint:
"""Call function with index of sequence.
The SequenceLearner's function receives a tuple ``(index, point)``
but the original function only takes ``point``.
Expand All @@ -15,18 +15,19 @@ class _IgnoreFirstArgument:
pickable.
"""

def __init__(self, function):
def __init__(self, function, sequence):
self.function = function
self.sequence = sequence

def __call__(self, index_point, *args, **kwargs):
index, point = index_point
def __call__(self, index, *args, **kwargs):
point = self.sequence[index]
return self.function(point, *args, **kwargs)

def __getstate__(self):
return self.function
return self.function, self.sequence

def __setstate__(self, function):
self.__init__(function)
def __setstate__(self, state):
self.__init__(*state)


class SequenceLearner(BaseLearner):
Expand All @@ -40,7 +41,7 @@ class SequenceLearner(BaseLearner):
Parameters
----------
function : callable
The function to learn. Must take a single element `sequence`.
The function to learn. Must take a single element of `sequence`.
sequence : sequence
The sequence to learn.
Expand All @@ -58,7 +59,7 @@ class SequenceLearner(BaseLearner):

def __init__(self, function, sequence):
self._original_function = function
self.function = _IgnoreFirstArgument(function)
self.function = _IndexToPoint(function, sequence)
self._to_do_indices = SortedSet({i for i, _ in enumerate(sequence)})
self._ntotal = len(sequence)
self.sequence = copy(sequence)
Expand All @@ -67,31 +68,26 @@ def __init__(self, function, sequence):

def ask(self, n, tell_pending=True):
indices = []
points = []
loss_improvements = []
for index in self._to_do_indices:
if len(points) >= n:
if len(indices) >= n:
break
point = self.sequence[index]
indices.append(index)
points.append((index, point))
loss_improvements.append(1 / self._ntotal)

if tell_pending:
for i, p in zip(indices, points):
self.tell_pending((i, p))
for index in indices:
self.tell_pending(index)

return points, loss_improvements
return indices, loss_improvements

def _get_data(self):
return self.data

def _set_data(self, data):
if data:
indices, values = zip(*data.items())
# the points aren't used by tell, so we can safely pass None
points = [(i, None) for i in indices]
self.tell_many(points, values)
self.tell_many(indices, values)

def loss(self, real=True):
if not (self._to_do_indices or self.pending_points):
Expand All @@ -105,14 +101,12 @@ def remove_unfinished(self):
self._to_do_indices.add(i)
self.pending_points = set()

def tell(self, point, value):
index, point = point
def tell(self, index, value):
self.data[index] = value
self.pending_points.discard(index)
self._to_do_indices.discard(index)

def tell_pending(self, point):
index, point = point
def tell_pending(self, index):
self.pending_points.add(index)
self._to_do_indices.discard(index)

Expand Down

0 comments on commit 396ed34

Please sign in to comment.