Skip to content

Commit

Permalink
abstract expansion dimension method (for LambHelmholtz decomposition)
Browse files Browse the repository at this point in the history
  • Loading branch information
rymanderson committed Jul 24, 2024
1 parent a112409 commit 5dc2e98
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 28 deletions.
9 changes: 9 additions & 0 deletions src/containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ const NORMAL = Normal()
struct Strength <: Indexable end
const STRENGTH = Strength()

#####
##### dispatch on multipole dimension method
#####
abstract type AbstractMethod end

abstract type ScalarPlusVector <: AbstractMethod end

abstract type LambHelmholtz <: AbstractMethod end

#####
##### dispatch convenience functions for multipole creation definition
#####
Expand Down
52 changes: 26 additions & 26 deletions src/fmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ end
#####
##### upward pass
#####
function upward_pass_singlethread!(branches, systems, expansion_order::Val{P}) where P
function upward_pass_singlethread!(branches, systems, expansion_order::Val{P}, ::Type{ScalarPlusVector}) where P

# loop over branches
for branch in view(branches,length(branches):-1:1) # no need to create a multipole expansion at the very top level
Expand Down Expand Up @@ -252,7 +252,7 @@ function translate_multipoles_multithread!(branches, expansion_order::Val{P}, le
end
end

function upward_pass_multithread!(branches, systems, expansion_order, levels_index, leaf_index, n_threads)
function upward_pass_multithread!(branches, systems, expansion_order, levels_index, leaf_index, n_threads, ::Type{ScalarPlusVector})
# create multipole expansions
body_2_multipole_multithread!(branches, systems, expansion_order, leaf_index, n_threads)

Expand Down Expand Up @@ -358,7 +358,7 @@ function nearfield_singlethread!(target_system, target_tree::Tree, derivatives_s
end
end

function horizontal_pass_singlethread!(target_branches, source_branches, m2l_list, expansion_order)
function horizontal_pass_singlethread!(target_branches, source_branches, m2l_list, expansion_order, ::Type{ScalarPlusVector})
for (i_target, j_source) in m2l_list
M2L!(target_branches[i_target], source_branches[j_source], expansion_order)
end
Expand All @@ -380,7 +380,7 @@ end
# containers = [preallocate_horizontal_pass(expansion_type, expansion_order) for _ in 1:n]
# end

function horizontal_pass_multithread!(target_branches, source_branches::Vector{<:Branch{TF}}, m2l_list, expansion_order::Val{P}, n_threads) where {TF,P}
function horizontal_pass_multithread!(target_branches, source_branches::Vector{<:Branch{TF}}, m2l_list, expansion_order::Val{P}, n_threads, ::Type{ScalarPlusVector}) where {TF,P}
# number of translations per thread
n_per_thread, rem = divrem(length(m2l_list),n_threads)
rem > 0 && (n_per_thread += 1)
Expand Down Expand Up @@ -420,7 +420,7 @@ function preallocate_l2b(float_type, expansion_type, expansion_order::Val{P}, n_
return containers
end

function downward_pass_singlethread!(branches, systems, derivatives_switches, expansion_order::Val{P}) where P
function downward_pass_singlethread!(branches, systems, derivatives_switches, expansion_order::Val{P}, ::Type{ScalarPlusVector}) where P
regular_harmonics = zeros(eltype(branches[1].multipole_expansion), 2, (P+1)*(P+1))
for branch in branches
if branch.n_branches == 0 # leaf level
Expand Down Expand Up @@ -497,7 +497,7 @@ function local_2_body_multithread!(branches, systems, derivatives_switches, expa
end
end

function downward_pass_multithread!(branches, systems, derivatives_switch, expansion_order, levels_index, leaf_index, n_threads)
function downward_pass_multithread!(branches, systems, derivatives_switch, expansion_order, levels_index, leaf_index, n_threads, ::Type{ScalarPlusVector})
# m2m translation
translate_locals_multithread!(branches, expansion_order, levels_index, n_threads)

Expand Down Expand Up @@ -896,7 +896,7 @@ function fmm!(target_systems, source_systems;
nearfield=true, farfield=true, self_induced=true,
unsort_source_bodies=true, unsort_target_bodies=true,
source_shrink_recenter=true, target_shrink_recenter=true,
save_tree=false, save_name="tree", gpu=false
save_tree=false, save_name="tree", gpu=false, method=ScalarPlusVector
)
# check for duplicate systems
target_systems = wrap_duplicates(target_systems, source_systems)
Expand All @@ -912,7 +912,7 @@ function fmm!(target_systems, source_systems;
reset_source_tree=false, reset_target_tree=false,
upward_pass, horizontal_pass, downward_pass,
nearfield, farfield, self_induced,
unsort_source_bodies, unsort_target_bodies, gpu
unsort_source_bodies, unsort_target_bodies, gpu, method
)

# visualize
Expand Down Expand Up @@ -960,7 +960,7 @@ function fmm!(systems;
upward_pass=true, horizontal_pass=true, downward_pass=true,
nearfield=true, farfield=true, self_induced=true,
unsort_bodies=true, shrink_recenter=true,
save_tree=false, save_name="tree", gpu=false
save_tree=false, save_name="tree", gpu=false, method=ScalarPlusVector
)
# create tree
tree = Tree(systems; expansion_order, leaf_size, shrink_recenter)
Expand All @@ -971,7 +971,7 @@ function fmm!(systems;
multipole_threshold, reset_tree=false,
upward_pass, horizontal_pass, downward_pass,
nearfield, farfield, self_induced,
unsort_bodies, gpu
unsort_bodies, gpu, method
)

# visualize
Expand Down Expand Up @@ -1014,7 +1014,7 @@ function fmm!(tree::Tree, systems;
multipole_threshold=0.4, reset_tree=true,
upward_pass=true, horizontal_pass=true, downward_pass=true,
nearfield=true, farfield=true, self_induced=true,
unsort_bodies=true, gpu=false
unsort_bodies=true, gpu=false, method=ScalarPlusVector
)

# assemble derivatives switch
Expand All @@ -1027,7 +1027,7 @@ function fmm!(tree::Tree, systems;
fmm!(tree, systems, m2l_list, direct_target_bodies, direct_source_bodies, derivatives_switches;
reset_tree,
nearfield, upward_pass, horizontal_pass, downward_pass,
unsort_bodies, gpu
unsort_bodies, gpu, method
)

return m2l_list, direct_target_bodies, direct_source_bodies, derivatives_switches
Expand Down Expand Up @@ -1076,7 +1076,7 @@ function fmm!(target_tree::Tree, target_systems, source_tree::Tree, source_syste
upward_pass=true, horizontal_pass=true, downward_pass=true,
nearfield=true, farfield=true, self_induced=true,
unsort_source_bodies=true, unsort_target_bodies=true,
gpu=false
gpu=false, method=ScalarPlusVector
)

# assemble derivatives switch
Expand All @@ -1089,7 +1089,7 @@ function fmm!(target_tree::Tree, target_systems, source_tree::Tree, source_syste
fmm!(target_tree, target_systems, source_tree, source_systems, m2l_list, direct_target_bodies, direct_source_bodies, derivatives_switches;
reset_source_tree, reset_target_tree,
nearfield, upward_pass, horizontal_pass, downward_pass,
unsort_source_bodies, unsort_target_bodies, gpu
unsort_source_bodies, unsort_target_bodies, gpu, method
)

return m2l_list, direct_target_bodies, direct_source_bodies, derivatives_switches
Expand Down Expand Up @@ -1123,14 +1123,14 @@ Dispatches `fmm!` using an existing `::Tree`.
function fmm!(tree::Tree, systems, m2l_list, direct_target_bodies, direct_source_bodies, derivatives_switches;
reset_tree=true,
nearfield=true, upward_pass=true, horizontal_pass=true, downward_pass=true,
unsort_bodies=true, gpu=false
unsort_bodies=true, gpu=false, method=ScalarPlusVector
)

fmm!(tree, systems, tree, systems, m2l_list, direct_target_bodies, direct_source_bodies, derivatives_switches;
reset_source_tree=reset_tree, reset_target_tree=false,
nearfield, upward_pass, horizontal_pass, downward_pass,
unsort_source_bodies=unsort_bodies, unsort_target_bodies=false,
gpu
gpu, method
)

end
Expand Down Expand Up @@ -1184,7 +1184,7 @@ function fmm!(target_tree::Tree, target_systems, source_tree::Tree, source_syste
reset_source_tree=true, reset_target_tree=true,
nearfield=true, upward_pass=true, horizontal_pass=true, downward_pass=true,
unsort_source_bodies=true, unsort_target_bodies=true,
gpu=false
gpu=false, method::Type{<:AbstractMethod}=ScalarPlusVector
)
# check if systems are empty
n_sources = get_n_bodies(source_systems)
Expand All @@ -1204,27 +1204,27 @@ function fmm!(target_tree::Tree, target_systems, source_tree::Tree, source_syste

@sync begin
Threads.@spawn nearfield_singlethread!(target_systems, direct_target_bodies, derivatives_switches, source_systems, direct_source_bodies, Val(gpu))
upward_pass && upward_pass_multithread!(source_tree.branches, source_systems, source_tree.expansion_order, source_tree.levels_index, source_tree.leaf_index, n_threads-1)
horizontal_pass && length(m2l_list) > 0 && horizontal_pass_multithread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, n_threads-1)
upward_pass && upward_pass_multithread!(source_tree.branches, source_systems, source_tree.expansion_order, source_tree.levels_index, source_tree.leaf_index, n_threads-1, method)
horizontal_pass && length(m2l_list) > 0 && horizontal_pass_multithread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, n_threads-1, method)
end
downward_pass && downward_pass_multithread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order, target_tree.levels_index, target_tree.leaf_index, n_threads)
downward_pass && downward_pass_multithread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order, target_tree.levels_index, target_tree.leaf_index, n_threads, method)

else

# nearfield interactions
if n_threads == 1 # && !gpu || gpu

nearfield_singlethread!(target_systems, direct_target_bodies, derivatives_switches, source_systems, direct_source_bodies, Val(gpu))
upward_pass && upward_pass_singlethread!(source_tree.branches, source_systems, source_tree.expansion_order)
horizontal_pass && length(m2l_list) > 0 && horizontal_pass_singlethread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order)
downward_pass && downward_pass_singlethread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order)
upward_pass && upward_pass_singlethread!(source_tree.branches, source_systems, source_tree.expansion_order, method)
horizontal_pass && length(m2l_list) > 0 && horizontal_pass_singlethread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, method)
downward_pass && downward_pass_singlethread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order, method)

