Merge pull request #178 from JuliaAI/dev
For a 0.10.13 release
ablaom authored Jun 21, 2022
2 parents 7e090bb + bfe6ac5 commit 66f99b8
Showing 19 changed files with 225 additions and 163 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
version:
- '1.0'
- '1.6'
- '1' # automatically expands to the latest stable 1.x release of Julia
11 changes: 8 additions & 3 deletions Project.toml
@@ -2,19 +2,24 @@ name = "DecisionTree"
uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
license = "MIT"
desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms"
version = "0.10.12"
version = "0.10.13"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScikitLearnBase = "6e75b9c4-186b-50bd-896f-2d2496a4843e"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
AbstractTrees = "0.3"
ScikitLearnBase = "0.5"
julia = "1"

[extras]
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["StableRNGs", "Test"]
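With `Test` moved out of `[deps]` and `StableRNGs` added under `[extras]`, both are now test-only dependencies, resolved just for the `test` target. A quick way to see this in action (a sketch; exact Pkg output will vary by environment):

```julia
using Pkg

# A regular install resolves only the packages listed under [deps]:
Pkg.add(name="DecisionTree", version="0.10.13")

# Running the test suite additionally pulls in the [extras] packages
# named by the "test" target, i.e. StableRNGs and Test:
Pkg.test("DecisionTree")
```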
74 changes: 40 additions & 34 deletions src/DecisionTree.jl
@@ -1,6 +1,6 @@
__precompile__()

module DecisionTree

import Base: length, show, convert, promote_rule, zero
using DelimitedFiles
@@ -80,55 +80,61 @@ length(ensemble::Ensemble) = length(ensemble.trees)
depth(leaf::Leaf) = 0
depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))

function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
n_matches = count(leaf.values .== leaf.majority)
ratio = string(n_matches, "/", length(leaf.values))
println(io, "$(leaf.majority) : $(ratio)")
end
function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
matches = findall(leaf.values .== leaf.majority)
ratio = string(length(matches)) * "/" * string(length(leaf.values))
println("$(leaf.majority) : $(ratio)")
return print_tree(stdout, leaf, depth, indent; feature_names=feature_names)
end


"""
print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing)
Print a textual visualization of the given decision tree `tree`.
In the example output below, the top node considers whether
"Feature 3" is above or below the threshold -28.156052806422238.
If the value of "Feature 3" is strictly below the threshold for some input to be classified,
we move to the `L->` part underneath, which is a node
looking at if "Feature 2" is above or below -161.04351901384842.
If the value of "Feature 2" is strictly below the threshold for some input to be classified,
we end up at `L-> 5 : 842/3650`. This is to be read as "In the left split,
the tree will classify the input as class 5, as 842 of the 3650 datapoints
in the training data that ended up here were of class 5."
print_tree([io::IO,] tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
Print a textual visualization of the specified `tree`. For example, if
for some input pattern the value of "Feature 3" is "-30" and the value
of "Feature 2" is "100", then, according to the sample output below,
the majority class prediction is 7. Moreover, one can see that of the
10555 training samples that terminate at the same leaf as this input
data, 2493 of these predict the majority class, leading to a
probabilistic prediction for class 7 of `2493/10555`. Ratios for
non-majority classes are not shown.
# Example output:
```
Feature 3, Threshold -28.156052806422238
L-> Feature 2, Threshold -161.04351901384842
L-> 5 : 842/3650
R-> 7 : 2493/10555
R-> Feature 7, Threshold 108.1408338577021
L-> 2 : 2434/15287
R-> 8 : 1227/3508
Feature 3 < -28.15 ?
├─ Feature 2 < -161.0 ?
├─ 5 : 842/3650
└─ 7 : 2493/10555
└─ Feature 7 < 108.1 ?
├─ 2 : 2434/15287
└─ 8 : 1227/3508
```
To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or
`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the
AbstractTrees.jl interface. See [`wrap`](@ref) for details.
"""
function print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing)
function print_tree(io::IO, tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
if depth == indent
println()
println(io)
return
end
featval = round(tree.featval; sigdigits=sigdigits)
if feature_names === nothing
println("Feature $(tree.featid), Threshold $(tree.featval)")
println(io, "Feature $(tree.featid) < $featval ?")
else
println("Feature $(tree.featid): \"$(feature_names[tree.featid])\", Threshold $(tree.featval)")
println(io, "Feature $(tree.featid): \"$(feature_names[tree.featid])\" < $featval ?")
end
print(" " ^ indent * "L-> ")
print_tree(tree.left, depth, indent + 1; feature_names = feature_names)
print(" " ^ indent * "R-> ")
print_tree(tree.right, depth, indent + 1; feature_names = feature_names)
print(io, " " ^ indent * "├─ ")
print_tree(io, tree.left, depth, indent + 1; feature_names=feature_names)
print(io, " " ^ indent * "└─ ")
print_tree(io, tree.right, depth, indent + 1; feature_names=feature_names)
end
function print_tree(tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
return print_tree(stdout, tree, depth, indent; sigdigits=sigdigits, feature_names=feature_names)
end

function show(io::IO, leaf::Leaf)
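Taken together, the new methods let callers pass an explicit `IO` and control threshold rounding via `sigdigits`. A minimal usage sketch (the data and feature names below are invented for illustration):

```julia
using DecisionTree

features = rand(100, 3)
labels   = string.(rand(1:2, 100))
tree     = build_tree(labels, features)

# Print to stdout, rounding thresholds to 3 significant digits:
print_tree(tree; sigdigits=3, feature_names=["height", "width", "depth"])

# Or capture the rendering instead of printing it, via the IO-first method:
io = IOBuffer()
print_tree(io, tree; sigdigits=3)
tree_str = String(take!(io))
```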
22 changes: 13 additions & 9 deletions src/classification/main.jl
@@ -37,7 +37,8 @@ end
function _convert(
node :: treeclassifier.NodeMeta{S},
list :: AbstractVector{T},
labels :: AbstractVector{T}) where {S, T}
labels :: AbstractVector{T}
) where {S, T}

if node.is_leaf
return Leaf{T}(list[node.label], labels[node.region])
@@ -138,7 +139,7 @@ function prune_tree(tree::LeafOrNode{S, T}, purity_thresh=1.0) where {S, T}
end


apply_tree(leaf::Leaf{T}, feature::AbstractVector{S}) where {S, T} = leaf.majority
apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.majority

function apply_tree(tree::Node{S, T}, features::AbstractVector{S}) where {S, T}
if tree.featid == 0
@@ -197,7 +198,7 @@ function build_forest(
min_samples_leaf = 1,
min_samples_split = 2,
min_purity_increase = 0.0;
rng = Random.GLOBAL_RNG) where {S, T}
rng::Union{Integer,AbstractRNG} = Random.GLOBAL_RNG) where {S, T}

if n_trees < 1
throw("the number of trees must be >= 1")
@@ -221,7 +222,12 @@

if rng isa Random.AbstractRNG
Threads.@threads for i in 1:n_trees
inds = rand(rng, 1:t_samples, n_samples)
# The Mersenne Twister (Julia's default) is not thread-safe.
_rng = copy(rng)
# Take some elements from the rng to have different states for each tree.
# This is the only way given that only a `copy` can be expected to exist for RNGs.
rand(_rng, i)
inds = rand(_rng, 1:t_samples, n_samples)
forest[i] = build_tree(
labels[inds],
features[inds,:],
@@ -231,9 +237,9 @@
min_samples_split,
min_purity_increase,
loss = loss,
rng = rng)
rng = _rng)
end
elseif rng isa Integer # each thread gets its own seeded rng
else # each thread gets its own seeded rng
Threads.@threads for i in 1:n_trees
Random.seed!(rng + i)
inds = rand(1:t_samples, n_samples)
@@ -247,8 +253,6 @@
min_purity_increase,
loss = loss)
end
else
throw("rng must of be type Integer or Random.AbstractRNG")
end

return Ensemble{S, T}(forest)
@@ -298,7 +302,7 @@ function build_adaboost_stumps(
labels :: AbstractVector{T},
features :: AbstractMatrix{S},
n_iterations :: Integer;
rng = Random.GLOBAL_RNG) where {S, T}
N = length(labels)
weights = ones(N) / N
stumps = Node{S, T}[]
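The `copy(rng)` plus `rand(_rng, i)` pattern above (repeated in `src/regression/main.jl` below) is what keeps forest building both thread-safe and reproducible from a single seed. A stripped-down sketch of the same idea, independent of any DecisionTree internals:

```julia
using Random

function bootstrap_indices(rng::AbstractRNG, n_trees::Int, n_samples::Int)
    out = Vector{Vector{Int}}(undef, n_trees)
    Threads.@threads for i in 1:n_trees
        # Copy the caller's RNG so threads never share mutable state
        # (MersenneTwister, Julia's default, is not thread-safe)...
        _rng = copy(rng)
        # ...then burn `i` draws so each copy reaches a distinct state.
        rand(_rng, i)
        out[i] = rand(_rng, 1:n_samples, n_samples)
    end
    return out
end

# Same seed in, same bootstrap samples out, regardless of thread scheduling:
bootstrap_indices(MersenneTwister(42), 4, 10)
```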
3 changes: 2 additions & 1 deletion src/measures.jl
@@ -135,7 +135,7 @@ function _nfoldCV(classifier::Symbol, labels::AbstractVector{T}, features::Abstr
predictions = apply_forest(model, test_features)
elseif classifier == :stumps
model, coeffs = build_adaboost_stumps(
train_labels, train_features, n_iterations)
train_labels, train_features, n_iterations; rng=rng)
predictions = apply_adaboost_stumps(model, coeffs, test_features)
end
cm = confusion_matrix(test_labels, predictions)
@@ -186,6 +186,7 @@ function nfoldCV_stumps(
n_iterations ::Integer = 10;
verbose :: Bool = true,
rng = Random.GLOBAL_RNG) where {S, T}
rng = mk_rng(rng)::Random.AbstractRNG
_nfoldCV(:stumps, labels, features, n_folds, n_iterations; verbose=verbose, rng=rng)
end

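The added line normalises the `rng` keyword through the package's internal `mk_rng` helper before use. Its definition is not part of this diff; a plausible sketch of what such a helper does (an assumption, not the package's verbatim code):

```julia
using Random

# Presumed shape of the helper: accept either a ready-made RNG or an
# integer seed, and always hand back an AbstractRNG.
mk_rng(rng::Random.AbstractRNG) = rng
mk_rng(seed::Integer) = Random.MersenneTwister(seed)

mk_rng(42)                        # a MersenneTwister seeded with 42
mk_rng(Random.MersenneTwister())  # passed through unchanged
```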
15 changes: 9 additions & 6 deletions src/regression/main.jl
@@ -56,7 +56,7 @@ function build_forest(
min_samples_leaf = 5,
min_samples_split = 2,
min_purity_increase = 0.0;
rng = Random.GLOBAL_RNG) where {S, T <: Float64}
rng::Union{Integer,AbstractRNG} = Random.GLOBAL_RNG) where {S, T <: Float64}

if n_trees < 1
throw("the number of trees must be >= 1")
@@ -77,7 +77,12 @@

if rng isa Random.AbstractRNG
Threads.@threads for i in 1:n_trees
inds = rand(rng, 1:t_samples, n_samples)
# The Mersenne Twister (Julia's default) is not thread-safe.
_rng = copy(rng)
# Take some elements from the rng to have different states for each tree.
# This is the only way given that only a `copy` can be expected to exist for RNGs.
rand(_rng, i)
inds = rand(_rng, 1:t_samples, n_samples)
forest[i] = build_tree(
labels[inds],
features[inds,:],
@@ -86,9 +91,9 @@
min_samples_leaf,
min_samples_split,
min_purity_increase,
rng = rng)
rng = _rng)
end
elseif rng isa Integer # each thread gets its own seeded rng
else # each thread gets its own seeded rng
Threads.@threads for i in 1:n_trees
Random.seed!(rng + i)
inds = rand(1:t_samples, n_samples)
@@ -101,8 +106,6 @@
min_samples_split,
min_purity_increase)
end
else
throw("rng must of be type Integer or Random.AbstractRNG")
end

return Ensemble{S, T}(forest)
6 changes: 4 additions & 2 deletions src/scikitlearnAPI.jl
@@ -386,5 +386,7 @@ length(dt::DecisionTreeClassifier) = length(dt.root)
length(dt::DecisionTreeRegressor) = length(dt.root)

print_tree(dt::DecisionTreeClassifier, depth=-1; kwargs...) = print_tree(dt.root, depth; kwargs...)
print_tree(dt::DecisionTreeRegressor, depth=-1; kwargs...) = print_tree(dt.root, depth; kwargs...)
print_tree(n::Nothing, depth=-1; kwargs...) = show(n)
print_tree(io::IO, dt::DecisionTreeClassifier, depth=-1; kwargs...) = print_tree(io, dt.root, depth; kwargs...)
print_tree(dt::DecisionTreeRegressor, depth=-1; kwargs...) = print_tree(dt.root, depth; kwargs...)
print_tree(io::IO, dt::DecisionTreeRegressor, depth=-1; kwargs...) = print_tree(io, dt.root, depth; kwargs...)
print_tree(n::Nothing, depth=-1; kwargs...) = show(n)
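With the two `io::IO` methods added here, the ScikitLearn-style wrappers can also print to an arbitrary stream. A small usage sketch on toy data (assuming the keyword constructor this package exports):

```julia
using DecisionTree

X = rand(50, 2)
y = string.(rand(1:2, 50))

clf = DecisionTreeClassifier(max_depth=3)
fit!(clf, X, y)

# Previously the tree could only go to stdout; now any IO works:
io = IOBuffer()
print_tree(io, clf)
print(String(take!(io)))
```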
12 changes: 6 additions & 6 deletions test/classification/adult.jl
@@ -5,7 +5,7 @@

features, labels = load_data("adult")

model = build_tree(labels, features)
model = build_tree(labels, features; rng=StableRNG(1))
preds = apply_tree(model, features)
cm = confusion_matrix(labels, preds)
@test cm.accuracy > 0.99
@@ -15,35 +15,35 @@ labels = string.(labels)

n_subfeatures = 3
n_trees = 5
model = build_forest(labels, features, n_subfeatures, n_trees)
model = build_forest(labels, features, n_subfeatures, n_trees; rng=StableRNG(1))
preds = apply_forest(model, features)
cm = confusion_matrix(labels, preds)
@test cm.accuracy > 0.9

n_iterations = 15
model, coeffs = build_adaboost_stumps(labels, features, n_iterations);
model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1));
preds = apply_adaboost_stumps(model, coeffs, features);
cm = confusion_matrix(labels, preds);
@test cm.accuracy > 0.8

println("\n##### 3 foldCV Classification Tree #####")
pruning_purity = 0.9
nfolds = 3
accuracy = nfoldCV_tree(labels, features, nfolds, pruning_purity; verbose=false);
accuracy = nfoldCV_tree(labels, features, nfolds, pruning_purity; rng=StableRNG(1), verbose=false);
@test mean(accuracy) > 0.8

println("\n##### 3 foldCV Classification Forest #####")
n_subfeatures = 2
n_trees = 10
n_folds = 3
partial_sampling = 0.5
accuracy = nfoldCV_forest(labels, features, n_folds, n_subfeatures, n_trees, partial_sampling; verbose=false)
accuracy = nfoldCV_forest(labels, features, n_folds, n_subfeatures, n_trees, partial_sampling; rng=StableRNG(1), verbose=false)
@test mean(accuracy) > 0.8

println("\n##### nfoldCV Classification Adaboosted Stumps #####")
n_iterations = 15
n_folds = 3
accuracy = nfoldCV_stumps(labels, features, n_folds, n_iterations; verbose=false);
accuracy = nfoldCV_stumps(labels, features, n_folds, n_iterations; rng=StableRNG(1), verbose=false);
@test mean(accuracy) > 0.8

end # @testset
3 changes: 2 additions & 1 deletion test/classification/digits.jl
@@ -69,7 +69,8 @@ model = DecisionTree.build_forest(
max_depth,
min_samples_leaf,
min_samples_split,
min_purity_increase)
min_purity_increase;
rng=StableRNG(1))
preds = apply_forest(model, X)
cm = confusion_matrix(Y, preds)
@test cm.accuracy > 0.95
12 changes: 6 additions & 6 deletions test/classification/heterogeneous.jl
@@ -5,29 +5,29 @@
m, n = 10^2, 5

tf = [trues(Int(m/2)) falses(Int(m/2))]
inds = Random.randperm(m)
inds = Random.randperm(StableRNG(1), m)
labels = string.(tf[inds])

features = Array{Any}(undef, m, n)
features[:,:] = randn(m, n)
features[:,2] = string.(tf[Random.randperm(m)])
features[:,:] = randn(StableRNG(1), m, n)
features[:,2] = string.(tf[Random.randperm(StableRNG(1), m)])
features[:,3] = map(t -> round.(Int, t), features[:,3])
features[:,4] = tf[inds]

model = build_tree(labels, features)
model = build_tree(labels, features; rng=StableRNG(1))
preds = apply_tree(model, features)
cm = confusion_matrix(labels, preds)
@test cm.accuracy > 0.9

n_subfeatures = 2
n_trees = 3
model = build_forest(labels, features, n_subfeatures, n_trees)
model = build_forest(labels, features, n_subfeatures, n_trees; rng=StableRNG(1))
preds = apply_forest(model, features)
cm = confusion_matrix(labels, preds)
@test cm.accuracy > 0.9

n_subfeatures = 7
model, coeffs = build_adaboost_stumps(labels, features, n_subfeatures)
model, coeffs = build_adaboost_stumps(labels, features, n_subfeatures; rng=StableRNG(1))
preds = apply_adaboost_stumps(model, coeffs, features)
cm = confusion_matrix(labels, preds)
@test cm.accuracy > 0.9
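Throughout the tests, every stochastic call is now seeded with a fresh `StableRNG`. A short sketch of why that (rather than `MersenneTwister`) suits tests that assert fixed accuracy thresholds:

```julia
using Random, StableRNGs

# MersenneTwister's stream may change between Julia releases, so a
# hard-coded accuracy threshold can silently break on a new version.
# StableRNG guarantees an identical stream for a given seed, forever:
@assert rand(StableRNG(1), 5) == rand(StableRNG(1), 5)

# Seeding each call site separately means reordering tests cannot
# perturb the draws seen downstream:
inds = Random.randperm(StableRNG(1), 10)
```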