Skip to content

Commit

Permalink
Merge pull request #201 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.12.1 release
  • Loading branch information
ablaom authored Dec 7, 2022
2 parents 5a04aba + 9062bd1 commit f71ebb1
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name = "DecisionTree"
uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
license = "MIT"
desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms"
version = "0.12.0"
version = "0.12.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
24 changes: 17 additions & 7 deletions src/abstract_trees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ apart from the two points mentioned.
In analogy to the type definitions of `DecisionTree`, the generic type `S` is
the type of the feature values used within a node as a threshold for the splits
between its children and `T` is the type of the classes given (these might be ids or labels).
!!! note
You may only add lacking class labels. It's not possible to overwrite existing labels
with this mechanism. In case you want add class labels, the generic type `T` must
be a subtype of `Integer`.
"""
struct InfoNode{S, T} <: AbstractTrees.AbstractNode{DecisionTree.Node{S,T}}
node :: DecisionTree.Node{S, T}
Expand Down Expand Up @@ -89,8 +94,8 @@ AbstractTrees.children(node::InfoNode) = (
AbstractTrees.children(node::InfoLeaf) = ()

"""
printnode(io::IO, node::InfoNode)
printnode(io::IO, leaf::InfoLeaf)
printnode(io::IO, node::InfoNode; sigdigits=4)
printnode(io::IO, leaf::InfoLeaf; sigdigits=4)
Write a printable representation of `node` or `leaf` to output-stream `io`.
Expand All @@ -108,23 +113,28 @@ For the condition of the form `feature < value` which gets printed in the `print
variant for `InfoNode`, the left subtree is the 'yes-branch' and the right subtree
accordingly the 'no-branch'. `AbstractTrees.print_tree` outputs the left subtree first
and then below the right subtree.
`value` gets rounded to `sigdigits` significant digits.
"""
function AbstractTrees.printnode(io::IO, node::InfoNode)
function AbstractTrees.printnode(io::IO, node::InfoNode; sigdigits=4)
featval = round(node.node.featval; sigdigits)
if :featurenames keys(node.info)
print(io, node.info.featurenames[node.node.featid], " < ", node.node.featval)
print(io, node.info.featurenames[node.node.featid], " < ", featval)
else
print(io, "Feature: ", node.node.featid, " < ", node.node.featval)
print(io, "Feature: ", node.node.featid, " < ", featval)
end
end

function AbstractTrees.printnode(io::IO, leaf::InfoLeaf)
function AbstractTrees.printnode(io::IO, leaf::InfoLeaf; sigdigits=4)
dt_leaf = leaf.leaf
matches = findall(dt_leaf.values .== dt_leaf.majority)
match_count = length(matches)
val_count = length(dt_leaf.values)
if :classlabels keys(leaf.info)
@assert dt_leaf.majority isa Integer "classes must be represented as Integers"
print(io, leaf.info.classlabels[dt_leaf.majority], " ($match_count/$val_count)")
else
print(io, "Class: ", dt_leaf.majority, " ($match_count/$val_count)")
print(io, dt_leaf.majority isa Integer ? "Class: " : "",
dt_leaf.majority, " ($match_count/$val_count)")
end
end
19 changes: 19 additions & 0 deletions test/miscellaneous/abstract_trees_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,23 @@ end
traverse_tree(leaf::InfoLeaf) = nothing

traverse_tree(wrapped_tree)
end

@testset "abstract_trees - test misuse" begin

@info("Test misuse of `classlabel` information")

@info("Create test data - a decision tree based on the iris data set")
features, labels = load_data("iris")
features = float.(features)
labels = string.(labels)
model = DecisionTreeClassifier()
fit!(model, features, labels)

@info("Try to replace the exisitng class labels")
class_labels = unique(labels)
dtree = model.root.node
wt = DecisionTree.wrap(dtree, (classlabels = class_labels,))
@test_throws AssertionError AbstractTrees.print_tree(wt)

end

0 comments on commit f71ebb1

Please sign in to comment.