From 05da07c3d4c9bcdf00164de18c696a0641c75dca Mon Sep 17 00:00:00 2001 From: Rik Huijzer Date: Tue, 14 Jun 2022 15:49:16 +0200 Subject: [PATCH 01/10] Fix typo in `print_tree` description --- src/DecisionTree.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index 9a8ab4d1..540e052b 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -94,7 +94,7 @@ In the example output below, the top node considers whether "Feature 3" is above or below the threshold -28.156052806422238. If the value of "Feature 3" is strictly below the threshold for some input to be classified, we move to the `L->` part underneath, which is a node -looking at if "Feature 2" is above or below -161.04351901384842. +looking at if "Feature 2" is above or equal -161.04351901384842. If the value of "Feature 2" is strictly below the threshold for some input to be classified, we end up at `L-> 5 : 842/3650`. This is to be read as "In the left split, the tree will classify the input as class 5, as 842 of the 3650 datapoints From dfe3c5dd475cc48927c741ff29dc1bd11b35f196 Mon Sep 17 00:00:00 2001 From: rikhuijzer Date: Tue, 14 Jun 2022 16:10:11 +0200 Subject: [PATCH 02/10] =?UTF-8?q?Replace=20`L`=20by=20`<`=20and=20`R`=20by?= =?UTF-8?q?=20`=E2=89=A5`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/DecisionTree.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index 9a8ab4d1..3590f63c 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -90,25 +90,25 @@ end print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing) Print a textual visualization of the given decision tree `tree`. -In the example output below, the top node considers whether +In the example output below, the top node considers whether "Feature 3" is above or below the threshold -28.156052806422238. -If the value of "Feature 3" is strictly below the threshold for some input to be classified, -we move to the `L->` part underneath, which is a node +If the value of "Feature 3" is strictly below the threshold for some input to be classified, +we move to the `L->` part underneath, which is a node looking at if "Feature 2" is above or below -161.04351901384842. -If the value of "Feature 2" is strictly below the threshold for some input to be classified, -we end up at `L-> 5 : 842/3650`. This is to be read as "In the left split, -the tree will classify the input as class 5, as 842 of the 3650 datapoints +If the value of "Feature 2" is strictly below the threshold for some input to be classified, +we end up at `L-> 5 : 842/3650`. This is to be read as "In the left split, +the tree will classify the input as class 5, as 842 of the 3650 datapoints in the training data that ended up here were of class 5." 
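The split rule described in the docstring above can be restated as a one-line check. This is an illustrative sketch of the convention only (the helper name `descend` is made up, not package code): an input goes to the left branch when its feature value is strictly below the node's threshold, and to the right branch otherwise.

```julia
# Illustrative only — not package code. `descend` is a hypothetical helper
# restating the split convention: strict `<` goes left, otherwise right.
descend(value::Real, threshold::Real) = value < threshold ? :left : :right

descend(-30.0, -28.156052806422238)  # :left  (the `L->` branch)
descend(-20.0, -28.156052806422238)  # :right (the `R->` branch)
```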
# Example output: ``` Feature 3, Threshold -28.156052806422238 -L-> Feature 2, Threshold -161.04351901384842 - L-> 5 : 842/3650 - R-> 7 : 2493/10555 -R-> Feature 7, Threshold 108.1408338577021 - L-> 2 : 2434/15287 - R-> 8 : 1227/3508 +< then Feature 2, Threshold -161.04351901384842 + < then 5 : 842/3650 + ≥ then 7 : 2493/10555 +≥ then Feature 7, Threshold 108.1408338577021 + < then 2 : 2434/15287 + ≥ then 8 : 1227/3508 ``` To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or @@ -125,9 +125,9 @@ function print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing) else println("Feature $(tree.featid): \"$(feature_names[tree.featid])\", Threshold $(tree.featval)") end - print(" " ^ indent * "L-> ") + print(" " ^ indent * "< then ") print_tree(tree.left, depth, indent + 1; feature_names = feature_names) - print(" " ^ indent * "R-> ") + print(" " ^ indent * "≥ then ") print_tree(tree.right, depth, indent + 1; feature_names = feature_names) end From b35a4dda2908e0a2dbcee3ad06639ee62dc46a36 Mon Sep 17 00:00:00 2001 From: rikhuijzer Date: Wed, 15 Jun 2022 08:31:49 +0200 Subject: [PATCH 03/10] Round digits in `print_tree` --- src/DecisionTree.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index 9a8ab4d1..402581b8 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -81,13 +81,13 @@ depth(leaf::Leaf) = 0 depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right)) function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing) - matches = findall(leaf.values .== leaf.majority) - ratio = string(length(matches)) * "/" * string(length(leaf.values)) + n_matches = count(leaf.values .== leaf.majority) + ratio = string(n_matches, "/", length(leaf.values)) println("$(leaf.majority) : $(ratio)") end """ - print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing) + print_tree(tree::Node, depth=-1, indent=0; digits=2, feature_names=nothing) Print a textual visualization of the given decision tree `tree`. In the example output below, the top node considers whether @@ -115,15 +115,16 @@ To facilitate visualisation of trees using third party packages, a `DecisionTree `DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the AbstractTrees.jl interface. See [`wrap`](@ref)` for details. """ -function print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing) +function print_tree(tree::Node, depth=-1, indent=0; digits=2, feature_names=nothing) if depth == indent println() return end + featval = round(tree.featval; digits=digits) if feature_names === nothing - println("Feature $(tree.featid), Threshold $(tree.featval)") + println("Feature $(tree.featid), Threshold $featval") else - println("Feature $(tree.featid): \"$(feature_names[tree.featid])\", Threshold $(tree.featval)") + println("Feature $(tree.featid): \"$(feature_names[tree.featid])\", Threshold $featval") end print(" " ^ indent * "L-> ") print_tree(tree.left, depth, indent + 1; feature_names = feature_names) From 7c25616c3bbc0c717b3d2ab90fc7d7164b84501d Mon Sep 17 00:00:00 2001 From: rikhuijzer Date: Wed, 15 Jun 2022 08:35:09 +0200 Subject: [PATCH 04/10] Update doc --- src/DecisionTree.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index 402581b8..c94d008b 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -102,11 +102,11 @@ in the training data that ended up here were of class 5." 
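The rounding introduced above uses Base `round` with the `digits` keyword; a later patch in this series switches to `sigdigits`. A quick standalone sketch of the difference, using the threshold from the example output:

```julia
x = -28.156052806422238

round(x; digits=2)     # -28.16 — two digits after the decimal point (patch 03)
round(x; sigdigits=2)  # -28.0  — two significant digits (patch 05)
round(x; sigdigits=4)  # -28.16 — four significant digits
```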
# Example output: ``` -Feature 3, Threshold -28.156052806422238 -L-> Feature 2, Threshold -161.04351901384842 +Feature 3, Threshold -28.16 +L-> Feature 2, Threshold -161.04 L-> 5 : 842/3650 R-> 7 : 2493/10555 -R-> Feature 7, Threshold 108.1408338577021 +R-> Feature 7, Threshold 108.14 L-> 2 : 2434/15287 R-> 8 : 1227/3508 ``` From 469ead99cbd55e0e1d9de835a06ca6b29dadb3ba Mon Sep 17 00:00:00 2001 From: rikhuijzer Date: Thu, 16 Jun 2022 08:30:10 +0200 Subject: [PATCH 05/10] Change `digits` to `sigdigits` --- src/DecisionTree.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index c94d008b..a1dbb0cf 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -87,7 +87,7 @@ function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing) end """ - print_tree(tree::Node, depth=-1, indent=0; digits=2, feature_names=nothing) + print_tree(tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing) Print a textual visualization of the given decision tree `tree`. In the example output below, the top node considers whether @@ -111,16 +111,16 @@ R-> Feature 7, Threshold 108.14 R-> 8 : 1227/3508 ``` -To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or -`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the -AbstractTrees.jl interface. See [`wrap`](@ref)` for details. +To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or +`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the +AbstractTrees.jl interface. See [`wrap`](@ref)` for details. """ -function print_tree(tree::Node, depth=-1, indent=0; digits=2, feature_names=nothing) +function print_tree(tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing) if depth == indent println() return end - featval = round(tree.featval; digits=digits) + featval = round(tree.featval; sigdigits=sigdigits) if feature_names === nothing println("Feature $(tree.featid), Threshold $featval") else From d206a3ac31221c3cb65b24939c275a1a3274dc9a Mon Sep 17 00:00:00 2001 From: rikhuijzer Date: Thu, 16 Jun 2022 08:46:42 +0200 Subject: [PATCH 06/10] Place comparison in the first line --- src/DecisionTree.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index 3590f63c..50b8e797 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -102,13 +102,13 @@ in the training data that ended up here were of class 5." # Example output: ``` -Feature 3, Threshold -28.156052806422238 -< then Feature 2, Threshold -161.04351901384842 - < then 5 : 842/3650 - ≥ then 7 : 2493/10555 -≥ then Feature 7, Threshold 108.1408338577021 - < then 2 : 2434/15287 - ≥ then 8 : 1227/3508 +Feature 3 < -28.156052806422238 ? +├─ Feature 2 < -161.04351901384842 ? + ├─ 5 : 842/3650 + └─ 7 : 2493/10555 +└─ Feature 7 < 108.1408338577021 ? 
+ ├─ 2 : 2434/15287 + └─ 8 : 1227/3508 ``` To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or @@ -121,13 +121,13 @@ function print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing) return end if feature_names === nothing - println("Feature $(tree.featid), Threshold $(tree.featval)") + println("Feature $(tree.featid) < $(tree.featval)") else - println("Feature $(tree.featid): \"$(feature_names[tree.featid])\", Threshold $(tree.featval)") + println("Feature $(tree.featid): \"$(feature_names[tree.featid])\" < $(tree.featval)") end - print(" " ^ indent * "< then ") + print(" " ^ indent * "├─ ") print_tree(tree.left, depth, indent + 1; feature_names = feature_names) - print(" " ^ indent * "≥ then ") + print(" " ^ indent * "└─ ") print_tree(tree.right, depth, indent + 1; feature_names = feature_names) end From e09e353c9206ec3125970102c2b16554867c4fdc Mon Sep 17 00:00:00 2001 From: Rik Huijzer Date: Fri, 17 Jun 2022 19:11:07 +0200 Subject: [PATCH 07/10] Set `sigdigits=4` Co-authored-by: Anthony Blaom, PhD --- src/DecisionTree.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index 75057f95..719584a9 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -87,7 +87,7 @@ function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing) end """ - print_tree(tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing) + print_tree(tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing) Print a textual visualization of the specified `tree`. For example, if for some input pattern the value of "Feature 3" is "-30" and the value From 3ba1651bf6148a526b04c76f5c467c829100fa10 Mon Sep 17 00:00:00 2001 From: rikhuijzer Date: Fri, 17 Jun 2022 19:40:24 +0200 Subject: [PATCH 08/10] Add test --- src/DecisionTree.jl | 29 ++++++++++++++++++----------- src/scikitlearnAPI.jl | 6 ++++-- test/classification/random.jl | 16 ++++++++++++++-- 3 files changed, 36 insertions(+), 15 deletions(-) diff --git a/src/DecisionTree.jl b/src/DecisionTree.jl index 719584a9..10c638bb 100644 --- a/src/DecisionTree.jl +++ b/src/DecisionTree.jl @@ -80,14 +80,18 @@ length(ensemble::Ensemble) = length(ensemble.trees) depth(leaf::Leaf) = 0 depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right)) -function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing) +function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; feature_names=nothing) n_matches = count(leaf.values .== leaf.majority) ratio = string(n_matches, "/", length(leaf.values)) - println("$(leaf.majority) : $(ratio)") + println(io, "$(leaf.majority) : $(ratio)") +end +function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing) + return print_tree(stdout, leaf, depth, indent; feature_names=feature_names) end + """ - print_tree(tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing) + print_tree([io::IO,] tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing) Print a textual visualization of the specified `tree`. For example, if for some input pattern the value of "Feature 3" is "-30" and the value @@ -113,21 +117,24 @@ To facilitate visualisation of trees using third party packages, a `DecisionTree `DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the AbstractTrees.jl interface. See [`wrap`](@ref)` for details. 
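The `io::IO` methods added in this patch make it possible to capture the rendered tree as a string rather than writing to `stdout`. A minimal sketch, assuming a trained `model::DecisionTree.Node` is already in scope (it mirrors the new test further below):

```julia
# Capture the textual tree instead of printing it to stdout.
# Assumes `model` is an already-trained DecisionTree.Node.
io = IOBuffer()
print_tree(io, model, 3)      # depth limited to 3, as in the new test
text = String(take!(io))
occursin("Feature", text)     # the rendered output names the splitting features
```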
""" -function print_tree(tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing) +function print_tree(io::IO, tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing) if depth == indent - println() + println(io) return end featval = round(tree.featval; sigdigits=sigdigits) if feature_names === nothing - println("Feature $(tree.featid) < $featval ?") + println(io, "Feature $(tree.featid) < $featval ?") else - println("Feature $(tree.featid): \"$(feature_names[tree.featid])\" < $featval ?") + println(io, "Feature $(tree.featid): \"$(feature_names[tree.featid])\" < $featval ?") end - print(" " ^ indent * "├─ ") - print_tree(tree.left, depth, indent + 1; feature_names = feature_names) - print(" " ^ indent * "└─ ") - print_tree(tree.right, depth, indent + 1; feature_names = feature_names) + print(io, " " ^ indent * "├─ ") + print_tree(io, tree.left, depth, indent + 1; feature_names=feature_names) + print(io, " " ^ indent * "└─ ") + print_tree(io, tree.right, depth, indent + 1; feature_names=feature_names) +end +function print_tree(tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing) + return print_tree(stdout, tree, depth, indent; sigdigits=sigdigits, feature_names=feature_names) end function show(io::IO, leaf::Leaf) diff --git a/src/scikitlearnAPI.jl b/src/scikitlearnAPI.jl index 249e531a..1f6986a3 100644 --- a/src/scikitlearnAPI.jl +++ b/src/scikitlearnAPI.jl @@ -386,5 +386,7 @@ length(dt::DecisionTreeClassifier) = length(dt.root) length(dt::DecisionTreeRegressor) = length(dt.root) print_tree(dt::DecisionTreeClassifier, depth=-1; kwargs...) = print_tree(dt.root, depth; kwargs...) -print_tree(dt::DecisionTreeRegressor, depth=-1; kwargs...) = print_tree(dt.root, depth; kwargs...) -print_tree(n::Nothing, depth=-1; kwargs...) = show(n) +print_tree(io::IO, dt::DecisionTreeClassifier, depth=-1; kwargs...) = print_tree(io, dt.root, depth; kwargs...) +print_tree(dt::DecisionTreeRegressor, depth=-1; kwargs...) = print_tree(dt.root, depth; kwargs...) +print_tree(io::IO, dt::DecisionTreeRegressor, depth=-1; kwargs...) = print_tree(io, dt.root, depth; kwargs...) +print_tree(n::Nothing, depth=-1; kwargs...) = show(n) diff --git a/test/classification/random.jl b/test/classification/random.jl index 4b0e52f0..ab195199 100644 --- a/test/classification/random.jl +++ b/test/classification/random.jl @@ -3,7 +3,7 @@ Random.seed!(16) -n,m = 10^3, 5; +n, m = 10^3, 5; features = rand(n,m); weights = rand(-1:1,m); labels = round.(Int, features * weights); @@ -15,7 +15,19 @@ preds = apply_tree(model, round.(Int, features)) max_depth = 3 model = build_tree(labels, features, 0, max_depth) @test depth(model) == max_depth -print_tree(model, 3) + +io = IOBuffer() +print_tree(io, model, 3) +text = String(take!(io)) +println() +print(text) +println() + +# Read the regex as: many not arrow left followed by an arrow left, a space, some numbers and +# a dot and a space and question mark. +rx = r"[^<]*< [0-9\.]* ?" 
+matches = eachmatch(rx, text) +@test !isempty(matches) model = build_tree(labels, features) preds = apply_tree(model, features) From 969c6378393ad29869badca9603ed8b230cec454 Mon Sep 17 00:00:00 2001 From: Rik Huijzer Date: Wed, 22 Jun 2022 00:40:23 +0200 Subject: [PATCH 09/10] Test multiple seeds (#174) * Test multiple rngs * Add comment * Fix rng not being passed correctly * Revert some changes * Update comment * Extend tests * Simplify test * Lower accuracy bound * Fix a bug in the usage of Mersenne Twister * Fix tests * Use some more StableRNG * Use some more StableRNG * Use some more StableRNG * Use some more StableRNG * Use some more StableRNG * Use some more StableRNG * Fix Julia 1.6 * Use `StableRNG` in test/classification/adult * Use `StableRNG` for data generation too * Put old numbers back * Add one more rng --- .github/workflows/CI.yml | 2 +- Project.toml | 9 +++- src/classification/main.jl | 22 ++++++---- src/measures.jl | 3 +- src/regression/main.jl | 15 ++++--- test/classification/adult.jl | 12 ++--- test/classification/digits.jl | 3 +- test/classification/heterogeneous.jl | 12 ++--- test/classification/iris.jl | 10 ++--- test/classification/low_precision.jl | 20 +++++---- test/classification/random.jl | 66 ++++++++++++++++------------ test/classification/scikitlearn.jl | 21 +++++---- test/regression/digits.jl | 7 +-- test/regression/low_precision.jl | 23 +++++----- test/regression/random.jl | 24 +++++----- test/regression/scikitlearn.jl | 35 ++++++++------- test/runtests.jl | 6 ++- 17 files changed, 166 insertions(+), 124 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 163aa689..123c6f12 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - version: + version: - '1.0' - '1.6' - '1' # automatically expands to the latest stable 1.x release of Julia diff --git a/Project.toml b/Project.toml index 106d4f1a..f3f22e20 100644 --- a/Project.toml +++ b/Project.toml @@ -7,14 +7,19 @@ version = "0.10.12" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" -Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ScikitLearnBase = "6e75b9c4-186b-50bd-896f-2d2496a4843e" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] AbstractTrees = "0.3" ScikitLearnBase = "0.5" julia = "1" + +[extras] +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["StableRNGs", "Test"] diff --git a/src/classification/main.jl b/src/classification/main.jl index 52e7bdb7..d1061a14 100644 --- a/src/classification/main.jl +++ b/src/classification/main.jl @@ -37,7 +37,8 @@ end function _convert( node :: treeclassifier.NodeMeta{S}, list :: AbstractVector{T}, - labels :: AbstractVector{T}) where {S, T} + labels :: AbstractVector{T} + ) where {S, T} if node.is_leaf return Leaf{T}(list[node.label], labels[node.region]) @@ -138,7 +139,7 @@ function prune_tree(tree::LeafOrNode{S, T}, purity_thresh=1.0) where {S, T} end -apply_tree(leaf::Leaf{T}, feature::AbstractVector{S}) where {S, T} = leaf.majority +apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.majority function apply_tree(tree::Node{S, T}, features::AbstractVector{S}) where {S, T} if tree.featid == 0 @@ -197,7 +198,7 @@ function build_forest( min_samples_leaf = 1, 
min_samples_split = 2, min_purity_increase = 0.0; - rng = Random.GLOBAL_RNG) where {S, T} + rng::Union{Integer,AbstractRNG} = Random.GLOBAL_RNG) where {S, T} if n_trees < 1 throw("the number of trees must be >= 1") @@ -221,7 +222,12 @@ function build_forest( if rng isa Random.AbstractRNG Threads.@threads for i in 1:n_trees - inds = rand(rng, 1:t_samples, n_samples) + # The Mersenne Twister (Julia's default) is not thread-safe. + _rng = copy(rng) + # Take some elements from the ring to have different states for each tree. + # This is the only way given that only a `copy` can be expected to exist for RNGs. + rand(_rng, i) + inds = rand(_rng, 1:t_samples, n_samples) forest[i] = build_tree( labels[inds], features[inds,:], @@ -231,9 +237,9 @@ function build_forest( min_samples_split, min_purity_increase, loss = loss, - rng = rng) + rng = _rng) end - elseif rng isa Integer # each thread gets its own seeded rng + else # each thread gets its own seeded rng Threads.@threads for i in 1:n_trees Random.seed!(rng + i) inds = rand(1:t_samples, n_samples) @@ -247,8 +253,6 @@ function build_forest( min_purity_increase, loss = loss) end - else - throw("rng must of be type Integer or Random.AbstractRNG") end return Ensemble{S, T}(forest) @@ -298,7 +302,7 @@ function build_adaboost_stumps( labels :: AbstractVector{T}, features :: AbstractMatrix{S}, n_iterations :: Integer; - rng = Random.GLOBAL_RNG) where {S, T} + rng = Random.GLOBAL_RNG) where {S, T} N = length(labels) weights = ones(N) / N stumps = Node{S, T}[] diff --git a/src/measures.jl b/src/measures.jl index 06de1e18..c9fb92f8 100644 --- a/src/measures.jl +++ b/src/measures.jl @@ -135,7 +135,7 @@ function _nfoldCV(classifier::Symbol, labels::AbstractVector{T}, features::Abstr predictions = apply_forest(model, test_features) elseif classifier == :stumps model, coeffs = build_adaboost_stumps( - train_labels, train_features, n_iterations) + train_labels, train_features, n_iterations; rng=rng) predictions = apply_adaboost_stumps(model, coeffs, test_features) end cm = confusion_matrix(test_labels, predictions) @@ -186,6 +186,7 @@ function nfoldCV_stumps( n_iterations ::Integer = 10; verbose :: Bool = true, rng = Random.GLOBAL_RNG) where {S, T} + rng = mk_rng(rng)::Random.AbstractRNG _nfoldCV(:stumps, labels, features, n_folds, n_iterations; verbose=verbose, rng=rng) end diff --git a/src/regression/main.jl b/src/regression/main.jl index 2d012aa0..7eca176f 100644 --- a/src/regression/main.jl +++ b/src/regression/main.jl @@ -56,7 +56,7 @@ function build_forest( min_samples_leaf = 5, min_samples_split = 2, min_purity_increase = 0.0; - rng = Random.GLOBAL_RNG) where {S, T <: Float64} + rng::Union{Integer,AbstractRNG} = Random.GLOBAL_RNG) where {S, T <: Float64} if n_trees < 1 throw("the number of trees must be >= 1") @@ -77,7 +77,12 @@ function build_forest( if rng isa Random.AbstractRNG Threads.@threads for i in 1:n_trees - inds = rand(rng, 1:t_samples, n_samples) + # The Mersenne Twister (Julia's default) is not thread-safe. + _rng = copy(rng) + # Take some elements from the ring to have different states for each tree. + # This is the only way given that only a `copy` can be expected to exist for RNGs. 
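The per-tree RNG handling introduced in this patch can also be sketched on its own, outside the diff: each tree gets a `copy` of the caller's RNG, and that copy is advanced by a tree-dependent number of draws so the copies diverge deterministically. Illustrative code only, with a made-up helper name:

```julia
using Random

# Illustrative sketch of the pattern used in `build_forest` (not the package code).
function per_tree_rngs(rng::Random.AbstractRNG, n_trees::Integer)
    rngs = Vector{typeof(rng)}(undef, n_trees)
    for i in 1:n_trees
        _rng = copy(rng)   # MersenneTwister (Julia's default) is not thread-safe
        rand(_rng, i)      # draw `i` values so each copy ends up in a distinct state
        rngs[i] = _rng
    end
    return rngs
end

per_tree_rngs(MersenneTwister(1), 4)
```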
+ rand(_rng, i) + inds = rand(_rng, 1:t_samples, n_samples) forest[i] = build_tree( labels[inds], features[inds,:], @@ -86,9 +91,9 @@ function build_forest( min_samples_leaf, min_samples_split, min_purity_increase, - rng = rng) + rng = _rng) end - elseif rng isa Integer # each thread gets its own seeded rng + else # each thread gets its own seeded rng Threads.@threads for i in 1:n_trees Random.seed!(rng + i) inds = rand(1:t_samples, n_samples) @@ -101,8 +106,6 @@ function build_forest( min_samples_split, min_purity_increase) end - else - throw("rng must of be type Integer or Random.AbstractRNG") end return Ensemble{S, T}(forest) diff --git a/test/classification/adult.jl b/test/classification/adult.jl index 5a351f1a..5d2b3add 100644 --- a/test/classification/adult.jl +++ b/test/classification/adult.jl @@ -5,7 +5,7 @@ features, labels = load_data("adult") -model = build_tree(labels, features) +model = build_tree(labels, features; rng=StableRNG(1)) preds = apply_tree(model, features) cm = confusion_matrix(labels, preds) @test cm.accuracy > 0.99 @@ -15,13 +15,13 @@ labels = string.(labels) n_subfeatures = 3 n_trees = 5 -model = build_forest(labels, features, n_subfeatures, n_trees) +model = build_forest(labels, features, n_subfeatures, n_trees; rng=StableRNG(1)) preds = apply_forest(model, features) cm = confusion_matrix(labels, preds) @test cm.accuracy > 0.9 n_iterations = 15 -model, coeffs = build_adaboost_stumps(labels, features, n_iterations); +model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1)); preds = apply_adaboost_stumps(model, coeffs, features); cm = confusion_matrix(labels, preds); @test cm.accuracy > 0.8 @@ -29,7 +29,7 @@ cm = confusion_matrix(labels, preds); println("\n##### 3 foldCV Classification Tree #####") pruning_purity = 0.9 nfolds = 3 -accuracy = nfoldCV_tree(labels, features, nfolds, pruning_purity; verbose=false); +accuracy = nfoldCV_tree(labels, features, nfolds, pruning_purity; rng=StableRNG(1), verbose=false); @test mean(accuracy) > 0.8 println("\n##### 3 foldCV Classification Forest #####") @@ -37,13 +37,13 @@ n_subfeatures = 2 n_trees = 10 n_folds = 3 partial_sampling = 0.5 -accuracy = nfoldCV_forest(labels, features, n_folds, n_subfeatures, n_trees, partial_sampling; verbose=false) +accuracy = nfoldCV_forest(labels, features, n_folds, n_subfeatures, n_trees, partial_sampling; rng=StableRNG(1), verbose=false) @test mean(accuracy) > 0.8 println("\n##### nfoldCV Classification Adaboosted Stumps #####") n_iterations = 15 n_folds = 3 -accuracy = nfoldCV_stumps(labels, features, n_folds, n_iterations; verbose=false); +accuracy = nfoldCV_stumps(labels, features, n_folds, n_iterations; rng=StableRNG(1), verbose=false); @test mean(accuracy) > 0.8 end # @testset diff --git a/test/classification/digits.jl b/test/classification/digits.jl index f3917ca2..e4fbae1e 100644 --- a/test/classification/digits.jl +++ b/test/classification/digits.jl @@ -69,7 +69,8 @@ model = DecisionTree.build_forest( max_depth, min_samples_leaf, min_samples_split, - min_purity_increase) + min_purity_increase; + rng=StableRNG(1)) preds = apply_forest(model, X) cm = confusion_matrix(Y, preds) @test cm.accuracy > 0.95 diff --git a/test/classification/heterogeneous.jl b/test/classification/heterogeneous.jl index 26173696..004d89b2 100644 --- a/test/classification/heterogeneous.jl +++ b/test/classification/heterogeneous.jl @@ -5,29 +5,29 @@ m, n = 10^2, 5 tf = [trues(Int(m/2)) falses(Int(m/2))] -inds = Random.randperm(m) +inds = Random.randperm(StableRNG(1), m) labels = 
string.(tf[inds]) features = Array{Any}(undef, m, n) -features[:,:] = randn(m, n) -features[:,2] = string.(tf[Random.randperm(m)]) +features[:,:] = randn(StableRNG(1), m, n) +features[:,2] = string.(tf[Random.randperm(StableRNG(1), m)]) features[:,3] = map(t -> round.(Int, t), features[:,3]) features[:,4] = tf[inds] -model = build_tree(labels, features) +model = build_tree(labels, features; rng=StableRNG(1)) preds = apply_tree(model, features) cm = confusion_matrix(labels, preds) @test cm.accuracy > 0.9 n_subfeatures = 2 n_trees = 3 -model = build_forest(labels, features, n_subfeatures, n_trees) +model = build_forest(labels, features, n_subfeatures, n_trees; rng=StableRNG(1)) preds = apply_forest(model, features) cm = confusion_matrix(labels, preds) @test cm.accuracy > 0.9 n_subfeatures = 7 -model, coeffs = build_adaboost_stumps(labels, features, n_subfeatures) +model, coeffs = build_adaboost_stumps(labels, features, n_subfeatures; rng=StableRNG(1)) preds = apply_adaboost_stumps(model, coeffs, features) cm = confusion_matrix(labels, preds) @test cm.accuracy > 0.9 diff --git a/test/classification/iris.jl b/test/classification/iris.jl index f3e54501..ff1c5683 100644 --- a/test/classification/iris.jl +++ b/test/classification/iris.jl @@ -59,14 +59,14 @@ cm = confusion_matrix(labels, preds) # run n-fold cross validation for pruned tree println("\n##### nfoldCV Classification Tree #####") nfolds = 3 -accuracy = nfoldCV_tree(labels, features, nfolds) +accuracy = nfoldCV_tree(labels, features, nfolds; rng=StableRNG(1)) @test mean(accuracy) > 0.8 # train random forest classifier n_trees = 10 n_subfeatures = 2 partial_sampling = 0.5 -model = build_forest(labels, features, n_subfeatures, n_trees, partial_sampling) +model = build_forest(labels, features, n_subfeatures, n_trees, partial_sampling; rng=StableRNG(2)) preds = apply_forest(model, features) cm = confusion_matrix(labels, preds) @test cm.accuracy > 0.95 @@ -80,12 +80,12 @@ n_subfeatures = 2 n_trees = 10 n_folds = 3 partial_sampling = 0.5 -accuracy = nfoldCV_forest(labels, features, nfolds, n_subfeatures, n_trees, partial_sampling) +accuracy = nfoldCV_forest(labels, features, nfolds, n_subfeatures, n_trees, partial_sampling; rng=StableRNG(1)) @test mean(accuracy) > 0.9 # train adaptive-boosted decision stumps n_iterations = 15 -model, coeffs = build_adaboost_stumps(labels, features, n_iterations) +model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1)) preds = apply_adaboost_stumps(model, coeffs, features) cm = confusion_matrix(labels, preds) @test cm.accuracy > 0.9 @@ -97,7 +97,7 @@ probs = apply_adaboost_stumps_proba(model, coeffs, features, classes) println("\n##### nfoldCV Classification Adaboosted Stumps #####") n_iterations = 15 nfolds = 3 -accuracy = nfoldCV_stumps(labels, features, nfolds, n_iterations) +accuracy = nfoldCV_stumps(labels, features, nfolds, n_iterations; rng=StableRNG(1)) @test mean(accuracy) > 0.85 end # @testset diff --git a/test/classification/low_precision.jl b/test/classification/low_precision.jl index 6086d67b..eda010e6 100644 --- a/test/classification/low_precision.jl +++ b/test/classification/low_precision.jl @@ -5,9 +5,9 @@ Random.seed!(16) n,m = 10^3, 5; features = Array{Any}(undef, n, m); -features[:,:] = rand(n, m); +features[:,:] = rand(StableRNG(1), n, m); features[:,1] = round.(Int32, features[:,1]); # convert a column of 32bit integers -weights = rand(-1:1,m); +weights = rand(StableRNG(1), -1:1, m); labels = round.(Int32, features * weights); model = build_stump(labels, 
features) @@ -25,7 +25,8 @@ model = build_tree( n_subfeatures, max_depth, min_samples_leaf, min_samples_split, - min_purity_increase) + min_purity_increase; + rng=StableRNG(1)) preds = apply_tree(model, features) cm = confusion_matrix(labels, preds) @test typeof(preds) == Vector{Int32} @@ -40,14 +41,15 @@ model = build_forest( n_subfeatures, n_trees, partial_sampling, - max_depth) + max_depth; + rng=StableRNG(1)) preds = apply_forest(model, features) cm = confusion_matrix(labels, preds) @test typeof(preds) == Vector{Int32} @test cm.accuracy > 0.9 n_iterations = Int32(25) -model, coeffs = build_adaboost_stumps(labels, features, n_iterations); +model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1)); preds = apply_adaboost_stumps(model, coeffs, features); cm = confusion_matrix(labels, preds) @test typeof(preds) == Vector{Int32} @@ -67,7 +69,8 @@ accuracy = nfoldCV_tree( max_depth, min_samples_leaf, min_samples_split, - min_purity_increase) + min_purity_increase; + rng=StableRNG(1)) @test mean(accuracy) > 0.7 println("\n##### nfoldCV Classification Forest #####") @@ -87,12 +90,13 @@ accuracy = nfoldCV_forest( max_depth, min_samples_leaf, min_samples_split, - min_purity_increase) + min_purity_increase; + rng=StableRNG(1)) @test mean(accuracy) > 0.7 println("\n##### nfoldCV Adaboosted Stumps #####") n_iterations = Int32(25) -accuracy = nfoldCV_stumps(labels, features, n_folds, n_iterations) +accuracy = nfoldCV_stumps(labels, features, n_folds, n_iterations; rng=StableRNG(1)) @test mean(accuracy) > 0.6 diff --git a/test/classification/random.jl b/test/classification/random.jl index ab195199..a91ebac6 100644 --- a/test/classification/random.jl +++ b/test/classification/random.jl @@ -4,8 +4,8 @@ Random.seed!(16) n, m = 10^3, 5; -features = rand(n,m); -weights = rand(-1:1,m); +features = rand(StableRNG(1), n, m); +weights = rand(StableRNG(1), -1:1, m); labels = round.(Int, features * weights); model = build_stump(labels, round.(Int, features)) @@ -29,7 +29,7 @@ rx = r"[^<]*< [0-9\.]* ?" 
matches = eachmatch(rx, text) @test !isempty(matches) -model = build_tree(labels, features) +model = build_tree(labels, features; rng=StableRNG(1)) preds = apply_tree(model, features) cm = confusion_matrix(labels, preds) @test cm.accuracy > 0.9 @@ -49,7 +49,7 @@ t3 = build_tree(labels, features, n_subfeatures; rng=mt) @test (length(t1) != length(t3)) || (depth(t1) != depth(t3)) -model = build_forest(labels, features) +model = build_forest(labels, features; rng=StableRNG(1)) preds = apply_forest(model, features) cm = confusion_matrix(labels, preds) @test cm.accuracy > 0.9 @@ -70,16 +70,17 @@ model = build_forest( max_depth, min_samples_leaf, min_samples_split, - min_purity_increase) + min_purity_increase; + rng=StableRNG(1)) preds = apply_forest(model, features) cm = confusion_matrix(labels, preds) -@test cm.accuracy > 0.9 +@test cm.accuracy > 0.6 @test length(model) == n_trees # test n_subfeatures n_subfeatures = 0 -m_partial = build_forest(labels, features) # default sqrt(n_features) -m_full = build_forest(labels, features, n_subfeatures) +m_partial = build_forest(labels, features; rng=StableRNG(1)) # default sqrt(n_features) +m_full = build_forest(labels, features, n_subfeatures; rng=StableRNG(1)) @test all( length.(m_full.trees) .< length.(m_partial.trees) ) # test partial_sampling parameter, train on single sample @@ -122,43 +123,54 @@ m3 = build_forest(labels, features, n_iterations = 25 -model, coeffs = build_adaboost_stumps(labels, features, n_iterations); +model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1)); preds = apply_adaboost_stumps(model, coeffs, features); cm = confusion_matrix(labels, preds) @test cm.accuracy > 0.6 @test typeof(preds) == Vector{Int} @test length(model) == n_iterations +""" +RNGs can look like they produce stable results, but do in fact differ when you run it many times. +In some RNGs the problem already shows up when doing two runs and comparing those. +This loop tests multiple RNGs to have a higher chance of spotting a problem. +See https://github.com/JuliaAI/DecisionTree.jl/pull/174 for more information. 
+""" +function test_rng(f::Function, args, expected_accuracy) + println("Testing $f") + accuracy = f(args...; rng=StableRNG(10), verbose=false) + accuracy2 = f(args...; rng=StableRNG(5), verbose=false) + @test accuracy != accuracy2 + + for i in 10:14 + accuracy = f(args...; rng=StableRNG(i), verbose=false) + accuracy2 = f(args...; rng=StableRNG(i), verbose=false) + @test mean(accuracy) > expected_accuracy + @test accuracy == accuracy2 + end +end + println("\n##### nfoldCV Classification Tree #####") nfolds = 3 pruning_purity = 1.0 max_depth = 5 -accuracy = nfoldCV_tree(labels, features, nfolds, pruning_purity, max_depth; rng=10, verbose=false) -accuracy2 = nfoldCV_tree(labels, features, nfolds, pruning_purity, max_depth; rng=10) -accuracy3 = nfoldCV_tree(labels, features, nfolds, pruning_purity, max_depth; rng=5) -@test mean(accuracy) > 0.7 -@test accuracy == accuracy2 -@test accuracy != accuracy3 +args = [labels, features, nfolds, pruning_purity, max_depth] +test_rng(nfoldCV_tree, args, 0.7) println("\n##### nfoldCV Classification Forest #####") nfolds = 3 n_subfeatures = 2 n_trees = 10 -accuracy = nfoldCV_forest(labels, features, nfolds, n_subfeatures, n_trees; rng=10, verbose=false) -accuracy2 = nfoldCV_forest(labels, features, nfolds, n_subfeatures, n_trees; rng=10) -accuracy3 = nfoldCV_forest(labels, features, nfolds, n_subfeatures, n_trees; rng=5) -@test mean(accuracy) > 0.7 -@test accuracy == accuracy2 -@test accuracy != accuracy3 +args = [labels, features, nfolds, n_subfeatures, n_trees] +test_rng(nfoldCV_forest, args, 0.7) + +# This is a smoke test to verify that the multi-threaded code doesn't crash. +nfoldCV_forest(args...; rng=MersenneTwister(1)) println("\n##### nfoldCV Adaboosted Stumps #####") n_iterations = 25 n_folds = 3 -accuracy = nfoldCV_stumps(labels, features, n_folds, n_iterations; rng=10, verbose=false) -accuracy2 = nfoldCV_stumps(labels, features, n_folds, n_iterations; rng=10) -accuracy3 = nfoldCV_stumps(labels, features, n_folds, n_iterations; rng=5) -@test mean(accuracy) > 0.6 -@test accuracy == accuracy2 -@test accuracy != accuracy3 +args = [labels, features, n_folds, n_iterations] +test_rng(nfoldCV_stumps, args, 0.6) end # @testset diff --git a/test/classification/scikitlearn.jl b/test/classification/scikitlearn.jl index c4329439..9417f68f 100644 --- a/test/classification/scikitlearn.jl +++ b/test/classification/scikitlearn.jl @@ -1,37 +1,36 @@ @testset "scikitlearn.jl" begin Random.seed!(2) -n,m = 10^3, 5 ; -features = rand(n,m); -weights = rand(-1:1,m); +n, m = 10^3, 5; +features = rand(StableRNG(1), n, m); +weights = rand(StableRNG(1), -1:1, m); labels = round.(Int, features * weights); # I wish we could use ScikitLearn.jl's cross-validation, but that'd require # installing it on Travis -model = fit!(DecisionTreeClassifier(pruning_purity_threshold=0.9), features, labels) +model = fit!(DecisionTreeClassifier(; rng=StableRNG(1), pruning_purity_threshold=0.9), features, labels) @test mean(predict(model, features) .== labels) > 0.8 -model = fit!(RandomForestClassifier(), features, labels) +model = fit!(RandomForestClassifier(; rng=StableRNG(1)), features, labels) @test mean(predict(model, features) .== labels) > 0.8 -model = fit!(AdaBoostStumpClassifier(), features, labels) +model = fit!(AdaBoostStumpClassifier(; rng=StableRNG(1)), features, labels) # Adaboost isn't so hot on this task, disabled for now mean(predict(model, features) .== labels) -Random.seed!(2) N = 3000 -X = randn(N, 10) +X = randn(StableRNG(1), N, 10) # TODO: we should probably support 
fit!(::DecisionTreeClassifier, ::BitArray) y = convert(Vector{Bool}, randn(N) .< 0) max_depth = 5 -model = fit!(DecisionTreeClassifier(max_depth=max_depth), X, y) +model = fit!(DecisionTreeClassifier(; rng=StableRNG(1), max_depth=max_depth), X, y) @test depth(model) == max_depth ## Test that the RNG arguments work as expected Random.seed!(2) -X = randn(100, 10) -y = rand(Bool, 100); +X = randn(StableRNG(1), 100, 10) +y = rand(StableRNG(1), Bool, 100); @test predict_proba(fit!(RandomForestClassifier(; rng=10), X, y), X) == predict_proba(fit!(RandomForestClassifier(; rng=10), X, y), X) diff --git a/test/regression/digits.jl b/test/regression/digits.jl index 8b74f6d8..fdf6105d 100644 --- a/test/regression/digits.jl +++ b/test/regression/digits.jl @@ -75,13 +75,14 @@ model = build_forest( max_depth, min_samples_leaf, min_samples_split, - min_purity_increase) + min_purity_increase; + rng=StableRNG(1)) preds = apply_forest(model, X) @test R2(Y, preds) > 0.8 println("\n##### 3 foldCV Regression Tree #####") n_folds = 5 -r2 = nfoldCV_tree(Y, X, n_folds; verbose=false); +r2 = nfoldCV_tree(Y, X, n_folds; rng=StableRNG(1), verbose=false); @test mean(r2) > 0.55 println("\n##### 3 foldCV Regression Forest #####") @@ -89,7 +90,7 @@ n_subfeatures = 2 n_trees = 10 n_folds = 5 partial_sampling = 0.5 -r2 = nfoldCV_forest(Y, X, n_folds, n_subfeatures, n_trees, partial_sampling; verbose=false) +r2 = nfoldCV_forest(Y, X, n_folds, n_subfeatures, n_trees, partial_sampling; rng=StableRNG(1), verbose=false) @test mean(r2) > 0.55 end # @testset diff --git a/test/regression/low_precision.jl b/test/regression/low_precision.jl index 9b6a5f9b..2a5ec11c 100644 --- a/test/regression/low_precision.jl +++ b/test/regression/low_precision.jl @@ -2,12 +2,12 @@ Random.seed!(5) -n, m = 10^3, 5 ; +n, m = 10^3, 5; features = Array{Any}(undef, n, m); -features[:,:] = randn(n, m); -features[:,1] = round.(Int32, features[:,1]); # convert a column of 32bit integers -weights = rand(-2:2,m); -labels = float.(features * weights); # cast to Array{Float64,1} +features[:,:] = randn(StableRNG(1), n, m); +features[:,1] = round.(Int32, features[:,1]); +weights = rand(StableRNG(1), -2:2, m); +labels = float.(features * weights); min_samples_leaf = Int32(1) n_subfeatures = Int32(0) @@ -20,7 +20,8 @@ model = build_tree( max_depth, min_samples_leaf, min_samples_split, - min_purity_increase) + min_purity_increase; + rng=StableRNG(1)) preds = apply_tree(model, round.(Int32, features)) @test R2(labels, preds) < 0.95 @test typeof(preds) <: Vector{Float64} @@ -40,7 +41,8 @@ model = build_forest( max_depth, min_samples_leaf, min_samples_split, - min_purity_increase) + min_purity_increase; + rng=StableRNG(1)) preds = apply_forest(model, features) @test R2(labels, preds) > 0.9 @test typeof(preds) <: Vector{Float64} @@ -59,7 +61,8 @@ r2 = nfoldCV_tree( max_depth, min_samples_leaf, min_samples_split, - min_purity_increase) + min_purity_increase; + rng=StableRNG(1)) @test mean(r2) > 0.6 println("\n##### nfoldCV Regression Forest #####") @@ -79,10 +82,10 @@ r2 = nfoldCV_forest( max_depth, min_samples_leaf, min_samples_split, - min_purity_increase) + min_purity_increase; + rng=StableRNG(1)) @test mean(r2) > 0.8 - # Test Float16 labels, and Float16 features features = Float16.(features) labels = Float16.(labels) diff --git a/test/regression/random.jl b/test/regression/random.jl index 31dce3ce..ac5d40ae 100644 --- a/test/regression/random.jl +++ b/test/regression/random.jl @@ -4,9 +4,9 @@ Random.seed!(5) n, m = 10^3, 5 ; features = Array{Any}(undef, n, m); 
-features[:,:] = randn(n, m); +features[:,:] = randn(StableRNG(1), n, m); features[:,1] = round.(Integer, features[:,1]); # convert a column of integers -weights = rand(-2:2,m); +weights = rand(StableRNG(1), -2:2, m); labels = float.(features * weights); # cast to Array{Float64,1} model = build_stump(labels, features) @@ -20,7 +20,8 @@ model = build_tree( labels, features, n_subfeatures, max_depth, - min_samples_leaf) + min_samples_leaf; + rng=StableRNG(1)) preds = apply_tree(model, features); @test R2(labels, preds) > 0.99 # R2: coeff of determination @test typeof(preds) <: Vector{Float64} @@ -88,7 +89,7 @@ t3 = build_tree(labels, features, n_subfeatures; rng=mt) @test (length(t1) != length(t3)) || (depth(t1) != depth(t3)) -model = build_forest(labels, features) +model = build_forest(labels, features; rng=StableRNG(1)) preds = apply_forest(model, features) @test R2(labels, preds) > 0.9 @test typeof(preds) <: Vector{Float64} @@ -108,7 +109,8 @@ model = build_forest( max_depth, min_samples_leaf, min_samples_split, - min_purity_increase) + min_purity_increase; + rng=StableRNG(1)) preds = apply_forest(model, features) @test R2(labels, preds) > 0.9 @test length(model) == n_trees @@ -125,7 +127,8 @@ m_partial = build_forest( n_trees, partial_sampling, max_depth, - min_samples_leaf) + min_samples_leaf; + rng=10) n_subfeatures = 0 m_full = build_forest( labels, features, @@ -133,7 +136,8 @@ m_full = build_forest( n_trees, partial_sampling, max_depth, - min_samples_leaf) + min_samples_leaf; + rng=10) @test mean(depth.(m_full.trees)) < mean(depth.(m_partial.trees)) # test partial_sampling parameter, train on single sample @@ -190,9 +194,9 @@ println("\n##### nfoldCV Regression Forest #####") nfolds = 3 n_subfeatures = 2 n_trees = 10 -r2_1 = nfoldCV_forest(labels, features, nfolds, n_subfeatures, n_trees; rng=10, verbose=false) -r2_2 = nfoldCV_forest(labels, features, nfolds, n_subfeatures, n_trees; rng=10) -r2_3 = nfoldCV_forest(labels, features, nfolds, n_subfeatures, n_trees; rng=5) +r2_1 = nfoldCV_forest(labels, features, nfolds, n_subfeatures, n_trees; rng=StableRNG(10), verbose=false) +r2_2 = nfoldCV_forest(labels, features, nfolds, n_subfeatures, n_trees; rng=StableRNG(10)) +r2_3 = nfoldCV_forest(labels, features, nfolds, n_subfeatures, n_trees; rng=StableRNG(5)) @test mean(r2_1) > 0.8 @test r2_1 == r2_2 @test r2_1 != r2_3 diff --git a/test/regression/scikitlearn.jl b/test/regression/scikitlearn.jl index 13e78742..2da415d9 100644 --- a/test/regression/scikitlearn.jl +++ b/test/regression/scikitlearn.jl @@ -1,33 +1,36 @@ @testset "scikitlearn.jl" begin -Random.seed!(2) -n,m = 10^3, 5 ; -features = rand(n,m); -weights = rand(-1:1,m); +n, m = 10^3, 5; +features = rand(StableRNG(1), n, m); +weights = rand(StableRNG(1), -1:1, m); labels = features * weights; -model = fit!(DecisionTreeRegressor(min_samples_leaf=5, pruning_purity_threshold=0.1), features, labels) -@test R2(labels, predict(model, features)) > 0.8 +let + regressor = DecisionTreeRegressor(; rng=StableRNG(1), min_samples_leaf=5, pruning_purity_threshold=0.1) + model = fit!(regressor, features, labels) + @test R2(labels, predict(model, features)) > 0.8 +end -model = fit!(DecisionTreeRegressor(min_samples_split=5), features, labels) +model = fit!(DecisionTreeRegressor(; rng=StableRNG(1), min_samples_split=5), features, labels) @test R2(labels, predict(model, features)) > 0.8 -model = fit!(RandomForestRegressor(n_trees=10, min_samples_leaf=5, n_subfeatures=2), features, labels) -@test R2(labels, predict(model, features)) > 0.8 +let + 
regressor = RandomForestRegressor(; rng=StableRNG(1), n_trees=10, min_samples_leaf=5, n_subfeatures=2) + model = fit!(regressor, features, labels) + @test R2(labels, predict(model, features)) > 0.8 +end -Random.seed!(2) N = 3000 -X = randn(N, 10) -y = randn(N) +X = randn(StableRNG(1), N, 10) +y = randn(StableRNG(1), N) max_depth = 5 -model = fit!(DecisionTreeRegressor(max_depth=max_depth), X, y) +model = fit!(DecisionTreeRegressor(; rng=StableRNG(1), max_depth=max_depth), X, y) @test depth(model) == max_depth ## Test that the RNG arguments work as expected -Random.seed!(2) -X = randn(100, 10) -y = randn(100) +X = randn(StableRNG(1), 100, 10) +y = randn(StableRNG(1), 100) @test fit_predict!(RandomForestRegressor(; rng=10), X, y) == fit_predict!(RandomForestRegressor(; rng=10), X, y) @test fit_predict!(RandomForestRegressor(; rng=10), X, y) != diff --git a/test/runtests.jl b/test/runtests.jl index abf30f68..fd5dd6f5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,9 +1,11 @@ +import AbstractTrees + using DecisionTree using DelimitedFiles using Random using ScikitLearnBase +using StableRNGs using Statistics -import AbstractTrees using Test println("Julia version: ", VERSION) @@ -56,4 +58,4 @@ test_suites = [ end end end -end \ No newline at end of file +end From bfe6ac5e7a1d0e30b843081296b8ede5a3e67666 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Wed, 22 Jun 2022 10:47:58 +1200 Subject: [PATCH 10/10] bump 0.10.13 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f3f22e20..f8d46189 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "DecisionTree" uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" license = "MIT" desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms" -version = "0.10.12" +version = "0.10.13" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
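The reproducibility property exercised by the new tests can be demonstrated end to end: passing the same explicit `rng` to `build_forest` yields the same forest and hence the same predictions. A sketch under the assumption that DecisionTree and StableRNGs are available; the data construction mirrors test/classification/random.jl:

```julia
using DecisionTree, StableRNGs

# Data construction mirrors test/classification/random.jl.
features = rand(StableRNG(1), 10^3, 5)
labels   = round.(Int, features * rand(StableRNG(1), -1:1, 5))

m1 = build_forest(labels, features; rng=StableRNG(10))
m2 = build_forest(labels, features; rng=StableRNG(10))

apply_forest(m1, features) == apply_forest(m2, features)  # expected: true
```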