Skip to content

Commit

Permalink
Merge pull request #39 from tataratat/ft-rknn
Browse files Browse the repository at this point in the history
add rknn
  • Loading branch information
j042 authored Jul 26, 2024
2 parents c46edf1 + 40775a3 commit 270e9b6
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 39 deletions.
30 changes: 30 additions & 0 deletions napf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,36 @@ def radius_search(self, queries, radius, return_sorted, nthread=None):
nthread,
)

def rknn_search(self, queries, radius, n_nearest, nthread=None):
    """
    Searches, for each query, for at most ``n_nearest`` nearest neighbors
    that lie within the radius.
    If a query has fewer than ``n_nearest`` neighbors within the radius,
    the remaining entries of its row are filled with dummy values:
    the maximum value of the respective data type, negated if that type
    is signed. Callers should filter these entries out before using the
    results.

    Parameters
    ----------
    queries: (m, d) np.ndarray
      Passed through enforce_contiguous(queries, self.dtype) before the
      search.
    radius: float
      Forwarded unchanged to the core tree.
      NOTE(review): presumably expressed in the tree metric's own distance
      measure - confirm against the core tree implementation.
    n_nearest: int
      Maximum number of neighbors returned per query. This is also the
      number of columns of both returned arrays.
    nthread: int
      Default is None, in which case self.nthread is used.

    Returns
    -------
    ids_and_distances: tuple
      ((m, n_nearest) np.ndarray - uint ids,
       (m, n_nearest) np.ndarray - double dists)
    """
    if nthread is None:
        nthread = self.nthread

    return self.core_tree.rknn_search(
        enforce_contiguous(queries, self.dtype),
        radius,
        n_nearest,
        nthread,
    )

