From 2e07d216929245d7b57056e7a5dd10bc7a0f5cdb Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Tue, 17 Jan 2023 13:53:01 +0100
Subject: [PATCH] Use EvoTrees instead of XGBoost in documentation (#57)

* Use EvoTrees instead of XGBoost

* Update runtests.jl

* Try to fix use of MLJ interface

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/rstar.jl

* Apply suggestions from code review

Co-authored-by: Seth Axen

* Fix RNG of `EvoTreeClassifier` in tests

* Update rstar.jl

* Update rstar.jl

* Update test/rstar.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Bump compat entries

* Apply suggested changes

* Use a more advanced example

* Test XGBoost as well

* Update rstar.jl

* Update rstar.jl

* Update EvoTrees dependency

* Update documentation

* Use MLJ traits

* Update src/rstar.jl

Co-authored-by: Seth Axen

* Some refactoring and additional tests

* Update Project.toml

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Seth Axen
---
 .github/workflows/CI.yml   |   2 +-
 .gitignore                 |   4 +-
 Project.toml               |   4 +-
 docs/Project.toml          |   8 +-
 src/MCMCDiagnosticTools.jl |   2 +-
 src/rstar.jl               | 160 +++++++++++++++++++++++++++----------
 test/Project.toml          |  13 ++-
 test/rstar.jl              |  72 ++++++++++++++++-
 test/runtests.jl           |   9 +--
 9 files changed, 207 insertions(+), 67 deletions(-)

diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index b96ee980..39be5311 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -14,7 +14,7 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.3'
+          - '1.6'
           - '1'
           - 'nightly'
         os:
diff --git a/.gitignore b/.gitignore
index d6e09773..1c02e5e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,5 @@
 *.jl.*.cov
 *.jl.cov
 *.jl.mem
-/Manifest.toml
-/test/Manifest.toml
-/test/rstar/Manifest.toml
+Manifest.toml
 /docs/build/
diff --git a/Project.toml b/Project.toml
index abac7047..9e702160 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "MCMCDiagnosticTools"
 uuid = "be115224-59cd-429b-ad48-344e309966f0"
 authors = ["David Widmann"]
-version = "0.2.5"
+version = "0.2.6"
 
 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -27,7 +27,7 @@ SpecialFunctions = "0.8, 0.9, 0.10, 1, 2"
 StatsBase = "0.33"
 StatsFuns = "1"
 Tables = "1"
-julia = "1.3"
+julia = "1.6"
 
 [extras]
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
diff --git a/docs/Project.toml b/docs/Project.toml
index ad15f1d3..96106f96 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -1,14 +1,16 @@
 [deps]
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
 MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
-MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
+MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
 Documenter = "0.27"
+EvoTrees = "0.14.7"
 MCMCDiagnosticTools = "0.2"
 MLJBase = "0.19, 0.20, 0.21"
-MLJXGBoostInterface = "0.1, 0.2, 0.3"
-julia = "1.3"
+MLJIteration = "0.5"
+julia = "1.6"
diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl
index beb1ca11..c2ef459e 100644
--- a/src/MCMCDiagnosticTools.jl
+++ b/src/MCMCDiagnosticTools.jl
@@ -4,7 +4,7 @@ using AbstractFFTs: AbstractFFTs
 using DataAPI: DataAPI
 using DataStructures: DataStructures
 using Distributions: Distributions
-using MLJModelInterface: MLJModelInterface
+using MLJModelInterface: MLJModelInterface as MMI
 using SpecialFunctions: SpecialFunctions
 using StatsBase: StatsBase
 using StatsFuns: StatsFuns
diff --git a/src/rstar.jl b/src/rstar.jl
index 820c3a02..26d761b0 100644
--- a/src/rstar.jl
+++ b/src/rstar.jl
@@ -1,7 +1,7 @@
 """
     rstar(
         rng::Random.AbstractRNG=Random.default_rng(),
-        classifier::MLJModelInterface.Supervised,
+        classifier,
         samples,
         chain_indices::AbstractVector{Int};
         subset::Real=0.7,
@@ -21,53 +21,97 @@ This method supports ragged chains, i.e. chains of nonequal lengths.
 """
 function rstar(
     rng::Random.AbstractRNG,
-    classifier::MLJModelInterface.Supervised,
+    classifier,
     x,
     y::AbstractVector{Int};
     subset::Real=0.7,
     split_chains::Int=2,
     verbosity::Int=0,
 )
-    # checks
-    MLJModelInterface.nrows(x) != length(y) && throw(DimensionMismatch())
+    # check the arguments
+    _check_model_supports_continuous_inputs(classifier)
+    _check_model_supports_multiclass_targets(classifier)
+    _check_model_supports_multiclass_predictions(classifier)
+    MMI.nrows(x) != length(y) && throw(DimensionMismatch())
     0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)"))
 
