Make sampling from Hypergeometric thread-safe #46
Comments
Happy to accept a PR as a patch file that gets applied (so that we can keep updating to new R versions and then applying the patch). |
I don't understand what you mean by "patch file". |
I guess that means you don't use pull requests? Fine I guess... Before changing anything, I tried to reproduce the error explicitly with these tests:
@testset "rhyper" begin
    # double rhyper(double nn1in, double nn2in, double kkin)
    Nred = 30.0
    Nblue = 40.0
    Npulled = 5.0
    hyper_samples = [
        ccall((:rhyper, libRmath), Float64, (Float64, Float64, Float64), Nred, Nblue, Npulled)
        for _ in 1:1_000_000
    ]
    expected_mean = Npulled * Nred / (Nred + Nblue)
    sample_mean = sum(hyper_samples) / length(hyper_samples)
    @test sample_mean ≈ expected_mean rtol = 0.001
    N = (Nred + Nblue)
    expected_variance = Npulled * Nred * (N - Nred) * (N - Npulled) / (N * N * (N - 1))
    sample_variance = 1 / (length(hyper_samples)) * sum((hyper_samples .- sample_mean) .^ 2)
    @test sample_variance ≈ expected_variance rtol = 0.001
end
@testset "rhyper_multithreaded" begin
    # double rhyper(double nn1in, double nn2in, double kkin)
    Nred = 30.0
    Nblue = 40.0
    Npulled = 5.0
    hyper_samples = Vector{Float64}(undef, 10_000_000)
    Threads.@threads for i in eachindex(hyper_samples)
        hyper_samples[i] = ccall(
            (:rhyper, libRmath), Float64, (Float64, Float64, Float64),
            Nred, Nblue, Npulled
        )
    end
    expected_mean = Npulled * Nred / (Nred + Nblue)
    sample_mean = sum(hyper_samples) / length(hyper_samples)
    @test sample_mean ≈ expected_mean rtol = 0.001
    N = (Nred + Nblue)
    expected_variance = Npulled * Nred * (N - Nred) * (N - Npulled) / (N * N * (N - 1))
    sample_variance = 1 / (length(hyper_samples)) * sum((hyper_samples .- sample_mean) .^ 2)
    @test sample_variance ≈ expected_variance rtol = 0.001
end
To my surprise, this seems to work correctly even when using threads. That's contrary to my expectations. I'm also not sure why the original issue with
Edit: An example that more closely reproduces the original issue does show the errors:
using Distributions
function sample_KkC(n; N, Q)
    total_errors = Distributions.Binomial(N, Q)
    K = rand(total_errors)
    k = ccall(
        (:rhyper, libRmath), Float64, (Float64, Float64, Float64),
        K, N-K, n
    )
    return k
end
@testset "fulll" begin
    function f(Q)
        objective(n) = [sample_KkC(n; N = 819_200, Q) for _ = 1:100]
        vals = [10, 100]
        objective.(vals)
    end
    Qs = [0.05, 0.055, 0.1, 0.2, 0.3]
    Threads.@threads for i in eachindex(Qs)
        f(Qs[i])
    end
end |
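One plausible reading of why the fixed-parameter test passes while the varying-parameter one fails, assuming rhyper keeps parameter-dependent setup values in static variables shared by all threads: if every thread passes the same arguments, whichever thread writes the cache writes the same values, so the race is harmless. With different arguments per thread, one thread can overwrite the cache between another thread's setup and use. The sketch below replays that interleaving deterministically on a single thread. The names setup_cache, setup, use_cache and interleaved_result are hypothetical, the "derived value" is just the hypergeometric mean, and none of this is copied from the Rmath sources.

```c
#include <stdbool.h>

/* Hypothetical stand-in for function-level static variables that all
 * threads would share in an unpatched sampler. */
typedef struct {
    double n1, n2, k;  /* parameters the cache was built for */
    double mean;       /* value derived from those parameters */
    bool valid;
} setup_cache;

/* Phase 1: rebuild the cache only when the parameters changed. */
static void setup(setup_cache *c, double n1, double n2, double k) {
    if (!c->valid || c->n1 != n1 || c->n2 != n2 || c->k != k) {
        c->n1 = n1;
        c->n2 = n2;
        c->k = k;
        c->mean = k * n1 / (n1 + n2);
        c->valid = true;
    }
}

/* Phase 2: consume the cached value (a real sampler would draw with it). */
static double use_cache(const setup_cache *c) {
    return c->mean;
}

/* Replays the interleaving "A sets up, B sets up, A uses" on one shared
 * cache and returns what caller A ends up seeing. */
double interleaved_result(double a_n1, double a_n2, double a_k,
                          double b_n1, double b_n2, double b_k) {
    setup_cache shared = {0};
    setup(&shared, a_n1, a_n2, a_k);  /* thread A, phase 1 */
    setup(&shared, b_n1, b_n2, b_k);  /* thread B sneaks in */
    return use_cache(&shared);        /* thread A, phase 2 */
}
```

With identical arguments for A and B the result is A's own value; with different arguments A silently receives B's value, which matches the pattern of the two tests above.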
And making the static variables Patch: Note: the patch contains my edits to |
I found some more places where I'm still confused if you want a PR or just the diff, so I made a PR as well... |
Right we want a PR, in which this patch is applied. That way we keep carrying the patch and applying it every time we upgrade the Rmath version from the R distribution with If we merge that PR, those changes will be overwritten when we sync a new upstream Rmath version with |
Do I understand correctly that you don't need anything further from me? |
Merged and added #51. Thank you! |
This code here says the static variables "should become 'thread_local globals'". Is there anything preventing just putting _Thread_local in the declaration?
The current version leads to sampling from the hypergeometric distribution (rhyper) being broken when used in a threaded context. I mentioned this previously at JuliaStats/Distributions.jl#1829.
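To make the question concrete, here is a minimal sketch of what putting _Thread_local (C11) in the declaration does to a cached-setup pattern. It assumes a C11 compiler and POSIX threads. The function mean_with_cache, the job struct and run_threads are hypothetical illustrations, not the Rmath code or the patch that was eventually merged.

```c
#include <pthread.h>
#include <stddef.h>

/* Hypothetical cached setup. _Thread_local gives every thread its own
 * copy, so one thread can no longer overwrite another's cache. */
static _Thread_local double cached_n1 = -1.0, cached_n2 = -1.0, cached_k = -1.0;
static _Thread_local double cached_mean;

double mean_with_cache(double n1, double n2, double k) {
    if (n1 != cached_n1 || n2 != cached_n2 || k != cached_k) {
        cached_n1 = n1;
        cached_n2 = n2;
        cached_k = k;
        cached_mean = k * n1 / (n1 + n2);
    }
    return cached_mean;
}

typedef struct {
    double n1, n2, k;
    int mismatches;
} job;

/* Each worker calls the function repeatedly with its own parameters and
 * counts results that differ from a direct computation. */
static void *worker(void *arg) {
    job *j = arg;
    double expected = j->k * j->n1 / (j->n1 + j->n2);
    for (int i = 0; i < 100000; i++) {
        if (mean_with_cache(j->n1, j->n2, j->k) != expected) {
            j->mismatches++;
        }
    }
    return NULL;
}

/* Runs four threads with different parameters. Returns the total number
 * of mismatches, or -1 if a thread could not be created. */
int run_threads(void) {
    job jobs[4] = {{30, 40, 5, 0}, {10, 90, 5, 0}, {50, 50, 10, 0}, {5, 95, 20, 0}};
    pthread_t threads[4];
    int created = 0;
    for (; created < 4; created++) {
        if (pthread_create(&threads[created], NULL, worker, &jobs[created]) != 0) {
            break;
        }
    }
    int total = 0;
    for (int i = 0; i < created; i++) {
        pthread_join(threads[i], NULL);
        total += jobs[i].mismatches;
    }
    return created == 4 ? total : -1;
}
```

The trade-off in this sketch is that each thread recomputes its own setup once instead of sharing one, which costs a little memory per thread but needs no locking.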