Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Oct 11, 2022
2 parents 3c9d902 + 9860573 commit 17e84e4
Showing 1 changed file with 25 additions and 15 deletions.
40 changes: 25 additions & 15 deletions adaptive/learner/sequence_learner.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from copy import copy
from typing import Any, Callable, Iterable
from typing import Any, Callable, Iterable, Tuple

import cloudpickle
import numpy as np
from sortedcontainers import SortedDict, SortedSet

from adaptive.learner.base_learner import BaseLearner
from adaptive.types import Int
from adaptive.utils import assign_defaults, partial_function_from_dataframe

try:
Expand All @@ -18,6 +18,14 @@
except ModuleNotFoundError:
with_pandas = False

try:
from typing import TypeAlias
except ImportError:
from typing_extensions import TypeAlias


PointType: TypeAlias = Tuple[Int, Any]


class _IgnoreFirstArgument:
"""Remove the first argument from the call signature.
Expand All @@ -32,9 +40,7 @@ class _IgnoreFirstArgument:
def __init__(self, function: Callable) -> None:
self.function = function # type: ignore

def __call__(
self, index_point: tuple[int, float | np.ndarray], *args, **kwargs
) -> float:
def __call__(self, index_point: PointType, *args, **kwargs):
index, point = index_point
return self.function(point, *args, **kwargs)

Expand Down Expand Up @@ -85,7 +91,9 @@ def new(self) -> SequenceLearner:
"""Return a new `~adaptive.SequenceLearner` without the data."""
return SequenceLearner(self._original_function, self.sequence)

def ask(self, n: int, tell_pending: bool = True) -> tuple[Any, list[float]]:
def ask(
self, n: int, tell_pending: bool = True
) -> tuple[list[PointType], list[float]]:
indices = []
points = []
loss_improvements = []
Expand All @@ -105,31 +113,31 @@ def ask(self, n: int, tell_pending: bool = True) -> tuple[Any, list[float]]:

def loss(self, real: bool = True) -> float:
if not (self._to_do_indices or self.pending_points):
return 0
return 0.0
else:
npoints = self.npoints + (0 if real else len(self.pending_points))
return (self._ntotal - npoints) / self._ntotal

def remove_unfinished(self):
def remove_unfinished(self) -> None:
for i in self.pending_points:
self._to_do_indices.add(i)
self.pending_points = set()

def tell(self, point: tuple[int, Any], value: Any) -> None:
def tell(self, point: PointType, value: Any) -> None:
index, point = point
self.data[index] = value
self.pending_points.discard(index)
self._to_do_indices.discard(index)

def tell_pending(self, point: Any) -> None:
def tell_pending(self, point: PointType) -> None:
index, point = point
self.pending_points.add(index)
self._to_do_indices.discard(index)

def done(self):
def done(self) -> bool:
return not self._to_do_indices and not self.pending_points

def result(self):
def result(self) -> list[Any]:
"""Get the function values in the same order as ``sequence``."""
if not self.done():
raise Exception("Learner is not yet complete.")
Expand Down Expand Up @@ -217,16 +225,18 @@ def load_dataframe(
y_name : str, optional
The ``y_name`` used in ``to_dataframe``, by default "y"
"""
self.tell_many(df[[index_name, x_name]].values, df[y_name].values)
indices = df[index_name].values
xs = df[x_name].values
self.tell_many(zip(indices, xs), df[y_name].values)
if with_default_function_args:
self.function = partial_function_from_dataframe(
self._original_function, df, function_prefix
)

def _get_data(self) -> SortedDict:
def _get_data(self) -> dict[int, Any]:
return self.data

def _set_data(self, data: SortedDict) -> None:
def _set_data(self, data: dict[int, Any]) -> None:
if data:
indices, values = zip(*data.items())
# the points aren't used by tell, so we can safely pass None
Expand Down

0 comments on commit 17e84e4

Please sign in to comment.