Skip to content

Commit

Permalink
Specialize sample for sparse weights
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Nov 25, 2024
1 parent 4f98560 commit 22957fc
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,11 @@ sample(wv::AbstractWeights) = sample(default_rng(), wv)
sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) = a[sample(rng, wv)]
sample(a::AbstractArray, wv::AbstractWeights) = sample(default_rng(), a, wv)

function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,T,V}) where {T<:Real,V<:SparseVector{T}}
i = sample(rng, Weights(nonzeros(wv.values), sum(wv)))
return SparseArrays.nonzeroinds(wv.values)[i]
end

"""
direct_sample!([rng], a::AbstractArray, wv::AbstractWeights, x::AbstractArray)
Expand Down
3 changes: 2 additions & 1 deletion test/wsampling.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using StatsBase
using Random, Test, OffsetArrays
using Random, Test, OffsetArrays, SparseArrays

Random.seed!(1234)

Expand Down Expand Up @@ -41,6 +41,7 @@ for wv in (
weights([0.2, 0.8, 0.4, 0.6]),
weights([2, 8, 4, 6]),
weights(Float32[0.2, 0.8, 0.4, 0.6]),
weights(sparsevec([0, 8, 0, 6])),
Weights(Float32[0.2, 0.8, 0.4, 0.6], 2),
Weights([2, 8, 4, 6], 20.0),
)
Expand Down

0 comments on commit 22957fc

Please sign in to comment.