Skip to content

Commit

Permalink
Merge pull request #79 from ACEsuit/co/lenbasis
Browse files Browse the repository at this point in the history
Introduce `basis_size`
  • Loading branch information
cortner authored Aug 29, 2024
2 parents 9290dcb + c097b5f commit c94ce7c
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions src/assemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@ using ParallelDataTransfer
using ProgressMeter
using SharedArrays

"""
basis_size(model)
Return the length of the basis, assuming that `model` is a linear model, or
when interpreted as a linear model. The returned integer must match the
size of the feature matrix that will be assembled for the given model.
It defaults to `Base.length` but can be overloaded if needed.
"""
basis_size(model) = Base.length(model)

struct DataPacket{T <: AbstractData}
rows::UnitRange
data::T
Expand All @@ -23,8 +34,8 @@ function assemble(data::AbstractVector{<:AbstractData}, basis)
packets = DataPacket.(rows, data)
sort!(packets, by = length, rev = true)
(nprocs() > 1) && sendto(workers(), basis = basis)
@info " - Creating feature matrix with size ($(rows[end][end]), $(length(basis)))."
A = SharedArray(zeros(rows[end][end], length(basis)))
@info " - Creating feature matrix with size ($(rows[end][end]), $(basis_size(basis)))."
A = SharedArray(zeros(rows[end][end], basis_size(basis)))
Y = SharedArray(zeros(size(A, 1)))
@info " - Beginning assembly with processor count: $(nprocs())."
@showprogress pmap(packets) do p
Expand Down

0 comments on commit c94ce7c

Please sign in to comment.