Skip to content

Commit

Permalink
Use EvoTrees instead of XGBoost in documentation (#57)
Browse files Browse the repository at this point in the history
* Use EvoTrees instead of XGBoost

* Update runtests.jl

* Try to fix use of MLJ interface

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/rstar.jl

* Apply suggestions from code review

Co-authored-by: Seth Axen <[email protected]>

* Fix RNG of `EvoTreeClassifier` in tests

* Update rstar.jl

* Update rstar.jl

* Update test/rstar.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Bump compat entries

* Apply suggested changes

* Use a more advanced example

* Test XGBoost as well

* Update rstar.jl

* Update rstar.jl

* Update EvoTrees dependency

* Update documentation

* Use MLJ traits

* Update src/rstar.jl

Co-authored-by: Seth Axen <[email protected]>

* Some refactoring and additional tests

* Update Project.toml

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Seth Axen <[email protected]>
  • Loading branch information
3 people authored Jan 17, 2023
1 parent 1799e79 commit 2e07d21
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 67 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.3'
- '1.6'
- '1'
- 'nightly'
os:
Expand Down
4 changes: 1 addition & 3 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
*.jl.*.cov
*.jl.cov
*.jl.mem
/Manifest.toml
/test/Manifest.toml
/test/rstar/Manifest.toml
Manifest.toml
/docs/build/
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MCMCDiagnosticTools"
uuid = "be115224-59cd-429b-ad48-344e309966f0"
authors = ["David Widmann"]
version = "0.2.5"
version = "0.2.6"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -27,7 +27,7 @@ SpecialFunctions = "0.8, 0.9, 0.10, 1, 2"
StatsBase = "0.33"
StatsFuns = "1"
Tables = "1"
julia = "1.3"
julia = "1.6"

[extras]
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand Down
8 changes: 5 additions & 3 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Documenter = "0.27"
EvoTrees = "0.14.7"
MCMCDiagnosticTools = "0.2"
MLJBase = "0.19, 0.20, 0.21"
MLJXGBoostInterface = "0.1, 0.2, 0.3"
julia = "1.3"
MLJIteration = "0.5"
julia = "1.6"
2 changes: 1 addition & 1 deletion src/MCMCDiagnosticTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using AbstractFFTs: AbstractFFTs
using DataAPI: DataAPI
using DataStructures: DataStructures
using Distributions: Distributions
using MLJModelInterface: MLJModelInterface
using MLJModelInterface: MLJModelInterface as MMI
using SpecialFunctions: SpecialFunctions
using StatsBase: StatsBase
using StatsFuns: StatsFuns
Expand Down
160 changes: 119 additions & 41 deletions src/rstar.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
rstar(
rng::Random.AbstractRNG=Random.default_rng(),
classifier::MLJModelInterface.Supervised,
classifier,
samples,
chain_indices::AbstractVector{Int};
subset::Real=0.7,
Expand All @@ -21,53 +21,97 @@ This method supports ragged chains, i.e. chains of nonequal lengths.
"""
function rstar(
rng::Random.AbstractRNG,
classifier::MLJModelInterface.Supervised,
classifier,
x,
y::AbstractVector{Int};
subset::Real=0.7,
split_chains::Int=2,
verbosity::Int=0,
)
# checks
MLJModelInterface.nrows(x) != length(y) && throw(DimensionMismatch())
# check the arguments
_check_model_supports_continuous_inputs(classifier)
_check_model_supports_multiclass_targets(classifier)
_check_model_supports_multiclass_predictions(classifier)
MMI.nrows(x) != length(y) && throw(DimensionMismatch())
0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)"))

ysplit = split_chain_indices(y, split_chains)

# randomly sub-select training and testing set
ysplit = split_chain_indices(y, split_chains)
train_ids, test_ids = shuffle_split_stratified(rng, ysplit, subset)
0 < length(train_ids) < length(y) ||
throw(ArgumentError("training and test data subsets must not be empty"))

xtable = _astable(x)
ycategorical = MMI.categorical(ysplit)
xdata, ydata = MMI.reformat(classifier, xtable, ycategorical)

# train classifier on training data
ycategorical = MLJModelInterface.categorical(ysplit)
xtrain = MLJModelInterface.selectrows(xtable, train_ids)
fitresult, _ = MLJModelInterface.fit(
classifier, verbosity, xtrain, ycategorical[train_ids]
)
xtrain, ytrain = MMI.selectrows(classifier, train_ids, xdata, ydata)
fitresult, _ = MMI.fit(classifier, verbosity, xtrain, ytrain)

# compute predictions on test data
xtest = MLJModelInterface.selectrows(xtable, test_ids)
xtest, = MMI.selectrows(classifier, test_ids, xdata)
ytest = ycategorical[test_ids]
predictions = _predict(classifier, fitresult, xtest)

# compute statistic
ytest = ycategorical[test_ids]
result = _rstar(predictions, ytest)
result = _rstar(MMI.scitype(predictions), predictions, ytest)

return result
end

# check that the model supports the inputs and targets, and has predictions of the desired form
# Verify that `classifier` accepts tables of continuous values as inputs.
# Throws an `ArgumentError` if its `MMI.input_scitype` trait rules this out.
function _check_model_supports_continuous_inputs(classifier)
    supported_inputs = MMI.input_scitype(classifier)
    # Ideally `MMI.Unknown` would be rejected, but some models do not implement
    # the traits, so an unknown input scitype is let through.
    supported_inputs === MMI.Unknown && return nothing
    MMI.Table(MMI.Continuous) <: supported_inputs && return nothing
    throw(
        ArgumentError("classifier does not support tables of continuous values as inputs")
    )
end
# Verify that `classifier` accepts vectors of multi-class labels as targets.
# Throws an `ArgumentError` if its `MMI.target_scitype` trait rules this out.
function _check_model_supports_multiclass_targets(classifier)
    supported_targets = MMI.target_scitype(classifier)
    # An unknown target scitype is let through instead of being rejected.
    supported_targets === MMI.Unknown && return nothing
    AbstractVector{<:MMI.Finite} <: supported_targets && return nothing
    throw(
        ArgumentError(
            "classifier does not support vectors of multi-class labels as targets"
        ),
    )
end
# Verify that `classifier` predicts vectors of multi-class labels or vectors of
# densities over such labels (an unknown prediction scitype is accepted as well).
# Throws an `ArgumentError` otherwise.
function _check_model_supports_multiclass_predictions(classifier)
    # prediction scitypes that are accepted by this check
    accepted = Union{
        MMI.Unknown,
        AbstractVector{<:MMI.Finite},
        AbstractVector{<:MMI.Density{<:MMI.Finite}},
    }
    prediction_scitype = MMI.predict_scitype(classifier)
    prediction_scitype <: accepted || throw(
        ArgumentError(
            "classifier does not support vectors of multi-class labels or their densities as predictions",
        ),
    )
    return nothing
end

# Return a table for the given samples:
# vectors and matrices are wrapped with `Tables.table`, anything else must already
# satisfy `Tables.istable` and is returned unchanged (otherwise an `ArgumentError`
# is thrown).
_astable(x::AbstractVecOrMat) = Tables.table(x)
function _astable(x)
    Tables.istable(x) || throw(ArgumentError("Argument is not a valid table"))
    return x
end

# Workaround for https://github.com/JuliaAI/MLJBase.jl/issues/863
# `MLJModelInterface.predict` sometimes returns predictions and sometimes predictions + additional information
# TODO: Remove once the upstream issue is fixed
function _predict(model::MLJModelInterface.Model, fitresult, x)
y = MLJModelInterface.predict(model, fitresult, x)
return if :predict in MLJModelInterface.reporting_operations(model)
function _predict(model::MMI.Model, fitresult, x)
y = MMI.predict(model, fitresult, x)
return if :predict in MMI.reporting_operations(model)
first(y)
else
y
Expand All @@ -77,7 +121,7 @@ end
"""
rstar(
rng::Random.AbstractRNG=Random.default_rng(),
classifier::MLJModelInterface.Supervised,
classifier,
samples::AbstractArray{<:Real,3};
subset::Real=0.7,
split_chains::Int=2,
Expand Down Expand Up @@ -109,77 +153,111 @@ is returned (algorithm 2).
# Examples
```jldoctest rstar; setup = :(using Random; Random.seed!(101))
julia> using MLJBase, MLJXGBoostInterface, Statistics
julia> using MLJBase, MLJIteration, EvoTrees, Statistics
julia> samples = fill(4.0, 100, 3, 2);
```
One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the
One can compute the distribution of the ``R^*`` statistic (algorithm 2) with a
probabilistic classifier.
For instance, we can use a gradient-boosted trees model with `nrounds = 100` sequentially stacked trees and learning rate `eta = 0.05`:
```jldoctest rstar
julia> distribution = rstar(XGBoostClassifier(), samples);
julia> model = EvoTreeClassifier(; nrounds=100, eta=0.05);
julia> isapprox(mean(distribution), 1; atol=0.1)
true
julia> distribution = rstar(model, samples);
julia> round(mean(distribution); digits=2)
1.0f0
```
Note, however, that it is recommended to determine `nrounds` based on early-stopping.
With the MLJ framework, this can be achieved in the following way (see the [MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/) for additional explanations):
```jldoctest rstar
julia> model = IteratedModel(;
model=EvoTreeClassifier(; eta=0.05),
iteration_parameter=:nrounds,
resampling=Holdout(),
measures=log_loss,
controls=[Step(5), Patience(2), NumberLimit(100)],
retrain=true,
);
julia> distribution = rstar(model, samples);
julia> round(mean(distribution); digits=2)
1.0f0
```
For deterministic classifiers, a single ``R^*`` statistic (algorithm 1) is returned.
Deterministic classifiers can also be derived from probabilistic classifiers by e.g.
predicting the mode. In MLJ this corresponds to a pipeline of models.
```jldoctest rstar
julia> xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode);
julia> evotree_deterministic = Pipeline(model; operation=predict_mode);
julia> value = rstar(xgboost_deterministic, samples);
julia> value = rstar(evotree_deterministic, samples);
julia> isapprox(value, 1; atol=0.2)
true
julia> round(value; digits=2)
1.0
```
# References
Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic with uncertainty using decision tree classifiers.
"""
function rstar(
rng::Random.AbstractRNG,
classifier::MLJModelInterface.Supervised,
x::AbstractArray{<:Any,3};
kwargs...,
)
function rstar(rng::Random.AbstractRNG, classifier, x::AbstractArray{<:Any,3}; kwargs...)
    # label every row of the flattened array with its index along the second
    # dimension of `x` (the chain index): each index is repeated `size(x, 1)` times
    rows_per_chain = size(x, 1)
    chain_indices = repeat(axes(x, 2); inner=rows_per_chain)

    # collapse the first two dimensions into rows; the third dimension becomes the columns
    ncolumns = size(x, 3)
    flattened = reshape(x, :, ncolumns)

    return rstar(rng, classifier, flattened, chain_indices; kwargs...)
end

function rstar(classif::MLJModelInterface.Supervised, x, y::AbstractVector{Int}; kwargs...)
return rstar(Random.default_rng(), classif, x, y; kwargs...)
function rstar(classifier, x, y::AbstractVector{Int}; kwargs...)
    # no random number generator was provided: fall back to the default one
    rng = Random.default_rng()
    return rstar(rng, classifier, x, y; kwargs...)
end

function rstar(classif::MLJModelInterface.Supervised, x::AbstractArray{<:Any,3}; kwargs...)
return rstar(Random.default_rng(), classif, x; kwargs...)
function rstar(classifier, x::AbstractArray{<:Any,3}; kwargs...)
    # no random number generator was provided: fall back to the default one
    rng = Random.default_rng()
    return rstar(rng, classifier, x; kwargs...)
end

# R⋆ for deterministic predictions (algorithm 1)
function _rstar(predictions::AbstractVector{T}, ytest::AbstractVector{T}) where {T}
function _rstar(
::Type{<:AbstractVector{<:MMI.Finite}},
predictions::AbstractVector,
ytest::AbstractVector,
)
length(predictions) == length(ytest) ||
error("numbers of predictions and targets must be equal")
mean_accuracy = Statistics.mean(p == y for (p, y) in zip(predictions, ytest))
nclasses = length(MLJModelInterface.classes(ytest))
nclasses = length(MMI.classes(ytest))
return nclasses * mean_accuracy
end

# R⋆ for probabilistic predictions (algorithm 2)
function _rstar(predictions::AbstractVector, ytest::AbstractVector)
function _rstar(
::Type{<:AbstractVector{<:MMI.Density{<:MMI.Finite}}},
predictions::AbstractVector,
ytest::AbstractVector,
)
length(predictions) == length(ytest) ||
error("numbers of predictions and targets must be equal")

# create Poisson binomial distribution with support `0:length(predictions)`
distribution = Distributions.PoissonBinomial(map(Distributions.pdf, predictions, ytest))

# scale distribution to support in `[0, nclasses]`
nclasses = length(MLJModelInterface.classes(ytest))
nclasses = length(MMI.classes(ytest))
scaled_distribution = (nclasses//length(predictions)) * distribution

return scaled_distribution
end

# unsupported types of predictions and targets
function _rstar(::Any, predictions, targets)
    # fallback method: always fails, reporting the concrete types of both arguments
    msg = "unsupported types of predictions ($(typeof(predictions))) and targets ($(typeof(targets)))"
    throw(ArgumentError(msg))
end
13 changes: 9 additions & 4 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -19,14 +21,17 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[compat]
Distributions = "0.25"
DynamicHMC = "3"
EvoTrees = "0.14.7"
FFTW = "1.1"
LogDensityProblems = "0.12, 1, 2"
LogExpFunctions = "0.3"
MCMCDiagnosticTools = "0.2"
MLJBase = "0.19, 0.20, 0.21"
MLJLIBSVMInterface = "0.1, 0.2"
MLJXGBoostInterface = "0.1, 0.2, 0.3"
MLJIteration = "0.5"
MLJLIBSVMInterface = "0.2"
MLJModels = "0.16"
MLJXGBoostInterface = "0.3"
OffsetArrays = "1"
StatsBase = "0.33"
Tables = "1"
julia = "1.3"
julia = "1.6"
Loading

2 comments on commit 2e07d21

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/75838

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.6 -m "<description of version>" 2e07d216929245d7b57056e7a5dd10bc7a0f5cdb
git push origin v0.2.6

Please sign in to comment.