Skip to content

Commit

Permalink
improve triangulation and LearnerND typing
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Dec 15, 2019
1 parent f21f19d commit 91c1ebd
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 69 deletions.
56 changes: 30 additions & 26 deletions adaptive/learner/learnerND.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random
from collections import OrderedDict
from collections.abc import Iterable
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import numpy as np
import scipy.spatial
Expand All @@ -13,6 +13,8 @@

from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors
from adaptive.learner.triangulation import (
Point,
Simplex,
Triangulation,
circumsphere,
fast_det,
Expand Down Expand Up @@ -40,7 +42,7 @@ def volume(simplex: List[Tuple[float, float]], ys: None = None,) -> float:
return vol


def orientation(simplex):
def orientation(simplex: np.ndarray):
matrix = np.subtract(simplex[:-1], simplex[-1])
# See https://www.jstor.org/stable/2315353
sign, _logdet = np.linalg.slogdet(matrix)
Expand Down Expand Up @@ -339,12 +341,14 @@ def __init__(

self.function = func
self._tri = None
self._losses = dict()
self._losses: Dict[Simplex, float] = dict()

self._pending_to_simplex = dict() # vertex → simplex
self._pending_to_simplex: Dict[Point, Simplex] = dict() # vertex → simplex

# triangulation of the pending points inside a specific simplex
self._subtriangulations = dict() # simplex → triangulation
self._subtriangulations: Dict[
Simplex, Triangulation
] = dict() # simplex → triangulation

# scale to unit hypercube
# for the input
Expand Down Expand Up @@ -456,7 +460,7 @@ def tell(self, point: Tuple[float, ...], value: Union[float, np.ndarray],) -> No
to_delete, to_add = tri.add_point(point, simplex, transform=self._transform)
self._update_losses(to_delete, to_add)

def _simplex_exists(self, simplex: Any) -> bool: # XXX: specify simplex: Any
def _simplex_exists(self, simplex: Simplex) -> bool:
simplex = tuple(sorted(simplex))
return simplex in self.tri.simplices

Expand Down Expand Up @@ -498,7 +502,7 @@ def tell_pending(self, point: Tuple[float, ...], *, simplex=None,) -> None:
self._update_subsimplex_losses(simpl, to_add)

def _try_adding_pending_point_to_simplex(
self, point: Tuple[float, ...], simplex: Any, # XXX: specify simplex: Any
self, point: Point, simplex: Simplex,
) -> Any:
# try to insert it
if not self.tri.point_in_simplex(point, simplex):
Expand All @@ -512,8 +516,8 @@ def _try_adding_pending_point_to_simplex(
return self._subtriangulations[simplex].add_point(point)

def _update_subsimplex_losses(
self, simplex: Any, new_subsimplices: Any
) -> None: # XXX: specify simplex: Any
self, simplex: Simplex, new_subsimplices: Set[Simplex]
) -> None:
loss = self._losses[simplex]

loss_density = loss / self.tri.volume(simplex)
Expand All @@ -534,7 +538,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any:
else:
return self._ask_and_tell_pending(n)

def _ask_bound_point(self,) -> Tuple[Tuple[float, ...], float]:
def _ask_bound_point(self,) -> Tuple[Point, float]:
# get the next bound point that is still available
new_point = next(
p
Expand All @@ -544,7 +548,7 @@ def _ask_bound_point(self,) -> Tuple[Tuple[float, ...], float]:
self.tell_pending(new_point)
return new_point, np.inf

def _ask_point_without_known_simplices(self,) -> Tuple[Tuple[float, ...], float]:
def _ask_point_without_known_simplices(self,) -> Tuple[Point, float]:
assert not self._bounds_available
# pick a random point inside the bounds
# XXX: change this into picking a point based on volume loss
Expand Down Expand Up @@ -585,7 +589,7 @@ def _pop_highest_existing_simplex(self) -> Any:
" be a simplex available if LearnerND.tri() is not None."
)

def _ask_best_point(self,) -> Tuple[Tuple[float, ...], float]:
def _ask_best_point(self,) -> Tuple[Point, float]:
assert self.tri is not None

loss, simplex, subsimplex = self._pop_highest_existing_simplex()
Expand All @@ -612,7 +616,7 @@ def _bounds_available(self) -> bool:
for p in self._bounds_points
)

def _ask(self,) -> Tuple[Tuple[float, ...], float]:
def _ask(self,) -> Tuple[Point, float]:
if self._bounds_available:
return self._ask_bound_point() # O(1)

Expand All @@ -624,7 +628,7 @@ def _ask(self,) -> Tuple[Tuple[float, ...], float]:

return self._ask_best_point() # O(log N)

def _compute_loss(self, simplex: Any) -> float: # XXX: specify simplex: Any
def _compute_loss(self, simplex: Simplex) -> float:
# get the loss
vertices = self.tri.get_vertices(simplex)
values = [self.data[tuple(v)] for v in vertices]
Expand Down Expand Up @@ -663,7 +667,7 @@ def _compute_loss(self, simplex: Any) -> float: # XXX: specify simplex: Any
)
)