-    ysplit = split_chain_indices(y, split_chains)
-
     # randomly sub-select training and testing set
+    ysplit = split_chain_indices(y, split_chains)
     train_ids, test_ids = shuffle_split_stratified(rng, ysplit, subset)
     0 < length(train_ids) < length(y) ||
         throw(ArgumentError("training and test data subsets must not be empty"))
 
     xtable = _astable(x)
+    ycategorical = MMI.categorical(ysplit)
+    xdata, ydata = MMI.reformat(classifier, xtable, ycategorical)
 
     # train classifier on training data
-    ycategorical = MLJModelInterface.categorical(ysplit)
-    xtrain = MLJModelInterface.selectrows(xtable, train_ids)
-    fitresult, _ = MLJModelInterface.fit(
-        classifier, verbosity, xtrain, ycategorical[train_ids]
-    )
+    xtrain, ytrain = MMI.selectrows(classifier, train_ids, xdata, ydata)
+    fitresult, _ = MMI.fit(classifier, verbosity, xtrain, ytrain)
 
     # compute predictions on test data
-    xtest = MLJModelInterface.selectrows(xtable, test_ids)
+    xtest, = MMI.selectrows(classifier, test_ids, xdata)
+    ytest = ycategorical[test_ids]
     predictions = _predict(classifier, fitresult, xtest)
 
     # compute statistic
-    ytest = ycategorical[test_ids]
-    result = _rstar(predictions, ytest)
+    result = _rstar(MMI.scitype(predictions), predictions, ytest)
 
     return result
 end
 
+# check that the model supports the inputs and targets, and has predictions of the desired form
+function _check_model_supports_continuous_inputs(classifier)
+    # ideally we would not allow MMI.Unknown but some models do not implement the traits
+    input_scitype_classifier = MMI.input_scitype(classifier)
+    if input_scitype_classifier !== MMI.Unknown &&
+        !(MMI.Table(MMI.Continuous) <: input_scitype_classifier)
+        throw(
+            ArgumentError(
+                "classifier does not support tables of continuous values as inputs"
+            ),
+        )
+    end
+    return nothing
+end
+function _check_model_supports_multiclass_targets(classifier)
+    target_scitype_classifier = MMI.target_scitype(classifier)
+    if target_scitype_classifier !== MMI.Unknown &&
+        !(AbstractVector{<:MMI.Finite} <: target_scitype_classifier)
+        throw(
+            ArgumentError(
+                "classifier does not support vectors of multi-class labels as targets"
+            ),
+        )
+    end
+    return nothing
+end
+function _check_model_supports_multiclass_predictions(classifier)
+    if !(
+        MMI.predict_scitype(classifier) <: Union{
+            MMI.Unknown,
+            AbstractVector{<:MMI.Finite},
+            AbstractVector{<:MMI.Density{<:MMI.Finite}},
+        }
+    )
+        throw(
+            ArgumentError(
+                "classifier does not support vectors of multi-class labels or their densities as predictions",
+            ),
+        )
+    end
+    return nothing
+end
+
 _astable(x::AbstractVecOrMat) = Tables.table(x)
 _astable(x) = Tables.istable(x) ? x : throw(ArgumentError("Argument is not a valid table"))
 
 # Workaround for https://github.com/JuliaAI/MLJBase.jl/issues/863
 # `MLJModelInterface.predict` sometimes returns predictions and sometimes predictions + additional information
 # TODO: Remove once the upstream issue is fixed
