Skip to content

Commit

Permalink
add type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Dec 19, 2019
1 parent dfeb5ef commit 30cded3
Show file tree
Hide file tree
Showing 17 changed files with 608 additions and 390 deletions.
9 changes: 5 additions & 4 deletions adaptive/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import subprocess
from collections import namedtuple
from distutils.command.build_py import build_py as build_py_orig
from typing import Dict

from setuptools.command.sdist import sdist as sdist_orig

Expand All @@ -19,7 +20,7 @@
STATIC_VERSION_FILE = "_static_version.py"


def get_version(version_file=STATIC_VERSION_FILE):
def get_version(version_file: str = STATIC_VERSION_FILE) -> str:
version_info = get_static_version_info(version_file)
version = version_info["version"]
if version == "__use_git__":
Expand All @@ -33,7 +34,7 @@ def get_version(version_file=STATIC_VERSION_FILE):
return version


def get_static_version_info(version_file=STATIC_VERSION_FILE):
def get_static_version_info(version_file: str = STATIC_VERSION_FILE) -> Dict[str, str]:
version_info = {}
with open(os.path.join(package_root, version_file), "rb") as f:
exec(f.read(), {}, version_info)
Expand All @@ -44,7 +45,7 @@ def version_is_from_git(version_file=STATIC_VERSION_FILE):
return get_static_version_info(version_file)["version"] == "__use_git__"


def pep440_format(version_info):
def pep440_format(version_info: Version) -> str:
release, dev, labels = version_info

version_parts = [release]
Expand All @@ -61,7 +62,7 @@ def pep440_format(version_info):
return "".join(version_parts)


def get_version_from_git():
def get_version_from_git() -> Version:
try:
p = subprocess.Popen(
["git", "rev-parse", "--show-toplevel"],
Expand Down
28 changes: 17 additions & 11 deletions adaptive/learner/average_learner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from math import sqrt
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np

Expand Down Expand Up @@ -30,7 +31,12 @@ class AverageLearner(BaseLearner):
Number of evaluated points.
"""

def __init__(self, function, atol=None, rtol=None):
def __init__(
self,
function: Callable,
atol: Optional[float] = None,
rtol: Optional[float] = None,
) -> None:
if atol is None and rtol is None:
raise Exception("At least one of `atol` and `rtol` should be set.")
if atol is None:
Expand All @@ -48,10 +54,10 @@ def __init__(self, function, atol=None, rtol=None):
self.sum_f_sq = 0

@property
def n_requested(self):
def n_requested(self) -> int:
return self.npoints + len(self.pending_points)

def ask(self, n, tell_pending=True):
def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[int], List[float]]:
points = list(range(self.n_requested, self.n_requested + n))

if any(p in self.data or p in self.pending_points for p in points):
Expand All @@ -68,7 +74,7 @@ def ask(self, n, tell_pending=True):
self.tell_pending(p)
return points, loss_improvements

def tell(self, n, value):
def tell(self, n: int, value: float) -> None:
if n in self.data:
# The point has already been added before.
return
Expand All @@ -79,16 +85,16 @@ def tell(self, n, value):
self.sum_f_sq += value ** 2
self.npoints += 1

def tell_pending(self, n):
def tell_pending(self, n: int) -> None:
self.pending_points.add(n)

@property
def mean(self):
def mean(self) -> float:
"""The average of all values in `data`."""
return self.sum_f / self.npoints

@property
def std(self):
def std(self) -> float:
"""The corrected sample standard deviation of the values
in `data`."""
n = self.npoints
Expand All @@ -101,7 +107,7 @@ def std(self):
return sqrt(numerator / (n - 1))

@cache_latest
def loss(self, real=True, *, n=None):
def loss(self, real: bool = True, *, n=None) -> float:
if n is None:
n = self.npoints if real else self.n_requested
else:
Expand All @@ -113,7 +119,7 @@ def loss(self, real=True, *, n=None):
standard_error / self.atol, standard_error / abs(self.mean) / self.rtol
)

def _loss_improvement(self, n):
def _loss_improvement(self, n: int) -> float:
loss = self.loss()
if np.isfinite(loss):
return loss - self.loss(n=self.npoints + n)
Expand All @@ -139,8 +145,8 @@ def plot(self):
vals = hv.Points(vals)
return hv.operation.histogram(vals, num_bins=num_bins, dimension=1)

def _get_data(self):
def _get_data(self) -> Tuple[Dict[int, float], int, float, float]:
return (self.data, self.npoints, self.sum_f, self.sum_f_sq)

def _set_data(self, data):
def _set_data(self, data: Tuple[Dict[int, float], int, float, float]) -> None:
self.data, self.npoints, self.sum_f, self.sum_f_sq = data
49 changes: 32 additions & 17 deletions adaptive/learner/balancing_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from contextlib import suppress
from functools import partial
from operator import itemgetter
from typing import Any, Callable, Dict, List, Set, Tuple, Union

import numpy as np

Expand All @@ -12,7 +13,7 @@
from adaptive.utils import cache_latest, named_product, restore


def dispatch(child_functions, arg):
def dispatch(child_functions: List[Callable], arg: Any) -> Union[Any]:
index, x = arg
return child_functions[index](x)

Expand Down Expand Up @@ -68,7 +69,9 @@ class BalancingLearner(BaseLearner):
behave in an undefined way. Change the `strategy` in that case.
"""

def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
def __init__(
self, learners: List[BaseLearner], *, cdims=None, strategy="loss_improvements"
) -> None:
self.learners = learners

# Naively we would make 'function' a method, but this causes problems
Expand All @@ -89,21 +92,21 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
self.strategy = strategy

@property
def data(self):
def data(self) -> Dict[Tuple[int, Any], Any]:
data = {}
for i, l in enumerate(self.learners):
data.update({(i, p): v for p, v in l.data.items()})
return data

@property
def pending_points(self):
def pending_points(self) -> Set[Tuple[int, Any]]:
pending_points = set()
for i, l in enumerate(self.learners):
pending_points.update({(i, p) for p in l.pending_points})
return pending_points

@property
def npoints(self):
def npoints(self) -> int:
return sum(l.npoints for l in self.learners)

@property
Expand Down Expand Up @@ -135,7 +138,9 @@ def strategy(self, strategy):
' strategy="npoints", or strategy="cycle" is implemented.'
)

