Skip to content

Commit

Permalink
Improved performance; halo_exchange now type stable
Browse files Browse the repository at this point in the history
  • Loading branch information
kaipartmann committed Feb 21, 2024
1 parent a13fce5 commit bea3462
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 85 deletions.
8 changes: 4 additions & 4 deletions src/core/halo_exchange.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@

struct HaloExchange
field::Symbol
from_chunk_id::Int
to_chunk_id::Int
from_loc_idxs::Vector{Int}
to_loc_idxs::Vector{Int}
src_chunk_id::Int
dest_chunk_id::Int
src_idxs::Vector{Int}
dest_idxs::Vector{Int}
end

function find_halo_exchanges(body_chunks::Vector{B}) where {B<:BodyChunk}
Expand Down
32 changes: 15 additions & 17 deletions src/core/threads_data_handler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function ThreadsDataHandler(body::Body, time_solver::AbstractTimeSolver,
_halo_exchanges = find_halo_exchanges(body_chunks)
halo_exchanges = [Vector{HaloExchange}() for _ in eachindex(body_chunks)]
@threads :static for chunk_id in eachindex(body_chunks)
halo_exchanges[chunk_id] = filter(x -> x.to_chunk_id == chunk_id, _halo_exchanges)
halo_exchanges[chunk_id] = filter(x -> x.dest_chunk_id == chunk_id, _halo_exchanges)
end
return ThreadsDataHandler(body_chunks, halo_exchanges)
end
Expand Down Expand Up @@ -54,27 +54,25 @@ function _export_results(dh::ThreadsDataHandler, options::ExportOptions, n::Int,
return nothing
end

function halo_exchange!(dh::ThreadsDataHandler, chunk_id)
function halo_exchange!(dh::ThreadsDataHandler, chunk_id::Int)
for he in dh.halo_exchanges[chunk_id]
_halo_exchange!(dh, he)
src_field = get_exchange_field(dh.chunks[he.src_chunk_id], he.field)
dest_field = get_exchange_field(dh.chunks[he.dest_chunk_id], he.field)
exchange!(dest_field, src_field, he.dest_idxs, he.src_idxs)
end
return nothing
end

function _halo_exchange!(dh::ThreadsDataHandler, he::HaloExchange)
from_chunk = dh.chunks[he.from_chunk_id]
to_chunk = dh.chunks[he.to_chunk_id]
field = he.field
from_loc_idxs = he.from_loc_idxs
to_loc_idxs = he.to_loc_idxs
for i in eachindex(from_loc_idxs, to_loc_idxs)
from_idx = from_loc_idxs[i]
to_idx = to_loc_idxs[i]
from_store_entry = getfield(from_chunk.store, field)
to_store_entry = getfield(to_chunk.store, field)
to_store_entry[1, to_idx] = from_store_entry[1, from_idx]
to_store_entry[2, to_idx] = from_store_entry[2, from_idx]
to_store_entry[3, to_idx] = from_store_entry[3, from_idx]
@inline function get_exchange_field(b::AbstractBodyChunk, fieldname::Symbol)
return getfield(b.store, fieldname)::Matrix{Float64}
end

function exchange!(dest::Matrix{Float64}, src::Matrix{Float64}, dest_idxs::Vector{Int},
src_idxs::Vector{Int})
for i in eachindex(dest_idxs, src_idxs)
dest[1, dest_idxs[i]] = src[1, src_idxs[i]]
dest[2, dest_idxs[i]] = src[2, src_idxs[i]]
dest[3, dest_idxs[i]] = src[3, src_idxs[i]]
end
return nothing
end
128 changes: 64 additions & 64 deletions test/core/test_halo_exchange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,104 +19,104 @@

halo_exchanges = Peridynamics.find_halo_exchanges(body_chunks)
@test halo_exchanges[1].field === :position
@test halo_exchanges[1].from_chunk_id == 2
@test halo_exchanges[1].to_chunk_id == 1
@test halo_exchanges[1].from_loc_idxs == [1, 2]
@test halo_exchanges[1].to_loc_idxs == [3, 4]
@test halo_exchanges[1].src_chunk_id == 2
@test halo_exchanges[1].dest_chunk_id == 1
@test halo_exchanges[1].src_idxs == [1, 2]
@test halo_exchanges[1].dest_idxs == [3, 4]
@test halo_exchanges[2].field === :position
@test halo_exchanges[2].from_chunk_id == 1
@test halo_exchanges[2].to_chunk_id == 2
@test halo_exchanges[2].from_loc_idxs == [1, 2]
@test halo_exchanges[2].to_loc_idxs == [3, 4]
@test halo_exchanges[2].src_chunk_id == 1
@test halo_exchanges[2].dest_chunk_id == 2
@test halo_exchanges[2].src_idxs == [1, 2]
@test halo_exchanges[2].dest_idxs == [3, 4]

point_decomp = Peridynamics.PointDecomposition(body, 4)
body_chunks = Peridynamics.chop_body_threads(body, ts, point_decomp, Val{1}())

halo_exchanges = Peridynamics.find_halo_exchanges(body_chunks)

idx1 = findall(x -> x.to_chunk_id == 1, halo_exchanges)
he1 = sort(halo_exchanges[idx1]; by = x -> x.from_chunk_id)
idx1 = findall(x -> x.dest_chunk_id == 1, halo_exchanges)
he1 = sort(halo_exchanges[idx1]; by = x -> x.src_chunk_id)

@test he1[1].field === :position
@test he1[1].from_chunk_id == 2
@test he1[1].to_chunk_id == 1
@test he1[1].from_loc_idxs == [1]
@test he1[1].to_loc_idxs == [2]
@test he1[1].src_chunk_id == 2
@test he1[1].dest_chunk_id == 1
@test he1[1].src_idxs == [1]
@test he1[1].dest_idxs == [2]

@test he1[2].field === :position
@test he1[2].from_chunk_id == 3
@test he1[2].to_chunk_id == 1
@test he1[2].from_loc_idxs == [1]
@test he1[2].to_loc_idxs == [3]
@test he1[2].src_chunk_id == 3
@test he1[2].dest_chunk_id == 1
@test he1[2].src_idxs == [1]
@test he1[2].dest_idxs == [3]

@test he1[3].field === :position
@test he1[3].from_chunk_id == 4
@test he1[3].to_chunk_id == 1
@test he1[3].from_loc_idxs == [1]
@test he1[3].to_loc_idxs == [4]
@test he1[3].src_chunk_id == 4
@test he1[3].dest_chunk_id == 1
@test he1[3].src_idxs == [1]
@test he1[3].dest_idxs == [4]

idx2 = findall(x -> x.to_chunk_id == 2, halo_exchanges)
he2 = sort(halo_exchanges[idx2]; by = x -> x.from_chunk_id)
idx2 = findall(x -> x.dest_chunk_id == 2, halo_exchanges)
he2 = sort(halo_exchanges[idx2]; by = x -> x.src_chunk_id)

@test he2[1].field === :position
@test he2[1].from_chunk_id == 1
@test he2[1].to_chunk_id == 2
@test he2[1].from_loc_idxs == [1]
@test he2[1].to_loc_idxs == [2]
@test he2[1].src_chunk_id == 1
@test he2[1].dest_chunk_id == 2
@test he2[1].src_idxs == [1]
@test he2[1].dest_idxs == [2]

@test he2[2].field === :position
@test he2[2].from_chunk_id == 3
@test he2[2].to_chunk_id == 2
@test he2[2].from_loc_idxs == [1]
@test he2[2].to_loc_idxs == [3]
@test he2[2].src_chunk_id == 3
@test he2[2].dest_chunk_id == 2
@test he2[2].src_idxs == [1]
@test he2[2].dest_idxs == [3]

@test he2[3].field === :position
@test he2[3].from_chunk_id == 4
@test he2[3].to_chunk_id == 2
@test he2[3].from_loc_idxs == [1]
@test he2[3].to_loc_idxs == [4]
@test he2[3].src_chunk_id == 4
@test he2[3].dest_chunk_id == 2
@test he2[3].src_idxs == [1]
@test he2[3].dest_idxs == [4]


idx3 = findall(x -> x.to_chunk_id == 3, halo_exchanges)
he3 = sort(halo_exchanges[idx3]; by = x -> x.from_chunk_id)
idx3 = findall(x -> x.dest_chunk_id == 3, halo_exchanges)
he3 = sort(halo_exchanges[idx3]; by = x -> x.src_chunk_id)

@test he3[1].field === :position
@test he3[1].from_chunk_id == 1
@test he3[1].to_chunk_id == 3
@test he3[1].from_loc_idxs == [1]
@test he3[1].to_loc_idxs == [2]
@test he3[1].src_chunk_id == 1
@test he3[1].dest_chunk_id == 3
@test he3[1].src_idxs == [1]
@test he3[1].dest_idxs == [2]

@test he3[2].field === :position
@test he3[2].from_chunk_id == 2
@test he3[2].to_chunk_id == 3
@test he3[2].from_loc_idxs == [1]
@test he3[2].to_loc_idxs == [3]
@test he3[2].src_chunk_id == 2
@test he3[2].dest_chunk_id == 3
@test he3[2].src_idxs == [1]
@test he3[2].dest_idxs == [3]

@test he3[3].field === :position
@test he3[3].from_chunk_id == 4
@test he3[3].to_chunk_id == 3
@test he3[3].from_loc_idxs == [1]
@test he3[3].to_loc_idxs == [4]
@test he3[3].src_chunk_id == 4
@test he3[3].dest_chunk_id == 3
@test he3[3].src_idxs == [1]
@test he3[3].dest_idxs == [4]

idx4 = findall(x -> x.to_chunk_id == 4, halo_exchanges)
he4 = sort(halo_exchanges[idx4]; by = x -> x.from_chunk_id)
idx4 = findall(x -> x.dest_chunk_id == 4, halo_exchanges)
he4 = sort(halo_exchanges[idx4]; by = x -> x.src_chunk_id)

@test he4[1].field === :position
@test he4[1].from_chunk_id == 1
@test he4[1].to_chunk_id == 4
@test he4[1].from_loc_idxs == [1]
@test he4[1].to_loc_idxs == [2]
@test he4[1].src_chunk_id == 1
@test he4[1].dest_chunk_id == 4
@test he4[1].src_idxs == [1]
@test he4[1].dest_idxs == [2]

@test he4[2].field === :position
@test he4[2].from_chunk_id == 2
@test he4[2].to_chunk_id == 4
@test he4[2].from_loc_idxs == [1]
@test he4[2].to_loc_idxs == [3]
@test he4[2].src_chunk_id == 2
@test he4[2].dest_chunk_id == 4
@test he4[2].src_idxs == [1]
@test he4[2].dest_idxs == [3]

@test he4[3].field === :position
@test he4[3].from_chunk_id == 3
@test he4[3].to_chunk_id == 4
@test he4[3].from_loc_idxs == [1]
@test he4[3].to_loc_idxs == [4]
@test he4[3].src_chunk_id == 3
@test he4[3].dest_chunk_id == 4
@test he4[3].src_idxs == [1]
@test he4[3].dest_idxs == [4]

end

0 comments on commit bea3462

Please sign in to comment.