Skip to content

Commit

Permalink
Merge pull request #198 from dhanak/dev
Browse files Browse the repository at this point in the history
Use seed! to put every copy of rng into a unique state
  • Loading branch information
ablaom authored Nov 29, 2022
2 parents f22e0a6 + 10843eb commit 2efcb75
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
6 changes: 2 additions & 4 deletions src/classification/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,12 +370,10 @@ function build_forest(
loss = (ns, n) -> util.entropy(ns, n, entropy_terms)

if rng isa Random.AbstractRNG
shared_seed = rand(rng, UInt)
Threads.@threads for i in 1:n_trees
# The Mersenne Twister (Julia's default) is not thread-safe.
_rng = copy(rng)
# Take some elements from the ring to have different states for each tree. This
# is the only way given that only a `copy` can be expected to exist for RNGs.
rand(_rng, i)
_rng = Random.seed!(copy(rng), shared_seed + i)
inds = rand(_rng, 1:t_samples, n_samples)
forest[i] = build_tree(
labels[inds],
Expand Down
6 changes: 2 additions & 4 deletions src/regression/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,10 @@ function build_forest(
forest = impurity_importance ? Vector{Root{S, T}}(undef, n_trees) : Vector{LeafOrNode{S, T}}(undef, n_trees)

if rng isa Random.AbstractRNG
shared_seed = rand(rng, UInt)
Threads.@threads for i in 1:n_trees
# The Mersenne Twister (Julia's default) is not thread-safe.
_rng = copy(rng)
# Take some elements from the ring to have different states for each tree.
# This is the only way given that only a `copy` can be expected to exist for RNGs.
rand(_rng, i)
_rng = Random.seed!(copy(rng), shared_seed + i)
inds = rand(_rng, 1:t_samples, n_samples)
forest[i] = build_tree(
labels[inds],
Expand Down

0 comments on commit 2efcb75

Please sign in to comment.