Skip to content

Commit

Permalink
Use default_rng() instead of GLOBAL_RNG on Julia >= 1.3 (#878)
Browse files Browse the repository at this point in the history
* Use `default_rng()` instead of `GLOBAL_RNG` on Julia >= 1.3

* Improve version check

Co-authored-by: Alex Arslan <[email protected]>

* Update test/sampling.jl

---------

Co-authored-by: Alex Arslan <[email protected]>
  • Loading branch information
devmotion and ararslan authored Jul 13, 2023
1 parent eac9bb8 commit 8696d51
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 44 deletions.
2 changes: 1 addition & 1 deletion docs/src/sampling.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Here are a list of algorithms implemented in the package. The functions below ar
- `wv`: the weight vector (of type `AbstractWeights`), for weighted sampling
- `n`: the length of `a`
- `k`: the length of `x`. For sampling without replacement, `k` must not exceed `n`.
- `rng`: optional random number generator (defaults to `Random.GLOBAL_RNG`)
- `rng`: optional random number generator (defaults to `Random.default_rng()` on Julia >= 1.3 and `Random.GLOBAL_RNG` on Julia < 1.3)

All following functions write results to `x` (pre-allocated) and return `x`.

Expand Down
90 changes: 48 additions & 42 deletions src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
#
###########################################################

using Random: Sampler, Random.GLOBAL_RNG
using Random: Sampler

if VERSION < v"1.3.0-DEV.565"
default_rng() = Random.GLOBAL_RNG
else
using Random: default_rng
end

### Algorithms for sampling with replacement

Expand All @@ -25,7 +31,7 @@ function direct_sample!(rng::AbstractRNG, a::UnitRange, x::AbstractArray)
end
return x
end
direct_sample!(a::UnitRange, x::AbstractArray) = direct_sample!(Random.GLOBAL_RNG, a, x)
direct_sample!(a::UnitRange, x::AbstractArray) = direct_sample!(default_rng(), a, x)

"""
direct_sample!([rng], a::AbstractArray, x::AbstractArray)
Expand All @@ -46,7 +52,7 @@ function direct_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
end
return x
end
direct_sample!(a::AbstractArray, x::AbstractArray) = direct_sample!(Random.GLOBAL_RNG, a, x)
direct_sample!(a::AbstractArray, x::AbstractArray) = direct_sample!(default_rng(), a, x)

# check whether we can use T to store indices 1:n exactly, and
# use some heuristics to decide whether it is beneficial for k samples
Expand Down Expand Up @@ -103,28 +109,28 @@ sample_ordered!(sampler!, rng::AbstractRNG, a::AbstractArray,
Draw a pair of distinct integers between 1 and `n` without replacement.
Optionally specify a random number generator `rng` as the first argument
(defaults to `Random.GLOBAL_RNG`).
(defaults to `Random.$(VERSION < v"1.3" ? "GLOBAL_RNG" : "default_rng()")`).
"""
function samplepair(rng::AbstractRNG, n::Integer)
i1 = rand(rng, one(n):n)
i2 = rand(rng, one(n):(n - one(n)))
return (i1, ifelse(i2 == i1, n, i2))
end
samplepair(n::Integer) = samplepair(Random.GLOBAL_RNG, n)
samplepair(n::Integer) = samplepair(default_rng(), n)

"""
samplepair([rng], a)
Draw a pair of distinct elements from the array `a` without replacement.
Optionally specify a random number generator `rng` as the first argument
(defaults to `Random.GLOBAL_RNG`).
(defaults to `Random.$(VERSION < v"1.3" ? "GLOBAL_RNG" : "default_rng()")`).
"""
function samplepair(rng::AbstractRNG, a::AbstractArray)
i1, i2 = samplepair(rng, length(a))
return a[i1], a[i2]
end
samplepair(a::AbstractArray) = samplepair(Random.GLOBAL_RNG, a)
samplepair(a::AbstractArray) = samplepair(default_rng(), a)

### Algorithm for sampling without replacement

Expand Down Expand Up @@ -173,7 +179,7 @@ function knuths_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray;
return x
end
knuths_sample!(a::AbstractArray, x::AbstractArray; initshuffle::Bool=true) =
knuths_sample!(Random.GLOBAL_RNG, a, x; initshuffle=initshuffle)
knuths_sample!(default_rng(), a, x; initshuffle=initshuffle)

"""
fisher_yates_sample!([rng], a::AbstractArray, x::AbstractArray)
Expand Down Expand Up @@ -223,7 +229,7 @@ function fisher_yates_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArr
return x
end
fisher_yates_sample!(a::AbstractArray, x::AbstractArray) =
fisher_yates_sample!(Random.GLOBAL_RNG, a, x)
fisher_yates_sample!(default_rng(), a, x)

"""
self_avoid_sample!([rng], a::AbstractArray, x::AbstractArray)
Expand Down Expand Up @@ -269,7 +275,7 @@ function self_avoid_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray
return x
end
self_avoid_sample!(a::AbstractArray, x::AbstractArray) =
self_avoid_sample!(Random.GLOBAL_RNG, a, x)
self_avoid_sample!(default_rng(), a, x)

"""
seqsample_a!([rng], a::AbstractArray, x::AbstractArray)
Expand Down Expand Up @@ -311,7 +317,7 @@ function seqsample_a!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
end
return x
end
seqsample_a!(a::AbstractArray, x::AbstractArray) = seqsample_a!(Random.GLOBAL_RNG, a, x)
seqsample_a!(a::AbstractArray, x::AbstractArray) = seqsample_a!(default_rng(), a, x)

"""
seqsample_c!([rng], a::AbstractArray, x::AbstractArray)
Expand Down Expand Up @@ -357,7 +363,7 @@ function seqsample_c!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
end
return x
end
seqsample_c!(a::AbstractArray, x::AbstractArray) = seqsample_c!(Random.GLOBAL_RNG, a, x)
seqsample_c!(a::AbstractArray, x::AbstractArray) = seqsample_c!(default_rng(), a, x)

"""
seqsample_d!([rng], a::AbstractArray, x::AbstractArray)
Expand Down Expand Up @@ -449,7 +455,7 @@ function seqsample_d!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
end
end

seqsample_d!(a::AbstractArray, x::AbstractArray) = seqsample_d!(Random.GLOBAL_RNG, a, x)
seqsample_d!(a::AbstractArray, x::AbstractArray) = seqsample_d!(default_rng(), a, x)


### Interface functions (poly-algorithms)
Expand All @@ -460,10 +466,10 @@ Select a single random element of `a`. Sampling probabilities are proportional t
the weights given in `wv`, if provided.
Optionally specify a random number generator `rng` as the first argument
(defaults to `Random.GLOBAL_RNG`).
(defaults to `Random.$(VERSION < v"1.3" ? "GLOBAL_RNG" : "default_rng()")`).
"""
sample(rng::AbstractRNG, a::AbstractArray) = a[rand(rng, 1:length(a))]
sample(a::AbstractArray) = sample(Random.GLOBAL_RNG, a)
sample(a::AbstractArray) = sample(default_rng(), a)


"""
Expand All @@ -478,7 +484,7 @@ an ordered sample (also called a sequential sample, i.e. a sample where
items appear in the same order as in `a`) should be taken.
Optionally specify a random number generator `rng` as the first argument
(defaults to `Random.GLOBAL_RNG`).
(defaults to `Random.$(VERSION < v"1.3" ? "GLOBAL_RNG" : "default_rng()")`).
Output array `a` must not be the same object as `x` or `wv`
nor share memory with them, or the result may be incorrect.
Expand Down Expand Up @@ -522,7 +528,7 @@ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray;
return x
end
sample!(a::AbstractArray, x::AbstractArray; replace::Bool=true, ordered::Bool=false) =
sample!(Random.GLOBAL_RNG, a, x; replace=replace, ordered=ordered)
sample!(default_rng(), a, x; replace=replace, ordered=ordered)


"""
Expand All @@ -536,14 +542,14 @@ an ordered sample (also called a sequential sample, i.e. a sample where
items appear in the same order as in `a`) should be taken.
Optionally specify a random number generator `rng` as the first argument
(defaults to `Random.GLOBAL_RNG`).
(defaults to `Random.$(VERSION < v"1.3" ? "GLOBAL_RNG" : "default_rng()")`).
"""
function sample(rng::AbstractRNG, a::AbstractArray{T}, n::Integer;
replace::Bool=true, ordered::Bool=false) where T
sample!(rng, a, Vector{T}(undef, n); replace=replace, ordered=ordered)
end
sample(a::AbstractArray, n::Integer; replace::Bool=true, ordered::Bool=false) =
sample(Random.GLOBAL_RNG, a, n; replace=replace, ordered=ordered)
sample(default_rng(), a, n; replace=replace, ordered=ordered)


"""
Expand All @@ -557,14 +563,14 @@ an ordered sample (also called a sequential sample, i.e. a sample where
items appear in the same order as in `a`) should be taken.
Optionally specify a random number generator `rng` as the first argument
(defaults to `Random.GLOBAL_RNG`).
(defaults to `Random.$(VERSION < v"1.3" ? "GLOBAL_RNG" : "default_rng()")`).
"""
function sample(rng::AbstractRNG, a::AbstractArray{T}, dims::Dims;
replace::Bool=true, ordered::Bool=false) where T
sample!(rng, a, Array{T}(undef, dims); replace=replace, ordered=ordered)
end
sample(a::AbstractArray, dims::Dims; replace::Bool=true, ordered::Bool=false) =
sample(Random.GLOBAL_RNG, a, dims; replace=replace, ordered=ordered)
sample(default_rng(), a, dims; replace=replace, ordered=ordered)

################################################################
#
Expand All @@ -579,7 +585,7 @@ Select a single random integer in `1:length(wv)` with probabilities
proportional to the weights given in `wv`.
Optionally specify a random number generator `rng` as the first argument
(defaults to `Random.GLOBAL_RNG`).
(defaults to `Random.$(VERSION < v"1.3" ? "GLOBAL_RNG" : "default_rng()")`).
"""
function sample(rng::AbstractRNG, wv::AbstractWeights)
1 == firstindex(wv) ||
Expand All @@ -594,10 +600,10 @@ function sample(rng::AbstractRNG, wv::AbstractWeights)
end
return i
end
sample(wv::AbstractWeights) = sample(Random.GLOBAL_RNG, wv)
sample(wv::AbstractWeights) = sample(default_rng(), wv)

sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) = a[sample(rng, wv)]
sample(a::AbstractArray, wv::AbstractWeights) = sample(Random.GLOBAL_RNG, a, wv)
sample(a::AbstractArray, wv::AbstractWeights) = sample(default_rng(), a, wv)

