Skip to content

Commit

Permalink
preparing merge
Browse files Browse the repository at this point in the history
  • Loading branch information
PasoStudio73 committed Jan 20, 2025
1 parent 0567d15 commit 622a875
Showing 1 changed file with 27 additions and 48 deletions.
75 changes: 27 additions & 48 deletions src/apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ function apply end
apply_model = apply

# apply_tree = apply_model
@deprecate apply_tree apply_model
# apply_forest = apply_model
@deprecate apply_forest apply_model
# @deprecate apply_tree apply_model
# @deprecate apply_forest apply_model

function apply_proba end

Expand Down Expand Up @@ -351,72 +350,52 @@ end

# use an array of trees to test features
function sprinkle(
trees::AbstractVector{<:DTree},
trees::AbstractVector{<:DTree{<:L}},
Xs,
Y::AbstractVector{<:L};
print_progress = !(Xs isa MultiLogiset),
tree_weights::Union{AbstractMatrix{Z},AbstractVector{Z},Nothing} = nothing,
suppress_parity_warning = false,
) where {L<:Label, Z<:Real}
) where {L<:Label,Z<:Real}
@logmsg LogDetail "sprinkle..."

trees = deepcopy(trees)
ntrees = length(trees)
_ninstances = ninstances(Xs)

# if !(tree_weights isa AbstractMatrix)
# if isnothing(tree_weights)
# tree_weights = Ones{Int}(length(trees), ninstances(Xs)) # TODO optimize?
# elseif tree_weights isa AbstractVector
# tree_weights = hcat([tree_weights for i_instance in 1:ninstances(Xs)]...)
# else
# @show typeof(tree_weights)
# error("Unexpected tree_weights encountered $(tree_weights).")
# end
# end

if isnothing(tree_weights)
tree_weights = Ones{Int}(ntrees)
if !(tree_weights isa AbstractMatrix)
if isnothing(tree_weights)
tree_weights = Ones{Int}(length(trees), ninstances(Xs)) # TODO optimize?
elseif tree_weights isa AbstractVector
tree_weights = hcat([tree_weights for i_instance in 1:ninstances(Xs)]...)
else
@show typeof(tree_weights)
error("Unexpected tree_weights encountered $(tree_weights).")
end
end

# @assert length(trees) == size(tree_weights, 1) "Each label must have a corresponding weight: labels length is $(length(labels)) and weights length is $(length(weights))."
# @assert ninstances(Xs) == size(tree_weights, 2) "Each label must have a corresponding weight: labels length is $(length(labels)) and weights length is $(length(weights))."
@assert length(trees) == size(tree_weights, 1) "Each label must have a corresponding weight: labels length is $(length(labels)) and weights length is $(length(weights))."
@assert ninstances(Xs) == size(tree_weights, 2) "Each label must have a corresponding weight: labels length is $(length(labels)) and weights length is $(length(weights))."

predictions = Vector{L}(undef, _ninstances)
# apply each tree to the whole dataset
_predictions = Matrix{L}(undef, ntrees, _ninstances)

if print_progress
p = Progress(ntrees; dt = 1, desc = "Applying trees...")
end

# apply each tree to the whole dataset
Threads.@threads for i_tree in 1:ntrees
_predictions[i_tree,:], trees[i_tree] = sprinkle(trees[i_tree], Xs, Y; print_progress = false)
print_progress && next!(p)
end

# for each instance, aggregate the predictions
for i_instance in 1:_ninstances
_counts = Dict{L, Float64}()

Threads.@threads for i_tree in 1:ntrees
prediction = _predictions[i_tree, i_instance]
_counts[prediction] = get(_counts, prediction, 0.0) + tree_weights[i_tree]
end

top_prediction = trees[1].root.left.prediction
top_count = -Inf

for (k, v) in _counts
if v > top_count
top_prediction = k
top_count = v
end
end
predictions[i_instance] = top_prediction

print_progress && next!(p)
predictions = Vector{L}(undef, _ninstances)
Threads.@threads for i_instance in 1:_ninstances
predictions[i_instance] = bestguess(
_predictions[:,i_instance],
tree_weights[:,i_instance];
suppress_parity_warning = suppress_parity_warning
)
end
# @show predictions

predictions, trees
end

Expand All @@ -430,7 +409,7 @@ function sprinkle(
) where {L<:Label}
predictions, trees = begin
if weight_trees_by == false
sprinkle(trees(forest), Xs, Y; kwargs...)
sprinkle(ModalDecisionTrees.trees(forest), Xs, Y; kwargs...)
elseif isa(weight_trees_by, AbstractVector)
sprinkle(trees(forest), Xs, Y; tree_weights = weight_trees_by, kwargs...)
# elseif weight_trees_by == :accuracy
Expand Down Expand Up @@ -747,4 +726,4 @@ end
# metrics
# end

# tree_walk_metrics(tree::DTree; kwargs...) = tree_walk_metrics(tree.root; kwargs...)
# tree_walk_metrics(tree::DTree; kwargs...) = tree_walk_metrics(tree.root; kwargs...)

0 comments on commit 622a875

Please sign in to comment.