-function _predict(model::MLJModelInterface.Model, fitresult, x)
-    y = MLJModelInterface.predict(model, fitresult, x)
-    return if :predict in MLJModelInterface.reporting_operations(model)
+function _predict(model::MMI.Model, fitresult, x)
+    y = MMI.predict(model, fitresult, x)
+    return if :predict in MMI.reporting_operations(model)
         first(y)
     else
         y
@@ -77,7 +121,7 @@ end
 """
     rstar(
         rng::Random.AbstractRNG=Random.default_rng(),
-        classifier::MLJModelInterface.Supervised,
+        classifier,
         samples::AbstractArray{<:Real,3};
         subset::Real=0.7,
         split_chains::Int=2,
@@ -109,19 +153,41 @@ is returned (algorithm 2).
 # Examples
 
 ```jldoctest rstar; setup = :(using Random; Random.seed!(101))
-julia> using MLJBase, MLJXGBoostInterface, Statistics
+julia> using MLJBase, MLJIteration, EvoTrees, Statistics
 
 julia> samples = fill(4.0, 100, 3, 2);
 ```
 
-One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the
+One can compute the distribution of the ``R^*`` statistic (algorithm 2) with a
 probabilistic classifier.
+For instance, we can use a gradient-boosted trees model with `nrounds = 100` sequentially stacked trees and a learning rate of `eta = 0.05`:
 
 ```jldoctest rstar
-julia> distribution = rstar(XGBoostClassifier(), samples);
+julia> model = EvoTreeClassifier(; nrounds=100, eta=0.05);
 
-julia> isapprox(mean(distribution), 1; atol=0.1)
-true
+julia> distribution = rstar(model, samples);
+
+julia> round(mean(distribution); digits=2)
+1.0f0
+```
+
+Note, however, that it is recommended to determine `nrounds` based on early stopping.
+With the MLJ framework, this can be achieved in the following way (see the [MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/) for additional explanations):
+
+```jldoctest rstar
+julia> model = IteratedModel(;
+           model=EvoTreeClassifier(; eta=0.05),
+           iteration_parameter=:nrounds,
+           resampling=Holdout(),
+           measures=log_loss,
+           controls=[Step(5), Patience(2), NumberLimit(100)],
+           retrain=true,
+       );
+
+julia> distribution = rstar(model, samples);
+
+julia> round(mean(distribution); digits=2)
+1.0f0
 ```
 
 For deterministic classifiers, a single ``R^*`` statistic (algorithm 1) is returned.
@@ -129,48 +195,51 @@ Deterministic classifiers can also be derived from probabilistic classifiers by
 predicting the mode.
 In MLJ this corresponds to a pipeline of models.
 
 ```jldoctest rstar
-julia> xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode);
+julia> evotree_deterministic = Pipeline(model; operation=predict_mode);
 
-julia> value = rstar(xgboost_deterministic, samples);
+julia> value = rstar(evotree_deterministic, samples);
 
-julia> isapprox(value, 1; atol=0.2)
-true
+julia> round(value; digits=2)
+1.0
 ```
 
 # References
 
 Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic with uncertainty using decision tree classifiers.
""" -function rstar( - rng::Random.AbstractRNG, - classifier::MLJModelInterface.Supervised, - x::AbstractArray{<:Any,3}; - kwargs..., -) +function rstar(rng::Random.AbstractRNG, classifier, x::AbstractArray{<:Any,3}; kwargs...) samples = reshape(x, :, size(x, 3)) chain_inds = repeat(axes(x, 2); inner=size(x, 1)) return rstar(rng, classifier, samples, chain_inds; kwargs...) end -function rstar(classif::MLJModelInterface.Supervised, x, y::AbstractVector{Int}; kwargs...) - return rstar(Random.default_rng(), classif, x, y; kwargs...) +function rstar(classifier, x, y::AbstractVector{Int}; kwargs...) + return rstar(Random.default_rng(), classifier, x, y; kwargs...) end -function rstar(classif::MLJModelInterface.Supervised, x::AbstractArray{<:Any,3}; kwargs...) - return rstar(Random.default_rng(), classif, x; kwargs...) +function rstar(classifier, x::AbstractArray{<:Any,3}; kwargs...) + return rstar(Random.default_rng(), classifier, x; kwargs...) end # R⋆ for deterministic predictions (algorithm 1) -function _rstar(predictions::AbstractVector{T}, ytest::AbstractVector{T}) where {T} +function _rstar( + ::Type{<:AbstractVector{<:MMI.Finite}}, + predictions::AbstractVector, + ytest::AbstractVector, +) length(predictions) == length(ytest) || error("numbers of predictions and targets must be equal") mean_accuracy = Statistics.mean(p == y for (p, y) in zip(predictions, ytest)) - nclasses = length(MLJModelInterface.classes(ytest)) + nclasses = length(MMI.classes(ytest)) return nclasses * mean_accuracy end # R⋆ for probabilistic predictions (algorithm 2) -function _rstar(predictions::AbstractVector, ytest::AbstractVector) +function _rstar( + ::Type{<:AbstractVector{<:MMI.Density{<:MMI.Finite}}}, + predictions::AbstractVector, + ytest::AbstractVector, +) length(predictions) == length(ytest) || error("numbers of predictions and targets must be equal") @@ -178,8 +247,17 @@ function _rstar(predictions::AbstractVector, ytest::AbstractVector) distribution = Distributions.PoissonBinomial(map(Distributions.pdf, predictions, ytest)) # scale distribution to support in `[0, nclasses]` - nclasses = length(MLJModelInterface.classes(ytest)) + nclasses = length(MMI.classes(ytest)) scaled_distribution = (nclasses//length(predictions)) * distribution return scaled_distribution end + +# unsupported types of predictions and targets +function _rstar(::Any, predictions, targets) + throw( + ArgumentError( + "unsupported types of predictions ($(typeof(predictions))) and targets ($(typeof(targets)))", + ), + ) +end diff --git a/test/Project.toml b/test/Project.toml index d571fe59..838532aa 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,15 +1,17 @@ [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" +EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55" MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52" +MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = 
"2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -19,14 +21,17 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Distributions = "0.25" DynamicHMC = "3" +EvoTrees = "0.14.7" FFTW = "1.1" LogDensityProblems = "0.12, 1, 2" LogExpFunctions = "0.3" MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" -MLJLIBSVMInterface = "0.1, 0.2" -MLJXGBoostInterface = "0.1, 0.2, 0.3" +MLJIteration = "0.5" +MLJLIBSVMInterface = "0.2" +MLJModels = "0.16" +MLJXGBoostInterface = "0.3" OffsetArrays = "1" StatsBase = "0.33" Tables = "1" -julia = "1.3" +julia = "1.6" diff --git a/test/rstar.jl b/test/rstar.jl index 92706416..9dee8e43 100644 --- a/test/rstar.jl +++ b/test/rstar.jl @@ -1,21 +1,34 @@ using MCMCDiagnosticTools using Distributions -using MLJBase +using EvoTrees +using MLJBase: MLJBase, Pipeline, predict_mode using MLJLIBSVMInterface +using MLJModels using MLJXGBoostInterface using Tables using Random using Test -const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode) +# XGBoost errors on 32bit systems: https://github.com/dmlc/XGBoost.jl/issues/92 +const XGBoostClassifiers = if Sys.WORD_SIZE == 64 + (XGBoostClassifier(), Pipeline(XGBoostClassifier(); operation=predict_mode)) +else + () +end @testset "rstar.jl" begin - classifiers = (XGBoostClassifier(), xgboost_deterministic, SVC()) N = 1_000 @testset "samples input type: $wrapper" for wrapper in [Vector, Array, Tables.table] + # In practice, probably you want to use EvoTreeClassifier with early stopping + classifiers = ( + EvoTreeClassifier(; nrounds=100, eta=0.3), + Pipeline(EvoTreeClassifier(; nrounds=100, eta=0.3); operation=predict_mode), + SVC(), + XGBoostClassifiers..., + ) @testset "examples (classifier = $classifier)" for classifier in classifiers sz = wrapper === Vector ? N : (N, 2) # Compute R⋆ statistic for a mixed chain. 
@@ -111,8 +124,18 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
             i += 1
         end
 