"""
direct_sample!([rng], a::AbstractArray, wv::AbstractWeights, x::AbstractArray)
Expand Down Expand Up @@ -627,7 +633,7 @@ function direct_sample!(rng::AbstractRNG, a::AbstractArray,
return x
end
direct_sample!(a::AbstractArray, wv::AbstractWeights, x::AbstractArray) =
direct_sample!(Random.GLOBAL_RNG, a, wv, x)
direct_sample!(default_rng(), a, wv, x)

function make_alias_table!(w::AbstractVector, wsum,
a::AbstractVector{Float64},
Expand Down Expand Up @@ -725,7 +731,7 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights,
return x
end
alias_sample!(a::AbstractArray, wv::AbstractWeights, x::AbstractArray) =
alias_sample!(Random.GLOBAL_RNG, a, wv, x)
alias_sample!(default_rng(), a, wv, x)

"""
naive_wsample_norep!([rng], a::AbstractArray, wv::AbstractWeights, x::AbstractArray)
Expand Down Expand Up @@ -769,7 +775,7 @@ function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
return x
end
naive_wsample_norep!(a::AbstractArray, wv::AbstractWeights, x::AbstractArray) =
naive_wsample_norep!(Random.GLOBAL_RNG, a, wv, x)
naive_wsample_norep!(default_rng(), a, wv, x)

