From 622a875d55a5f3ac7b6b71f386ad1bd2d228814e Mon Sep 17 00:00:00 2001 From: PasoStudio73 Date: Mon, 20 Jan 2025 18:55:36 +0100 Subject: [PATCH] preparing merge --- src/apply.jl | 75 +++++++++++++++++++--------------------------------- 1 file changed, 27 insertions(+), 48 deletions(-) diff --git a/src/apply.jl b/src/apply.jl index 72806cb..d1030f1 100644 --- a/src/apply.jl +++ b/src/apply.jl @@ -16,9 +16,8 @@ function apply end apply_model = apply # apply_tree = apply_model -@deprecate apply_tree apply_model -# apply_forest = apply_model -@deprecate apply_forest apply_model +# @deprecate apply_tree apply_model +# @deprecate apply_forest apply_model function apply_proba end @@ -351,72 +350,52 @@ end # use an array of trees to test features function sprinkle( - trees::AbstractVector{<:DTree}, + trees::AbstractVector{<:DTree{<:L}}, Xs, Y::AbstractVector{<:L}; print_progress = !(Xs isa MultiLogiset), tree_weights::Union{AbstractMatrix{Z},AbstractVector{Z},Nothing} = nothing, suppress_parity_warning = false, -) where {L<:Label, Z<:Real} +) where {L<:Label,Z<:Real} @logmsg LogDetail "sprinkle..." - trees = deepcopy(trees) ntrees = length(trees) _ninstances = ninstances(Xs) - # if !(tree_weights isa AbstractMatrix) - # if isnothing(tree_weights) - # tree_weights = Ones{Int}(length(trees), ninstances(Xs)) # TODO optimize? - # elseif tree_weights isa AbstractVector - # tree_weights = hcat([tree_weights for i_instance in 1:ninstances(Xs)]...) - # else - # @show typeof(tree_weights) - # error("Unexpected tree_weights encountered $(tree_weights).") - # end - # end - - if isnothing(tree_weights) - tree_weights = Ones{Int}(ntrees) + if !(tree_weights isa AbstractMatrix) + if isnothing(tree_weights) + tree_weights = Ones{Int}(length(trees), ninstances(Xs)) # TODO optimize? + elseif tree_weights isa AbstractVector + tree_weights = hcat([tree_weights for i_instance in 1:ninstances(Xs)]...) + else + @show typeof(tree_weights) + error("Unexpected tree_weights encountered $(tree_weights).") + end end - # @assert length(trees) == size(tree_weights, 1) "Each label must have a corresponding weight: labels length is $(length(labels)) and weights length is $(length(weights))." - # @assert ninstances(Xs) == size(tree_weights, 2) "Each label must have a corresponding weight: labels length is $(length(labels)) and weights length is $(length(weights))." + @assert length(trees) == size(tree_weights, 1) "Each label must have a corresponding weight: labels length is $(length(labels)) and weights length is $(length(weights))." + @assert ninstances(Xs) == size(tree_weights, 2) "Each label must have a corresponding weight: labels length is $(length(labels)) and weights length is $(length(weights))." - predictions = Vector{L}(undef, _ninstances) + # apply each tree to the whole dataset _predictions = Matrix{L}(undef, ntrees, _ninstances) - if print_progress p = Progress(ntrees; dt = 1, desc = "Applying trees...") end - - # apply each tree to the whole dataset Threads.@threads for i_tree in 1:ntrees _predictions[i_tree,:], trees[i_tree] = sprinkle(trees[i_tree], Xs, Y; print_progress = false) + print_progress && next!(p) end # for each instance, aggregate the predictions - for i_instance in 1:_ninstances - _counts = Dict{L, Float64}() - - Threads.@threads for i_tree in 1:ntrees - prediction = _predictions[i_tree, i_instance] - _counts[prediction] = get(_counts, prediction, 0.0) + tree_weights[i_tree] - end - - top_prediction = trees[1].root.left.prediction - top_count = -Inf - - for (k, v) in _counts - if v > top_count - top_prediction = k - top_count = v - end - end - predictions[i_instance] = top_prediction - - print_progress && next!(p) + predictions = Vector{L}(undef, _ninstances) + Threads.@threads for i_instance in 1:_ninstances + predictions[i_instance] = bestguess( + _predictions[:,i_instance], + tree_weights[:,i_instance]; + suppress_parity_warning = suppress_parity_warning + ) end - # @show predictions + predictions, trees end @@ -430,7 +409,7 @@ function sprinkle( ) where {L<:Label} predictions, trees = begin if weight_trees_by == false - sprinkle(trees(forest), Xs, Y; kwargs...) + sprinkle(ModalDecisionTrees.trees(forest), Xs, Y; kwargs...) elseif isa(weight_trees_by, AbstractVector) sprinkle(trees(forest), Xs, Y; tree_weights = weight_trees_by, kwargs...) # elseif weight_trees_by == :accuracy @@ -747,4 +726,4 @@ end # metrics # end -# tree_walk_metrics(tree::DTree; kwargs...) = tree_walk_metrics(tree.root; kwargs...) +# tree_walk_metrics(tree::DTree; kwargs...) = tree_walk_metrics(tree.root; kwargs...) \ No newline at end of file