def _ask_and_tell_based_on_loss_improvements(self, n):
def _ask_and_tell_based_on_loss_improvements(
self, n: int
) -> Tuple[List[Tuple[int, Any]], List[float]]:
selected = [] # tuples ((learner_index, point), loss_improvement)
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
for _ in range(n):
Expand All @@ -158,7 +163,9 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
points, loss_improvements = map(list, zip(*selected))
return points, loss_improvements

def _ask_and_tell_based_on_loss(self, n):
def _ask_and_tell_based_on_loss(
self, n: int
) -> Tuple[List[Tuple[int, Any]], List[float]]:
selected = [] # tuples ((learner_index, point), loss_improvement)
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
for _ in range(n):
Expand All @@ -179,7 +186,9 @@ def _ask_and_tell_based_on_loss(self, n):
points, loss_improvements = map(list, zip(*selected))
return points, loss_improvements

def _ask_and_tell_based_on_npoints(self, n):
def _ask_and_tell_based_on_npoints(
self, n: int
) -> Tuple[List[Tuple[int, Any]], List[float]]:
selected = [] # tuples ((learner_index, point), loss_improvement)
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
for _ in range(n):
Expand All @@ -195,7 +204,9 @@ def _ask_and_tell_based_on_npoints(self, n):
points, loss_improvements = map(list, zip(*selected))
return points, loss_improvements

def _ask_and_tell_based_on_cycle(self, n):
def _ask_and_tell_based_on_cycle(
self, n: int
) -> Tuple[List[Tuple[int, Any]], List[float]]:
points, loss_improvements = [], []
for _ in range(n):
index = next(self._cycle)
Expand All @@ -206,7 +217,9 @@ def _ask_and_tell_based_on_cycle(self, n):

return points, loss_improvements

