Merge pull request #226 from xinadi/dev
Float64 replaced by AbstractFloat for regression
rikhuijzer authored Oct 16, 2023
2 parents 605e4d4 + 00fd6cb commit 2c75d94
Showing 6 changed files with 23 additions and 24 deletions.
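Net effect of the commit: regression code paths that previously required `Float64` labels now accept any `AbstractFloat` (`Float16`, `Float32`, `BigFloat`, ...). A minimal sketch of what now dispatches correctly, using the package's exported API with made-up toy data:

    using DecisionTree

    # Float32 labels; before this commit, regression dispatch required Float64.
    features = rand(100, 3)
    labels = Float32.(features[:, 1] .+ 0.1 .* randn(100))

    tree = build_tree(labels, features)   # regression tree, since eltype(labels) <: AbstractFloat
    preds = apply_tree(tree, features)    # predictions keep the Float32 eltype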
2 changes: 1 addition & 1 deletion Project.toml
@@ -2,7 +2,7 @@ name = "DecisionTree"
 uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
 license = "MIT"
 desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithms"
-version = "0.12.3"
+version = "0.12.4"

 [deps]
 AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
14 changes: 5 additions & 9 deletions src/classification/main.jl
@@ -93,7 +93,7 @@ function update_pruned_impurity!(
     feature_importance::Vector{Float64},
     ntt::Int,
     loss::Function=mean_squared_error,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
     μl = mean(tree.left.values)
     nl = length(tree.left.values)
     μr = mean(tree.right.values)
@@ -220,7 +220,7 @@ See also [`build_tree`](@ref).
 function prune_tree(
     tree::Union{Root{S,T},LeafOrNode{S,T}},
     purity_thresh=1.0,
-    loss::Function=T <: Float64 ? mean_squared_error : util.entropy,
+    loss::Function=T <: AbstractFloat ? mean_squared_error : util.entropy,
 ) where {S,T}
     if purity_thresh >= 1.0
         return tree
@@ -293,11 +293,7 @@ function apply_tree(tree::LeafOrNode{S,T}, features::AbstractMatrix{S}) where {S
     for i in 1:N
         predictions[i] = apply_tree(tree, features[i, :])
     end
-    if T <: Float64
-        return Float64.(predictions)
-    else
-        return predictions
-    end
+    return predictions
 end

 """
@@ -343,7 +339,7 @@ end
 Train a random forest model, built on standard CART decision trees, using the specified
 `labels` (target) and `features` (patterns). Here:
-- `labels` is any `AbstractVector`. If the element type is `Float64`, regression is
+- `labels` is any `AbstractVector`. If the element type is `AbstractFloat`, regression is
   applied, and otherwise classification is applied.
 - `features` is any `AbstractMatrix{T}` where `T` supports ordering with `<` (unordered
@@ -619,7 +615,7 @@ function apply_forest(forest::Ensemble{S,T}, features::AbstractVector{S}) where
         votes[i] = apply_tree(forest.trees[i], features)
     end

-    if T <: Float64
+    if T <: AbstractFloat
         return mean(votes)
     else
         return majority_vote(votes)
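This `apply_forest` hunk is the dispatch point between regression and classification: any `AbstractFloat` label eltype now averages the per-tree votes, while everything else still falls through to a majority vote. Roughly, with made-up toy data:

    labels_r = Float32[0.1, 0.2, 0.3, 0.4]   # T <: AbstractFloat -> mean(votes)
    labels_c = ["a", "b", "a", "b"]          # otherwise          -> majority_vote(votes)
    X = [1.0 2.0; 3.0 4.0; 5.0 6.0; 7.0 8.0]
    apply_forest(build_forest(labels_r, X), X[1, :])   # a Float32 mean
    apply_forest(build_forest(labels_c, X), X[1, :])   # "a" or "b"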
6 changes: 3 additions & 3 deletions src/measures.jl
@@ -269,7 +269,7 @@ function _nfoldCV(
     args...;
     verbose,
     rng,
-) where {T<:Float64}
+) where {T<:AbstractFloat}
     _rng = mk_rng(rng)::Random.AbstractRNG
     nfolds = args[1]
     if nfolds < 2
@@ -361,7 +361,7 @@ function nfoldCV_tree(
     min_purity_increase::Float64=0.0;
     verbose::Bool=true,
     rng=Random.GLOBAL_RNG,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
     _nfoldCV(
         :tree,
         labels,
@@ -389,7 +389,7 @@ function nfoldCV_forest(
     min_purity_increase::Float64=0.0;
     verbose::Bool=true,
     rng=Random.GLOBAL_RNG,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
     _nfoldCV(
         :forest,
         labels,
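With the two `nfoldCV_*` signatures relaxed, cross-validated regression is no longer limited to `Float64` labels either. A sketch, assuming the usual positional API with `nfolds` as the third argument (data made up):

    labels = Float32.(rand(90))
    features = rand(90, 4)
    nfoldCV_tree(labels, features, 3)     # 3-fold CV; regression reports R² per fold
    nfoldCV_forest(labels, features, 3)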
8 changes: 4 additions & 4 deletions src/regression/main.jl
@@ -1,6 +1,6 @@
 include("tree.jl")

-function _convert(node::treeregressor.NodeMeta{S}, labels::Array{T}) where {S,T<:Float64}
+function _convert(node::treeregressor.NodeMeta{S}, labels::Array{T}) where {S,T<:AbstractFloat}
     if node.is_leaf
         return Leaf{T}(node.label, labels[node.region])
     else
@@ -27,7 +27,7 @@ function build_stump(
     features::AbstractMatrix{S};
     rng=Random.GLOBAL_RNG,
     impurity_importance::Bool=true,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
     return build_tree(labels, features, 0, 1; rng, impurity_importance)
 end

@@ -41,7 +41,7 @@ function build_tree(
     min_purity_increase=0.0;
     rng=Random.GLOBAL_RNG,
     impurity_importance::Bool=true,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
     if max_depth == -1
         max_depth = typemax(Int)
     end
@@ -85,7 +85,7 @@ function build_forest(
     min_purity_increase=0.0;
     rng::Union{Integer,AbstractRNG}=Random.GLOBAL_RNG,
     impurity_importance::Bool=true,
-) where {S,T<:Float64}
+) where {S,T<:AbstractFloat}
    if n_trees < 1
        throw("the number of trees must be >= 1")
    end
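Taken together, the four hunks in this file open up the regression entry points (`_convert`, `build_stump`, `build_tree`, `build_forest`) to any `AbstractFloat` label vector. This is exactly what the new `Float16` test below exercises end to end; roughly:

    labels = Float16.(rand(20))
    features = rand(20, 3)
    forest = build_forest(labels, features)
    apply_forest(forest, features) isa Vector{Float16}   # per the test in low_precision.jl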
14 changes: 7 additions & 7 deletions src/regression/tree.jl
@@ -47,7 +47,7 @@ end
 # (max_depth, min_samples_split, min_purity_increase)
 function _split!(
     X::AbstractMatrix{S}, # the feature array
-    Y::AbstractVector{Float64}, # the label array
+    Y::AbstractVector{T}, # the label array
     W::AbstractVector{U},
     node::NodeMeta{S}, # the node to split
     max_features::Int, # number of features to consider
@@ -59,10 +59,10 @@ function _split!(
     # we split using samples in indX[node.region]
     # the two arrays below are given for optimization purposes
     Xf::AbstractVector{S},
-    Yf::AbstractVector{Float64},
+    Yf::AbstractVector{T},
     Wf::AbstractVector{U},
     rng::Random.AbstractRNG,
-) where {S,U}
+) where {S,T<:AbstractFloat,U}
     region = node.region
     n_samples = length(region)
     r_start = region.start - 1
@@ -245,18 +245,18 @@ end

 function _fit(
     X::AbstractMatrix{S},
-    Y::AbstractVector{Float64},
+    Y::AbstractVector{T},
     W::AbstractVector{U},
     max_features::Int,
     max_depth::Int,
     min_samples_leaf::Int,
     min_samples_split::Int,
     min_purity_increase::Float64,
     rng=Random.GLOBAL_RNG::Random.AbstractRNG,
-) where {S,U}
+) where {S,T<:AbstractFloat,U}
     n_samples, n_features = size(X)

-    Yf = Array{Float64}(undef, n_samples)
+    Yf = Array{T}(undef, n_samples)
     Xf = Array{S}(undef, n_samples)
     Wf = Array{U}(undef, n_samples)

@@ -293,7 +293,7 @@ end

 function fit(;
     X::AbstractMatrix{S},
-    Y::AbstractVector{Float64},
+    Y::AbstractVector{<:AbstractFloat},
     W::Union{Nothing,AbstractVector{U}},
     max_features::Int,
     max_depth::Int,
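Note on `_split!` and `_fit`: the scratch buffer `Yf` now matches the label eltype instead of hard-coding `Float64`, so the hot loop stays type-stable and avoids widening low-precision labels. The pattern in isolation (a sketch with a hypothetical helper name, not the package's code verbatim):

    # T is inferred from the labels and reused for the scratch buffer.
    function alloc_label_buffer(Y::AbstractVector{T}) where {T<:AbstractFloat}
        return Array{T}(undef, length(Y))   # was Array{Float64}(undef, n) before this commit
    end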
3 changes: 3 additions & 0 deletions test/regression/low_precision.jl
@@ -113,6 +113,9 @@
     model = build_forest(labels, features)
     preds = apply_forest(model, features)
     @test typeof(preds) == Vector{Float16}
+    # Verify that the `preds` were calculated based on `labels` of the same type.
+    # If the code at some point converts the numbers to, say, `Float64`, then this test will fail.
+    @test !all(x -> (x in labels), preds)

     preds_MT = apply_forest(model, features; use_multithreading=true)
     @test typeof(preds_MT) == Vector{Float16}

