Skip to content

Commit

Permalink
update sort interaction list functions
Browse files Browse the repository at this point in the history
  • Loading branch information
rymanderson committed Nov 12, 2024
1 parent b97142f commit a2031fc
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 52 deletions.
94 changes: 48 additions & 46 deletions src/interaction_list.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,68 +246,70 @@ end
@inline preallocate_bodies_index(T::Type{<:MultiBranch{<:Any,NT}}, n) where NT = Tuple(Vector{UnitRange{Int64}}(undef, n) for _ in 1:NT)
@inline preallocate_bodies_index(T::Type{<:SingleBranch}, n) = Vector{UnitRange{Int64}}(undef, n)

function sort_list_by_target(direct_list, target_branches::Vector{TT}, source_branches::Vector{TS}, n_leaves) where {TT,TS}
target_counter = zeros(Int32, n_leaves)
place_counter = zeros(Int32, n_leaves)
direct_target_bodies = preallocate_bodies_index(TT, length(direct_list))
direct_source_bodies = preallocate_bodies_index(TS, length(direct_list))

# tally the contributions of each source
for (i_target, j_source) in direct_list
i_leaf = source_branches[i_target].i_leaf
target_counter[i_leaf] += 1
function sort_list_by_target(direct_list, target_branches::Vector{<:Branch})
# count cardinality of each target leaf in direct_list
target_counter = zeros(Int32, 2, length(target_branches))
for (i,j) in direct_list
target_counter[1,i] += 1
end

# prepare place counter
i_cum = 1
for (i,n) in enumerate(target_counter)
place_counter[i] = i_cum
i_cum += n
# cumsum cardinality to obtain an index map
target_counter[2,1] = Int32(1)
for i in 2:size(target_counter,2)
target_counter[2,i] = target_counter[2,i-1] + target_counter[1,i-1]
end

# place interactions
for (i,(i_target, j_source)) in enumerate(direct_list)
i_leaf = target_branches[i_target].i_leaf
update_direct_bodies!(direct_target_bodies, place_counter[i_leaf], target_branches[i_target].bodies_index)
update_direct_bodies!(direct_source_bodies, place_counter[i_leaf], source_branches[j_source].bodies_index)
place_counter[i_leaf] += 1
# preallocate sorted direct_list
sorted_direct_list = similar(direct_list)

# sort direct_list by source
for ij in direct_list
# get source branch index
i = ij[1]

# get and update target destination index for this branch
i_dest = target_counter[2,i]
target_counter[2,i] += Int32(1)

# place target-source pair in the sorted list
sorted_direct_list[i_dest] = ij
end

return direct_target_bodies, direct_source_bodies
return sorted_direct_list
end

function sort_list_by_source(direct_list, target_branches::Vector{TT}, source_branches::Vector{TS}, n_leaves) where {TT,TS}
source_counter = zeros(Int32, n_leaves)
place_counter = zeros(Int32, n_leaves)
direct_target_bodies = preallocate_bodies_index(TT, length(direct_list))
direct_source_bodies = preallocate_bodies_index(TS, length(direct_list))

# tally the contributions of each source
for (i_target, j_source) in direct_list
j_leaf = source_branches[j_source].i_leaf
source_counter[j_leaf] += 1
function sort_list_by_source(direct_list, source_branches::Vector{<:Branch})
# count cardinality of each source leaf in direct_list
source_counter = zeros(Int32, 2, length(source_branches))
for (i,j) in direct_list
source_counter[1,j] += Int32(1)
end

# prepare place counter
i_cum = 1
for (i,n) in enumerate(source_counter)
place_counter[i] = i_cum
i_cum += n
# cumsum cardinality to obtain an index map
source_counter[2,1] = Int32(1)
for i in 2:size(source_counter,2)
source_counter[2,i] = source_counter[2,i-1] + source_counter[1,i-1]
end

# place interactions
for (i, (i_target, j_source)) in enumerate(direct_list)
i_leaf = target_branches[i_target].i_leaf
j_leaf = source_branches[j_source].i_leaf
update_direct_bodies!(direct_target_bodies, place_counter[j_leaf], target_branches[i_target].bodies_index)
update_direct_bodies!(direct_source_bodies, place_counter[j_leaf], source_branches[j_source].bodies_index)
place_counter[j_leaf] += 1
# preallocate sorted direct_list
sorted_direct_list = similar(direct_list)

# sort direct_list by source
for ij in direct_list
# get source branch index
j = ij[2]

# get and update target destination index for this branch
i_dest = source_counter[2,j]
source_counter[2,j] += Int32(1)

# place target-source pair in the sorted list
sorted_direct_list[i_dest] = ij
end

return direct_target_bodies, direct_source_bodies
return sorted_direct_list
end


@inline function update_direct_bodies!(direct_bodies::Vector{<:UnitRange}, leaf_index, bodies_index::UnitRange)
direct_bodies[leaf_index] = bodies_index
end
Expand Down
27 changes: 27 additions & 0 deletions test/interaction_list_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
@testset "interaction list: sort convenience function" begin

Random.seed!(456)
n_bodies = 10_000
bodies = rand(8,n_bodies)
masses = Gravitational(bodies)
tree = Tree(masses; leaf_size = 100)
mac = 0.5
expansion_order = 0

m2l_list, direct_list = build_interaction_lists(tree.branches, tree.branches, tree.leaf_index, mac, true, true, true, UnequalSpheres(), expansion_order)

source_sorted_direct_list = FastMultipole.sort_list_by_source(direct_list, tree.branches)
for i in 2:length(source_sorted_direct_list)
_, i_source = source_sorted_direct_list[i]
_, im1_source = source_sorted_direct_list[i-1]
@test i_source >= im1_source
end

target_sorted_direct_list = FastMultipole.sort_list_by_target(direct_list, tree.branches)
for i in 2:length(target_sorted_direct_list)
i_target, _ = target_sorted_direct_list[i]
im1_target, _ = target_sorted_direct_list[i-1]
@test i_target >= im1_target
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,6 @@ include("evaluate_expansions_test.jl")
include("lamb_helmholtz_test.jl")
include("tree_test.jl")
include("dynamic_expansion_order_test.jl")
include("interaction_list_test.jl")
include("fmm_test.jl")

6 changes: 0 additions & 6 deletions test/vortex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,6 @@ end
function Base.setindex!(vp::VortexParticles, val, i, ::VelocityGradient)
vp.potential[i_VELOCITY_GRADIENT_vortex,i] .= reshape(val,9)
end
function Base.setindex!(vp::VortexParticles, val, i, ::Strength)
p = vp.bodies[i]
position = p.position
sigma = p.sigma
vp.bodies[i] = Vorton(position, val, sigma)
end
FastMultipole.get_n_bodies(vp::VortexParticles) = length(vp.bodies)
Base.eltype(::VortexParticles{TF}) where TF = TF

Expand Down

0 comments on commit a2031fc

Please sign in to comment.