From 160e6642047848b499c748381e90be17e1154dec Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Tue, 14 Apr 2020 16:55:18 +0200 Subject: [PATCH] pass the sequence with the function, fixes https://github.com/python-adaptive/adaptive/issues/265 --- adaptive/learner/sequence_learner.py | 68 +++++++++++++++------------- 1 file changed, 37 insertions(+), 31 deletions(-) diff --git a/adaptive/learner/sequence_learner.py b/adaptive/learner/sequence_learner.py index c7398dfa4..6dfdceece 100644 --- a/adaptive/learner/sequence_learner.py +++ b/adaptive/learner/sequence_learner.py @@ -5,8 +5,8 @@ from adaptive.learner.base_learner import BaseLearner -class _IgnoreFirstArgument: - """Remove the first argument from the call signature. +class _IndexToPoint: + """Call function with index of sequence. The SequenceLearner's function receives a tuple ``(index, point)`` but the original function only takes ``point``. @@ -15,18 +15,19 @@ class _IgnoreFirstArgument: pickable. """ - def __init__(self, function): + def __init__(self, function, sequence): self.function = function + self.sequence = sequence - def __call__(self, index_point, *args, **kwargs): - index, point = index_point + def __call__(self, index, *args, **kwargs): + point = self.sequence[index] return self.function(point, *args, **kwargs) def __getstate__(self): - return self.function + return self.function, self.sequence - def __setstate__(self, function): - self.__init__(function) + def __setstate__(self, state): + self.__init__(*state) class SequenceLearner(BaseLearner): @@ -40,7 +41,7 @@ class SequenceLearner(BaseLearner): Parameters ---------- function : callable - The function to learn. Must take a single element `sequence`. + The function to learn. Must take a single element of `sequence`. sequence : sequence The sequence to learn. @@ -58,7 +59,7 @@ class SequenceLearner(BaseLearner): def __init__(self, function, sequence): self._original_function = function - self.function = _IgnoreFirstArgument(function) + self.function = _IndexToPoint(function, sequence) self._to_do_indices = SortedSet({i for i, _ in enumerate(sequence)}) self._ntotal = len(sequence) self.sequence = copy(sequence) @@ -67,31 +68,18 @@ def __init__(self, function, sequence): def ask(self, n, tell_pending=True): indices = [] - points = [] loss_improvements = [] for index in self._to_do_indices: - if len(points) >= n: + if len(indices) >= n: break - point = self.sequence[index] indices.append(index) - points.append((index, point)) loss_improvements.append(1 / self._ntotal) if tell_pending: - for i, p in zip(indices, points): - self.tell_pending((i, p)) + for index in indices: + self.tell_pending(index) - return points, loss_improvements - - def _get_data(self): - return self.data - - def _set_data(self, data): - if data: - indices, values = zip(*data.items()) - # the points aren't used by tell, so we can safely pass None - points = [(i, None) for i in indices] - self.tell_many(points, values) + return indices, loss_improvements def loss(self, real=True): if not (self._to_do_indices or self.pending_points): @@ -105,14 +93,12 @@ def remove_unfinished(self): self._to_do_indices.add(i) self.pending_points = set() - def tell(self, point, value): - index, point = point + def tell(self, index, value): self.data[index] = value self.pending_points.discard(index) self._to_do_indices.discard(index) - def tell_pending(self, point): - index, point = point + def tell_pending(self, index): self.pending_points.add(index) self._to_do_indices.discard(index) @@ -128,3 +114,23 @@ def result(self): @property def npoints(self): return len(self.data) + + def _get_data(self): + return self.data + + def _set_data(self, data): + if data: + indices, values = zip(*data.items()) + self.tell_many(indices, values) + + def __getstate__(self): + return ( + self._original_function, + self.sequence, + self._get_data(), + ) + + def __setstate__(self, state): + function, sequence, data = state + self.__init__(function, sequence) + self._set_data(data)