diff --git a/src/containers.jl b/src/containers.jl index ce0924a..022b83d 100644 --- a/src/containers.jl +++ b/src/containers.jl @@ -33,6 +33,15 @@ const NORMAL = Normal() struct Strength <: Indexable end const STRENGTH = Strength() +##### +##### dispatch on multipole dimension method +##### +abstract type AbstractMethod end + +abstract type ScalarPlusVector <: AbstractMethod end + +abstract type LambHelmholtz <: AbstractMethod end + ##### ##### dispatch convenience functions for multipole creation definition ##### diff --git a/src/fmm.jl b/src/fmm.jl index c568564..6ce30e1 100644 --- a/src/fmm.jl +++ b/src/fmm.jl @@ -118,7 +118,7 @@ end ##### ##### upward pass ##### -function upward_pass_singlethread!(branches, systems, expansion_order::Val{P}) where P +function upward_pass_singlethread!(branches, systems, expansion_order::Val{P}, ::Type{ScalarPlusVector}) where P # loop over branches for branch in view(branches,length(branches):-1:1) # no need to create a multipole expansion at the very top level @@ -252,7 +252,7 @@ function translate_multipoles_multithread!(branches, expansion_order::Val{P}, le end end -function upward_pass_multithread!(branches, systems, expansion_order, levels_index, leaf_index, n_threads) +function upward_pass_multithread!(branches, systems, expansion_order, levels_index, leaf_index, n_threads, ::Type{ScalarPlusVector}) # create multipole expansions body_2_multipole_multithread!(branches, systems, expansion_order, leaf_index, n_threads) @@ -358,7 +358,7 @@ function nearfield_singlethread!(target_system, target_tree::Tree, derivatives_s end end -function horizontal_pass_singlethread!(target_branches, source_branches, m2l_list, expansion_order) +function horizontal_pass_singlethread!(target_branches, source_branches, m2l_list, expansion_order, ::Type{ScalarPlusVector}) for (i_target, j_source) in m2l_list M2L!(target_branches[i_target], source_branches[j_source], expansion_order) end @@ -380,7 +380,7 @@ end # containers = [preallocate_horizontal_pass(expansion_type, expansion_order) for _ in 1:n] # end -function horizontal_pass_multithread!(target_branches, source_branches::Vector{<:Branch{TF}}, m2l_list, expansion_order::Val{P}, n_threads) where {TF,P} +function horizontal_pass_multithread!(target_branches, source_branches::Vector{<:Branch{TF}}, m2l_list, expansion_order::Val{P}, n_threads, ::Type{ScalarPlusVector}) where {TF,P} # number of translations per thread n_per_thread, rem = divrem(length(m2l_list),n_threads) rem > 0 && (n_per_thread += 1) @@ -420,7 +420,7 @@ function preallocate_l2b(float_type, expansion_type, expansion_order::Val{P}, n_ return containers end -function downward_pass_singlethread!(branches, systems, derivatives_switches, expansion_order::Val{P}) where P +function downward_pass_singlethread!(branches, systems, derivatives_switches, expansion_order::Val{P}, ::Type{ScalarPlusVector}) where P regular_harmonics = zeros(eltype(branches[1].multipole_expansion), 2, (P+1)*(P+1)) for branch in branches if branch.n_branches == 0 # leaf level @@ -497,7 +497,7 @@ function local_2_body_multithread!(branches, systems, derivatives_switches, expa end end -function downward_pass_multithread!(branches, systems, derivatives_switch, expansion_order, levels_index, leaf_index, n_threads) +function downward_pass_multithread!(branches, systems, derivatives_switch, expansion_order, levels_index, leaf_index, n_threads, ::Type{ScalarPlusVector}) # m2m translation translate_locals_multithread!(branches, expansion_order, levels_index, n_threads) @@ -896,7 +896,7 @@ function fmm!(target_systems, source_systems; nearfield=true, farfield=true, self_induced=true, unsort_source_bodies=true, unsort_target_bodies=true, source_shrink_recenter=true, target_shrink_recenter=true, - save_tree=false, save_name="tree", gpu=false + save_tree=false, save_name="tree", gpu=false, method=ScalarPlusVector ) # check for duplicate systems target_systems = wrap_duplicates(target_systems, source_systems) @@ -912,7 +912,7 @@ function fmm!(target_systems, source_systems; reset_source_tree=false, reset_target_tree=false, upward_pass, horizontal_pass, downward_pass, nearfield, farfield, self_induced, - unsort_source_bodies, unsort_target_bodies, gpu + unsort_source_bodies, unsort_target_bodies, gpu, method ) # visualize @@ -960,7 +960,7 @@ function fmm!(systems; upward_pass=true, horizontal_pass=true, downward_pass=true, nearfield=true, farfield=true, self_induced=true, unsort_bodies=true, shrink_recenter=true, - save_tree=false, save_name="tree", gpu=false + save_tree=false, save_name="tree", gpu=false, method=ScalarPlusVector ) # create tree tree = Tree(systems; expansion_order, leaf_size, shrink_recenter) @@ -971,7 +971,7 @@ function fmm!(systems; multipole_threshold, reset_tree=false, upward_pass, horizontal_pass, downward_pass, nearfield, farfield, self_induced, - unsort_bodies, gpu + unsort_bodies, gpu, method ) # visualize @@ -1014,7 +1014,7 @@ function fmm!(tree::Tree, systems; multipole_threshold=0.4, reset_tree=true, upward_pass=true, horizontal_pass=true, downward_pass=true, nearfield=true, farfield=true, self_induced=true, - unsort_bodies=true, gpu=false + unsort_bodies=true, gpu=false, method=ScalarPlusVector ) # assemble derivatives switch @@ -1027,7 +1027,7 @@ function fmm!(tree::Tree, systems; fmm!(tree, systems, m2l_list, direct_target_bodies, direct_source_bodies, derivatives_switches; reset_tree, nearfield, upward_pass, horizontal_pass, downward_pass, - unsort_bodies, gpu + unsort_bodies, gpu, method ) return m2l_list, direct_target_bodies, direct_source_bodies, derivatives_switches @@ -1076,7 +1076,7 @@ function fmm!(target_tree::Tree, target_systems, source_tree::Tree, source_syste upward_pass=true, horizontal_pass=true, downward_pass=true, nearfield=true, farfield=true, self_induced=true, unsort_source_bodies=true, unsort_target_bodies=true, - gpu=false + gpu=false, method=ScalarPlusVector ) # assemble derivatives switch @@ -1089,7 +1089,7 @@ function fmm!(target_tree::Tree, target_systems, source_tree::Tree, source_syste fmm!(target_tree, target_systems, source_tree, source_systems, m2l_list, direct_target_bodies, direct_source_bodies, derivatives_switches; reset_source_tree, reset_target_tree, nearfield, upward_pass, horizontal_pass, downward_pass, - unsort_source_bodies, unsort_target_bodies, gpu + unsort_source_bodies, unsort_target_bodies, gpu, method ) return m2l_list, direct_target_bodies, direct_source_bodies, derivatives_switches @@ -1123,14 +1123,14 @@ Dispatches `fmm!` using an existing `::Tree`. function fmm!(tree::Tree, systems, m2l_list, direct_target_bodies, direct_source_bodies, derivatives_switches; reset_tree=true, nearfield=true, upward_pass=true, horizontal_pass=true, downward_pass=true, - unsort_bodies=true, gpu=false + unsort_bodies=true, gpu=false, method=ScalarPlusVector ) fmm!(tree, systems, tree, systems, m2l_list, direct_target_bodies, direct_source_bodies, derivatives_switches; reset_source_tree=reset_tree, reset_target_tree=false, nearfield, upward_pass, horizontal_pass, downward_pass, unsort_source_bodies=unsort_bodies, unsort_target_bodies=false, - gpu + gpu, method ) end @@ -1184,7 +1184,7 @@ function fmm!(target_tree::Tree, target_systems, source_tree::Tree, source_syste reset_source_tree=true, reset_target_tree=true, nearfield=true, upward_pass=true, horizontal_pass=true, downward_pass=true, unsort_source_bodies=true, unsort_target_bodies=true, - gpu=false + gpu=false, method::Type{<:AbstractMethod}=ScalarPlusVector ) # check if systems are empty n_sources = get_n_bodies(source_systems) @@ -1204,10 +1204,10 @@ function fmm!(target_tree::Tree, target_systems, source_tree::Tree, source_syste @sync begin Threads.@spawn nearfield_singlethread!(target_systems, direct_target_bodies, derivatives_switches, source_systems, direct_source_bodies, Val(gpu)) - upward_pass && upward_pass_multithread!(source_tree.branches, source_systems, source_tree.expansion_order, source_tree.levels_index, source_tree.leaf_index, n_threads-1) - horizontal_pass && length(m2l_list) > 0 && horizontal_pass_multithread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, n_threads-1) + upward_pass && upward_pass_multithread!(source_tree.branches, source_systems, source_tree.expansion_order, source_tree.levels_index, source_tree.leaf_index, n_threads-1, method) + horizontal_pass && length(m2l_list) > 0 && horizontal_pass_multithread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, n_threads-1, method) end - downward_pass && downward_pass_multithread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order, target_tree.levels_index, target_tree.leaf_index, n_threads) + downward_pass && downward_pass_multithread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order, target_tree.levels_index, target_tree.leaf_index, n_threads, method) else @@ -1215,16 +1215,16 @@ function fmm!(target_tree::Tree, target_systems, source_tree::Tree, source_syste if n_threads == 1 # && !gpu || gpu nearfield_singlethread!(target_systems, direct_target_bodies, derivatives_switches, source_systems, direct_source_bodies, Val(gpu)) - upward_pass && upward_pass_singlethread!(source_tree.branches, source_systems, source_tree.expansion_order) - horizontal_pass && length(m2l_list) > 0 && horizontal_pass_singlethread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order) - downward_pass && downward_pass_singlethread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order) + upward_pass && upward_pass_singlethread!(source_tree.branches, source_systems, source_tree.expansion_order, method) + horizontal_pass && length(m2l_list) > 0 && horizontal_pass_singlethread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, method) + downward_pass && downward_pass_singlethread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order, method) else # n_threads > 1 && !gpu nearfield_multithread!(target_systems, direct_target_bodies, derivatives_switches, source_systems, direct_source_bodies, n_threads) - upward_pass && upward_pass_multithread!(source_tree.branches, source_systems, source_tree.expansion_order, source_tree.levels_index, source_tree.leaf_index, n_threads) - horizontal_pass && length(m2l_list) > 0 && horizontal_pass_multithread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, n_threads) - downward_pass && downward_pass_multithread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order, target_tree.levels_index, target_tree.leaf_index, n_threads) + upward_pass && upward_pass_multithread!(source_tree.branches, source_systems, source_tree.expansion_order, source_tree.levels_index, source_tree.leaf_index, n_threads, method) + horizontal_pass && length(m2l_list) > 0 && horizontal_pass_multithread!(target_tree.branches, source_tree.branches, m2l_list, source_tree.expansion_order, n_threads, method) + downward_pass && downward_pass_multithread!(target_tree.branches, target_systems, derivatives_switches, target_tree.expansion_order, target_tree.levels_index, target_tree.leaf_index, n_threads, method) end end diff --git a/test/runtests.jl b/test/runtests.jl index 753d822..aea076a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,7 +22,7 @@ include(joinpath(test_dir, "gravitational.jl")) # allow expansion_order::Int to be used while testing FastMultipole.M2M!(branch, child, harmonics, M, expansion_order::Int) = FastMultipole.M2M!(branch, child, harmonics, M, Val(expansion_order)) FastMultipole.M2L!(target_branch, source_branch, harmonics, L, expansion_order::Int) = FastMultipole.M2L!(target_branch, source_branch, harmonics, L, Val(expansion_order)) -FastMultipole.upward_pass_singlethread!(branches, systems, expansion_order::Int) = FastMultipole.upward_pass_singlethread!(branches, systems, Val(expansion_order)) +FastMultipole.upward_pass_singlethread!(branches, systems, expansion_order::Int) = FastMultipole.upward_pass_singlethread!(branches, systems, Val(expansion_order), FastMultipole.ScalarPlusVector) FastMultipole.M2L!(target_branch, source_branch, expansion_order::Int) = FastMultipole.M2L!(target_branch, source_branch, Val(expansion_order)) @testset "complex" begin @@ -568,7 +568,7 @@ u_fmm_67 = mass_target_potential[1] # perform horizontal pass m2l_list, direct_list_target, direct_list_source = FastMultipole.build_interaction_lists(tree.branches, tree.branches, tree.leaf_index, multipole_threshold, true, true, true) FastMultipole.nearfield_singlethread!((elements,), direct_list_target, (FastMultipole.DerivativesSwitch(),), (elements,), direct_list_source) -FastMultipole.horizontal_pass_singlethread!(tree.branches, tree.branches, m2l_list, expansion_order) +FastMultipole.horizontal_pass_singlethread!(tree.branches, tree.branches, m2l_list, expansion_order, FastMultipole.ScalarPlusVector) # consider the effect on branch 3 (mass 2) elements.potential[i_POTENTIAL,new_order_index[2]] .*= 0 # reset potential at mass 2