diff --git a/src/sampling.jl b/src/sampling.jl index 2dffbbacd..e6569ba21 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -515,16 +515,28 @@ function efraimidis_ares_wsample_norep!(a::AbstractArray, wv::WeightVec, x::Abst # initialize priority queue pq = Vector{Pair{Float64,Int}}(k) - @inbounds for i in 1:k - pq[i] = (wv.values[i]/randexp() => i) + i = 0 + s = 0 + @inbounds for s in 1:n + w = wv.values[s] + w < 0 && error("Negative weight found in weight vector at index $s") + if w > 0 + i += 1 + pq[i] = (w/randexp() => s) + end + i >= k && break end + i < k && throw(DimensionMismatch("wv must have at least $k strictly positive entries (got $i)")) heapify!(pq) # set threshold @inbounds threshold = pq[1].first - @inbounds for i in k+1:n - key = wv.values[i]/randexp() + @inbounds for i in s+1:n + w = wv.values[i] + w < 0 && error("Negative weight found in weight vector at index $i") + w > 0 || continue + key = w/randexp() # if key is larger than the threshold if key > threshold @@ -561,17 +573,28 @@ function efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::WeightVec, x::Abs # initialize priority queue pq = Vector{Pair{Float64,Int}}(k) - @inbounds for i in 1:k - pq[i] = (wv.values[i]/randexp() => i) + i = 0 + s = 0 + @inbounds for s in 1:n + w = wv.values[s] + w < 0 && error("Negative weight found in weight vector at index $s") + if w > 0 + i += 1 + pq[i] = (w/randexp() => s) + end + i >= k && break end + i < k && throw(DimensionMismatch("wv must have at least $k strictly positive entries (got $i)")) heapify!(pq) # set threshold @inbounds threshold = pq[1].first X = threshold*randexp() - @inbounds for i in k+1:n + @inbounds for i in s+1:n w = wv.values[i] + w < 0 && error("Negative weight found in weight vector at index $i") + w > 0 || continue X -= w X <= 0 || continue diff --git a/test/sampling.jl b/test/sampling.jl index c19094f99..bc6027cd7 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -149,3 +149,28 @@ check_sample_norep(a, (3, 12), 0; ordered=false) a = sample(3:12, 5; replace=false, ordered=true) check_sample_norep(a, (3, 12), 0; ordered=true) + +# test of weighted sampling without replacement +a = [1:10;] +wv = WeightVec([zeros(6); 1:4]) +x = vcat([sample(a, wv, 1, replace=false) for j in 1:100000]...) +@test minimum(x) == 7 +@test maximum(x) == 10 +@test maximum(abs(proportions(x) - (1:4)/10)) < 0.01 + +x = vcat([sample(a, wv, 2, replace=false) for j in 1:50000]...) +exact2 = [0.117261905, 0.220634921, 0.304166667, 0.357936508] +@test minimum(x) == 7 +@test maximum(x) == 10 +@test maximum(abs(proportions(x) - exact2)) < 0.01 + +x = vcat([sample(a, wv, 4, replace=false) for j in 1:10000]...) +@test minimum(x) == 7 +@test maximum(x) == 10 +@test maximum(abs(proportions(x) - 0.25)) == 0 + +@test_throws DimensionMismatch sample(a, wv, 5, replace=false) + +wv = WeightVec([zeros(5); 1:4; -1]) +@test_throws ErrorException sample(a, wv, 1, replace=false) +