# Weighted sampling without replacement
# Instead of keys u^(1/w) where u = random(0,1) keys w/v where v = randexp(1) are used.
Expand Down Expand Up @@ -810,7 +816,7 @@ function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
return x
end
efraimidis_a_wsample_norep!(a::AbstractArray, wv::AbstractWeights, x::AbstractArray) =
efraimidis_a_wsample_norep!(Random.GLOBAL_RNG, a, wv, x)
efraimidis_a_wsample_norep!(default_rng(), a, wv, x)

# Weighted sampling without replacement
# Instead of keys u^(1/w) where u = random(0,1) keys w/v where v = randexp(1) are used.
Expand Down Expand Up @@ -882,7 +888,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
return x
end
efraimidis_ares_wsample_norep!(a::AbstractArray, wv::AbstractWeights, x::AbstractArray) =
efraimidis_ares_wsample_norep!(Random.GLOBAL_RNG, a, wv, x)
efraimidis_ares_wsample_norep!(default_rng(), a, wv, x)

# Weighted sampling without replacement
# Instead of keys u^(1/w) where u = random(0,1) keys w/v where v = randexp(1) are used.
Expand Down Expand Up @@ -964,7 +970,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
end
efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::AbstractWeights, x::AbstractArray;
ordered::Bool=false) =
efraimidis_aexpj_wsample_norep!(Random.GLOBAL_RNG, a, wv, x; ordered=ordered)
efraimidis_aexpj_wsample_norep!(default_rng(), a, wv, x; ordered=ordered)