def _update_losses(self, to_delete: set, to_add: set) -> None:
def _update_losses(self, to_delete: Set[Simplex], to_add: Set[Simplex]) -> None:
# XXX: add the points outside the triangulation to this as well
pending_points_unbound = set()

Expand Down Expand Up @@ -733,13 +737,11 @@ def _recompute_all_losses(self) -> None:
)

@property
def _scale(self) -> Union[float, np.int64]:
def _scale(self) -> float:
# get the output scale
return self._max_value - self._min_value

def _update_range(
self, new_output: Union[List[int], float, float, np.ndarray]
) -> bool:
def _update_range(self, new_output: Union[List[int], float, np.ndarray]) -> bool:
if self._min_value is None or self._max_value is None:
# this is the first point, nothing to do, just set the range
self._min_value = np.min(new_output)
Expand Down Expand Up @@ -790,7 +792,7 @@ def remove_unfinished(self) -> None:
# Plotting related stuff #
##########################

def plot(self, n=None, tri_alpha=0):
def plot(self, n: Optional[int] = None, tri_alpha: float = 0):
"""Plot the function we want to learn, only works in 2D.
Parameters
Expand Down Expand Up @@ -851,7 +853,7 @@ def plot(self, n=None, tri_alpha=0):

return im.opts(style=im_opts) * tris.opts(style=tri_opts, **no_hover)

def plot_slice(self, cut_mapping, n=None):
def plot_slice(self, cut_mapping: Dict[int, float], n: Optional[int] = None):
"""Plot a 1D or 2D interpolated slice of a N-dimensional function.
Parameters
Expand Down Expand Up @@ -921,7 +923,7 @@ def plot_slice(self, cut_mapping, n=None):
else:
raise ValueError("Only 1 or 2-dimensional plots can be generated.")

def plot_3D(self, with_triangulation=False):
def plot_3D(self, with_triangulation: bool = False):
"""Plot the learner's data in 3D using plotly.
Does *not* work with the
Expand Down Expand Up @@ -1010,7 +1012,7 @@ def _set_data(self, data: OrderedDict) -> None:
if data:
self.tell_many(*zip(*data.items()))

def _get_iso(self, level=0.0, which="surface"):
def _get_iso(self, level: float = 0.0, which: str = "surface"):
if which == "surface":
if self.ndim != 3 or self.vdim != 1:
raise Exception(
Expand Down Expand Up @@ -1081,7 +1083,9 @@ def _get_vertex_index(a, b):

return vertices, faces_or_lines

def plot_isoline(self, level=0.0, n=None, tri_alpha=0):
def plot_isoline(
self, level: float = 0.0, n: Optional[int] = None, tri_alpha: float = 0
):
"""Plot the isoline at a specific level, only works in 2D.
Parameters
Expand Down Expand Up @@ -1121,7 +1125,7 @@ def plot_isoline(self, level=0.0, n=None, tri_alpha=0):
contour = contour.opts(style=contour_opts)
return plot * contour

def plot_isosurface(self, level=0.0, hull_opacity=0.2):
def plot_isosurface(self, level: float = 0.0, hull_opacity: float = 0.2):
"""Plots a linearly interpolated isosurface.
This is the 3D analog of an isoline. Does *not* work with the
Expand Down Expand Up @@ -1159,7 +1163,7 @@ def plot_isosurface(self, level=0.0, hull_opacity=0.2):
hull_mesh = self._get_hull_mesh(opacity=hull_opacity)
return plotly.offline.iplot([isosurface, hull_mesh])

def _get_hull_mesh(self, opacity=0.2):
def _get_hull_mesh(self, opacity: float = 0.2):
plotly = ensure_plotly()
hull = scipy.spatial.ConvexHull(self._bounds_points)

Expand Down
Loading

0 comments on commit 91c1ebd

Please sign in to comment.