Use EvoTrees instead of XGBoost in documentation #57

Merged · 24 commits · Jan 17, 2023
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -14,7 +14,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.3'
- '1.6'
- '1'
- 'nightly'
os:
4 changes: 1 addition & 3 deletions .gitignore
@@ -1,7 +1,5 @@
*.jl.*.cov
*.jl.cov
*.jl.mem
/Manifest.toml
/test/Manifest.toml
/test/rstar/Manifest.toml
Manifest.toml
/docs/build/
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "MCMCDiagnosticTools"
uuid = "be115224-59cd-429b-ad48-344e309966f0"
authors = ["David Widmann"]
version = "0.2.5"
version = "0.2.6"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -27,7 +27,7 @@ SpecialFunctions = "0.8, 0.9, 0.10, 1, 2"
StatsBase = "0.33"
StatsFuns = "1"
Tables = "1"
julia = "1.3"
julia = "1.6"

[extras]
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
8 changes: 5 additions & 3 deletions docs/Project.toml
@@ -1,14 +1,16 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Documenter = "0.27"
EvoTrees = "0.14.7"
MCMCDiagnosticTools = "0.2"
MLJBase = "0.19, 0.20, 0.21"
MLJXGBoostInterface = "0.1, 0.2, 0.3"
julia = "1.3"
MLJIteration = "0.5"
julia = "1.6"
2 changes: 1 addition & 1 deletion src/MCMCDiagnosticTools.jl
@@ -4,7 +4,7 @@ using AbstractFFTs: AbstractFFTs
using DataAPI: DataAPI
using DataStructures: DataStructures
using Distributions: Distributions
using MLJModelInterface: MLJModelInterface
using MLJModelInterface: MLJModelInterface as MMI
using SpecialFunctions: SpecialFunctions
using StatsBase: StatsBase
using StatsFuns: StatsFuns
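Note: the `using MLJModelInterface: MLJModelInterface as MMI` form above relies on the `as`-renaming syntax for `using`/`import`, which only exists since Julia 1.6 and is consistent with the `julia = "1.6"` compat bound raised in this PR. A minimal sketch:

```julia
# `using A: A as B` requires Julia >= 1.6.
using MLJModelInterface: MLJModelInterface as MMI

MMI isa Module  # true; qualified calls such as `MMI.fit` now go through the short alias
```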
160 changes: 119 additions & 41 deletions src/rstar.jl
@@ -1,7 +1,7 @@
"""
rstar(
rng::Random.AbstractRNG=Random.default_rng(),
classifier::MLJModelInterface.Supervised,
classifier,
samples,
chain_indices::AbstractVector{Int};
subset::Real=0.7,
@@ -21,53 +21,97 @@ This method supports ragged chains, i.e. chains of nonequal lengths.
"""
function rstar(
rng::Random.AbstractRNG,
classifier::MLJModelInterface.Supervised,
classifier,
x,
y::AbstractVector{Int};
subset::Real=0.7,
split_chains::Int=2,
verbosity::Int=0,
)
# checks
MLJModelInterface.nrows(x) != length(y) && throw(DimensionMismatch())
# check the arguments
_check_model_supports_continuous_inputs(classifier)
_check_model_supports_multiclass_targets(classifier)
_check_model_supports_multiclass_predictions(classifier)
MMI.nrows(x) != length(y) && throw(DimensionMismatch())
0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)"))

ysplit = split_chain_indices(y, split_chains)

# randomly sub-select training and testing set
ysplit = split_chain_indices(y, split_chains)
train_ids, test_ids = shuffle_split_stratified(rng, ysplit, subset)
0 < length(train_ids) < length(y) ||
throw(ArgumentError("training and test data subsets must not be empty"))

xtable = _astable(x)
ycategorical = MMI.categorical(ysplit)
xdata, ydata = MMI.reformat(classifier, xtable, ycategorical)

# train classifier on training data
ycategorical = MLJModelInterface.categorical(ysplit)
xtrain = MLJModelInterface.selectrows(xtable, train_ids)
fitresult, _ = MLJModelInterface.fit(
classifier, verbosity, xtrain, ycategorical[train_ids]
)
xtrain, ytrain = MMI.selectrows(classifier, train_ids, xdata, ydata)
fitresult, _ = MMI.fit(classifier, verbosity, xtrain, ytrain)

# compute predictions on test data
xtest = MLJModelInterface.selectrows(xtable, test_ids)
xtest, = MMI.selectrows(classifier, test_ids, xdata)
ytest = ycategorical[test_ids]
predictions = _predict(classifier, fitresult, xtest)

# compute statistic
ytest = ycategorical[test_ids]
result = _rstar(predictions, ytest)
result = _rstar(MMI.scitype(predictions), predictions, ytest)

return result
end

# check that the model supports the inputs and targets, and has predictions of the desired form
function _check_model_supports_continuous_inputs(classifier)
# ideally we would not allow MMI.Unknown but some models do not implement the traits
input_scitype_classifier = MMI.input_scitype(classifier)
if input_scitype_classifier !== MMI.Unknown &&
!(MMI.Table(MMI.Continuous) <: input_scitype_classifier)
throw(
ArgumentError(
"classifier does not support tables of continuous values as inputs"
),
)
end
return nothing
end
function _check_model_supports_multiclass_targets(classifier)
target_scitype_classifier = MMI.target_scitype(classifier)
if target_scitype_classifier !== MMI.Unknown &&
!(AbstractVector{<:MMI.Finite} <: target_scitype_classifier)
throw(
ArgumentError(
"classifier does not support vectors of multi-class labels as targets"
),
)
end
return nothing
end
function _check_model_supports_multiclass_predictions(classifier)
if !(
MMI.predict_scitype(classifier) <: Union{
MMI.Unknown,
AbstractVector{<:MMI.Finite},
AbstractVector{<:MMI.Density{<:MMI.Finite}},
}
)
throw(
ArgumentError(
"classifier does not support vectors of multi-class labels or their densities as predictions",
),
)
end
return nothing
end
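As an illustration of the trait checks above, a hypothetical classifier that declares a non-continuous `input_scitype` would be rejected before any training happens. This is a sketch, not package code; the model type and its trait value are made up:

```julia
using MLJModelInterface: MLJModelInterface as MMI

# Hypothetical model that only accepts textual features.
struct TextOnlyClassifier <: MMI.Probabilistic end
MMI.input_scitype(::Type{TextOnlyClassifier}) = MMI.Table(MMI.Textual)

# Table(Continuous) <: Table(Textual) is false, so
# `_check_model_supports_continuous_inputs(TextOnlyClassifier())`
# would throw an ArgumentError.
```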

_astable(x::AbstractVecOrMat) = Tables.table(x)
_astable(x) = Tables.istable(x) ? x : throw(ArgumentError("Argument is not a valid table"))
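For reference, `Tables.table` is what lets a plain sample matrix pass through `_astable`; a quick sketch (the column names are the Tables.jl defaults):

```julia
using Tables

x = Tables.table(randn(4, 2))  # wraps a 4×2 matrix as a table with columns :Column1, :Column2
Tables.istable(x)              # true, so `_astable` returns such inputs unchanged
```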

# Workaround for https://github.com/JuliaAI/MLJBase.jl/issues/863
# `MLJModelInterface.predict` sometimes returns predictions and sometimes predictions + additional information
# TODO: Remove once the upstream issue is fixed
function _predict(model::MLJModelInterface.Model, fitresult, x)
y = MLJModelInterface.predict(model, fitresult, x)
return if :predict in MLJModelInterface.reporting_operations(model)
function _predict(model::MMI.Model, fitresult, x)
y = MMI.predict(model, fitresult, x)
return if :predict in MMI.reporting_operations(model)
first(y)
else
y
Expand All @@ -77,7 +121,7 @@ end
"""
rstar(
rng::Random.AbstractRNG=Random.default_rng(),
classifier::MLJModelInterface.Supervised,
classifier,
samples::AbstractArray{<:Real,3};
subset::Real=0.7,
split_chains::Int=2,
@@ -109,77 +153,111 @@ is returned (algorithm 2).
# Examples

```jldoctest rstar; setup = :(using Random; Random.seed!(101))
julia> using MLJBase, MLJXGBoostInterface, Statistics
julia> using MLJBase, MLJIteration, EvoTrees, Statistics

julia> samples = fill(4.0, 100, 3, 2);
```

One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the
One can compute the distribution of the ``R^*`` statistic (algorithm 2) with a
probabilistic classifier.
For instance, we can use a gradient-boosted trees model with 100 sequentially stacked trees (`nrounds = 100`) and learning rate `eta = 0.05`:

```jldoctest rstar
julia> distribution = rstar(XGBoostClassifier(), samples);
julia> model = EvoTreeClassifier(; nrounds=100, eta=0.05);

julia> isapprox(mean(distribution), 1; atol=0.1)
true
julia> distribution = rstar(model, samples);

julia> round(mean(distribution); digits=2)
1.0f0
```

Note, however, that it is recommended to determine `nrounds` via early stopping.
With the MLJ framework, this can be achieved as follows (see the [MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/) for additional explanations):

```jldoctest rstar
julia> model = IteratedModel(;
           model=EvoTreeClassifier(; eta=0.05),
           iteration_parameter=:nrounds,
           resampling=Holdout(),
           measures=log_loss,
           controls=[Step(5), Patience(2), NumberLimit(100)],
           retrain=true,
       );

julia> distribution = rstar(model, samples);

julia> round(mean(distribution); digits=2)
1.0f0
```

Review thread on this example (resolved):

- Member: It's too bad that this setup is so much more verbose than just calling `XGBoostClassifier()`.
- Member (author): Well, we can just use `EvoTrees(; nrounds=100, eta=x)` (I don't remember the default in `XGBoostClassifier`) and would get the same setting, since `XGBoostClassifier` just uses `nrounds = 100` by default without any tuning of this hyperparameter. Based on the comments above, I thought it would be good to highlight how it can be set/estimated in a better way. Maybe we should add a comment, though, and show `EvoTrees(; nrounds=100)` as well.
- Member: Yeah, for a usage example in the docstring I slightly prefer the simpler approach. But I agree that it is good to document the more robust approach somewhere.

For deterministic classifiers, a single ``R^*`` statistic (algorithm 1) is returned.
Deterministic classifiers can also be derived from probabilistic classifiers by e.g.
predicting the mode. In MLJ this corresponds to a pipeline of models.

```jldoctest rstar
julia> xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode);
julia> evotree_deterministic = Pipeline(model; operation=predict_mode);

julia> value = rstar(xgboost_deterministic, samples);
julia> value = rstar(evotree_deterministic, samples);

julia> isapprox(value, 1; atol=0.2)
true
julia> round(value; digits=2)
1.0
```

# References

Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic with uncertainty using decision tree classifiers.
"""
function rstar(
rng::Random.AbstractRNG,
classifier::MLJModelInterface.Supervised,
x::AbstractArray{<:Any,3};
kwargs...,
)
function rstar(rng::Random.AbstractRNG, classifier, x::AbstractArray{<:Any,3}; kwargs...)
samples = reshape(x, :, size(x, 3))
chain_inds = repeat(axes(x, 2); inner=size(x, 1))
return rstar(rng, classifier, samples, chain_inds; kwargs...)
end
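The reshaping above pools draws across chains under the assumed `(draws, chains, parameters)` layout; a small sketch:

```julia
x = rand(100, 3, 2)                                # 100 draws, 3 chains, 2 parameters
samples = reshape(x, :, size(x, 3))                # 300×2 matrix of pooled draws
chain_inds = repeat(axes(x, 2); inner=size(x, 1))  # 300 labels: 100 ones, then twos, then threes
```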

function rstar(classif::MLJModelInterface.Supervised, x, y::AbstractVector{Int}; kwargs...)
return rstar(Random.default_rng(), classif, x, y; kwargs...)
function rstar(classifier, x, y::AbstractVector{Int}; kwargs...)
return rstar(Random.default_rng(), classifier, x, y; kwargs...)
end

function rstar(classif::MLJModelInterface.Supervised, x::AbstractArray{<:Any,3}; kwargs...)
return rstar(Random.default_rng(), classif, x; kwargs...)
function rstar(classifier, x::AbstractArray{<:Any,3}; kwargs...)
return rstar(Random.default_rng(), classifier, x; kwargs...)
end

# R⋆ for deterministic predictions (algorithm 1)
function _rstar(predictions::AbstractVector{T}, ytest::AbstractVector{T}) where {T}
function _rstar(
::Type{<:AbstractVector{<:MMI.Finite}},
predictions::AbstractVector,
ytest::AbstractVector,
)
length(predictions) == length(ytest) ||
error("numbers of predictions and targets must be equal")
mean_accuracy = Statistics.mean(p == y for (p, y) in zip(predictions, ytest))
nclasses = length(MLJModelInterface.classes(ytest))
nclasses = length(MMI.classes(ytest))
return nclasses * mean_accuracy
end

# R⋆ for probabilistic predictions (algorithm 2)
function _rstar(predictions::AbstractVector, ytest::AbstractVector)
function _rstar(
::Type{<:AbstractVector{<:MMI.Density{<:MMI.Finite}}},
predictions::AbstractVector,
ytest::AbstractVector,
)
length(predictions) == length(ytest) ||
error("numbers of predictions and targets must be equal")

# create Poisson binomial distribution with support `0:length(predictions)`
distribution = Distributions.PoissonBinomial(map(Distributions.pdf, predictions, ytest))

# scale distribution to support in `[0, nclasses]`
nclasses = length(MLJModelInterface.classes(ytest))
nclasses = length(MMI.classes(ytest))
scaled_distribution = (nclasses//length(predictions)) * distribution

return scaled_distribution
end

# unsupported types of predictions and targets
function _rstar(::Any, predictions, targets)
throw(
ArgumentError(
"unsupported types of predictions ($(typeof(predictions))) and targets ($(typeof(targets)))",
),
)
end
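To make the two algorithms concrete, here is a hedged numeric sketch of what the `_rstar` methods compute, with hand-picked probabilities (with 2 chains, well-mixed samples give predicted probabilities near 1/2, hence ``R^*`` close to 1):

```julia
using Distributions, Statistics

nclasses = 2                 # number of (split) chains
p = [0.5, 0.55, 0.45, 0.5]   # predicted probability of the true chain label per test point

# Algorithm 2: R* is distributed as (nclasses / n) * PoissonBinomial(p).
dist = (nclasses // length(p)) * PoissonBinomial(p)
mean(dist)                   # == nclasses * mean(p) == 1.0

# Algorithm 1: with hard predictions, R* = nclasses * classification accuracy.
accuracy = 0.5               # chance-level accuracy for 2 chains
nclasses * accuracy          # == 1.0, as expected for converged chains
```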
13 changes: 9 additions & 4 deletions test/Project.toml
@@ -1,15 +1,17 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -19,14 +21,17 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[compat]
Distributions = "0.25"
DynamicHMC = "3"
EvoTrees = "0.14.7"
FFTW = "1.1"
LogDensityProblems = "0.12, 1, 2"
LogExpFunctions = "0.3"
MCMCDiagnosticTools = "0.2"
MLJBase = "0.19, 0.20, 0.21"
MLJLIBSVMInterface = "0.1, 0.2"
MLJXGBoostInterface = "0.1, 0.2, 0.3"
MLJIteration = "0.5"
MLJLIBSVMInterface = "0.2"
MLJModels = "0.16"
MLJXGBoostInterface = "0.3"
OffsetArrays = "1"
StatsBase = "0.33"
Tables = "1"
julia = "1.3"
julia = "1.6"