diff --git a/src/sampling.jl b/src/sampling.jl index ea19a9306..b8045e6f5 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -584,7 +584,10 @@ Optionally specify a random number generator `rng` as the first argument function sample(rng::AbstractRNG, wv::AbstractWeights) 1 == firstindex(wv) || throw(ArgumentError("non 1-based arrays are not supported")) - t = rand(rng) * sum(wv) + all(>=(0), wv) || throw(ArgumentError("negative weights are not allowed")) + s = sum(wv) + s > 0 || throw(ArgumentError("sum of weights must be greater than 0")) + t = rand(rng) * s n = length(wv) i = 1 cw = wv[1] @@ -621,6 +624,8 @@ function direct_sample!(rng::AbstractRNG, a::AbstractArray, throw(ArgumentError("non 1-based arrays are not supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) + all(>=(0), wv) || throw(ArgumentError("negative weights are not allowed")) + sum(wv) > 0 || throw(ArgumentError("sum of weights must be greater than 0")) for i = 1:length(x) x[i] = a[sample(rng, wv)] end @@ -710,6 +715,8 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, throw(ArgumentError("non 1-based arrays are not supported")) n = length(a) length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) + all(>=(0), wv) || throw(ArgumentError("negative weights are not allowed")) + sum(wv) > 0 || throw(ArgumentError("sum of weights must be greater than 0")) # create alias table ap = Vector{Float64}(undef, n) @@ -749,6 +756,8 @@ function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray, n = length(a) length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) k = length(x) + all(>=(0), wv) || throw(ArgumentError("negative weights are not allowed")) + sum(wv) > 0 || throw(ArgumentError("sum of weights must be greater than 0")) w = Vector{Float64}(undef, n) copyto!(w, wv) @@ -795,6 +804,8 @@ function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray, n = length(a) length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) k = length(x) + all(>=(0), wv) || throw(ArgumentError("negative weights are not allowed")) + sum(wv) > 0 || throw(ArgumentError("sum of weights must be greater than 0")) # calculate keys for all items keys = randexp(rng, n) @@ -845,14 +856,14 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, @inbounds for _s in 1:n s = _s w = wv.values[s] - w < 0 && error("Negative weight found in weight vector at index $s") + w < 0 && throw(ArgumentError("Negative weight found in weight vector at index $s")) if w > 0 i += 1 pq[i] = (w/randexp(rng) => s) end i >= k && break end - i < k && throw(DimensionMismatch("wv must have at least $k strictly positive entries (got $i)")) + i < k && throw(ArgumentError("wv must have at least $k strictly positive entries (got $i)")) heapify!(pq) # set threshold @@ -860,7 +871,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, @inbounds for i in s+1:n w = wv.values[i] - w < 0 && error("Negative weight found in weight vector at index $i") + w < 0 && throw(ArgumentError("Negative weight found in weight vector at index $i")) w > 0 || continue key = w/randexp(rng) @@ -918,14 +929,14 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, @inbounds for _s in 1:n s = _s w = wv.values[s] - w < 0 && error("Negative weight found in weight vector at index $s") + w < 0 && throw(ArgumentError("Negative weight found in weight vector at index $s")) if w > 0 i += 1 pq[i] = (w/randexp(rng) => s) end i >= k && break end - i < k && throw(DimensionMismatch("wv must have at least $k strictly positive entries (got $i)")) + i < k && throw(ArgumentError("wv must have at least $k strictly positive entries (got $i)")) heapify!(pq) # set threshold @@ -934,7 +945,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, @inbounds for i in s+1:n w = wv.values[i] - w < 0 && error("Negative weight found in weight vector at index $i") + w < 0 && throw(ArgumentError("Negative weight found in weight vector at index $i")) w > 0 || continue X -= w X <= 0 || continue @@ -991,7 +1002,7 @@ function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::Abs end end else - k <= n || error("Cannot draw $k samples from $n samples without replacement.") + k <= n || throw(ArgumentError("Cannot draw $k samples from $n samples without replacement.")) efraimidis_aexpj_wsample_norep!(rng, a, wv, x; ordered=ordered) end return x diff --git a/src/weights.jl b/src/weights.jl index 78091c2ae..465a57211 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -5,20 +5,24 @@ abstract type AbstractWeights{S<:Real, T<:Real, V<:AbstractVector{T}} <: Abstrac @weights name Generates a new generic weight type with specified `name`, which subtypes `AbstractWeights` -and stores the `values` (`V<:RealVector`) and `sum` (`S<:Real`). +and stores the `values` (`V<:RealVector`), the pre-computed `sum` (`S<:Real`) and +whether any values are `negative`. """ macro weights(name) return quote mutable struct $name{S<:Real, T<:Real, V<:AbstractVector{T}} <: AbstractWeights{S, T, V} values::V sum::S - function $(esc(name)){S, T, V}(values, sum) where {S<:Real, T<:Real, V<:AbstractVector{T}} + negative::Union{Bool, Missing} + function $(esc(name)){S, T, V}(values, sum, negative=missing) where {S<:Real, T<:Real, V<:AbstractVector{T}} isfinite(sum) || throw(ArgumentError("weights cannot contain Inf or NaN values")) - return new{S, T, V}(values, sum) + return new{S, T, V}(values, sum, negative) end end - $(esc(name))(values::AbstractVector{T}, sum::S) where {S<:Real, T<:Real} = $(esc(name)){S, T, typeof(values)}(values, sum) - $(esc(name))(values::AbstractVector{<:Real}) = $(esc(name))(values, sum(values)) + $(esc(name))(values::AbstractVector{T}, + sum::S=Base.sum(values), + negative::Union{Bool, Missing}=missing) where {S<:Real, T<:Real} = + $(esc(name)){S, T, typeof(values)}(values, sum, negative) end end @@ -53,9 +57,35 @@ Base.getindex(wv::W, ::Colon) where {W <: AbstractWeights} = W(copy(wv.values), isfinite(sum) || throw(ArgumentError("weights cannot contain Inf or NaN values")) wv.values[i] = v wv.sum = sum + wv.negative = v < zero(v) ? true : + wv.negative === false ? false : missing v end +function Base.all(f::Base.Fix2{typeof(>=)}, wv::AbstractWeights) + if iszero(f.x) + if wv.negative === missing + # sum is significantly faster than all when no entries are negative + wv.negative = sum(<(0), wv.values) > 0 + end + return !wv.negative + else + return all(f, wv.values) + end +end + +function Base.any(f::Base.Fix2{typeof(<)}, wv::AbstractWeights) + if iszero(f.x) + if wv.negative === missing + # sum is significantly faster than all when no entries are negative + wv.negative = sum(<(0), wv.values) > 0 + end + return wv.negative + else + return any(f, wv.values) + end +end + """ varcorrection(n::Integer, corrected=false) @@ -333,6 +363,9 @@ end Base.getindex(wv::UnitWeights{T}, ::Colon) where {T} = UnitWeights{T}(wv.len) +Base.all(f::Base.Fix2{typeof(>=)}, wv::UnitWeights{T}) where {T} = one(T) >= f.x +Base.any(f::Base.Fix2{typeof(<)}, wv::UnitWeights{T}) where {T} = one(T) < f.x + """ uweights(s::Integer) uweights(::Type{T}, s::Integer) where T<:Real diff --git a/test/sampling.jl b/test/sampling.jl index 27fcd2d3c..bb2947a32 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -213,10 +213,10 @@ x = vcat([sample(a, wv, 4, replace=false) for j in 1:10000]...) @test maximum(x) == 10 @test maximum(abs, proportions(x) .- 0.25) == 0 -@test_throws DimensionMismatch sample(a, wv, 5, replace=false) +@test_throws ArgumentError sample(a, wv, 5, replace=false) wv = Weights([zeros(5); 1:4; -1]) -@test_throws ErrorException sample(a, wv, 1, replace=false) +@test_throws ArgumentError sample(a, wv, 1, replace=false) #### weighted sampling with dimension diff --git a/test/weights.jl b/test/weights.jl index 52142efd8..e47b05b91 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -42,7 +42,6 @@ weight_funcs = (weights, aweights, fweights, pweights) @test_throws ArgumentError f([0.1, Inf]) @test_throws ArgumentError f([0.1, NaN]) - end @testset "$f, setindex!" for f in weight_funcs @@ -125,6 +124,50 @@ end @test Base.dataids(wv) == () end +@testset "Fast-path all(<=, wv) and any(<, wv)" begin + for f in weight_funcs + @test all(>=(0), f([1, 2])) + @test all(>=(0), f([-0.0, 0.0])) + @test !all(>=(0), f([1, -2])) + @test !any(<(0), f([1, 2])) + @test !any(<(0), f([-0.0, 0.0])) + @test any(<(0), f([1, -2])) + + @test all(>=(1), f([2, 3, 4])) + @test !all(>=(1), f([0, 1, 2])) + @test any(<(3), f([2, 3, 4])) + @test !any(<(1), f([1, 2, 3])) + + wv = f([1.0, 2.0, 3.0]) + @test all(>=(0), wv) + @test !any(<(0), wv) + wv[2] = -0.0 + @test all(>=(0), wv) + @test !any(<(0), wv) + wv[2] = -1.0 + @test !all(>=(0), wv) + @test any(<(0), wv) + wv[2] = 1.0 + @test all(>=(0), wv) + @test !any(<(0), wv) + + wv = f([1.0, 2.0, 3.0]) + wv[2] = -1.0 + @test !all(>=(0), wv) + @test any(<(0), wv) + wv[2] = 1.0 + @test all(>=(0), wv) + @test !any(<(0), wv) + end + + @test all(>=(0), uweights(2)) + @test !any(<(0), uweights(2)) + @test all(>=(1), uweights(2)) + @test !any(<(1), uweights(2)) + @test !all(>=(2), uweights(2)) + @test any(<(2), uweights(2)) +end + ## wsum @testset "wsum" begin diff --git a/test/wsampling.jl b/test/wsampling.jl index d1de4c855..cd9210372 100644 --- a/test/wsampling.jl +++ b/test/wsampling.jl @@ -161,5 +161,20 @@ end # This corner case should theoretically succeed # but it currently fails as Base.mightalias is not smart enough @test_broken f(y, weights(view(x, 5:6)), view(x, 2:4)) + + # Check that negative weights are not allowed + if f === efraimidis_ares_wsample_norep! || f === efraimidis_aexpj_wsample_norep! + y[3] = -0.0 + @test_throws ArgumentError f(x, weights(y), z) + else + y[3] = -0.0 + f(x, weights(y), z) + end + y[3] = -1.0 + @test_throws ArgumentError f(x, weights(y), z) + + # Check that sum of weights cannot be zero + @test_throws ArgumentError f(x, weights(fill(0.0, 10)), z) + @test_throws ArgumentError f(x, weights(fill(-0.0, 10)), z) end end \ No newline at end of file