def ask(self, n, tell_pending=True):
def ask(
self, n: int, tell_pending: bool = True
) -> Tuple[List[Tuple[int, Any]], List[float]]:
"""Chose points for learners."""
if n == 0:
return [], []
Expand All @@ -217,20 +230,20 @@ def ask(self, n, tell_pending=True):
else:
return self._ask_and_tell(n)

def tell(self, x, y):
def tell(self, x: Tuple[int, Any], y: Any) -> None:
index, x = x
self._ask_cache.pop(index, None)
self._loss.pop(index, None)
self._pending_loss.pop(index, None)
self.learners[index].tell(x, y)

def tell_pending(self, x):
def tell_pending(self, x: Tuple[int, Any]) -> None:
index, x = x
self._ask_cache.pop(index, None)
self._loss.pop(index, None)
self.learners[index].tell_pending(x)

def _losses(self, real=True):
def _losses(self, real: bool = True) -> List[float]:
losses = []
loss_dict = self._loss if real else self._pending_loss

Expand All @@ -242,7 +255,7 @@ def _losses(self, real=True):
return losses

@cache_latest
def loss(self, real=True):
def loss(self, real: bool = True) -> Union[float]:
losses = self._losses(real)
return max(losses)

Expand Down Expand Up @@ -325,7 +338,9 @@ def remove_unfinished(self):
learner.remove_unfinished()

@classmethod
def from_product(cls, f, learner_type, learner_kwargs, combos):
def from_product(
cls, f, learner_type, learner_kwargs, combos
) -> "BalancingLearner":
"""Create a `BalancingLearner` with learners of all combinations of
named variables’ values. The `cdims` will be set correctly, so calling
`learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.
Expand Down Expand Up @@ -372,7 +387,7 @@ def from_product(cls, f, learner_type, learner_kwargs, combos):
learners.append(learner)
return cls(learners, cdims=arguments)

def save(self, fname, compress=True):
def save(self, fname: Callable, compress: bool = True) -> None:
"""Save the data of the child learners into pickle files
in a directory.
Expand Down Expand Up @@ -410,7 +425,7 @@ def save(self, fname, compress=True):
for l in self.learners:
l.save(fname(l), compress=compress)

def load(self, fname, compress=True):
def load(self, fname: Callable, compress: bool = True) -> None:
"""Load the data of the child learners from pickle files
in a directory.
Expand Down
15 changes: 8 additions & 7 deletions adaptive/learner/base_learner.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import abc
from contextlib import suppress
from copy import deepcopy
from typing import Any, Callable, Dict

from adaptive.utils import _RequireAttrsABCMeta, load, save


def uses_nth_neighbors(n):
def uses_nth_neighbors(n: int) -> Callable:
"""Decorator to specify how many neighboring intervals the loss function uses.
Wraps loss functions to indicate that they expect intervals together
Expand Down Expand Up @@ -84,7 +85,7 @@ class BaseLearner(metaclass=_RequireAttrsABCMeta):
npoints: int
pending_points: set

def tell(self, x, y):
def tell(self, x: Any, y) -> None:
"""Tell the learner about a single value.
Parameters
Expand All @@ -94,7 +95,7 @@ def tell(self, x, y):
"""
self.tell_many([x], [y])

def tell_many(self, xs, ys):
def tell_many(self, xs: Any, ys: Any) -> None:
"""Tell the learner about some values.
Parameters
Expand Down Expand Up @@ -161,7 +162,7 @@ def copy_from(self, other):
"""
self._set_data(other._get_data())

def save(self, fname, compress=True):
def save(self, fname: str, compress: bool = True) -> None:
"""Save the data of the learner into a pickle file.
Parameters
Expand All @@ -175,7 +176,7 @@ def save(self, fname, compress=True):
data = self._get_data()
save(fname, data, compress)

def load(self, fname, compress=True):
def load(self, fname: str, compress: bool = True) -> None:
"""Load the data of a learner from a pickle file.
Parameters
Expand All @@ -190,8 +191,8 @@ def load(self, fname, compress=True):
data = load(fname, compress)
self._set_data(data)

def __getstate__(self):
def __getstate__(self) -> Dict[str, Any]:
return deepcopy(self.__dict__)

def __setstate__(self, state):
def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__ = state
Loading

0 comments on commit 30cded3

Please sign in to comment.