diff --git a/src/sampling.jl b/src/sampling.jl index a80693dcf..cbb0d5c5d 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -608,6 +608,11 @@ sample(wv::AbstractWeights) = sample(default_rng(), wv) sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) = a[sample(rng, wv)] sample(a::AbstractArray, wv::AbstractWeights) = sample(default_rng(), a, wv) +function sample(rng::AbstractRNG, wv::AbstractWeights{<:Real,T,V}) where {T<:Real,V<:SparseVector{T}} + i = sample(rng, Weights(nonzeros(wv.values), sum(wv))) + return SparseArrays.nonzeroinds(wv.values)[i] +end + """ direct_sample!([rng], a::AbstractArray, wv::AbstractWeights, x::AbstractArray) diff --git a/test/wsampling.jl b/test/wsampling.jl index efe9a608f..9b7ebc155 100644 --- a/test/wsampling.jl +++ b/test/wsampling.jl @@ -1,5 +1,5 @@ using StatsBase -using Random, Test, OffsetArrays +using Random, Test, OffsetArrays, SparseArrays Random.seed!(1234) @@ -41,6 +41,7 @@ for wv in ( weights([0.2, 0.8, 0.4, 0.6]), weights([2, 8, 4, 6]), weights(Float32[0.2, 0.8, 0.4, 0.6]), + weights(sparsevec([0, 8, 0, 6])), Weights(Float32[0.2, 0.8, 0.4, 0.6], 2), Weights([2, 8, 4, 6], 20.0), )