Skip to content

Commit

Permalink
Correctly handle non-euclidean distances (fixes #40)
Browse files Browse the repository at this point in the history
  • Loading branch information
stefankoegl committed Feb 25, 2018
1 parent 587edc7 commit 8513e7a
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions kdtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,13 +387,17 @@ def axis_dist(self, point, axis):
return math.pow(self.data[axis] - point[axis], 2)


def dist(self, point, axis=None):
    """
    Squared distance between the current Node
    and the given point

    If `axis` is None (the default), the per-axis squared distances
    over all `self.dimensions` axes are summed. If `axis` is given,
    only the squared distance along that single axis is returned.
    """
    # Restrict the sum to one axis when requested; otherwise use them all.
    if axis is None:
        axes = range(self.dimensions)
    else:
        axes = [axis]

    return sum(self.axis_dist(point, i) for i in axes)


def search_knn(self, point, k, dist=None):
Expand All @@ -406,18 +410,21 @@ def search_knn(self, point, k, dist=None):
distances.
dist is a distance function, expecting two points and returning a
distance value. dist should accept an optional `axis` parameter. If
given, only the distance along the specified axis should be calculated.
Distance values can be any comparable type.
The result is an ordered list of (node, distance) tuples.
"""

if k < 1:
raise ValueError("k must be greater than 0.")

if dist is None:
get_dist = lambda n: n.dist(point)
else:
get_dist = lambda n: dist(n.data, point)
def get_dist(n, axis=None):
    """
    Distance between node `n` and the query point.

    Uses the node's own `dist` method unless a custom `dist` function
    was supplied. `axis` is forwarded unchanged, so a caller asking for
    the distance along a single axis actually gets it (None means the
    callee's default, i.e. no axis restriction).
    """
    if dist is None:
        return n.dist(point, axis=axis)
    return dist(n.data, point, axis=axis)

results = []

Expand Down Expand Up @@ -446,12 +453,10 @@ def _search_node(self, point, k, results, get_dist, counter):
heapq.heapreplace(results, item)
else:
heapq.heappush(results, item)
# get the splitting plane

split_plane = self.data[self.axis]
# get the squared distance between the point and the splitting plane
# (squared since all distances are squared).
plane_dist = point[self.axis] - split_plane
plane_dist2 = plane_dist * plane_dist
pt = KDNode(point, dimensions=self.dimensions)
plane_dist = get_dist(pt, axis=self.axis)

# Search the side of the splitting plane that the point is in
if point[self.axis] < split_plane:
Expand All @@ -463,7 +468,7 @@ def _search_node(self, point, k, results, get_dist, counter):

# Search the other side of the splitting plane if it may contain
# points closer than the farthest point in the current results.
if -plane_dist2 > results[0][0] or len(results) < k:
if -plane_dist > results[0][0] or len(results) < k:
if point[self.axis] < self.data[self.axis]:
if self.right is not None:
self.right._search_node(point, k, results, get_dist,
Expand Down

0 comments on commit 8513e7a

Please sign in to comment.