else # n_threads > 1 && !gpu

nearfield_multithread!(target_systems, direct_target_bodies, derivatives_switches, source_systems, direct_source_bodies, n_threads)
upward_pass && upward_pass_multithread!(source_tree.branches, source_systems, source_tree.expansion_order, source_tree.levels_index, source_tree.leaf_index, n_threads)
horizontal_pass && length(m2l_list) > 0 && horizontal_pass_multithread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, n_threads)
downward_pass && downward_pass_multithread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order, target_tree.levels_index, target_tree.leaf_index, n_threads)
upward_pass && upward_pass_multithread!(source_tree.branches, source_systems, source_tree.expansion_order, source_tree.levels_index, source_tree.leaf_index, n_threads, method)
horizontal_pass && length(m2l_list) > 0 && horizontal_pass_multithread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, n_threads, method)
downward_pass && downward_pass_multithread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order, target_tree.levels_index, target_tree.leaf_index, n_threads, method)
end

end
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ include(joinpath(test_dir, "gravitational.jl"))
# allow expansion_order::Int to be used while testing
FastMultipole.M2M!(branch, child, harmonics, M, expansion_order::Int) = FastMultipole.M2M!(branch, child, harmonics, M, Val(expansion_order))
FastMultipole.M2L!(target_branch, source_branch, harmonics, L, expansion_order::Int) = FastMultipole.M2L!(target_branch, source_branch, harmonics, L, Val(expansion_order))
FastMultipole.upward_pass_singlethread!(branches, systems, expansion_order::Int) = FastMultipole.upward_pass_singlethread!(branches, systems, Val(expansion_order))
FastMultipole.upward_pass_singlethread!(branches, systems, expansion_order::Int) = FastMultipole.upward_pass_singlethread!(branches, systems, Val(expansion_order), FastMultipole.ScalarPlusVector)
FastMultipole.M2L!(target_branch, source_branch, expansion_order::Int) = FastMultipole.M2L!(target_branch, source_branch, Val(expansion_order))

@testset "complex" begin
Expand Down Expand Up @@ -568,7 +568,7 @@ u_fmm_67 = mass_target_potential[1]
# perform horizontal pass
m2l_list, direct_list_target, direct_list_source = FastMultipole.build_interaction_lists(tree.branches, tree.branches, tree.leaf_index, multipole_threshold, true, true, true)
FastMultipole.nearfield_singlethread!((elements,), direct_list_target, (FastMultipole.DerivativesSwitch(),), (elements,), direct_list_source)
FastMultipole.horizontal_pass_singlethread!(tree.branches, tree.branches, m2l_list, expansion_order)
FastMultipole.horizontal_pass_singlethread!(tree.branches, tree.branches, m2l_list, expansion_order, FastMultipole.ScalarPlusVector)

# consider the effect on branch 3 (mass 2)
elements.potential[i_POTENTIAL,new_order_index[2]] .*= 0 # reset potential at mass 2
Expand Down

0 comments on commit 5dc2e98

Please sign in to comment.