From 8513e7a7c8e8e26985305d871cf64189ab5c6551 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20K=C3=B6gl?= Date: Sun, 25 Feb 2018 15:55:43 +0100 Subject: [PATCH] Correctly handle non-euclidean distances (fixes #40) --- kdtree.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/kdtree.py b/kdtree.py index 68ef492..c3475e7 100644 --- a/kdtree.py +++ b/kdtree.py @@ -387,13 +387,17 @@ def axis_dist(self, point, axis): return math.pow(self.data[axis] - point[axis], 2) - def dist(self, point): + def dist(self, point, axis=None): """ Squared distance between the current Node and the given point """ - r = range(self.dimensions) - return sum([self.axis_dist(point, i) for i in r]) + if axis is None: + axes = range(self.dimensions) + else: + axes = [axis] + + return sum([self.axis_dist(point, i) for i in axes]) def search_knn(self, point, k, dist=None): @@ -406,7 +410,9 @@ def search_knn(self, point, k, dist=None): distances. dist is a distance function, expecting two points and returning a - distance value. Distance values can be any comparable type. + distance value. dist should expect an optional `axis` parameter. If + given, the distance on the specified axis should be calculated. + Distance values can be any comparable type. The result is an ordered list of (node, distance) tuples. """ @@ -414,10 +420,11 @@ def search_knn(self, point, k, dist=None): if k < 1: raise ValueError("k must be greater than 0.") - if dist is None: - get_dist = lambda n: n.dist(point) - else: - get_dist = lambda n: dist(n.data, point) + def get_dist(n, axis=None): + if dist is None: + return n.dist(point, axis=None) + else: + return dist(n.data, point, axis=None) results = [] @@ -446,12 +453,10 @@ def _search_node(self, point, k, results, get_dist, counter): heapq.heapreplace(results, item) else: heapq.heappush(results, item) - # get the splitting plane + split_plane = self.data[self.axis] - # get the squared distance between the point and the splitting plane - # (squared since all distances are squared). - plane_dist = point[self.axis] - split_plane - plane_dist2 = plane_dist * plane_dist + pt = KDNode(point, dimensions=self.dimensions) + plane_dist = get_dist(pt, axis=self.axis) # Search the side of the splitting plane that the point is in if point[self.axis] < split_plane: @@ -463,7 +468,7 @@ def _search_node(self, point, k, results, get_dist, counter): # Search the other side of the splitting plane if it may contain # points closer than the farthest point in the current results. - if -plane_dist2 > results[0][0] or len(results) < k: + if -plane_dist > results[0][0] or len(results) < k: if point[self.axis] < self.data[self.axis]: if self.right is not None: self.right._search_node(point, k, results, get_dist,