def query_ball_point(self, queries, radius, return_sorted, nthread=None):
"""
scipy-like KDTree query_ball_point call.
Expand Down
78 changes: 75 additions & 3 deletions src/python/pykdt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
#include <algorithm>
#include <iostream>
#include <iterator>
#include <limits>
#include <memory>
#include <thread>
#include <type_traits>
#include <utility>

#include <pybind11/numpy.h>
Expand All @@ -30,6 +32,16 @@ using IndexType = typename UIntVector::value_type;
using IndexVector = UIntVector;
using IndexVectorVector = UIntVectorVector;

// helper function to get dummy values:
// numeric_limits<Type>::max() for unsigned types, and its negation for
// signed types (this includes floating point types).
template<typename Type>
Type max_and_negative_if_signed() {
  const Type largest = std::numeric_limits<Type>::max();
  return std::is_signed<Type>::value ? static_cast<Type>(-largest) : largest;
}

template<typename DataT, unsigned int metric>
class PyKDT {
public:
Expand Down Expand Up @@ -201,6 +213,58 @@ class PyKDT {
return py::make_tuple<py::return_value_policy::move>(out_indices, out_dist);
}

/// @brief radius knn search: for each query, at most `n_nearest` nearest
/// neighbors within `radius`.
///
/// Output slots that the search does not fill (fewer than `n_nearest`
/// matches) are set to dummy values obtained from
/// max_and_negative_if_signed().
///
/// @param qpts query points. Read as a flat buffer with `dim_` entries per
/// query; shape[0] is taken as the number of queries.
/// @param radius search radius, forwarded unchanged to tree_->rknnSearch()
/// @param n_nearest maximum number of neighbors per query. This is also
/// the number of columns of both output arrays.
/// @param nthread forwarded to nthread_execution()
/// @return tuple (indices, distances), each of shape (n_queries, n_nearest)
py::tuple rknn_search(const py::array_t<DataT> qpts,
                      const DistT radius,
                      const int n_nearest,
                      const int nthread) {

  // in
  const py::buffer_info q_buf = qpts.request();
  const DataT* q_buf_ptr = static_cast<DataT*>(q_buf.ptr);
  const int qlen = q_buf.shape[0];

  // out
  py::array_t<IndexType> indices({qlen, n_nearest});
  py::array_t<DistT> distances({qlen, n_nearest});
  // get pointers
  IndexType* i_ptr = static_cast<IndexType*>(indices.request().ptr);
  DistT* d_ptr = static_cast<DistT*>(distances.request().ptr);

  // NOTE(review): nthread_execution presumably calls this with a
  // [begin, end) range of query ids per thread - confirm. The third
  // argument is unused here.
  auto searchradiusknn = [&](int begin, int end, int) {
    // get pointers for this chunk/thread.
    // Each query owns n_nearest output slots, so the chunk starts at
    // begin * n_nearest (not begin * dim_, which would make chunks of
    // different threads write to wrong / overlapping rows).
    // Multiply in a wider type to avoid int overflow for large outputs.
    const long long chunk_offset = static_cast<long long>(begin) * n_nearest;
    IndexType* t_i_ptr = i_ptr + chunk_offset;
    DistT* t_d_ptr = d_ptr + chunk_offset;

    // dummy values - put max value (negative if the type is signed)
    const DistT dummy_dist = max_and_negative_if_signed<DistT>();
    const IndexType dummy_index = max_and_negative_if_signed<IndexType>();

    for (int i{begin}; i < end; i++) {

      // call
      const auto n_matches = tree_->rknnSearch(&q_buf_ptr[i * dim_],
                                               n_nearest,
                                               t_i_ptr,
                                               t_d_ptr,
                                               radius);

      // in case n_matches < n_nearest, we fill the rest with dummy values
      for (int j{static_cast<int>(n_matches)}; j < n_nearest; ++j) {
        t_i_ptr[j] = dummy_index;
        t_d_ptr[j] = dummy_dist;
      }

      // next pointers - advance to the output row of the next query
      t_i_ptr += n_nearest;
      t_d_ptr += n_nearest;
    }
  };

  nthread_execution(searchradiusknn, qlen, nthread);

  return py::make_tuple<py::return_value_policy::move>(indices, distances);
}

/// @brief
/// @param qpts
/// @param radius
Expand Down Expand Up @@ -257,8 +321,8 @@ class PyKDT {
return out_indices;
}

/// @brief unique points, indices of unique points, inverse indices to create
/// original points base on unique points.
/// @brief unique points, indices of unique points, inverse indices to
/// create original points base on unique points.
/// @param radius
/// @param return_intersection returns neighbor
/// @param nthread
Expand Down Expand Up @@ -312,7 +376,8 @@ class PyKDT {
this_intersection.emplace_back(match.first);
}
std::sort(this_intersection.begin(), this_intersection.end());
// set inverse_ids - it is the smallest neighbor (intersection) index
// set inverse_ids - it is the smallest neighbor (intersection)
// index
unique_id = this_intersection[0];
} else {
// here, we'd only need min.
Expand Down Expand Up @@ -433,6 +498,13 @@ void add_kdt_pyclass(py::module_& m, const char* class_name) {
py::arg("return_sorted"),
py::arg("nthread"),
py::return_value_policy::move)
.def("rknn_search",
&KDT::rknn_search,
py::arg("queries"),
py::arg("radius"),
py::arg("n_nearest"),
py::arg("nthread"),
py::return_value_policy::move)
.def("query_ball_point",
&KDT::query_ball_point,
py::arg("queries"),
Expand Down
105 changes: 69 additions & 36 deletions tests/test_init_and_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,73 +6,106 @@
import napf


def loop_all_and_test(dims, data_type, metrics, test_func):
    """
    Calls test_func(dim, data_t, metric) once for every combination of
    the given dims, data types and metrics. Combinations are visited in
    itertools.product order, i.e. the metric varies fastest.
    """
    combinations = itertools.product(dims, data_type, metrics)
    for combination in combinations:
        test_func(*combination)


class InitAndQueryTest(unittest.TestCase):
def test_init_and_query(self):
dims = [
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
]
dims = list(range(1, 21))
data_type = ["float64", "float32", "int64", "int32"]
metric = [1, 2]
metrics = [1, 2]

for d, dt, m in itertools.product(dims, data_type, metric):
def test_func(dim, data_t, metric):
# try to initialize the tree with
n_data = 100
randata = (np.random.random((n_data, d)) * n_data).astype(dt)
randata = (np.random.random((n_data, dim)) * n_data).astype(data_t)
tree_data = np.vstack(
(randata, np.array([[-1] * d], dtype=dt))
(randata, np.array([[-1] * dim], dtype=data_t))
).astype(
dt
data_t
) # make sure one more time

kdt = napf.KDT(tree_data, m)
kdt = napf.KDT(tree_data, metric)

# init test
qname = type(kdt.core_tree).__qualname__
assert kdt.core_tree.dim == d, f"wrong dim init for {qname}"
assert kdt.core_tree.dim == dim, f"wrong dim init for {qname}"
assert (
kdt.core_tree.tree_data.dtype == dt
kdt.core_tree.tree_data.dtype == data_t
), f"wrong dtype init for {qname}"
assert kdt.core_tree.metric == m, f"wrong metric init for {qname}"
assert (
kdt.core_tree.metric == metric
), f"wrong metric init for {qname}"

# query test should be all be dist = 0 and id = n_data
dist, ids = kdt.query([[-1] * d], nthread=1)
dist, ids = kdt.query([[-1] * dim], nthread=1)
assert np.isclose(dist[0], 0), f"wrong dist query for {qname}"
assert ids[0] == n_data, f"wrong index query for {qname}"

# test knn_search
dist, ids = kdt.knn_search(kdt.tree_data, 1, nthread=2)
assert np.isclose(
dist.sum(), 0
), f"wrong dist query for {qname}, dim {d}"
), f"wrong dist query for {qname}, dim {dim}"

# skip integer types for index check
# as it is too easy for them to have duplicates.
# with default options of nanoflann, this will return smaller index
if dt.startswith("int"):
continue
if data_t.startswith("int"):
return

assert np.all(
ids.ravel() == np.arange(len(kdt.tree_data))
), f"wrong index query for {qname}, dim {d}"
), f"wrong index query for {qname}, dim {dim}"

loop_all_and_test(dims, data_type, metrics, test_func)

def test_rknn(self):
    """
    rknn_search on a tree that contains every query point twice:
    requested neighbors that exist must come back with zero distance,
    and columns beyond the available matches must hold dummy values.
    """
    dims = list(range(1, 21))
    data_type = ["float64", "float32"]
    metrics = [1, 2]

    def test_func(dim, data_t, metric):
        n_data = 100
        points = (np.random.random((n_data, dim)) * n_data).astype(data_t)

        # stack the points on themselves to create exactly two matches
        # per query
        doubled_points = np.vstack((points, points)).astype(data_t)

        tree = napf.KDT(doubled_points, metric)

        # one neighbor requested
        k = 1
        ids, dists = tree.rknn_search(points, 1e-10, k)
        assert ids.shape[1] == k
        assert dists.shape[1] == k
        # we can't guarantee ids but distance should be all zero
        assert np.isclose(dists.sum(), 0)

        # two neighbors requested, also zero dist
        k = 2
        ids, dists = tree.rknn_search(points, 1e-10, k)
        assert dists.shape[1] == k
        assert ids.shape[1] == k
        assert np.isclose(dists.sum(), 0)

        # five neighbors requested: only the first two columns are real
        # matches with zero dist, the last three are dummy values
        k = 5
        ids, dists = tree.rknn_search(points, 1e-10, k)
        assert dists.shape[1] == k
        assert ids.shape[1] == k
        assert np.isclose(dists[:, :2].sum(), 0)
        # dummy dist is negative for floats
        assert (dists[:, 2:] < 0).all()
        # dummy id must be bigger than the number of tree points
        # (2 * n_data), as long as n_data does not overflow
        assert (abs(ids[:, 2:]) > n_data * 2).all()

    loop_all_and_test(dims, data_type, metrics, test_func)


if __name__ == "__main__":
Expand Down

0 comments on commit 270e9b6

Please sign in to comment.