+        # In practice, you probably want to use EvoTreeClassifier with early stopping
+        rng = MersenneTwister(42)
+        classifiers = (
+            EvoTreeClassifier(; rng=rng, nrounds=100, eta=0.3),
+            Pipeline(
+                EvoTreeClassifier(; rng=rng, nrounds=100, eta=0.3); operation=predict_mode
+            ),
+            SVC(),
+            XGBoostClassifiers...,
+        )
         @testset "classifier = $classifier" for classifier in classifiers
-            rng = MersenneTwister(42)
+            Random.seed!(rng, 42)
             dist1 = rstar(rng, classifier, samples_mat, chain_inds)
             Random.seed!(rng, 42)
             dist2 = rstar(rng, classifier, samples)
@@ -120,4 +143,45 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
             @test typeof(rstar(classifier, samples)) === typeof(dist2)
         end
     end
+
+    @testset "model traits requirements" begin
+        samples = randn(2, 3, 4)
+
+        inputs_error = ArgumentError(
+            "classifier does not support tables of continuous values as inputs"
+        )
+        model = UnivariateDiscretizer()
+        @test_throws inputs_error rstar(model, samples)
+        @test_throws inputs_error MCMCDiagnosticTools._check_model_supports_continuous_inputs(
+            model
+        )
+
+        targets_error = ArgumentError(
+            "classifier does not support vectors of multi-class labels as targets"
+        )
+        predictions_error = ArgumentError(
+            "classifier does not support vectors of multi-class labels or their densities as predictions",
+        )
+        models = if Sys.WORD_SIZE == 64
+            (EvoTreeRegressor(), EvoTreeCount(), XGBoostRegressor(), XGBoostCount())
+        else
+            (EvoTreeRegressor(), EvoTreeCount())
+        end
+        for model in models
+            @test_throws targets_error rstar(model, samples)
+            @test_throws targets_error MCMCDiagnosticTools._check_model_supports_multiclass_targets(
+                model
+            )
+            @test_throws predictions_error MCMCDiagnosticTools._check_model_supports_multiclass_predictions(
+                model
+            )
+        end
+    end
+
+    @testset "incorrect type of predictions" begin
+        @test_throws ArgumentError MCMCDiagnosticTools._rstar(
+            AbstractVector{<:MLJBase.Continuous}, rand(2), rand(3)
+        )
+        @test_throws ArgumentError MCMCDiagnosticTools._rstar(1.0, rand(2), rand(2))
+    end
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index af19e7be..e616643c 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,3 @@
-using Pkg
-
 using MCMCDiagnosticTools
 
 using FFTW
@@ -40,11 +38,6 @@ Random.seed!(1)
         include("rafterydiag.jl")
     end
     @testset "R⋆ diagnostic" begin
-        # XGBoost errors on 32bit systems: https://github.com/dmlc/XGBoost.jl/issues/92
-        if Sys.WORD_SIZE == 64
-            include("rstar.jl")
-        else
-            @info "R⋆ not tested: requires 64bit architecture"
-        end
+        include("rstar.jl")
     end
 end