Skip to content

Commit

Permalink
Merge branch 'wcw/threaded-assembly' into dev-assemble
Browse files Browse the repository at this point in the history
  • Loading branch information
wcwitt committed Sep 22, 2023
2 parents c4c1aa0 + 82ee211 commit 01f9a7f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.1.4"

[deps]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LowRankApprox = "898213cb-b102-5a47-900c-97e73b919f73"
Expand Down
26 changes: 20 additions & 6 deletions src/assemble.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Distributed
using Folds
using ParallelDataTransfer
using ProgressMeter
using SharedArrays
Expand Down Expand Up @@ -42,7 +43,7 @@ end
"""
Assemble feature matrix and target vector for given data and basis.
"""
function assemble(data::AbstractVector{<:AbstractData}, basis; do_gc = true)
function assemble_mixed(data::AbstractVector{<:AbstractData}, basis; mode=:threaded)
@info "Assembling linear problem."
rows = Array{UnitRange}(undef, length(data)) # row ranges for each element of data
rows[1] = 1:count_observations(data[1])
Expand All @@ -51,15 +52,26 @@ function assemble(data::AbstractVector{<:AbstractData}, basis; do_gc = true)
end
packets = DataPacket.(rows, data)
sort!(packets, by = length, rev = true)
(nprocs() > 1) && sendto(workers(), basis = basis)
@info " - Creating feature matrix with size ($(rows[end][end]), $(length(basis)))."
A = SharedArray(zeros(rows[end][end], length(basis)))
Y = SharedArray(zeros(size(A, 1)))
@info " - Beginning assembly with processor count: $(nprocs())."
@showprogress pmap(packets) do p
A[p.rows, :] .= feature_matrix(p.data, basis)
W = SharedArray(zeros(size(A, 1)))
if mode == :serial
@info " - Beginning serial assembly."
elseif mode == :threaded
@info " - Beginning threaded assembly with $(Threads.nthreads()) threads."
map = Folds.map
elseif mode == :distributed
@info " - Beginning distributed assembly with $(nprocs()) processes."
map = pmap
(nprocs() > 1) && sendto(workers(), basis = basis)
end
progress = Progress(length(data))
map(packets) do p
A[p.rows,:] .= feature_matrix(p.data, basis)
Y[p.rows] .= target_vector(p.data)
do_gc && GC.gc()
next!(progress)
GC.gc()
end
@info " - Assembly completed."
return Array(A), Array(Y), assemble_weights(data)
Expand All @@ -80,6 +92,8 @@ function assemble_weights(data::AbstractVector{<:AbstractData})
W = SharedArray(zeros(rows[end][end]))
@showprogress pmap(packets) do p
W[p.rows] .= weight_vector(p.data)
next!(progress)
GC.gc()
end
return Array(W)
end
Expand Down

0 comments on commit 01f9a7f

Please sign in to comment.