function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray;
replace::Bool=true, ordered::Bool=false)
Expand Down Expand Up @@ -998,21 +1004,21 @@ function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::Abs
end
sample!(a::AbstractArray, wv::AbstractWeights, x::AbstractArray;
replace::Bool=true, ordered::Bool=false) =
sample!(Random.GLOBAL_RNG, a, wv, x; replace=replace, ordered=ordered)
sample!(default_rng(), a, wv, x; replace=replace, ordered=ordered)

sample(rng::AbstractRNG, a::AbstractArray{T}, wv::AbstractWeights, n::Integer;
replace::Bool=true, ordered::Bool=false) where {T} =
sample!(rng, a, wv, Vector{T}(undef, n); replace=replace, ordered=ordered)
sample(a::AbstractArray, wv::AbstractWeights, n::Integer;
replace::Bool=true, ordered::Bool=false) =
sample(Random.GLOBAL_RNG, a, wv, n; replace=replace, ordered=ordered)
sample(default_rng(), a, wv, n; replace=replace, ordered=ordered)

sample(rng::AbstractRNG, a::AbstractArray{T}, wv::AbstractWeights, dims::Dims;
replace::Bool=true, ordered::Bool=false) where {T} =
sample!(rng, a, wv, Array{T}(undef, dims); replace=replace, ordered=ordered)
sample(a::AbstractArray, wv::AbstractWeights, dims::Dims;
replace::Bool=true, ordered::Bool=false) =
sample(Random.GLOBAL_RNG, a, wv, dims; replace=replace, ordered=ordered)
sample(default_rng(), a, wv, dims; replace=replace, ordered=ordered)

# wsample interface

