diff --git a/adaptive/learner/learnerND.py b/adaptive/learner/learnerND.py index de6b79d3d..61621388d 100644 --- a/adaptive/learner/learnerND.py +++ b/adaptive/learner/learnerND.py @@ -3,7 +3,7 @@ import random from collections import OrderedDict from collections.abc import Iterable -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import numpy as np import scipy.spatial @@ -13,6 +13,8 @@ from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors from adaptive.learner.triangulation import ( + Point, + Simplex, Triangulation, circumsphere, fast_det, @@ -40,7 +42,7 @@ def volume(simplex: List[Tuple[float, float]], ys: None = None,) -> float: return vol -def orientation(simplex): +def orientation(simplex: np.ndarray): matrix = np.subtract(simplex[:-1], simplex[-1]) # See https://www.jstor.org/stable/2315353 sign, _logdet = np.linalg.slogdet(matrix) @@ -339,12 +341,14 @@ def __init__( self.function = func self._tri = None - self._losses = dict() + self._losses: Dict[Simplex, float] = dict() - self._pending_to_simplex = dict() # vertex → simplex + self._pending_to_simplex: Dict[Point, Simplex] = dict() # vertex → simplex # triangulation of the pending points inside a specific simplex - self._subtriangulations = dict() # simplex → triangulation + self._subtriangulations: Dict[ + Simplex, Triangulation + ] = dict() # simplex → triangulation # scale to unit hypercube # for the input @@ -456,7 +460,7 @@ def tell(self, point: Tuple[float, ...], value: Union[float, np.ndarray],) -> No to_delete, to_add = tri.add_point(point, simplex, transform=self._transform) self._update_losses(to_delete, to_add) - def _simplex_exists(self, simplex: Any) -> bool: # XXX: specify simplex: Any + def _simplex_exists(self, simplex: Simplex) -> bool: simplex = tuple(sorted(simplex)) return simplex in self.tri.simplices @@ -498,7 +502,7 @@ def tell_pending(self, point: Tuple[float, ...], *, simplex=None,) -> None: self._update_subsimplex_losses(simpl, to_add) def _try_adding_pending_point_to_simplex( - self, point: Tuple[float, ...], simplex: Any, # XXX: specify simplex: Any + self, point: Point, simplex: Simplex, ) -> Any: # try to insert it if not self.tri.point_in_simplex(point, simplex): @@ -512,8 +516,8 @@ def _try_adding_pending_point_to_simplex( return self._subtriangulations[simplex].add_point(point) def _update_subsimplex_losses( - self, simplex: Any, new_subsimplices: Any - ) -> None: # XXX: specify simplex: Any + self, simplex: Simplex, new_subsimplices: Set[Simplex] + ) -> None: loss = self._losses[simplex] loss_density = loss / self.tri.volume(simplex) @@ -534,7 +538,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Any: else: return self._ask_and_tell_pending(n) - def _ask_bound_point(self,) -> Tuple[Tuple[float, ...], float]: + def _ask_bound_point(self,) -> Tuple[Point, float]: # get the next bound point that is still available new_point = next( p @@ -544,7 +548,7 @@ def _ask_bound_point(self,) -> Tuple[Tuple[float, ...], float]: self.tell_pending(new_point) return new_point, np.inf - def _ask_point_without_known_simplices(self,) -> Tuple[Tuple[float, ...], float]: + def _ask_point_without_known_simplices(self,) -> Tuple[Point, float]: assert not self._bounds_available # pick a random point inside the bounds # XXX: change this into picking a point based on volume loss @@ -585,7 +589,7 @@ def _pop_highest_existing_simplex(self) -> Any: " be a simplex available if LearnerND.tri() is not None." ) - def _ask_best_point(self,) -> Tuple[Tuple[float, ...], float]: + def _ask_best_point(self,) -> Tuple[Point, float]: assert self.tri is not None loss, simplex, subsimplex = self._pop_highest_existing_simplex() @@ -612,7 +616,7 @@ def _bounds_available(self) -> bool: for p in self._bounds_points ) - def _ask(self,) -> Tuple[Tuple[float, ...], float]: + def _ask(self,) -> Tuple[Point, float]: if self._bounds_available: return self._ask_bound_point() # O(1) @@ -624,7 +628,7 @@ def _ask(self,) -> Tuple[Tuple[float, ...], float]: return self._ask_best_point() # O(log N) - def _compute_loss(self, simplex: Any) -> float: # XXX: specify simplex: Any + def _compute_loss(self, simplex: Simplex) -> float: # get the loss vertices = self.tri.get_vertices(simplex) values = [self.data[tuple(v)] for v in vertices] @@ -663,7 +667,7 @@ def _compute_loss(self, simplex: Any) -> float: # XXX: specify simplex: Any ) ) - def _update_losses(self, to_delete: set, to_add: set) -> None: + def _update_losses(self, to_delete: Set[Simplex], to_add: Set[Simplex]) -> None: # XXX: add the points outside the triangulation to this as well pending_points_unbound = set() @@ -733,13 +737,11 @@ def _recompute_all_losses(self) -> None: ) @property - def _scale(self) -> Union[float, np.int64]: + def _scale(self) -> float: # get the output scale return self._max_value - self._min_value - def _update_range( - self, new_output: Union[List[int], float, float, np.ndarray] - ) -> bool: + def _update_range(self, new_output: Union[List[int], float, np.ndarray]) -> bool: if self._min_value is None or self._max_value is None: # this is the first point, nothing to do, just set the range self._min_value = np.min(new_output) @@ -790,7 +792,7 @@ def remove_unfinished(self) -> None: # Plotting related stuff # ########################## - def plot(self, n=None, tri_alpha=0): + def plot(self, n: Optional[int] = None, tri_alpha: float = 0): """Plot the function we want to learn, only works in 2D. Parameters @@ -851,7 +853,7 @@ def plot(self, n=None, tri_alpha=0): return im.opts(style=im_opts) * tris.opts(style=tri_opts, **no_hover) - def plot_slice(self, cut_mapping, n=None): + def plot_slice(self, cut_mapping: Dict[int, float], n: Optional[int] = None): """Plot a 1D or 2D interpolated slice of a N-dimensional function. Parameters @@ -921,7 +923,7 @@ def plot_slice(self, cut_mapping, n=None): else: raise ValueError("Only 1 or 2-dimensional plots can be generated.") - def plot_3D(self, with_triangulation=False): + def plot_3D(self, with_triangulation: bool = False): """Plot the learner's data in 3D using plotly. Does *not* work with the @@ -1010,7 +1012,7 @@ def _set_data(self, data: OrderedDict) -> None: if data: self.tell_many(*zip(*data.items())) - def _get_iso(self, level=0.0, which="surface"): + def _get_iso(self, level: float = 0.0, which: str = "surface"): if which == "surface": if self.ndim != 3 or self.vdim != 1: raise Exception( @@ -1081,7 +1083,9 @@ def _get_vertex_index(a, b): return vertices, faces_or_lines - def plot_isoline(self, level=0.0, n=None, tri_alpha=0): + def plot_isoline( + self, level: float = 0.0, n: Optional[int] = None, tri_alpha: float = 0 + ): """Plot the isoline at a specific level, only works in 2D. Parameters @@ -1121,7 +1125,7 @@ def plot_isoline(self, level=0.0, n=None, tri_alpha=0): contour = contour.opts(style=contour_opts) return plot * contour - def plot_isosurface(self, level=0.0, hull_opacity=0.2): + def plot_isosurface(self, level: float = 0.0, hull_opacity: float = 0.2): """Plots a linearly interpolated isosurface. This is the 3D analog of an isoline. Does *not* work with the @@ -1159,7 +1163,7 @@ def plot_isosurface(self, level=0.0, hull_opacity=0.2): hull_mesh = self._get_hull_mesh(opacity=hull_opacity) return plotly.offline.iplot([isosurface, hull_mesh]) - def _get_hull_mesh(self, opacity=0.2): + def _get_hull_mesh(self, opacity: float = 0.2): plotly = ensure_plotly() hull = scipy.spatial.ConvexHull(self._bounds_points) diff --git a/adaptive/learner/triangulation.py b/adaptive/learner/triangulation.py index 41d75766b..a59c265e4 100644 --- a/adaptive/learner/triangulation.py +++ b/adaptive/learner/triangulation.py @@ -1,14 +1,18 @@ +import collections.abc import math from collections import Counter -from collections.abc import Iterable, Sized from itertools import chain, combinations from math import factorial -from typing import Any, Iterator, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union import numpy as np import scipy.spatial -Simplex = Tuple[int, ...] # XXX: check if this is correct +SimplexPoints = Union[ + List[Tuple[float, ...]], np.ndarray +] # XXX: check if this is correct +Simplex = Tuple[int, ...] +Point = Union[Tuple[float, ...], np.ndarray] # XXX: check if this is correct def fast_norm(v: Union[Tuple[float, ...], np.ndarray]) -> float: @@ -21,9 +25,7 @@ def fast_norm(v: Union[Tuple[float, ...], np.ndarray]) -> float: def fast_2d_point_in_simplex( - point: Tuple[float, ...], - simplex: Union[List[Tuple[float, ...]], np.ndarray], - eps: float = 1e-8, + point: Point, simplex: SimplexPoints, eps: float = 1e-8 ) -> Union[bool, np.bool_]: (p0x, p0y), (p1x, p1y), (p2x, p2y) = simplex px, py = point @@ -38,9 +40,7 @@ def fast_2d_point_in_simplex( return (t >= -eps) and (s + t <= 1 + eps) -def point_in_simplex( - point: Any, simplex: Simplex, eps: float = 1e-8 -) -> Union[bool, np.bool_]: +def point_in_simplex(point: Point, simplex: SimplexPoints, eps: float = 1e-8) -> bool: if len(point) == 2: return fast_2d_point_in_simplex(point, simplex, eps) @@ -51,7 +51,7 @@ def point_in_simplex( return all(alpha > -eps) and sum(alpha) < 1 + eps -def fast_2d_circumcircle(points: np.ndarray,) -> Tuple[Tuple[float, float], float]: +def fast_2d_circumcircle(points: Iterable[Point]) -> Tuple[Tuple[float, float], float]: """Compute the center and radius of the circumscribed circle of a triangle Parameters @@ -88,7 +88,7 @@ def fast_2d_circumcircle(points: np.ndarray,) -> Tuple[Tuple[float, float], floa def fast_3d_circumcircle( - points: np.ndarray, + points: Iterable[Point], ) -> Tuple[Tuple[float, float, float], float]: """Compute the center and radius of the circumscribed shpere of a simplex. @@ -140,7 +140,7 @@ def fast_det(matrix: np.ndarray) -> float: return np.linalg.det(matrix) -def circumsphere(pts: np.ndarray,) -> Tuple[Tuple[float, ...], float]: +def circumsphere(pts: np.ndarray) -> Tuple[Tuple[float, ...], float]: dim = len(pts) - 1 if dim == 2: return fast_2d_circumcircle(pts) @@ -193,10 +193,12 @@ def orientation(face: np.ndarray, origin: np.ndarray) -> int: def is_iterable_and_sized(obj: Any) -> bool: - return isinstance(obj, Iterable) and isinstance(obj, Sized) + return isinstance(obj, collections.abc.Iterable) and isinstance( + obj, collections.abc.Sized + ) -def simplex_volume_in_embedding(vertices: List[Tuple[float, ...]]) -> float: +def simplex_volume_in_embedding(vertices: Iterable[Point]) -> float: """Calculate the volume of a simplex in a higher dimensional embedding. That is: dim > len(vertices) - 1. For example if you would like to know the surface area of a triangle in a 3d space. @@ -277,7 +279,7 @@ class Triangulation: or more simplices in the """ - def __init__(self, coords: np.ndarray) -> None: + def __init__(self, coords: Iterable[Point]) -> None: if not is_iterable_and_sized(coords): raise TypeError("Please provide a 2-dimensional list of points") coords = list(coords) @@ -305,10 +307,10 @@ def __init__(self, coords: np.ndarray) -> None: "(the points are linearly dependent)" ) - self.vertices = list(coords) - self.simplices = set() + self.vertices: List[Point] = list(coords) + self.simplices: Set[Simplex] = set() # initialise empty set for each vertex - self.vertex_to_simplices = [set() for _ in coords] + self.vertex_to_simplices: List[Set[Simplex]] = [set() for _ in coords] # find a Delaunay triangulation to start with, then we will throw it # away and continue with our own algorithm @@ -328,16 +330,16 @@ def add_simplex(self, simplex: Simplex) -> None: for vertex in simplex: self.vertex_to_simplices[vertex].add(simplex) - def get_vertices(self, indices: Sequence[int]) -> Any: + def get_vertices(self, indices: Sequence[int]) -> List[Optional[Point]]: return [self.get_vertex(i) for i in indices] - def get_vertex(self, index: Optional[int]) -> Any: + def get_vertex(self, index: Optional[int]) -> Optional[Point]: if index is None: return None return self.vertices[index] def get_reduced_simplex( - self, point: Any, simplex: Simplex, eps: float = 1e-8 + self, point: Point, simplex: Simplex, eps: float = 1e-8 ) -> list: """Check whether vertex lies within a simplex. @@ -364,12 +366,12 @@ def get_reduced_simplex( return [simplex[i] for i in result] def point_in_simplex( - self, point: Any, simplex: Simplex, eps: float = 1e-8 - ) -> Union[bool, np.bool_]: + self, point: Point, simplex: Simplex, eps: float = 1e-8 + ) -> bool: vertices = self.get_vertices(simplex) return point_in_simplex(point, vertices, eps) - def locate_point(self, point: Any) -> Any: + def locate_point(self, point: Point) -> Simplex: """Find to which simplex the point belongs. Return indices of the simplex containing the point. @@ -385,8 +387,11 @@ def dim(self) -> int: return len(self.vertices[0]) def faces( - self, dim: None = None, simplices: Optional[Any] = None, vertices: None = None - ) -> Iterator[Any]: + self, + dim: Optional[int] = None, + simplices: Optional[Iterable[Simplex]] = None, + vertices: Optional[Iterable[int]] = None, + ) -> Iterator[Tuple[int, ...]]: """Iterator over faces of a simplex or vertex sequence.""" if dim is None: dim = self.dim @@ -407,11 +412,11 @@ def faces( else: return faces - def containing(self, face): + def containing(self, face: Tuple[int, ...]) -> Set[Simplex]: """Simplices containing a face.""" return set.intersection(*(self.vertex_to_simplices[i] for i in face)) - def _extend_hull(self, new_vertex: Any, eps: float = 1e-8) -> Any: + def _extend_hull(self, new_vertex: Point, eps: float = 1e-8) -> Set[Simplex]: # count multiplicities in order to get all hull faces multiplicities = Counter(face for face in self.faces()) hull_faces = [face for face, count in multiplicities.items() if count == 1] @@ -471,7 +476,7 @@ def circumscribed_circle( def point_in_cicumcircle( self, pt_index: int, simplex: Simplex, transform: np.ndarray - ) -> np.bool_: + ) -> bool: # return self.fast_point_in_circumcircle(pt_index, simplex, transform) eps = 1e-8 @@ -487,9 +492,9 @@ def default_transform(self) -> np.ndarray: def bowyer_watson( self, pt_index: int, - containing_simplex: Optional[Any] = None, + containing_simplex: Optional[Simplex] = None, transform: Optional[np.ndarray] = None, - ) -> Any: + ) -> Tuple[Set[Simplex], Set[Simplex]]: """Modified Bowyer-Watson point adding algorithm. Create a hole in the triangulation around the new point, @@ -549,7 +554,7 @@ def bowyer_watson( new_triangles = self.vertex_to_simplices[pt_index] return bad_triangles - new_triangles, new_triangles - bad_triangles - def _simplex_is_almost_flat(self, simplex: Simplex) -> np.bool_: + def _simplex_is_almost_flat(self, simplex: Simplex) -> bool: return self._relative_volume(simplex) < 1e-8 def _relative_volume(self, simplex: Simplex) -> float: @@ -565,8 +570,8 @@ def _relative_volume(self, simplex: Simplex) -> float: def add_point( self, - point: Any, - simplex: Optional[Any] = None, + point: Point, + simplex: Optional[Simplex] = None, transform: Optional[np.ndarray] = None, ) -> Any: """Add a new vertex and create simplices as appropriate. @@ -575,13 +580,13 @@ def add_point( ---------- point : float vector Coordinates of the point to be added. - transform : N*N matrix of floats - Multiplication matrix to apply to the point (and neighbouring - simplices) when running the Bowyer Watson method. simplex : tuple of ints, optional Simplex containing the point. Empty tuple indicates points outside the hull. If not provided, the algorithm costs O(N), so this should be used whenever possible. + transform : N*N matrix of floats + Multiplication matrix to apply to the point (and neighbouring + simplices) when running the Bowyer Watson method. """ point = tuple(point) if simplex is None: @@ -626,7 +631,7 @@ def volume(self, simplex: Simplex) -> float: def volumes(self) -> List[float]: return [self.volume(sim) for sim in self.simplices] - def reference_invariant(self): + def reference_invariant(self) -> bool: """vertex_to_simplices and simplices are compatible.""" for vertex in range(len(self.vertices)): if any(vertex not in tri for tri in self.vertex_to_simplices[vertex]): @@ -640,26 +645,28 @@ def vertex_invariant(self, vertex): """Simplices originating from a vertex don't overlap.""" raise NotImplementedError - def get_neighbors_from_vertices(self, simplex: Simplex) -> Any: + def get_neighbors_from_vertices(self, simplex: Simplex) -> Set[Simplex]: return set.union(*[self.vertex_to_simplices[p] for p in simplex]) - def get_face_sharing_neighbors(self, neighbors: Any, simplex: Simplex) -> Any: + def get_face_sharing_neighbors( + self, neighbors: Set[Simplex], simplex: Simplex + ) -> Set[Simplex]: """Keep only the simplices sharing a whole face with simplex.""" return { simpl for simpl in neighbors if len(set(simpl) & set(simplex)) == self.dim } # they share a face - def get_simplices_attached_to_points(self, indices: Any) -> Any: + def get_simplices_attached_to_points(self, indices: Simplex) -> Set[Simplex]: # Get all simplices that share at least a point with the simplex neighbors = self.get_neighbors_from_vertices(indices) return self.get_face_sharing_neighbors(neighbors, indices) - def get_opposing_vertices(self, simplex: Simplex,) -> Any: + def get_opposing_vertices(self, simplex: Simplex) -> Tuple[int, ...]: if simplex not in self.simplices: raise ValueError("Provided simplex is not part of the triangulation") neighbors = self.get_simplices_attached_to_points(simplex) - def find_opposing_vertex(vertex): + def find_opposing_vertex(vertex: int): # find the simplex: simp = next((x for x in neighbors if vertex not in x), None) if simp is None: