Skip to content

Commit

Permalink
Included export inside @threads loop
Browse files Browse the repository at this point in the history
  • Loading branch information
kaipartmann committed Feb 23, 2024
1 parent 7d897bd commit 82411cb
Show file tree
Hide file tree
Showing 12 changed files with 144 additions and 79 deletions.
1 change: 1 addition & 0 deletions src/Peridynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const FRAC_KWARGS = (:Gc, :epsilon_c)
const DEFAULT_POINT_KWARGS = (:horizon, :rho, ELASTIC_KWARGS..., FRAC_KWARGS...)
const CONTACT_KWARGS = (:radius, :sc)
const EXPORT_KWARGS = (:path, :freq, :write)
const DEFAULT_EXPORT_FIELDS = (:displacement, :damage)
const JOB_KWARGS = (EXPORT_KWARGS...,)
const SUBMIT_KWARGS = (:quiet,)

Expand Down
91 changes: 79 additions & 12 deletions src/auxiliary/io.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
const ExportField = Tuple{Symbol,DataType}

struct ExportOptions{N}
exportflag::Bool
root::String
vtk::String
logfile::String
freq::Int
fields::NTuple{N,Symbol}
fields::NTuple{N,ExportField}

function ExportOptions(root::String, freq::Int, fields::NTuple{N,Symbol}) where {N}
function ExportOptions(root::String, freq::Int, fields::NTuple{N,ExportField}) where {N}
if isempty(root)
return new{0}(false, "", "", "", 0, NTuple{0,Symbol}())
return new{0}(false, "", "", "", 0, NTuple{0,ExportField}())
end
vtk = joinpath(root, "vtk")
logfile = joinpath(root, "logfile.log")
return new{N}(true, root, vtk, logfile, freq, fields)
end
end

function get_export_options(::Type{M}, o::Dict{Symbol,Any}) where {M<:AbstractMaterial}
function get_export_options(::Type{S}, o::Dict{Symbol,Any}) where {S<:AbstractStorage}
local root::String
local freq::Int
local fields::NTuple{N,Symbol} where {N}

if haskey(o, :path) && haskey(o, :freq)
root = string(o[:path])
Expand All @@ -37,20 +37,87 @@ function get_export_options(::Type{M}, o::Dict{Symbol,Any}) where {M<:AbstractMa
end
freq < 0 && throw(ArgumentError("`freq` should be larger than zero!\n"))

fields = get_export_fields(S, o)

return ExportOptions(root, freq, fields)
end

function get_export_fields(::Type{S}, o::Dict{Symbol,Any}) where {S}
export_fieldnames = get_export_fieldnames(o)
storage_fieldnames = fieldnames(S)
storage_fieldtypes = fieldtypes(S)

_export_fields = Vector{ExportField}()

for name in export_fieldnames
idx = findfirst(x -> x === name, storage_fieldnames)
isnothing(idx) && unknown_fieldname_error(S, name)
type = storage_fieldtypes[idx]
push!(_export_fields, (name, type))
end

export_fields = Tuple(_export_fields)

return export_fields
end

function unknown_fieldname_error(::Type{S}, name::Symbol) where {S<:AbstractStorage}
msg = "unknown field $(name) specified for export!\n"
msg *= "Allowed fields for $S:\n"
for allowed_name in fieldnames(S)
msg *= " - $allowed_name\n"
end
throw(ArgumentError(msg))
end

function get_export_fieldnames(o::Dict{Symbol,Any})
local export_fieldnames::NTuple{N,Symbol} where {N}
if haskey(o, :write)
fields = o[:write]
export_fieldnames = o[:write]
else
fields = default_export_fields(M)
export_fieldnames = DEFAULT_EXPORT_FIELDS
end

return ExportOptions(root, freq, fields)
return export_fieldnames
end

function export_results(dh::AbstractDataHandler, options::ExportOptions, timestep::Int,
time::Float64)
function export_results(dh::AbstractDataHandler, options::ExportOptions, chunk_id::Int,
timestep::Int, time::Float64)
options.exportflag || return nothing
if mod(timestep, options.freq) == 0
_export_results(dh, options, timestep, time)
_export_results(dh.chunks[chunk_id], chunk_id, dh.n_chunks, options, timestep, time)
end
return nothing
end

function export_reference_results(dh::AbstractDataHandler, options::ExportOptions)
options.exportflag || return nothing
@threads :static for chunk_id in eachindex(dh.chunks)
_export_results(dh.chunks[chunk_id], chunk_id, dh.n_chunks, options, 0, 0.0)
end
return nothing
end

function _export_results(b::AbstractBodyChunk, chunk_id::Int, n_chunks::Int,
options::ExportOptions, n::Int, t::Float64)
filename = joinpath(options.vtk, @sprintf("timestep_%05d", n))
position = get_loc_position(b)
pvtk_grid(filename, position, b.cells; part=chunk_id, nparts=n_chunks) do vtk
for (field, type) in options.fields
# TODO:
# - solve type instability of `getfield`
# - check if field is part of halo write or read access -> correct length!
vtk[string(field), VTKPointData()] = get_export_field(b.store, field, type)
end
vtk["time", VTKFieldData()] = t
end
return nothing
end

@inline function get_export_field(s::AbstractStorage, name::Symbol, ::Type{V}) where {V}
export_field::V = getfield(s, name)
return export_field
end

@inline function get_loc_position(b::AbstractBodyChunk)
return @views b.store.position[:, 1:b.ch.n_loc_points]
end
2 changes: 1 addition & 1 deletion src/core/job.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ end
function Job(spatial_setup::S, time_solver::T; kwargs...) where {S,T}
o = Dict{Symbol,Any}(kwargs)
check_kwargs(o, JOB_KWARGS)
options = get_export_options(material_type(spatial_setup), o)
options = get_export_options(storage_type(spatial_setup.mat, time_solver), o)
return Job(spatial_setup, time_solver, options)
end
30 changes: 7 additions & 23 deletions src/core/threads_data_handler.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
struct ThreadsDataHandler{C<:AbstractBodyChunk} <: AbstractDataHandler
n_chunks::Int
chunks::Vector{C}
halo_exchanges::Vector{Vector{HaloExchange}}
end
Expand All @@ -12,13 +13,14 @@ end

function ThreadsDataHandler(body::Body, time_solver::AbstractTimeSolver,
point_decomp::PointDecomposition, v::Val{N}) where {N}
body_chunks = chop_body_threads(body, time_solver, point_decomp, v)
_halo_exchanges = find_halo_exchanges(body_chunks)
halo_exchanges = [Vector{HaloExchange}() for _ in eachindex(body_chunks)]
@threads :static for chunk_id in eachindex(body_chunks)
chunks = chop_body_threads(body, time_solver, point_decomp, v)
n_chunks = length(chunks)
_halo_exchanges = find_halo_exchanges(chunks)
halo_exchanges = [Vector{HaloExchange}() for _ in eachindex(chunks)]
@threads :static for chunk_id in eachindex(chunks)
halo_exchanges[chunk_id] = filter(x -> x.dest_chunk_id == chunk_id, _halo_exchanges)
end
return ThreadsDataHandler(body_chunks, halo_exchanges)
return ThreadsDataHandler(n_chunks, chunks, halo_exchanges)
end

function ThreadsDataHandler(multibody::MultibodySetup, time_solver::AbstractTimeSolver,
Expand All @@ -36,24 +38,6 @@ end

get_cells(n::Int) = [MeshCell(VTKCellTypes.VTK_VERTEX, (i,)) for i in 1:n]

function _export_results(dh::ThreadsDataHandler, options::ExportOptions, n::Int, t::Float64)
filename = joinpath(options.vtk, @sprintf("timestep_%05d", n))
n_chunks = length(dh.chunks)
@threads :static for chunk_id in eachindex(dh.chunks)
chunk = dh.chunks[chunk_id]
n_local_points = length(chunk.ch.loc_points)
position = @views chunk.store.position[:, 1:n_local_points]
cells = get_cells(n_local_points)
pvtk_grid(filename, position, cells; part=chunk_id, nparts=n_chunks) do vtk
for fld in options.fields
vtk[string(fld), VTKPointData()] = getfield(chunk.store, fld)
end
vtk["time", VTKFieldData()] = t
end
end
return nothing
end

function halo_exchange!(dh::ThreadsDataHandler, chunk_id::Int)
for he in dh.halo_exchanges[chunk_id]
src_field = get_exchange_field(dh.chunks[he.src_chunk_id], he.field)
Expand Down
8 changes: 6 additions & 2 deletions src/discretizations/body_chunk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ struct MultiParamBodyChunk{M<:AbstractMaterial,P<:AbstractPointParameters,
psets::Dict{Symbol,Vector{Int}}
sdbcs::Vector{SingleDimBC}
ch::ChunkHandler
cells::Vector{MeshCell{VTKCellType, Tuple{Int64}}}
end

function MultiParamBodyChunk(body::Body{M,P}, ts::T, pd::PointDecomposition,
Expand All @@ -20,7 +21,8 @@ function MultiParamBodyChunk(body::Body{M,P}, ts::T, pd::PointDecomposition,
param = body.point_params
psets = localized_point_sets(body.point_sets, ch)
sdbcs = body.single_dim_bcs
return MultiParamBodyChunk(mat, dscr, store, param, paramap, psets, sdbcs, ch)
cells = get_cells(ch.n_loc_points)
return MultiParamBodyChunk(mat, dscr, store, param, paramap, psets, sdbcs, ch, cells)
end

struct BodyChunk{M<:AbstractMaterial,P<:AbstractPointParameters,D<:AbstractDiscretization,
Expand All @@ -32,6 +34,7 @@ struct BodyChunk{M<:AbstractMaterial,P<:AbstractPointParameters,D<:AbstractDiscr
psets::Dict{Symbol,Vector{Int}}
sdbcs::Vector{SingleDimBC}
ch::ChunkHandler
cells::Vector{MeshCell{VTKCellType, Tuple{Int64}}}
end

function BodyChunk(body::Body{M,P}, ts::T, pd::PointDecomposition,
Expand All @@ -43,7 +46,8 @@ function BodyChunk(body::Body{M,P}, ts::T, pd::PointDecomposition,
param = first(body.point_params)
psets = localized_point_sets(body.point_sets, ch)
sdbcs = body.single_dim_bcs
return BodyChunk(mat, dscr, store, param, psets, sdbcs, ch)
cells = get_cells(ch.n_loc_points)
return BodyChunk(mat, dscr, store, param, psets, sdbcs, ch, cells)
end

@inline function get_param(b::MultiParamBodyChunk, point_id::Int)
Expand Down
5 changes: 4 additions & 1 deletion src/discretizations/chunk_handler.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
struct ChunkHandler
n_loc_points::Int
point_ids::Vector{Int}
loc_points::UnitRange{Int}
halo_points::Vector{Int}
Expand All @@ -8,11 +9,13 @@ end

function ChunkHandler(bonds::Vector{Bond}, pd::PointDecomposition, chunk_id::Int)
loc_points = pd.decomp[chunk_id]
n_loc_points = length(loc_points)
halo_points = find_halo_points(bonds, loc_points)
halo_by_src = sort_halo_by_src!(halo_points, pd.point_src, length(loc_points))
point_ids = vcat(loc_points, halo_points)
localizer = find_localizer(point_ids)
return ChunkHandler(point_ids, loc_points, halo_points, halo_by_src, localizer)
return ChunkHandler(n_loc_points, point_ids, loc_points, halo_points, halo_by_src,
localizer)
end

function find_halo_points(bonds::Vector{Bond}, loc_points::UnitRange{Int})
Expand Down
4 changes: 2 additions & 2 deletions src/physics/bond_based.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ end
@inline reads_from_halo(::BBMaterial) = (:position,)
@inline writes_to_halo(::BBMaterial) = ()

function force_density!(s::BBStorage, bd::BondDiscretization, param::BBPointParameters,
i::Int)
function force_density_point!(s::BBStorage, bd::BondDiscretization,
param::BBPointParameters, i::Int)
for bond_id in each_bond_idx(bd, i)
bond = bd.bonds[bond_id]
j, L = bond.neighbor, bond.length
Expand Down
2 changes: 1 addition & 1 deletion src/physics/force_density.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ function calc_force_density!(b::AbstractBodyChunk)
b.store.n_active_bonds .= 0
for point_id in each_point_idx(b)
param = get_param(b, point_id)
force_density!(b.store, b.dscr, param, point_id)
force_density_point!(b.store, b.dscr, param, point_id)
end
return nothing
end
4 changes: 0 additions & 4 deletions src/physics/material_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,3 @@ function writes_to_halo(mat::AbstractMaterial)
end

#---- optional interface functions ----#

@inline function default_export_fields(::Type{M}) where {M<:AbstractMaterial}
return (:displacement, :damage)
end
20 changes: 10 additions & 10 deletions src/time_solvers/solve_velocity_verlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ function _update_vel!(velocity, velocity_half, acceleration, Δt½, i)
end

function solve!(dh::ThreadsDataHandler, vv::VelocityVerlet, options::ExportOptions)
_export_results(dh, options, 0, 0.0)
export_reference_results(dh, options)

Δt = vv.Δt
Δt½ = 0.5 * vv.Δt
Expand All @@ -83,19 +83,19 @@ function solve!(dh::ThreadsDataHandler, vv::VelocityVerlet, options::ExportOptio
for n in 1:vv.n_steps
t = n * Δt
@threads :static for chunk_id in eachindex(dh.chunks)
body_chunk = dh.chunks[chunk_id]
update_vel_half!(body_chunk, Δt½)
apply_bcs!(body_chunk, t)
update_disp_and_pos!(body_chunk, Δt)
chunk = dh.chunks[chunk_id]
update_vel_half!(chunk, Δt½)
apply_bcs!(chunk, t)
update_disp_and_pos!(chunk, Δt)
end
@threads :static for chunk_id in eachindex(dh.chunks)
halo_exchange!(dh, chunk_id)
body_chunk = dh.chunks[chunk_id]
calc_force_density!(body_chunk)
calc_damage!(body_chunk)
update_acc_and_vel!(body_chunk, Δt½)
chunk = dh.chunks[chunk_id]
calc_force_density!(chunk)
calc_damage!(chunk)
update_acc_and_vel!(chunk, Δt½)
export_results(dh, options, chunk_id, n, t)
end
export_results(dh, options, n, t)
next!(p)
end
finish!(p)
Expand Down
32 changes: 21 additions & 11 deletions test/auxiliary/test_io.jl
Original file line number Diff line number Diff line change
@@ -1,45 +1,55 @@
@testitem "ExportOptions" begin
@testitem "ExportOptions BBVerletStorage" begin
o = Dict{Symbol,Any}(:path => "rootpath", :freq => 10)
eo = Peridynamics.get_export_options(BBMaterial, o)
eo = Peridynamics.get_export_options(Peridynamics.BBVerletStorage, o)
@test eo.exportflag == true
@test eo.root == "rootpath"
@test eo.vtk == joinpath("rootpath", "vtk")
@test eo.logfile == joinpath("rootpath", "logfile.log")
@test eo.freq == 10
@test eo.fields == (:displacement, :damage)
@test eo.fields == ((:displacement, Matrix{Float64}), (:damage, Vector{Float64}))

o = Dict{Symbol,Any}(:path => "rootpath")
eo = Peridynamics.get_export_options(BBMaterial, o)
eo = Peridynamics.get_export_options(Peridynamics.BBVerletStorage, o)
@test eo.exportflag == true
@test eo.root == "rootpath"
@test eo.vtk == joinpath("rootpath", "vtk")
@test eo.logfile == joinpath("rootpath", "logfile.log")
@test eo.freq == 10
@test eo.fields == (:displacement, :damage)
@test eo.fields == ((:displacement, Matrix{Float64}), (:damage, Vector{Float64}))

o = Dict{Symbol,Any}(:freq => 10)
msg = "if `freq` is spedified, the keyword `path` is also needed!\n"
@test_throws ArgumentError(msg) Peridynamics.get_export_options(BBMaterial, o)
@test_throws ArgumentError(msg) begin
Peridynamics.get_export_options(Peridynamics.BBVerletStorage, o)
end

o = Dict{Symbol,Any}()
eo = Peridynamics.get_export_options(BBMaterial, o)
eo = Peridynamics.get_export_options(Peridynamics.BBVerletStorage, o)
@test eo.exportflag == false
@test eo.root == ""
@test eo.vtk == ""
@test eo.logfile == ""
@test eo.freq == 0
@test eo.fields == NTuple{0,Symbol}()
@test eo.fields == NTuple{0,Tuple{Symbol,DataType}}()

o = Dict{Symbol,Any}(:path => "rootpath", :freq => -10)
msg = "`freq` should be larger than zero!\n"
@test_throws ArgumentError(msg) Peridynamics.get_export_options(BBMaterial, o)
@test_throws ArgumentError(msg) begin
Peridynamics.get_export_options(Peridynamics.BBVerletStorage, o)
end

o = Dict{Symbol,Any}(:path => "rootpath", :write => (:displacement, :damage, :b_int))
eo = Peridynamics.get_export_options(BBMaterial, o)
eo = Peridynamics.get_export_options(Peridynamics.BBVerletStorage, o)
@test eo.exportflag == true
@test eo.root == "rootpath"
@test eo.vtk == joinpath("rootpath", "vtk")
@test eo.logfile == joinpath("rootpath", "logfile.log")
@test eo.freq == 10
@test eo.fields == (:displacement, :damage, :b_int)
@test eo.fields == ((:displacement, Matrix{Float64}),
(:damage, Vector{Float64}),
(:b_int, Matrix{Float64}))
end

@testitem "export_results" begin

end
Loading

0 comments on commit 82411cb

Please sign in to comment.