Expand All @@ -1026,14 +1032,14 @@ an ordered sample (also called a sequential sample, i.e. a sample where
items appear in the same order as in `a`) should be taken.
Optionally specify a random number generator `rng` as the first argument
(defaults to `Random.GLOBAL_RNG`).
(defaults to `Random.$(VERSION < v"1.3" ? "GLOBAL_RNG" : "default_rng()")`).
"""
wsample!(rng::AbstractRNG, a::AbstractArray, w::AbstractVector{<:Real}, x::AbstractArray;
replace::Bool=true, ordered::Bool=false) =
sample!(rng, a, weights(w), x; replace=replace, ordered=ordered)
wsample!(a::AbstractArray, w::AbstractVector{<:Real}, x::AbstractArray;
replace::Bool=true, ordered::Bool=false) =
sample!(Random.GLOBAL_RNG, a, weights(w), x; replace=replace, ordered=ordered)
sample!(default_rng(), a, weights(w), x; replace=replace, ordered=ordered)

"""
wsample([rng], [a], w)
Expand All @@ -1042,12 +1048,12 @@ Select a weighted random sample of size 1 from `a` with probabilities proportion
to the weights given in `w`. If `a` is not present, select a random weight from `w`.
Optionally specify a random number generator `rng` as the first argument
(defaults to `Random.GLOBAL_RNG`).
(defaults to `Random.$(VERSION < v"1.3" ? "GLOBAL_RNG" : "default_rng()")`).
"""
wsample(rng::AbstractRNG, w::AbstractVector{<:Real}) = sample(rng, weights(w))
wsample(w::AbstractVector{<:Real}) = wsample(Random.GLOBAL_RNG, w)
wsample(w::AbstractVector{<:Real}) = wsample(default_rng(), w)
wsample(rng::AbstractRNG, a::AbstractArray, w::AbstractVector{<:Real}) = sample(rng, a, weights(w))
wsample(a::AbstractArray, w::AbstractVector{<:Real}) = wsample(Random.GLOBAL_RNG, a, w)
wsample(a::AbstractArray, w::AbstractVector{<:Real}) = wsample(default_rng(), a, w)


"""
Expand All @@ -1061,14 +1067,14 @@ an ordered sample (also called a sequential sample, i.e. a sample where
items appear in the same order as in `a`) should be taken.
Optionally specify a random number generator `rng` as the first argument
(defaults to `Random.GLOBAL_RNG`).
(defaults to `Random.$(VERSION < v"1.3" ? "GLOBAL_RNG" : "default_rng()")`).
"""
wsample(rng::AbstractRNG, a::AbstractArray{T}, w::AbstractVector{<:Real}, n::Integer;
replace::Bool=true, ordered::Bool=false) where {T} =
wsample!(rng, a, w, Vector{T}(undef, n); replace=replace, ordered=ordered)
wsample(a::AbstractArray, w::AbstractVector{<:Real}, n::Integer;
replace::Bool=true, ordered::Bool=false) =
wsample(Random.GLOBAL_RNG, a, w, n; replace=replace, ordered=ordered)
wsample(default_rng(), a, w, n; replace=replace, ordered=ordered)

"""
wsample([rng], [a], w, dims::Dims; replace=true, ordered=false)
Expand All @@ -1078,11 +1084,11 @@ weights given in `w` if `a` is present, otherwise select a random sample of size
`n` of the weights given in `w`. The dimensions of the output are given by `dims`.
Optionally specify a random number generator `rng` as the first argument
(defaults to `Random.GLOBAL_RNG`).
(defaults to `Random.$(VERSION < v"1.3" ? "GLOBAL_RNG" : "default_rng()")`).
"""
wsample(rng::AbstractRNG, a::AbstractArray{T}, w::AbstractVector{<:Real}, dims::Dims;
replace::Bool=true, ordered::Bool=false) where {T} =
wsample!(rng, a, w, Array{T}(undef, dims); replace=replace, ordered=ordered)
wsample(a::AbstractArray, w::AbstractVector{<:Real}, dims::Dims;
replace::Bool=true, ordered::Bool=false) =
wsample(Random.GLOBAL_RNG, a, w, dims; replace=replace, ordered=ordered)
wsample(default_rng(), a, w, dims; replace=replace, ordered=ordered)
8 changes: 7 additions & 1 deletion test/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ n = 100000
# a) if the same rng is passed to a sample function twice,
# the results should be the same (repeatability)
# b) not specifying a rng should be the same as specifying Random.GLOBAL_RNG
# and Random.default_rng() on Julia >= 1.3
function test_rng_use(func, non_rng_args...)
# some sampling methods mutate a passed array and return it
# so that the tests don't pass trivially, we need to copy those
Expand All @@ -17,12 +18,17 @@ function test_rng_use(func, non_rng_args...)
# repeatability
@test func(MersenneTwister(1), deepcopy(non_rng_args)...) ==
func(MersenneTwister(1), deepcopy(non_rng_args)...)
# default RNG is Random.GLOBAL_RNG
# default RNG is Random.GLOBAL_RNG/Random.default_rng()
Random.seed!(47)
x = func(deepcopy(non_rng_args)...)
Random.seed!(47)
y = func(Random.GLOBAL_RNG, deepcopy(non_rng_args)...)
@test x == y
if VERSION >= v"1.3.0-DEV.565"
Random.seed!(47)
y = func(Random.default_rng(), deepcopy(non_rng_args)...)
@test x == y
end
end

#### sample with replacement
Expand Down

0 comments on commit 8696d51

Please sign in to comment.