From 33ccb17e7cc896ff22917ffa68af061c48b2c187 Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Fri, 7 Jan 2022 13:32:20 +0100 Subject: [PATCH 1/2] parallelizing knn and inrange searches --- Project.toml | 2 +- src/NearestNeighbors.jl | 1 + src/inrange.jl | 2 +- src/knn.jl | 8 ++++---- test/test_inrange.jl | 4 ++-- test/test_knn.jl | 4 ++-- test/test_monkey.jl | 6 +++--- 7 files changed, 14 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 5b5d135..0380f83 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "NearestNeighbors" uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.9" +version = "0.4.10" [deps] Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index 6aa4796..7e8c5b1 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -5,6 +5,7 @@ import Distances: Metric, result_type, eval_reduce, eval_end, eval_op, eval_star using StaticArrays import Base.show +using Base.Threads: @threads export NNTree, BruteTree, KDTree, BallTree, DataFreeTree export knn, nn, inrange # TODOs? , allpairs, distmat, npairs diff --git a/src/inrange.jl b/src/inrange.jl index be5c7a8..d678cad 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -15,7 +15,7 @@ function inrange(tree::NNTree, idxs = [Vector{Int}() for _ in 1:length(points)] - for i in 1:length(points) + @threads for i in 1:length(points) inrange_point!(tree, points[i], radius, sortres, idxs[i]) end return idxs diff --git a/src/knn.jl b/src/knn.jl index 6fb4bb1..34bd2ce 100644 --- a/src/knn.jl +++ b/src/knn.jl @@ -11,16 +11,16 @@ end Performs a lookup of the `k` nearest neigbours to the `points` from the data in the `tree`. If `sortres = true` the result is sorted such that the results are in the order of increasing distance to the point. `skip` is an optional predicate -to determine if a point that would be returned should be skipped based on its +to determine if a point that would be returned should be skipped based on its index. """ function knn(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: AbstractVector, F<:Function} check_input(tree, points) check_k(tree, k) n_points = length(points) - dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points] - idxs = [Vector{Int}(undef, k) for _ in 1:n_points] - for i in 1:n_points + dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points] + idxs = [Vector{Int}(undef, k) for _ in 1:n_points] + @threads for i in 1:n_points knn_point!(tree, points[i], sortres, dists[i], idxs[i], skip) end return idxs, dists diff --git a/test/test_inrange.jl b/test/test_inrange.jl index 652faf2..1988cf5 100644 --- a/test/test_inrange.jl +++ b/test/test_inrange.jl @@ -1,7 +1,7 @@ # Does not test leafsize @testset "inrange" begin - @testset "metric" for metric in [Euclidean()] - @testset "tree type" for TreeType in trees_with_brute + @testset "metric $Metric" for metric in [Euclidean()] + @testset "tree type $TreeType" for TreeType in trees_with_brute function test(data) tree = TreeType(data, metric; leafsize=2) dosort = true diff --git a/test/test_knn.jl b/test/test_knn.jl index 3648ae5..7c94776 100644 --- a/test/test_knn.jl +++ b/test/test_knn.jl @@ -3,8 +3,8 @@ import Distances.evaluate @testset "knn" begin - @testset "metric" for metric in [metrics; WeightedEuclidean(ones(2))] - @testset "tree type" for TreeType in trees_with_brute + @testset "metric $metric" for metric in [metrics; WeightedEuclidean(ones(2))] + @testset "tree type $TreeType" for TreeType in trees_with_brute function test(data) tree = TreeType(data, metric; leafsize=2) diff --git a/test/test_monkey.jl b/test/test_monkey.jl index a3fdb0b..12f41f2 100644 --- a/test/test_monkey.jl +++ b/test/test_monkey.jl @@ -3,9 +3,9 @@ import NearestNeighbors.MinkowskiMetric # some edge case has been missed in the real tests -@testset "metric" for metric in fullmetrics - @testset "tree type" for TreeType in trees_with_brute - @testset "type" for T in (Float32, Float64) +@testset "metric $metric" for metric in fullmetrics + @testset "tree type $TreeType" for TreeType in trees_with_brute + @testset "element type $T" for T in (Float32, Float64) @testset "knn monkey" begin # Checks that we find existing point in the tree # and that it is the closest From 18b93cbaf2d830e160b51c44df7ab0038c0c35aa Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Fri, 7 Jan 2022 19:56:24 +0100 Subject: [PATCH 2/2] parallelizing BallTree construction --- .github/workflows/CI.yml | 2 +- Project.toml | 5 +- src/NearestNeighbors.jl | 2 +- src/ball_tree.jl | 100 ++++++++++++++++++++++++++++++--------- src/hyperspheres.jl | 64 +++++++++---------------- test/runtests.jl | 5 +- test/test_monkey.jl | 28 +++++++---- 7 files changed, 126 insertions(+), 80 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index dc12ba9..2e4cdd6 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -15,7 +15,7 @@ jobs: fail-fast: false matrix: version: - - '1.0' + - '1.3' - '1' - 'nightly' os: diff --git a/Project.toml b/Project.toml index 0380f83..7661bd2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,15 +1,16 @@ name = "NearestNeighbors" uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.10" +version = "0.5.0" [deps] Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] Distances = "0.9, 0.10" StaticArrays = "0.9, 0.10, 0.11, 0.12, 1.0" -julia = "1.0" +julia = "1.3" [extras] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/NearestNeighbors.jl b/src/NearestNeighbors.jl index 7e8c5b1..a1a33b1 100644 --- a/src/NearestNeighbors.jl +++ b/src/NearestNeighbors.jl @@ -5,7 +5,7 @@ import Distances: Metric, result_type, eval_reduce, eval_end, eval_op, eval_star using StaticArrays import Base.show -using Base.Threads: @threads +using Base.Threads export NNTree, BruteTree, KDTree, BallTree, DataFreeTree export knn, nn, inrange # TODOs? , allpairs, distmat, npairs diff --git a/src/ball_tree.jl b/src/ball_tree.jl index cb59b95..22ef166 100644 --- a/src/ball_tree.jl +++ b/src/ball_tree.jl @@ -12,16 +12,8 @@ struct BallTree{V <: AbstractVector,N,T,M <: Metric} <: NNTree{V,M} reordered::Bool # If the data has been reordered end -# When we create the bounding spheres we need some temporary arrays. -# We create a type to hold them to not allocate these arrays at every -# function call and to reduce the number of parameters in the tree builder. -struct ArrayBuffers{N,T <: AbstractFloat} - center::MVector{N,T} -end - -function ArrayBuffers(::Type{Val{N}}, ::Type{T}) where {N, T} - ArrayBuffers(zeros(MVector{N,T})) -end +# minimum number of data points above which parallelization is triggered by default +const DEFAULT_BALLTREE_MIN_PARALLEL_SIZE = 1024 """ BallTree(data [, metric = Euclidean(), leafsize = 10]) -> balltree @@ -33,6 +25,8 @@ function BallTree(data::AbstractVector{V}, leafsize::Int = 10, reorder::Bool = true, storedata::Bool = true, + parallel::Bool = true, + parallel_size::Int = DEFAULT_BALLTREE_MIN_PARALLEL_SIZE, reorderbuffer::Vector{V} = Vector{V}()) where {V <: AbstractArray, M <: Metric} reorder = !isempty(reorderbuffer) || (storedata ? reorder : false) @@ -40,7 +34,6 @@ function BallTree(data::AbstractVector{V}, n_d = length(V) n_p = length(data) - array_buffs = ArrayBuffers(Val{length(V)}, get_T(eltype(V))) indices = collect(1:n_p) # Bottom up creation of hyper spheres so need spheres even for leafs) @@ -70,7 +63,8 @@ function BallTree(data::AbstractVector{V}, if n_p > 0 # Call the recursive BallTree builder build_BallTree(1, data, data_reordered, hyper_spheres, metric, indices, indices_reordered, - 1, length(data), tree_data, array_buffs, reorder) + 1, length(data), tree_data, reorder, Val(parallel), parallel_size) + end if reorder @@ -86,6 +80,8 @@ function BallTree(data::AbstractVecOrMat{T}, leafsize::Int = 10, storedata::Bool = true, reorder::Bool = true, + parallel::Bool = true, + parallel_size::Int = DEFAULT_BALLTREE_MIN_PARALLEL_SIZE, reorderbuffer::Matrix{T} = Matrix{T}(undef, 0, 0)) where {T <: AbstractFloat, M <: Metric} dim = size(data, 1) npoints = size(data, 2) @@ -96,7 +92,7 @@ function BallTree(data::AbstractVecOrMat{T}, reorderbuffer_points = copy_svec(T, reorderbuffer, Val(dim)) end BallTree(points, metric, leafsize = leafsize, storedata = storedata, reorder = reorder, - reorderbuffer = reorderbuffer_points) + parallel = parallel, parallel_size = parallel_size, reorderbuffer = reorderbuffer_points) end # Recursive function to build the tree. @@ -110,8 +106,9 @@ function build_BallTree(index::Int, low::Int, high::Int, tree_data::TreeData, - array_buffs::ArrayBuffers{N,T}, - reorder::Bool) where {V <: AbstractVector, N, T} + reorder::Bool, + parallel::Val{false}, + parallel_size::Int = 0) where {V <: AbstractVector, N, T} n_points = high - low + 1 # Points left if n_points <= tree_data.leafsize @@ -119,7 +116,7 @@ function build_BallTree(index::Int, reorder_data!(data_reordered, data, index, indices, indices_reordered, tree_data) end # Create bounding sphere of points in leaf node by brute force - hyper_spheres[index] = create_bsphere(data, metric, indices, low, high, array_buffs) + hyper_spheres[index] = create_bsphere(data, metric, indices, low, high) return end @@ -132,22 +129,79 @@ function build_BallTree(index::Int, # Sort the data at the mid_idx boundary using the split_dim # to compare - select_spec!(indices, mid_idx, low, high, data, split_dim) + select_spec!(indices, mid_idx, low, high, data, split_dim) # culprit? technically, low and high should be disjoint for different threads build_BallTree(getleft(index), data, data_reordered, hyper_spheres, metric, - indices, indices_reordered, low, mid_idx - 1, - tree_data, array_buffs, reorder) + indices, indices_reordered, low, mid_idx - 1, + tree_data, reorder, parallel) build_BallTree(getright(index), data, data_reordered, hyper_spheres, metric, - indices, indices_reordered, mid_idx, high, - tree_data, array_buffs, reorder) + indices, indices_reordered, mid_idx, high, + tree_data, reorder, parallel) + + # Finally create bounding hyper sphere from the two children's hyper spheres + hyper_spheres[index] = create_bsphere(metric, hyper_spheres[getleft(index)], + hyper_spheres[getright(index)]) + return +end + +# Parallelized recursive function to build the tree. +function build_BallTree(index::Int, + data::Vector{V}, + data_reordered::Vector{V}, + hyper_spheres::Vector{HyperSphere{N,T}}, + metric::Metric, + indices::Vector{Int}, + indices_reordered::Vector{Int}, + low::Int, + high::Int, + tree_data::TreeData, + reorder::Bool, + parallel::Val{true}, + parallel_size::Int) where {V <: AbstractVector, N, T} + + n_points = high - low + 1 # Points left + if n_points <= tree_data.leafsize + if reorder + reorder_data!(data_reordered, data, index, indices, indices_reordered, tree_data) + end + # Create bounding sphere of points in leaf node by brute force + hyper_spheres[index] = create_bsphere(data, metric, indices, low, high) + return + end + + # Find split such that one of the sub trees has 2^p points + # and the left sub tree has more points + mid_idx = find_split(low, tree_data.leafsize, n_points) + + # Brute force to find the dimension with the largest spread + split_dim = find_largest_spread(data, indices, low, high) + + # Sort the data at the mid_idx boundary using the split_dim + # to compare + select_spec!(indices, mid_idx, low, high, data, split_dim) # culprit? technically, low and high should be disjoint for different threads + + @sync begin + left_n_points = mid_idx - low + left_parallel = Val(left_n_points > parallel_size) + @spawn build_BallTree(getleft(index), data, data_reordered, hyper_spheres, metric, + indices, indices_reordered, low, mid_idx - 1, + tree_data, reorder, left_parallel, parallel_size) + + right_n_points = high - mid_idx + 1 + right_parallel = Val(right_n_points > parallel_size) + @spawn build_BallTree(getright(index), data, data_reordered, hyper_spheres, metric, + indices, indices_reordered, mid_idx, high, + tree_data, reorder, right_parallel, parallel_size) + end # Finally create bounding hyper sphere from the two children's hyper spheres hyper_spheres[index] = create_bsphere(metric, hyper_spheres[getleft(index)], - hyper_spheres[getright(index)], - array_buffs) + hyper_spheres[getright(index)]) + return end + function _knn(tree::BallTree, point::AbstractVector, best_idxs::AbstractVector{Int}, diff --git a/src/hyperspheres.jl b/src/hyperspheres.jl index 2dd6766..d6d6d62 100644 --- a/src/hyperspheres.jl +++ b/src/hyperspheres.jl @@ -7,6 +7,8 @@ end HyperSphere(center::SVector{N,T1}, r::T2) where {N, T1, T2} = HyperSphere(center, convert(T1, r)) +Base.:(==)(A::HyperSphere, B::HyperSphere) = A.center == B.center && A.r == B.r + @inline function intersects(m::M, s1::HyperSphere{N,T}, s2::HyperSphere{N,T}) where {T <: AbstractFloat, N, M <: Metric} @@ -19,55 +21,22 @@ end evaluate(m, s1.center, s2.center) + s1.r <= s2.r end -@inline function interpolate(::M, - c1::V, - c2::V, - x, - d, - ab) where {V <: AbstractVector, M <: NormMetric} - alpha = x / d - @assert length(c1) == length(c2) - @inbounds for i in eachindex(ab.center) - ab.center[i] = (1 - alpha) .* c1[i] + alpha .* c2[i] - end - return ab.center, true -end - -@inline function interpolate(::M, - c1::V, - ::V, - ::Any, - ::Any, - ::Any) where {V <: AbstractVector, M <: Metric} - return c1, false -end - -function create_bsphere(data::AbstractVector{V}, metric::Metric, indices::Vector{Int}, low, high, ab) where {V} - n_dim = size(data, 1) - n_points = high - low + 1 - # First find center of all points - fill!(ab.center, 0.0) - for i in low:high - for j in 1:length(ab.center) - ab.center[j] += data[indices[i]][j] - end - end - ab.center .*= 1 / n_points - +# versions with no array buffer - still not allocating in sequential BallTree construction +using Statistics: mean +function create_bsphere(data::AbstractVector{V}, metric::Metric, indices::Vector{Int}, low, high) where {V} + # find center + center = mean(@views(data[indices[low:high]])) # Then find r r = zero(get_T(eltype(V))) for i in low:high - r = max(r, evaluate(metric, data[indices[i]], ab.center)) + r = max(r, evaluate(metric, data[indices[i]], center)) end r += eps(get_T(eltype(V))) - return HyperSphere(SVector{length(V),eltype(V)}(ab.center), r) + return HyperSphere(SVector{length(V),eltype(V)}(center), r) end # Creates a bounding sphere from two other spheres -function create_bsphere(m::Metric, - s1::HyperSphere{N,T}, - s2::HyperSphere{N,T}, - ab) where {N, T <: AbstractFloat} +function create_bsphere(m::Metric, s1::HyperSphere{N,T}, s2::HyperSphere{N,T}) where {N, T <: AbstractFloat} if encloses(m, s1, s2) return HyperSphere(s2.center, s2.r) elseif encloses(m, s2, s1) @@ -79,7 +48,7 @@ function create_bsphere(m::Metric, # neither s1 nor s2 contains the other) dist = evaluate(m, s1.center, s2.center) x = 0.5 * (s2.r - s1.r + dist) - center, is_exact_center = interpolate(m, s1.center, s2.center, x, dist, ab) + center, is_exact_center = interpolate(m, s1.center, s2.center, x, dist) if is_exact_center rad = 0.5 * (s2.r + s1.r + dist) else @@ -88,3 +57,14 @@ function create_bsphere(m::Metric, return HyperSphere(SVector{N,T}(center), rad) end + +@inline function interpolate(::M, c1::V, c2::V, x, d) where {V <: AbstractVector, M <: NormMetric} + length(c1) == length(c2) || throw(DimensionMismatch("interpolate arguments have length $(length(c1)) and $(length(c2))")) + alpha = x / d + center = (1 - alpha) * c1 + alpha * c2 + return center, true +end + +@inline function interpolate(::M, c1::V, ::V, ::Any, ::Any) where {V <: AbstractVector, M <: Metric} + return c1, false +end diff --git a/test/runtests.jl b/test/runtests.jl index a257562..c480d53 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,13 +6,12 @@ using LinearAlgebra using Distances: Distances, Metric, evaluate, PeriodicEuclidean struct CustomMetric1 <: Metric end -Distances.evaluate(::CustomMetric1, a::AbstractVector, b::AbstractVector) = maximum(abs.(a .- b)) +Distances.evaluate(::CustomMetric1, a::AbstractVector, b::AbstractVector) = maximum(abs, (a .- b)) function NearestNeighbors.interpolate(::CustomMetric1, a::V, b::V, x, - d, - ab) where {V <: AbstractVector} + d) where {V <: AbstractVector} idx = (abs.(b .- a) .>= d - x) c = copy(Array(a)) c[idx] = (1 - x / d) * a[idx] + (x / d) * b[idx] diff --git a/test/test_monkey.jl b/test/test_monkey.jl index 12f41f2..56834e4 100644 --- a/test/test_monkey.jl +++ b/test/test_monkey.jl @@ -1,9 +1,8 @@ import NearestNeighbors.MinkowskiMetric # This contains a bunch of random tests that should hopefully detect if # some edge case has been missed in the real tests - - @testset "metric $metric" for metric in fullmetrics + nrep = 30 @testset "tree type $TreeType" for TreeType in trees_with_brute @testset "element type $T" for T in (Float32, Float64) @testset "knn monkey" begin @@ -14,7 +13,7 @@ import NearestNeighbors.MinkowskiMetric elseif TreeType == BallTree && isa(metric, Hamming) continue end - for i in 1:30 + for i in 1:nrep dim_data = rand(1:4) size_data = rand(1000:1300) data = rand(T, dim_data, size_data) @@ -28,7 +27,7 @@ import NearestNeighbors.MinkowskiMetric end # Compares vs Brute Force - for i in 1:30 + for i in 1:nrep dim_data = rand(1:5) size_data = rand(100:151) data = rand(T, dim_data, size_data) @@ -45,7 +44,7 @@ import NearestNeighbors.MinkowskiMetric @testset "inrange monkey" begin # Test against brute force - for i in 1:30 + for i in 1:nrep dim_data = rand(1:6) size_data = rand(20:250) data = rand(T, dim_data, size_data) @@ -62,17 +61,30 @@ import NearestNeighbors.MinkowskiMetric end @testset "coupled monkey" begin - for i in 1:50 + for i in 1:nrep dim_data = rand(1:5) size_data = rand(100:1000) data = randn(T, dim_data, size_data) - tree = TreeType(data, metric; leafsize = rand(1:8)) + + lf = rand(1:8) + tree = TreeType(data, metric; leafsize = lf) + + if TreeType == BallTree # this caught a race-condition in an early version of the parallel BallTree code + tree2 = TreeType(data, metric; leafsize = lf, parallel = true, parallel_size = 0) # triggering parallel code + @test tree.data == tree2.data + @test tree.hyper_spheres[1] == tree2.hyper_spheres[1] + @test tree.indices == tree2.indices + @test tree.metric == tree2.metric + @test tree.tree_data == tree2.tree_data + @test tree.reordered == tree2.reordered + end + point = randn(dim_data) idxs_ball = Int[] r = 0.1 while length(idxs_ball) < 10 r *= 2.0 - idxs_ball = inrange(tree, point, r, true) + idxs_ball = inrange(tree, point, r, true) end idxs_knn, dists = knn(tree, point, length(idxs_ball))