
Merge pull request #219 from JuliaAI/rh/format
Apply JuliaFormatter
ablaom authored Feb 15, 2023
2 parents 7591df6 + 413a2d9 commit 792567b
Showing 33 changed files with 3,216 additions and 2,667 deletions.
149 changes: 91 additions & 58 deletions src/DecisionTree.jl
@@ -7,79 +7,105 @@ using Random
using Statistics
import AbstractTrees

-export Leaf, Node, Root, Ensemble, print_tree, depth, build_stump, build_tree,
-    prune_tree, apply_tree, apply_tree_proba, nfoldCV_tree, build_forest,
-    apply_forest, apply_forest_proba, nfoldCV_forest, build_adaboost_stumps,
-    apply_adaboost_stumps, apply_adaboost_stumps_proba, nfoldCV_stumps,
-    load_data, impurity_importance, split_importance, permutation_importance
+export Leaf,
+    Node,
+    Root,
+    Ensemble,
+    print_tree,
+    depth,
+    build_stump,
+    build_tree,
+    prune_tree,
+    apply_tree,
+    apply_tree_proba,
+    nfoldCV_tree,
+    build_forest,
+    apply_forest,
+    apply_forest_proba,
+    nfoldCV_forest,
+    build_adaboost_stumps,
+    apply_adaboost_stumps,
+    apply_adaboost_stumps_proba,
+    nfoldCV_stumps,
+    load_data,
+    impurity_importance,
+    split_importance,
+    permutation_importance

# ScikitLearn API
-export DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier,
-    RandomForestRegressor, AdaBoostStumpClassifier,
-    # Should we export these functions? They have a conflict with
-    # DataFrames/RDataset over fit!, and users can always
-    # `using ScikitLearnBase`.
-    predict, predict_proba, fit!, get_classes
+export DecisionTreeClassifier,
+    DecisionTreeRegressor,
+    RandomForestClassifier,
+    RandomForestRegressor,
+    AdaBoostStumpClassifier,
+    # Should we export these functions? They have a conflict with
+    # DataFrames/RDataset over fit!, and users can always
+    # `using ScikitLearnBase`.
+    predict,
+    predict_proba,
+    fit!,
+    get_classes

export InfoNode, InfoLeaf, wrap

###########################
########## Types ##########

struct Leaf{T}
-    majority :: T
-    values :: Vector{T}
+    majority::T
+    values::Vector{T}
end

-struct Node{S, T}
-    featid :: Int
-    featval :: S
-    left :: Union{Leaf{T}, Node{S, T}}
-    right :: Union{Leaf{T}, Node{S, T}}
+struct Node{S,T}
+    featid::Int
+    featval::S
+    left::Union{Leaf{T},Node{S,T}}
+    right::Union{Leaf{T},Node{S,T}}
end

-const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
+const LeafOrNode{S,T} = Union{Leaf{T},Node{S,T}}

-struct Root{S, T}
-    node :: LeafOrNode{S, T}
-    n_feat :: Int
-    featim :: Vector{Float64} # impurity importance
+struct Root{S,T}
+    node::LeafOrNode{S,T}
+    n_feat::Int
+    featim::Vector{Float64} # impurity importance
end

-struct Ensemble{S, T}
-    trees :: Vector{LeafOrNode{S, T}}
-    n_feat :: Int
-    featim :: Vector{Float64}
+struct Ensemble{S,T}
+    trees::Vector{LeafOrNode{S,T}}
+    n_feat::Int
+    featim::Vector{Float64}
end

is_leaf(l::Leaf) = true
is_leaf(n::Node) = false

_zero(::Type{String}) = ""
_zero(x::Any) = zero(x)
-convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} =
-    Node(0, _zero(S), lf, Leaf(_zero(T), [_zero(T)]))
-convert(::Type{Root{S, T}}, node::LeafOrNode{S, T}) where {S, T} =
-    Root{S, T}(node, 0, Float64[])
-convert(::Type{LeafOrNode{S, T}}, tree::Root{S, T}) where {S, T} = tree.node
-promote_rule(::Type{Node{S, T}}, ::Type{Leaf{T}}) where {S, T} = Node{S, T}
-promote_rule(::Type{Leaf{T}}, ::Type{Node{S, T}}) where {S, T} = Node{S, T}
-promote_rule(::Type{Root{S, T}}, ::Type{Leaf{T}}) where {S, T} = Root{S, T}
-promote_rule(::Type{Leaf{T}}, ::Type{Root{S, T}}) where {S, T} = Root{S, T}
-promote_rule(::Type{Root{S, T}}, ::Type{Node{S, T}}) where {S, T} = Root{S, T}
-promote_rule(::Type{Node{S, T}}, ::Type{Root{S, T}}) where {S, T} = Root{S, T}
+function convert(::Type{Node{S,T}}, lf::Leaf{T}) where {S,T}
+    Node(0, _zero(S), lf, Leaf(_zero(T), [_zero(T)]))
+end
+function convert(::Type{Root{S,T}}, node::LeafOrNode{S,T}) where {S,T}
+    Root{S,T}(node, 0, Float64[])
+end
+convert(::Type{LeafOrNode{S,T}}, tree::Root{S,T}) where {S,T} = tree.node
+promote_rule(::Type{Node{S,T}}, ::Type{Leaf{T}}) where {S,T} = Node{S,T}
+promote_rule(::Type{Leaf{T}}, ::Type{Node{S,T}}) where {S,T} = Node{S,T}
+promote_rule(::Type{Root{S,T}}, ::Type{Leaf{T}}) where {S,T} = Root{S,T}
+promote_rule(::Type{Leaf{T}}, ::Type{Root{S,T}}) where {S,T} = Root{S,T}
+promote_rule(::Type{Root{S,T}}, ::Type{Node{S,T}}) where {S,T} = Root{S,T}
+promote_rule(::Type{Node{S,T}}, ::Type{Root{S,T}}) where {S,T} = Root{S,T}

const DOC_WHATS_A_TREE =
-    "Here `tree` is any `DecisionTree.Root`, `DecisionTree.Node` or "*
+    "Here `tree` is any `DecisionTree.Root`, `DecisionTree.Node` or " *
"`DecisionTree.Leaf` instance, as returned, for example, by [`build_tree`](@ref)."
const DOC_WHATS_A_FOREST =
-    "Here `forest` is any `DecisionTree.Ensemble` instance, as returned, for "*
+    "Here `forest` is any `DecisionTree.Ensemble` instance, as returned, for " *
"example, by [`build_forest`](@ref)."
-const DOC_ENSEMBLE =
-    "`DecisionTree.Ensemble` objects are returned by, for example, `build_forest`."
+const DOC_ENSEMBLE = "`DecisionTree.Ensemble` objects are returned by, for example, `build_forest`."
const ERR_ENSEMBLE_VCAT = DimensionMismatch(
-    "Ensembles that record feature impurity importances cannot be combined when "*
-    "they were generated using differing numbers of features. "
+    "Ensembles that record feature impurity importances cannot be combined when " *
+    "they were generated using differing numbers of features. ",
)

"""
@@ -124,12 +150,13 @@ function Base.vcat(e1::Ensemble{S,T}, e2::Ensemble{S,T}) where {S,T}
Ensemble{S,T}(trees, e2.n_feat, featim)
end

-Base.getindex(ensemble::DecisionTree.Ensemble, I) =
+function Base.getindex(ensemble::DecisionTree.Ensemble, I)
    DecisionTree.Ensemble(ensemble.trees[I], ensemble.n_feat, ensemble.featim)
+end

# make a Random Number Generator object
mk_rng(rng::Random.AbstractRNG) = rng
-mk_rng(seed::T) where T <: Integer = Random.MersenneTwister(seed)
+mk_rng(seed::T) where {T<:Integer} = Random.MersenneTwister(seed)

##############################
########## Includes ##########
@@ -142,7 +169,6 @@ include("regression/main.jl")
include("scikitlearnAPI.jl")
include("abstract_trees.jl")


#############################
########## Methods ##########

@@ -155,7 +181,9 @@ depth(leaf::Leaf) = 0
depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
depth(tree::Root) = depth(tree.node)

-function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
+function print_tree(
+    io::IO, leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing
+)
n_matches = count(leaf.values .== leaf.majority)
ratio = string(n_matches, "/", length(leaf.values))
println(io, "$(leaf.majority) : $(ratio)")
@@ -164,8 +192,9 @@ function print_tree(leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=n
return print_tree(stdout, leaf, depth, indent; sigdigits, feature_names)
end


-function print_tree(io::IO, tree::Root, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
+function print_tree(
+    io::IO, tree::Root, depth=-1, indent=0; sigdigits=4, feature_names=nothing
+)
return print_tree(io, tree.node, depth, indent; sigdigits, feature_names)
end
function print_tree(tree::Root, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
@@ -199,20 +228,24 @@ To facilitate visualisation of trees using third party packages, a `DecisionTree
`DecisionTree.Node` object or `DecisionTree.Root` object can be wrapped to obtain a tree structure implementing the
AbstractTrees.jl interface. See [`wrap`](@ref) for details.
"""
-function print_tree(io::IO, tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
+function print_tree(
+    io::IO, tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing
+)
if depth == indent
println(io)
-        return
+        return nothing
end
featval = round(tree.featval; sigdigits)
if feature_names === nothing
println(io, "Feature $(tree.featid) < $featval ?")
else
-        println(io, "Feature $(tree.featid): \"$(feature_names[tree.featid])\" < $featval ?")
+        println(
+            io, "Feature $(tree.featid): \"$(feature_names[tree.featid])\" < $featval ?"
+        )
end
-    print(io, " " ^ indent * "├─ ")
+    print(io, " "^indent * "├─ ")
print_tree(io, tree.left, depth, indent + 1; sigdigits, feature_names)
-    print(io, " " ^ indent * "└─ ")
+    print(io, " "^indent * "└─ ")
print_tree(io, tree.right, depth, indent + 1; sigdigits, feature_names)
end
function print_tree(tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
@@ -222,26 +255,26 @@ end
function show(io::IO, leaf::Leaf)
println(io, "Decision Leaf")
println(io, "Majority: $(leaf.majority)")
-    print(io, "Samples: $(length(leaf.values))")
+    print(io, "Samples:  $(length(leaf.values))")
end

function show(io::IO, tree::Node)
println(io, "Decision Tree")
println(io, "Leaves: $(length(tree))")
-    print(io, "Depth: $(depth(tree))")
+    print(io, "Depth:  $(depth(tree))")
end

function show(io::IO, tree::Root)
println(io, "Decision Tree")
println(io, "Leaves: $(length(tree))")
-    print(io, "Depth: $(depth(tree))")
+    print(io, "Depth:  $(depth(tree))")
end

function show(io::IO, ensemble::Ensemble)
println(io, "Ensemble of Decision Trees")
println(io, "Trees: $(length(ensemble))")
println(io, "Avg Leaves: $(mean([length(tree) for tree in ensemble.trees]))")
-    print(io, "Avg Depth: $(mean([depth(tree) for tree in ensemble.trees]))")
+    print(io, "Avg Depth:  $(mean([depth(tree) for tree in ensemble.trees]))")
end

end # module
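
For orientation, a minimal usage sketch of the API this file defines, assuming only the package's documented interface; the data, labels, and feature names below are invented for illustration.

using DecisionTree

features = rand(100, 3)                        # 100 samples, 3 features
labels = ifelse.(features[:, 1] .< 0.5, 1, 2)  # two integer classes

tree = build_tree(labels, features)            # a DecisionTree.Root
print_tree(tree; sigdigits=3, feature_names=["height", "width", "depth"])

forest = build_forest(labels, features)        # a DecisionTree.Ensemble
subforest = forest[1:5]          # Base.getindex (defined above) slices the trees
combined = vcat(forest, forest)  # Base.vcat; throws ERR_ENSEMBLE_VCAT when feature counts differ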
41 changes: 22 additions & 19 deletions src/abstract_trees.jl
@@ -34,15 +34,15 @@ between its children and `T` is the type of the classes given (these might be id
with this mechanism. In case you want to add class labels, the generic type `T` must
be a subtype of `Integer`.
"""
-struct InfoNode{S, T} <: AbstractTrees.AbstractNode{DecisionTree.Node{S,T}}
-    node :: DecisionTree.Node{S, T}
-    info :: NamedTuple
+struct InfoNode{S,T} <: AbstractTrees.AbstractNode{DecisionTree.Node{S,T}}
+    node::DecisionTree.Node{S,T}
+    info::NamedTuple
end
AbstractTrees.nodevalue(n::InfoNode) = n.node

struct InfoLeaf{T} <: AbstractTrees.AbstractNode{DecisionTree.Leaf{T}}
-    leaf :: DecisionTree.Leaf{T}
-    info :: NamedTuple
+    leaf::DecisionTree.Leaf{T}
+    info::NamedTuple
end
AbstractTrees.nodevalue(l::InfoLeaf) = l.leaf

@@ -72,9 +72,9 @@ In the first case `dc` gets just wrapped, no information is added. No. 2 adds fe
as well as class labels. In the last two cases either piece of information is added (note the
trailing comma; it's needed to make it a tuple).
"""
-wrap(tree::DecisionTree.Root, info::NamedTuple = NamedTuple()) = wrap(tree.node, info)
-wrap(node::DecisionTree.Node, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
-wrap(leaf::DecisionTree.Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)
+wrap(tree::DecisionTree.Root, info::NamedTuple=NamedTuple()) = wrap(tree.node, info)
+wrap(node::DecisionTree.Node, info::NamedTuple=NamedTuple()) = InfoNode(node, info)
+wrap(leaf::DecisionTree.Leaf, info::NamedTuple=NamedTuple()) = InfoLeaf(leaf, info)

"""
children(node::InfoNode)
@@ -87,10 +87,9 @@ one right child. `children` is used for tree traversal.
The additional information `info` is carried over from `node` to its children.
"""
-AbstractTrees.children(node::InfoNode) = (
-    wrap(node.node.left, node.info),
-    wrap(node.node.right, node.info)
-)
+function AbstractTrees.children(node::InfoNode)
+    (wrap(node.node.left, node.info), wrap(node.node.right, node.info))
+end
AbstractTrees.children(node::InfoLeaf) = ()

"""
@@ -118,23 +117,27 @@ and then below the right subtree.
"""
function AbstractTrees.printnode(io::IO, node::InfoNode; sigdigits=4)
featval = round(node.node.featval; sigdigits)
-    if :featurenames ∈ keys(node.info)
+    if :featurenames ∈ keys(node.info)
print(io, node.info.featurenames[node.node.featid], " < ", featval)
else
-        print(io, "Feature: ", node.node.featid, " < ", featval)
+        print(io, "Feature: ", node.node.featid, " < ", featval)
end
end

function AbstractTrees.printnode(io::IO, leaf::InfoLeaf; sigdigits=4)
dt_leaf = leaf.leaf
-    matches = findall(dt_leaf.values .== dt_leaf.majority)
-    match_count = length(matches)
-    val_count = length(dt_leaf.values)
+    matches = findall(dt_leaf.values .== dt_leaf.majority)
+    match_count = length(matches)
+    val_count = length(dt_leaf.values)
    if :classlabels ∈ keys(leaf.info)
@assert dt_leaf.majority isa Integer "classes must be represented as Integers"
print(io, leaf.info.classlabels[dt_leaf.majority], " ($match_count/$val_count)")
else
-        print(io, dt_leaf.majority isa Integer ? "Class: " : "",
-            dt_leaf.majority, " ($match_count/$val_count)")
+        print(
+            io,
+            dt_leaf.majority isa Integer ? "Class: " : "",
+            dt_leaf.majority,
+            " ($match_count/$val_count)",
+        )
end
end
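
A companion sketch of the wrapping workflow this file implements, again assuming the documented API; names and data are illustrative. `wrap` attaches feature names and class labels as a `NamedTuple`, which `printnode` consults when rendering each node.

using DecisionTree
import AbstractTrees

features = rand(50, 2)
labels = ifelse.(features[:, 2] .< 0.5, 1, 2)  # classlabels require Integer classes

model = build_tree(labels, features)
wrapped = wrap(model, (featurenames=["width", "height"], classlabels=["short", "tall"]))
AbstractTrees.print_tree(wrapped)  # traverses via the children/printnode methods above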
