Skip to content

Commit

Permalink
fix multithreading
Browse files Browse the repository at this point in the history
  • Loading branch information
rymanderson committed Oct 16, 2024
1 parent 2027703 commit 038232a
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 114 deletions.
183 changes: 69 additions & 114 deletions src/fmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,63 +143,13 @@ function upward_pass_singlethread!(branches::AbstractVector{<:Branch{TF}}, syste
end
end

function body_2_multipole_multithread!(branches, systems::Tuple, expansion_order::Val{P}, leaf_index, n_threads) where P
## load balance
leaf_assignments = fill(1:0, length(systems), n_threads)
for (i_system,system) in enumerate(systems)

# total number of bodies
n_bodies = 0
for i_leaf in leaf_index
n_bodies += length(branches[i_leaf].bodies_index[i_system])
end

# number of bodies per thread
n_per_thread, rem = divrem(n_bodies, n_threads)
rem > 0 && (n_per_thread += 1)

# if there are too many threads, we'll actually hurt performance
n_per_thread < MIN_NPT_B2M && (n_per_thread = MIN_NPT_B2M)

# create chunks
i_start = 1
i_thread = 1
n_bodies = 0
for (i_end,i_leaf) in enumerate(leaf_index)
n_bodies += length(branches[i_leaf].bodies_index[i_system])
if n_bodies >= n_per_thread
leaf_assignments[i_system,i_thread] = i_start:i_end
i_start = i_end+1
i_thread += 1
n_bodies = 0
end
end
i_thread <= n_threads && (leaf_assignments[i_system,i_thread] = i_start:length(leaf_index))
end

## compute multipole expansion coefficients
Threads.@threads for i_thread in 1:n_threads
for (i_system,system) in enumerate(systems)
leaf_assignment = leaf_assignments[i_system,i_thread]
for i_branch in view(leaf_index, leaf_assignment)
branch = branches[i_branch]
Threads.lock(branch.lock) do
body_to_multipole!(system, branch, branch.bodies_index[i_system], branch.harmonics, expansion_order)
end
end
end
end
end
function body_to_multipole_multithread!(branches, systems, expansion_order::Val{P}, leaf_index, n_threads) where P

function body_2_multipole_multithread!(branches, system, expansion_order::Val{P}, leaf_index, n_threads) where P
## load balance
leaf_assignments = fill(1:0, n_threads)

# total number of bodies
n_bodies = 0
for i_leaf in leaf_index
n_bodies += length(branches[i_leaf].bodies_index)
end
n_bodies = get_n_bodies(systems)

# number of bodies per thread
n_per_thread, rem = divrem(n_bodies, n_threads)
Expand All @@ -213,7 +163,7 @@ function body_2_multipole_multithread!(branches, system, expansion_order::Val{P}
i_thread = 1
n_bodies = 0
for (i_end,i_leaf) in enumerate(leaf_index)
n_bodies += length(branches[i_leaf].bodies_index)
n_bodies += get_n_bodies(branches[i_leaf])
if n_bodies >= n_per_thread
leaf_assignments[i_thread] = i_start:i_end
i_start = i_end+1
Expand All @@ -224,15 +174,21 @@ function body_2_multipole_multithread!(branches, system, expansion_order::Val{P}
i_thread <= n_threads && (leaf_assignments[i_thread] = i_start:length(leaf_index))

## compute multipole expansion coefficients
Threads.@threads for i_thread in 1:n_threads
for i_branch in view(leaf_index, leaf_assignments[i_thread])
Threads.@threads for assignment in leaf_assignments
for i_task in assignment
i_branch = leaf_index[i_task]
branch = branches[i_branch]
body_to_multipole!(system, branch, branch.bodies_index, branch.harmonics, expansion_order)
body_to_multipole!(branch, systems, branch.harmonics, expansion_order)
end
end
end

function translate_multipoles_multithread!(branches, expansion_order::Val{P}, levels_index, n_threads) where P
function translate_multipoles_multithread!(branches::AbstractVector{<:Branch{TF}}, expansion_order::Val{P}, lamb_helmholtz, levels_index, n_threads) where {TF,P}
# try preallocating one container per task to be reused
Ts = [zeros(length_Ts(P)) for _ in 1:n_threads]
eimϕs = [zeros(2, P+1) for _ in 1:n_threads]
weights_tmp_1 = [initialize_expansion(P, TF) for _ in 1:n_threads]
weights_tmp_2 = [initialize_expansion(P, TF) for _ in 1:n_threads]

# iterate over levels
for level_index in view(levels_index,length(levels_index):-1:2)
Expand All @@ -246,7 +202,10 @@ function translate_multipoles_multithread!(branches, expansion_order::Val{P}, le
n_per_thread < MIN_NPT_M2M && (n_per_thread = MIN_NPT_M2M)

# assign thread start branches
Threads.@threads for i_start in 1:n_per_thread:n_branches
i_starts = 1:n_per_thread:n_branches
Threads.@threads for i_task in 1:length(i_starts)
i_start = i_starts[i_task]

# get final branch
i_end = min(n_branches, i_start+n_per_thread-1)

Expand All @@ -255,19 +214,19 @@ function translate_multipoles_multithread!(branches, expansion_order::Val{P}, le
child_branch = branches[i_branch]
parent_branch = branches[child_branch.i_parent]
Threads.lock(parent_branch.lock) do
M2M!(parent_branch, child_branch, child_branch.harmonics, child_branch.ML, expansion_order)
multipole_to_multipole!(parent_branch, child_branch, weights_tmp_1[i_task], weights_tmp_2[i_task], Ts[i_task], eimϕs[i_task], ζs_mag, Hs_π2, expansion_order, lamb_helmholtz)
end
end
end
end
end

function upward_pass_multithread!(branches, systems, expansion_order, levels_index, leaf_index, n_threads)
function upward_pass_multithread!(branches, systems, expansion_order, lamb_helmholtz, levels_index, leaf_index, n_threads)
# create multipole expansions
body_2_multipole_multithread!(branches, systems, expansion_order, leaf_index, n_threads)
body_to_multipole_multithread!(branches, systems, expansion_order, leaf_index, n_threads)

# m2m translation
translate_multipoles_multithread!(branches, expansion_order, levels_index, n_threads)
translate_multipoles_multithread!(branches, expansion_order, lamb_helmholtz, levels_index, n_threads)
end

#------- direct interaction matrix -------#
Expand All @@ -293,48 +252,28 @@ function horizontal_pass_singlethread!(target_branches::Vector{<:Branch{TF}}, so

end

# function horizontal_pass_singlethread!(target_branches, source_branches, m2l_list, expansion_order, harmonics, L)
# for (i_target, j_source) in m2l_list
# @lock target_branches[i_target].lock M2L!(target_branches[i_target], source_branches[j_source], harmonics, L, expansion_order)
# end
# end

# @inline function preallocate_horizontal_pass(expansion_type, expansion_order)
# harmonics = zeros(expansion_type, (expansion_order<<1 + 1)*(expansion_order<<1 + 1))
# L = zeros(expansion_type, 4)
# return harmonics, L
# end
function horizontal_pass_multithread!(target_branches, source_branches::Vector{<:Branch{TF}}, m2l_list, expansion_order::Val{P}, lamb_helmholtz, n_threads) where {TF,P}
# try preallocating one container per task to be reused
Ts = [zeros(length_Ts(P)) for _ in 1:n_threads]
eimϕs = [zeros(2, P+1) for _ in 1:n_threads]
weights_tmp_1 = [initialize_expansion(P, TF) for _ in 1:n_threads]
weights_tmp_2 = [initialize_expansion(P, TF) for _ in 1:n_threads]

# @inline function preallocate_horizontal_pass(expansion_type, expansion_order, n)
# containers = [preallocate_horizontal_pass(expansion_type, expansion_order) for _ in 1:n]
# end

function horizontal_pass_multithread!(target_branches, source_branches::Vector{<:Branch{TF}}, m2l_list, expansion_order::Val{P}, n_threads) where {TF,P}
# number of translations per thread
n_per_thread, rem = divrem(length(m2l_list),n_threads)
rem > 0 && (n_per_thread += 1)
assignments = 1:n_per_thread:length(m2l_list)

# preallocate memory
# harmonics_preallocated = [initialize_harmonics(P,TF) for _ in 1:length(assignments)]
# ML_preallocated = [initialize_ML(P,TF) for _ in 1:length(assignments)]

# execute tasks
Threads.@threads for i_thread in 1:length(assignments)
# Threads.@threads for i_start in 1:n_per_thread:length(m2l_list)
i_start = assignments[i_thread]
Threads.@threads for i_task in 1:length(assignments)
i_start = assignments[i_task]
i_stop = min(i_start+n_per_thread-1, length(m2l_list))
# harmonics = harmonics_preallocated[i_thread]
# ML = ML_preallocated[i_thread]
# harmonics = initialize_harmonics(P,TF)
# ML = initialize_ML(P,TF)
for (i_target, j_source) in m2l_list[i_start:i_stop]
Threads.lock(target_branches[i_target].lock) do
M2L!(target_branches[i_target], source_branches[j_source], expansion_order)
# M2L!(target_branches[i_target], source_branches[j_source], harmonics, ML, expansion_order)
target_branch = target_branches[i_target]
source_branch = source_branches[j_source]
Threads.lock(target_branch.lock) do
multipole_to_local!(target_branch, source_branch, weights_tmp_1[i_task], weights_tmp_2[i_task], Ts[i_task], eimϕs[i_task], ζs_mag, ηs_mag, Hs_π2, expansion_order, lamb_helmholtz)
end
# target_branch = target_branches[i_target]
# Threads.@lock target_branch.lock M2L!(target_branch, source_branches[j_source], expansion_order)
end
end

Expand Down Expand Up @@ -368,7 +307,13 @@ function downward_pass_singlethread!(branches::AbstractVector{<:Branch{TF}}, sys
end
end

function translate_locals_multithread!(branches, expansion_order::Val{P}, levels_index, n_threads) where P
function translate_locals_multithread!(branches::AbstractVector{<:Branch{TF}}, expansion_order::Val{P}, lamb_helmholtz, levels_index, n_threads) where {TF,P}

# try preallocating one container per task to be reused
Ts = [zeros(length_Ts(P)) for _ in 1:n_threads]
eimϕs = [zeros(2, P+1) for _ in 1:n_threads]
weights_tmp_1 = [initialize_expansion(P, TF) for _ in 1:n_threads]
weights_tmp_2 = [initialize_expansion(P, TF) for _ in 1:n_threads]

# iterate over levels
for level_index in view(levels_index,2:length(levels_index))
Expand All @@ -381,19 +326,25 @@ function translate_locals_multithread!(branches, expansion_order::Val{P}, levels
n_per_thread < MIN_NPT_L2L && (n_per_thread = MIN_NPT_L2L)

# loop over branches
Threads.@threads for i_start in 1:n_per_thread:length(level_index)
i_starts = 1:n_per_thread:length(level_index)
#Threads.@threads for (i_task,i_start) in enumerate(1:n_per_thread:length(level_index))
Threads.@threads for i_task in 1:length(i_starts)
i_start = i_starts[i_task]
i_stop = min(i_start+n_per_thread-1,length(level_index))

# loop over branches
for child_branch in view(branches,view(level_index,i_start:i_stop))
L2L!(branches[child_branch.i_parent], child_branch, child_branch.harmonics, child_branch.ML, expansion_order)
local_to_local!(branches[child_branch.i_parent], child_branch, weights_tmp_1[i_task], weights_tmp_2[i_task], Ts[i_task], eimϕs[i_task], ηs_mag, Hs_π2, expansion_order, lamb_helmholtz)
end
end
end
return nothing
end

function local_2_body_multithread!(branches, systems, derivatives_switches, expansion_order, leaf_index, n_threads)
function local_to_body_multithread!(branches::AbstractVector{<:Branch{TF}}, systems, derivatives_switches, expansion_order::Val{P}, lamb_helmholtz, leaf_index, n_threads) where {TF,P}
# preallocate containers
velocity_n_m = [initialize_velocity_n_m(P,TF) for _ in 1:n_threads]

# create assignments
n_bodies = 0
for i_leaf in leaf_index
Expand Down Expand Up @@ -423,21 +374,21 @@ function local_2_body_multithread!(branches, systems, derivatives_switches, expa
resize!(assignments, i_thread)

# spread remainder across rem chunks
Threads.@threads for i_thread in eachindex(assignments)
assignment = assignments[i_thread]
Threads.@threads for i_task in eachindex(assignments)
assignment = assignments[i_task]
for i_leaf in view(leaf_index,assignment)
leaf = branches[i_leaf]
L2B!(systems, leaf, derivatives_switches, expansion_order)
evaluate_local!(systems, leaf, leaf.harmonics, velocity_n_m[i_task], expansion_order, lamb_helmholtz, derivatives_switches)
end
end
end

function downward_pass_multithread!(branches, systems, derivatives_switch, expansion_order, levels_index, leaf_index, n_threads)
# m2m translation
translate_locals_multithread!(branches, expansion_order, levels_index, n_threads)
function downward_pass_multithread!(branches, systems, derivatives_switch, expansion_order, lamb_helmholtz, levels_index, leaf_index, n_threads)
# local to local translation
translate_locals_multithread!(branches, expansion_order, lamb_helmholtz, levels_index, n_threads)

# local to body interaction
local_2_body_multithread!(branches, systems, derivatives_switch, expansion_order, leaf_index, n_threads)
local_to_body_multithread!(branches, systems, derivatives_switch, expansion_order, lamb_helmholtz, leaf_index, n_threads)
end

#####
Expand Down Expand Up @@ -946,6 +897,10 @@ function fmm!(target_tree::Tree, target_systems, source_tree::Tree, source_syste
n_sources = get_n_bodies(source_systems)
n_targets = get_n_bodies(target_systems)

# wrap lamb_helmholtz in Val
lamb_helmholtz = Val(lamb_helmholtz)


if n_sources > 0 && n_targets > 0

# precompute y-axis rotation by π/2 matrices (if not already done)
Expand All @@ -968,28 +923,28 @@ function fmm!(target_tree::Tree, target_systems, source_tree::Tree, source_syste
@sync begin

Threads.@spawn nearfield_singlethread!(target_systems, direct_target_bodies, derivatives_switches, source_systems, direct_source_bodies, Val(gpu))
upward_pass && upward_pass_multithread!(source_tree.branches, source_systems, source_tree.expansion_order, source_tree.levels_index, source_tree.leaf_index, n_threads-1)
horizontal_pass && length(m2l_list) > 0 && horizontal_pass_multithread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, n_threads-1)
upward_pass && upward_pass_multithread!(source_tree.branches, source_systems, source_tree.expansion_order, lamb_helmholtz, source_tree.levels_index, source_tree.leaf_index, n_threads-1)
horizontal_pass && length(m2l_list) > 0 && horizontal_pass_multithread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, lamb_helmholtz, n_threads-1)

end
downward_pass && downward_pass_multithread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order, target_tree.levels_index, target_tree.leaf_index, n_threads)
downward_pass && downward_pass_multithread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order, lamb_helmholtz, target_tree.levels_index, target_tree.leaf_index, n_threads)

else

# nearfield interactions
if n_threads == 1 # && !gpu || gpu

nearfield_singlethread!(target_systems, direct_target_bodies, derivatives_switches, source_systems, direct_source_bodies, Val(gpu))
upward_pass && upward_pass_singlethread!(source_tree.branches, source_systems, source_tree.expansion_order, Val(lamb_helmholtz))
horizontal_pass && length(m2l_list) > 0 && horizontal_pass_singlethread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, Val(lamb_helmholtz))
downward_pass && downward_pass_singlethread!(target_tree.branches, target_systems, target_tree.expansion_order, Val(lamb_helmholtz), derivatives_switches)
upward_pass && upward_pass_singlethread!(source_tree.branches, source_systems, source_tree.expansion_order, lamb_helmholtz)
horizontal_pass && length(m2l_list) > 0 && horizontal_pass_singlethread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, lamb_helmholtz)
downward_pass && downward_pass_singlethread!(target_tree.branches, target_systems, target_tree.expansion_order, lamb_helmholtz, derivatives_switches)

else # n_threads > 1 && !gpu

nearfield_multithread!(target_systems, direct_target_bodies, derivatives_switches, source_systems, direct_source_bodies, n_threads)
upward_pass && upward_pass_multithread!(source_tree.branches, source_systems, source_tree.expansion_order, source_tree.levels_index, source_tree.leaf_index, n_threads)
horizontal_pass && length(m2l_list) > 0 && horizontal_pass_multithread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, n_threads)
downward_pass && downward_pass_multithread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order, target_tree.levels_index, target_tree.leaf_index, n_threads)
upward_pass && upward_pass_multithread!(source_tree.branches, source_systems, source_tree.expansion_order, lamb_helmholtz, source_tree.levels_index, source_tree.leaf_index, n_threads)
horizontal_pass && length(m2l_list) > 0 && horizontal_pass_multithread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, lamb_helmholtz, n_threads)
downward_pass && downward_pass_multithread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order, lamb_helmholtz, target_tree.levels_index, target_tree.leaf_index, n_threads)

end

Expand Down
6 changes: 6 additions & 0 deletions src/tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,12 @@ function initialize_expansion(expansion_order, type=Float64)
return zeros(type, 2, 2, ((expansion_order+1) * (expansion_order+2)) >> 1)
end

function initialize_velocity_n_m(expansion_order, type=Float64)
p = expansion_order
n_harmonics = harmonic_index(p,p)
return zeros(type, 2, 3, n_harmonics)
end

function initialize_harmonics(expansion_order, type=Float64)
p = expansion_order
n_harmonics = harmonic_index(p,p)
Expand Down

0 comments on commit 038232a

Please sign in to comment.