From 19a09b94a8dc882eb04387fb01f8e21171cb7a3a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 4 Jan 2023 00:59:39 +0100 Subject: [PATCH 01/22] Use EvoTrees instead of XGBoost --- docs/Project.toml | 4 ++-- src/rstar.jl | 8 ++++---- test/Project.toml | 5 ++--- test/rstar.jl | 6 +++--- test/runtests.jl | 7 +------ 5 files changed, 12 insertions(+), 18 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index ad15f1d3..4a0fb86b 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,14 +1,14 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] Documenter = "0.27" +EvoTrees = "0.14" MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" -MLJXGBoostInterface = "0.1, 0.2, 0.3" julia = "1.3" diff --git a/src/rstar.jl b/src/rstar.jl index 820c3a02..78ec4b2f 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -109,7 +109,7 @@ is returned (algorithm 2). # Examples ```jldoctest rstar; setup = :(using Random; Random.seed!(101)) -julia> using MLJBase, MLJXGBoostInterface, Statistics +julia> using MLJBase, EvoTrees, Statistics julia> samples = fill(4.0, 100, 3, 2); ``` @@ -118,7 +118,7 @@ One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the probabilistic classifier. ```jldoctest rstar -julia> distribution = rstar(XGBoostClassifier(), samples); +julia> distribution = rstar(EvoTreeClassifier(), samples); julia> isapprox(mean(distribution), 1; atol=0.1) true @@ -129,9 +129,9 @@ Deterministic classifiers can also be derived from probabilistic classifiers by predicting the mode. In MLJ this corresponds to a pipeline of models. ```jldoctest rstar -julia> xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode); +julia> evotree_deterministic = Pipeline(EvoTreeClassifier(); operation=predict_mode); -julia> value = rstar(xgboost_deterministic, samples); +julia> value = rstar(evotree_deterministic, samples); julia> isapprox(value, 1; atol=0.2) true diff --git a/test/Project.toml b/test/Project.toml index 09fca5b2..97fe1542 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,11 +1,10 @@ [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52" -MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" @@ -13,10 +12,10 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Distributions = "0.25" +EvoTrees = "0.14" FFTW = "1.1" MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" MLJLIBSVMInterface = "0.1, 0.2" -MLJXGBoostInterface = "0.1, 0.2, 0.3" Tables = "1" julia = "1.3" diff --git a/test/rstar.jl b/test/rstar.jl index 92706416..a47d8d25 100644 --- a/test/rstar.jl +++ b/test/rstar.jl @@ -1,18 +1,18 @@ using MCMCDiagnosticTools using Distributions +using EvoTrees using MLJBase using MLJLIBSVMInterface -using MLJXGBoostInterface using Tables using Random using Test -const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode) +const evotree_deterministic = Pipeline(EvoTreeClassifier(); operation=predict_mode) @testset "rstar.jl" begin - classifiers = (XGBoostClassifier(), xgboost_deterministic, SVC()) + classifiers = (EvoTreeClassifier(), evotree_deterministic, SVC()) N = 1_000 @testset "samples input type: $wrapper" for wrapper in [Vector, Array, Tables.table] diff --git a/test/runtests.jl b/test/runtests.jl index af19e7be..6207fb2e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -40,11 +40,6 @@ Random.seed!(1) include("rafterydiag.jl") end @testset "R⋆ diagnostic" begin - # XGBoost errors on 32bit systems: https://github.com/dmlc/XGBoost.jl/issues/92 - if Sys.WORD_SIZE == 64 - include("rstar.jl") - else - @info "R⋆ not tested: requires 64bit architecture" - end + include("rstar.jl") end end From faad5a97282186e460864932eab46813eb01df5f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 4 Jan 2023 01:07:55 +0100 Subject: [PATCH 02/22] Update runtests.jl --- test/runtests.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6207fb2e..e616643c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,3 @@ -using Pkg - using MCMCDiagnosticTools using FFTW From 8db36198679423b3e148e3fcec19ceb39b44576d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 4 Jan 2023 16:34:23 +0100 Subject: [PATCH 03/22] Try to fix use of MLJ interface --- src/rstar.jl | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/rstar.jl b/src/rstar.jl index 78ec4b2f..8545f6a2 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -40,21 +40,22 @@ function rstar( throw(ArgumentError("training and test data subsets must not be empty")) xtable = _astable(x) + ycategorical = MLJModelInterface.categorical(ysplit) # train classifier on training data - ycategorical = MLJModelInterface.categorical(ysplit) - xtrain = MLJModelInterface.selectrows(xtable, train_ids) - fitresult, _ = MLJModelInterface.fit( - classifier, verbosity, xtrain, ycategorical[train_ids] + xtrain, ytrain = MLJModelInterface.reformat( + classifier, MLJModelInterface.selectrows(xtable, train_ids), ycategorical[train_ids] ) + fitresult, _ = MLJModelInterface.fit(classifier, verbosity, xtrain, ytrain) # compute predictions on test data - xtest = MLJModelInterface.selectrows(xtable, test_ids) + xtest, ytest = MLJModelInterface.reformat( + classifier, MLJModelInterface.selectrows(xtable, test_ids), ycategorical[train_ids] + ) predictions = _predict(classifier, fitresult, xtest) # compute statistic - ytest = ycategorical[test_ids] - result = _rstar(predictions, ytest) + result = _rstar(classifier, predictions, ytest) return result end @@ -161,7 +162,7 @@ function rstar(classif::MLJModelInterface.Supervised, x::AbstractArray{<:Any,3}; end # R⋆ for deterministic predictions (algorithm 1) -function _rstar(predictions::AbstractVector{T}, ytest::AbstractVector{T}) where {T} +function _rstar(::MLJModelIntetface.Deterministic, predictions::AbstractVector, ytest::AbstractVector) length(predictions) == length(ytest) || error("numbers of predictions and targets must be equal") mean_accuracy = Statistics.mean(p == y for (p, y) in zip(predictions, ytest)) @@ -170,7 +171,7 @@ function _rstar(predictions::AbstractVector{T}, ytest::AbstractVector{T}) where end # R⋆ for probabilistic predictions (algorithm 2) -function _rstar(predictions::AbstractVector, ytest::AbstractVector) +function _rstar(::MLJModelInferface.Probabilistic, predictions::AbstractVector, ytest::AbstractVector) length(predictions) == length(ytest) || error("numbers of predictions and targets must be equal") From 6abeb7013a7157a92b721fdb62ecd1d8c52d9412 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 4 Jan 2023 16:37:16 +0100 Subject: [PATCH 04/22] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rstar.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/rstar.jl b/src/rstar.jl index 8545f6a2..f01ec487 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -162,7 +162,9 @@ function rstar(classif::MLJModelInterface.Supervised, x::AbstractArray{<:Any,3}; end # R⋆ for deterministic predictions (algorithm 1) -function _rstar(::MLJModelIntetface.Deterministic, predictions::AbstractVector, ytest::AbstractVector) +function _rstar( + ::MLJModelIntetface.Deterministic, predictions::AbstractVector, ytest::AbstractVector +) length(predictions) == length(ytest) || error("numbers of predictions and targets must be equal") mean_accuracy = Statistics.mean(p == y for (p, y) in zip(predictions, ytest)) @@ -171,7 +173,9 @@ function _rstar(::MLJModelIntetface.Deterministic, predictions::AbstractVector, end # R⋆ for probabilistic predictions (algorithm 2) -function _rstar(::MLJModelInferface.Probabilistic, predictions::AbstractVector, ytest::AbstractVector) +function _rstar( + ::MLJModelInferface.Probabilistic, predictions::AbstractVector, ytest::AbstractVector +) length(predictions) == length(ytest) || error("numbers of predictions and targets must be equal") From e13968cf3c28216ccc66cfc039a2de6cf9545a4b Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 4 Jan 2023 16:39:17 +0100 Subject: [PATCH 05/22] Update src/rstar.jl --- src/rstar.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rstar.jl b/src/rstar.jl index f01ec487..be3674f6 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -163,7 +163,7 @@ end # R⋆ for deterministic predictions (algorithm 1) function _rstar( - ::MLJModelIntetface.Deterministic, predictions::AbstractVector, ytest::AbstractVector + ::MLJModelInterface.Deterministic, predictions::AbstractVector, ytest::AbstractVector ) length(predictions) == length(ytest) || error("numbers of predictions and targets must be equal") From 06364453d69e109e016f88c29c416ddb4ce21ce7 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 4 Jan 2023 16:54:27 +0100 Subject: [PATCH 06/22] Apply suggestions from code review Co-authored-by: Seth Axen --- src/rstar.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/rstar.jl b/src/rstar.jl index be3674f6..f4f98d10 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -41,17 +41,14 @@ function rstar( xtable = _astable(x) ycategorical = MLJModelInterface.categorical(ysplit) + xdata, ydata = MLJModelInterface.reformat(classifier, xtable, ycategorical) # train classifier on training data - xtrain, ytrain = MLJModelInterface.reformat( - classifier, MLJModelInterface.selectrows(xtable, train_ids), ycategorical[train_ids] - ) + xtrain, ytrain = MLJModelInterface.selectrows(classifier, train_ids, xdata, ydata) fitresult, _ = MLJModelInterface.fit(classifier, verbosity, xtrain, ytrain) # compute predictions on test data - xtest, ytest = MLJModelInterface.reformat( - classifier, MLJModelInterface.selectrows(xtable, test_ids), ycategorical[train_ids] - ) + xtest, ytest = MLJModelInterface.selectrows(classifier, test_ids, xdata, ydata) predictions = _predict(classifier, fitresult, xtest) # compute statistic @@ -174,7 +171,7 @@ end # R⋆ for probabilistic predictions (algorithm 2) function _rstar( - ::MLJModelInferface.Probabilistic, predictions::AbstractVector, ytest::AbstractVector + ::MLJModelInterface.Probabilistic, predictions::AbstractVector, ytest::AbstractVector ) length(predictions) == length(ytest) || error("numbers of predictions and targets must be equal") From 4bac113b65d4baa00a5433be8d988890c02e96cb Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 4 Jan 2023 18:41:06 +0100 Subject: [PATCH 07/22] Fix RNG of `EvoTreeClassifier` in tests --- test/rstar.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/rstar.jl b/test/rstar.jl index a47d8d25..96baa27b 100644 --- a/test/rstar.jl +++ b/test/rstar.jl @@ -9,10 +9,10 @@ using Tables using Random using Test -const evotree_deterministic = Pipeline(EvoTreeClassifier(); operation=predict_mode) +const evotree_deterministic = Pipeline(EvoTreeClassifier(; rng=1234); operation=predict_mode) @testset "rstar.jl" begin - classifiers = (EvoTreeClassifier(), evotree_deterministic, SVC()) + classifiers = (EvoTreeClassifier(; rng=1234), evotree_deterministic, SVC()) N = 1_000 @testset "samples input type: $wrapper" for wrapper in [Vector, Array, Tables.table] From dac6e9b7cf766c78ef9a421c4bec449de32a65ba Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 8 Jan 2023 00:42:17 +0100 Subject: [PATCH 08/22] Update rstar.jl --- test/rstar.jl | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/test/rstar.jl b/test/rstar.jl index 96baa27b..0ee0e727 100644 --- a/test/rstar.jl +++ b/test/rstar.jl @@ -9,13 +9,15 @@ using Tables using Random using Test -const evotree_deterministic = Pipeline(EvoTreeClassifier(; rng=1234); operation=predict_mode) - @testset "rstar.jl" begin - classifiers = (EvoTreeClassifier(; rng=1234), evotree_deterministic, SVC()) N = 1_000 @testset "samples input type: $wrapper" for wrapper in [Vector, Array, Tables.table] + classifiers = ( + EvoTreeClassifier(), + Pipeline(EvoTreeClassifier(); operation=predict_mode), + SVC(), + ) @testset "examples (classifier = $classifier)" for classifier in classifiers sz = wrapper === Vector ? N : (N, 2) # Compute R⋆ statistic for a mixed chain. @@ -111,8 +113,14 @@ const evotree_deterministic = Pipeline(EvoTreeClassifier(; rng=1234); operation= i += 1 end + rng = MersenneTwister(42) + classifiers = ( + EvoTreeClassifier(; rng=rng), + Pipeline(EvoTreeClassifier(; rng=rng); operation=predict_mode), + SVC(), + ) @testset "classifier = $classifier" for classifier in classifiers - rng = MersenneTwister(42) + Random.seed!(rng, 42) dist1 = rstar(rng, classifier, samples_mat, chain_inds) Random.seed!(rng, 42) dist2 = rstar(rng, classifier, samples) From caccd36606098a4a762120711174889ef6d7de18 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 9 Jan 2023 00:12:02 +0100 Subject: [PATCH 09/22] Update rstar.jl --- test/rstar.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/rstar.jl b/test/rstar.jl index 0ee0e727..b1dd0192 100644 --- a/test/rstar.jl +++ b/test/rstar.jl @@ -13,9 +13,10 @@ using Test N = 1_000 @testset "samples input type: $wrapper" for wrapper in [Vector, Array, Tables.table] + # Should use EvoTreeClassifier with early stopping classifiers = ( - EvoTreeClassifier(), - Pipeline(EvoTreeClassifier(); operation=predict_mode), + EvoTreeClassifier(; nrounds=100, eta=0.3), + Pipeline(EvoTreeClassifier(; nrounds=100, eta=0.3); operation=predict_mode), SVC(), ) @testset "examples (classifier = $classifier)" for classifier in classifiers @@ -113,10 +114,11 @@ using Test i += 1 end + # Should use EvoTreeClassifier with early stopping rng = MersenneTwister(42) classifiers = ( - EvoTreeClassifier(; rng=rng), - Pipeline(EvoTreeClassifier(; rng=rng); operation=predict_mode), + EvoTreeClassifier(; rng=rng, nrounds=100, eta=0.3), + Pipeline(EvoTreeClassifier(; rng=rng, nrounds=100, eta=0.3); operation=predict_mode), SVC(), ) @testset "classifier = $classifier" for classifier in classifiers From b35269458b21ace41ce51edf95ae9fe186ba90aa Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 9 Jan 2023 07:35:16 +0100 Subject: [PATCH 10/22] Update test/rstar.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/rstar.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/rstar.jl b/test/rstar.jl index b1dd0192..2a46dc98 100644 --- a/test/rstar.jl +++ b/test/rstar.jl @@ -118,7 +118,9 @@ using Test rng = MersenneTwister(42) classifiers = ( EvoTreeClassifier(; rng=rng, nrounds=100, eta=0.3), - Pipeline(EvoTreeClassifier(; rng=rng, nrounds=100, eta=0.3); operation=predict_mode), + Pipeline( + EvoTreeClassifier(; rng=rng, nrounds=100, eta=0.3); operation=predict_mode + ), SVC(), ) @testset "classifier = $classifier" for classifier in classifiers From 82dae8832234bf01a167a56e71493403ede5f6ea Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 9 Jan 2023 10:46:49 +0100 Subject: [PATCH 11/22] Bump compat entries --- .github/workflows/CI.yml | 2 +- Project.toml | 2 +- docs/Project.toml | 4 ++-- test/Project.toml | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b96ee980..39be5311 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: version: - - '1.3' + - '1.6' - '1' - 'nightly' os: diff --git a/Project.toml b/Project.toml index 7d5cd9c9..f71e1d4d 100644 --- a/Project.toml +++ b/Project.toml @@ -25,7 +25,7 @@ MLJModelInterface = "1.6" SpecialFunctions = "0.8, 0.9, 0.10, 1, 2" StatsBase = "0.33" Tables = "1" -julia = "1.3" +julia = "1.6" [extras] Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/docs/Project.toml b/docs/Project.toml index 4a0fb86b..4df17ef4 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -8,7 +8,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] Documenter = "0.27" -EvoTrees = "0.14" +EvoTrees = "0.14.6" MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" -julia = "1.3" +julia = "1.6" diff --git a/test/Project.toml b/test/Project.toml index 97fe1542..24e201ae 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,10 +12,10 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Distributions = "0.25" -EvoTrees = "0.14" +EvoTrees = "0.14.6" FFTW = "1.1" MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" MLJLIBSVMInterface = "0.1, 0.2" Tables = "1" -julia = "1.3" +julia = "1.6" From 07f15529def2438f7b79b1c3beaaec65c58d17ff Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 9 Jan 2023 22:53:28 +0100 Subject: [PATCH 12/22] Apply suggested changes --- src/rstar.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/rstar.jl b/src/rstar.jl index f4f98d10..8cc3472f 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -48,7 +48,8 @@ function rstar( fitresult, _ = MLJModelInterface.fit(classifier, verbosity, xtrain, ytrain) # compute predictions on test data - xtest, ytest = MLJModelInterface.selectrows(classifier, test_ids, xdata, ydata) + xtest, = MLJModelInterface.selectrows(classifier, test_ids, xdata) + ytest = ycategorical[test_ids] predictions = _predict(classifier, fitresult, xtest) # compute statistic From 10f412f898d435f45f2b720fedc26bcbdfdb82ac Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 14 Jan 2023 01:04:11 +0100 Subject: [PATCH 13/22] Use a more advanced example --- .gitignore | 4 +--- docs/Project.toml | 2 ++ src/rstar.jl | 15 ++++++++++++--- test/Project.toml | 2 ++ 4 files changed, 17 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index d6e09773..1c02e5e1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,5 @@ *.jl.*.cov *.jl.cov *.jl.mem -/Manifest.toml -/test/Manifest.toml -/test/rstar/Manifest.toml +Manifest.toml /docs/build/ diff --git a/docs/Project.toml b/docs/Project.toml index 4df17ef4..f76db00f 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,6 +3,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -11,4 +12,5 @@ Documenter = "0.27" EvoTrees = "0.14.6" MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" +MLJIteration = "0.5" julia = "1.6" diff --git a/src/rstar.jl b/src/rstar.jl index 8cc3472f..d7df291f 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -108,7 +108,7 @@ is returned (algorithm 2). # Examples ```jldoctest rstar; setup = :(using Random; Random.seed!(101)) -julia> using MLJBase, EvoTrees, Statistics +julia> using MLJBase, MLJIteration, EvoTrees, Statistics julia> samples = fill(4.0, 100, 3, 2); ``` @@ -117,7 +117,16 @@ One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the probabilistic classifier. ```jldoctest rstar -julia> distribution = rstar(EvoTreeClassifier(), samples); +julia> model = IteratedModel(; + model=EvoTreeClassifier(; eta=0.005), + iteration_parameter=:nrounds, + resampling=Holdout(), + measures=log_loss, + controls=[Step(5), Patience(2), NumberLimit(100)], + retrain=true, + ); + +julia> distribution = rstar(model, samples); julia> isapprox(mean(distribution), 1; atol=0.1) true @@ -128,7 +137,7 @@ Deterministic classifiers can also be derived from probabilistic classifiers by predicting the mode. In MLJ this corresponds to a pipeline of models. ```jldoctest rstar -julia> evotree_deterministic = Pipeline(EvoTreeClassifier(); operation=predict_mode); +julia> evotree_deterministic = Pipeline(model; operation=predict_mode); julia> value = rstar(evotree_deterministic, samples); diff --git a/test/Project.toml b/test/Project.toml index 24e201ae..0ea10f88 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55" MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -16,6 +17,7 @@ EvoTrees = "0.14.6" FFTW = "1.1" MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" +MLJIteration = "0.5" MLJLIBSVMInterface = "0.1, 0.2" Tables = "1" julia = "1.6" From c69e32a1721efce47dfca84f66aad94d70cbe3a9 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 14 Jan 2023 02:50:10 +0100 Subject: [PATCH 14/22] Test XGBoost as well --- test/Project.toml | 4 +++- test/rstar.jl | 17 +++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 2cba0dc4..fcf31e19 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,6 +9,7 @@ MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55" MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52" +MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -25,7 +26,8 @@ LogExpFunctions = "0.3" MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" MLJIteration = "0.5" -MLJLIBSVMInterface = "0.1, 0.2" +MLJLIBSVMInterface = "0.2" +MLJXGBoostInterface = "0.3" StatsBase = "0.33" Tables = "1" julia = "1.6" diff --git a/test/rstar.jl b/test/rstar.jl index 2a46dc98..72bbda3a 100644 --- a/test/rstar.jl +++ b/test/rstar.jl @@ -4,20 +4,32 @@ using Distributions using EvoTrees using MLJBase using MLJLIBSVMInterface +using MLJXGBoostInterface using Tables using Random using Test +# XGBoost errors on 32bit systems: https://github.com/dmlc/XGBoost.jl/issues/92 +const XGBoostClassifiers = if Sys.WORD_SIZE == 64 + ( + XGBoostClassifier(), + Pipeline(XGBoostClassifier(); operation=predict_mode), + ) +else + () +end + @testset "rstar.jl" begin N = 1_000 @testset "samples input type: $wrapper" for wrapper in [Vector, Array, Tables.table] - # Should use EvoTreeClassifier with early stopping + # In practice, probably you want to use EvoTreeClassifier with early stopping classifiers = ( EvoTreeClassifier(; nrounds=100, eta=0.3), Pipeline(EvoTreeClassifier(; nrounds=100, eta=0.3); operation=predict_mode), SVC(), + XGBoostClassifiers..., ) @testset "examples (classifier = $classifier)" for classifier in classifiers sz = wrapper === Vector ? N : (N, 2) @@ -114,7 +126,7 @@ using Test i += 1 end - # Should use EvoTreeClassifier with early stopping + # In practice, probably you want to use EvoTreeClassifier with early stopping rng = MersenneTwister(42) classifiers = ( EvoTreeClassifier(; rng=rng, nrounds=100, eta=0.3), @@ -122,6 +134,7 @@ using Test EvoTreeClassifier(; rng=rng, nrounds=100, eta=0.3); operation=predict_mode ), SVC(), + XGBoostClassifiers..., ) @testset "classifier = $classifier" for classifier in classifiers Random.seed!(rng, 42) From 52e5a5aafec9ef3c4196616e03ae9f20b43d7963 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 14 Jan 2023 03:24:28 +0100 Subject: [PATCH 15/22] Update rstar.jl --- test/rstar.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/test/rstar.jl b/test/rstar.jl index 72bbda3a..45217844 100644 --- a/test/rstar.jl +++ b/test/rstar.jl @@ -12,12 +12,9 @@ using Test # XGBoost errors on 32bit systems: https://github.com/dmlc/XGBoost.jl/issues/92 const XGBoostClassifiers = if Sys.WORD_SIZE == 64 - ( - XGBoostClassifier(), - Pipeline(XGBoostClassifier(); operation=predict_mode), - ) + (XGBoostClassifier(), Pipeline(XGBoostClassifier(); operation=predict_mode)) else - () + () end @testset "rstar.jl" begin From df4772b75fd924a5d41144ae5c23097aa42877ea Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 14 Jan 2023 10:13:14 +0100 Subject: [PATCH 16/22] Update rstar.jl --- src/rstar.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rstar.jl b/src/rstar.jl index d7df291f..42e66a83 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -113,7 +113,7 @@ julia> using MLJBase, MLJIteration, EvoTrees, Statistics julia> samples = fill(4.0, 100, 3, 2); ``` -One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the +One can compute the distribution of the ``R^*`` statistic (algorithm 2) with a probabilistic classifier. ```jldoctest rstar From 1cfaebc8988119b8346c63bce93898bd480edae9 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 14 Jan 2023 17:01:12 +0100 Subject: [PATCH 17/22] Update EvoTrees dependency --- docs/Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index f76db00f..96106f96 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -9,7 +9,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] Documenter = "0.27" -EvoTrees = "0.14.6" +EvoTrees = "0.14.7" MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" MLJIteration = "0.5" diff --git a/test/Project.toml b/test/Project.toml index fcf31e19..65680ef7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -19,7 +19,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Distributions = "0.25" DynamicHMC = "3" -EvoTrees = "0.14.6" +EvoTrees = "0.14.7" FFTW = "1.1" LogDensityProblems = "0.12, 1, 2" LogExpFunctions = "0.3" From 36763fbb498e1a43192a3df7fc6a95b90b13c037 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 15 Jan 2023 01:05:26 +0100 Subject: [PATCH 18/22] Update documentation --- src/rstar.jl | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/rstar.jl b/src/rstar.jl index 42e66a83..5441c901 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -115,10 +115,23 @@ julia> samples = fill(4.0, 100, 3, 2); One can compute the distribution of the ``R^*`` statistic (algorithm 2) with a probabilistic classifier. +For instance, we can use a gradient-boosted trees model with `nrounds = 100` sequentially stacked trees and learning rate `eta = 0.05`: + +```jldoctest rstar +julia> model = EvoTreeClassifier(; nrounds=100, eta=0.05); + +julia> distribution = rstar(model, samples); + +julia> round(mean(distribution); digits=2) +1.0f0 +``` + +Note, however, that it is recommended to determine `nrounds` based on early-stopping. +With the MLJ framework, this can be achieved in the following way (see the [MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/) for additional explanations): ```jldoctest rstar julia> model = IteratedModel(; - model=EvoTreeClassifier(; eta=0.005), + model=EvoTreeClassifier(; eta=0.05), iteration_parameter=:nrounds, resampling=Holdout(), measures=log_loss, @@ -128,8 +141,8 @@ julia> model = IteratedModel(; julia> distribution = rstar(model, samples); -julia> isapprox(mean(distribution), 1; atol=0.1) -true +julia> round(mean(distribution); digits=2) +1.0f0 ``` For deterministic classifiers, a single ``R^*`` statistic (algorithm 1) is returned. @@ -141,8 +154,8 @@ julia> evotree_deterministic = Pipeline(model; operation=predict_mode); julia> value = rstar(evotree_deterministic, samples); -julia> isapprox(value, 1; atol=0.2) -true +julia> round(value; digits=2) +1.0 ``` # References From 86baaaf9ba0d9449aa979a247e4cfd9201ce8595 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 15 Jan 2023 01:11:01 +0100 Subject: [PATCH 19/22] Use MLJ traits --- src/MCMCDiagnosticTools.jl | 2 +- src/rstar.jl | 101 ++++++++++++++++++++++++++----------- 2 files changed, 72 insertions(+), 31 deletions(-) diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index beb1ca11..c2ef459e 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -4,7 +4,7 @@ using AbstractFFTs: AbstractFFTs using DataAPI: DataAPI using DataStructures: DataStructures using Distributions: Distributions -using MLJModelInterface: MLJModelInterface +using MLJModelInterface: MLJModelInterface as MMI using SpecialFunctions: SpecialFunctions using StatsBase: StatsBase using StatsFuns: StatsFuns diff --git a/src/rstar.jl b/src/rstar.jl index 5441c901..3bf83bce 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -1,7 +1,7 @@ """ rstar( rng::Random.AbstractRNG=Random.default_rng(), - classifier::MLJModelInterface.Supervised, + classifier, samples, chain_indices::AbstractVector{Int}; subset::Real=0.7, @@ -21,39 +21,72 @@ This method supports ragged chains, i.e. chains of nonequal lengths. """ function rstar( rng::Random.AbstractRNG, - classifier::MLJModelInterface.Supervised, + classifier, x, y::AbstractVector{Int}; subset::Real=0.7, split_chains::Int=2, verbosity::Int=0, ) - # checks - MLJModelInterface.nrows(x) != length(y) && throw(DimensionMismatch()) - 0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)")) + # check that the model supports the inputs and targets, and has predictions of the desired form + # ideally we would not allow MMI.Unknown but some models do not implement the traits + input_scitype_classifier = MMI.input_scitype(classifier) + if input_scitype_classifier !== MMI.Unknown && + !(MMI.Table(MMI.Continuous) <: input_scitype_classifier) + throw( + ArgumentError( + "classifier does not support tables of continuous values as inputs" + ), + ) + end + target_scitype_classifier = MMI.target_scitype(classifier) + if target_scitype_classifier !== MMI.Unknown && + !(AbstractVector{<:MMI.Finite} <: target_scitype_classifier) + throw( + ArgumentError( + "classifier does not support vectors of multi-class labels as targets" + ), + ) + end + if !( + MMI.predict_scitype(classifier) <: Union{ + MMI.Unknown, + AbstractVector{<:MMI.Finite}, + AbstractVector{<:MMI.Density{<:MMI.Finite}}, + } + ) + throw( + ArgumentError( + "classifier does not support vectors of multi-class labels or their densities as predictions", + ), + ) + end - ysplit = split_chain_indices(y, split_chains) + # check the other arguments + MMI.nrows(x) != length(y) && throw(DimensionMismatch()) + 0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)")) # randomly sub-select training and testing set + ysplit = split_chain_indices(y, split_chains) train_ids, test_ids = shuffle_split_stratified(rng, ysplit, subset) 0 < length(train_ids) < length(y) || throw(ArgumentError("training and test data subsets must not be empty")) xtable = _astable(x) - ycategorical = MLJModelInterface.categorical(ysplit) - xdata, ydata = MLJModelInterface.reformat(classifier, xtable, ycategorical) + ycategorical = MMI.categorical(ysplit) + xdata, ydata = MMI.reformat(classifier, xtable, ycategorical) # train classifier on training data - xtrain, ytrain = MLJModelInterface.selectrows(classifier, train_ids, xdata, ydata) - fitresult, _ = MLJModelInterface.fit(classifier, verbosity, xtrain, ytrain) + xtrain, ytrain = MMI.selectrows(classifier, train_ids, xdata, ydata) + fitresult, _ = MMI.fit(classifier, verbosity, xtrain, ytrain) # compute predictions on test data - xtest, = MLJModelInterface.selectrows(classifier, test_ids, xdata) + xtest, = MMI.selectrows(classifier, test_ids, xdata) ytest = ycategorical[test_ids] predictions = _predict(classifier, fitresult, xtest) # compute statistic - result = _rstar(classifier, predictions, ytest) + result = _rstar(MMI.scitype(predictions), predictions, ytest) return result end @@ -64,9 +97,9 @@ _astable(x) = Tables.istable(x) ? x : throw(ArgumentError("Argument is not a val # Workaround for https://github.com/JuliaAI/MLJBase.jl/issues/863 # `MLJModelInterface.predict` sometimes returns predictions and sometimes predictions + additional information # TODO: Remove once the upstream issue is fixed -function _predict(model::MLJModelInterface.Model, fitresult, x) - y = MLJModelInterface.predict(model, fitresult, x) - return if :predict in MLJModelInterface.reporting_operations(model) +function _predict(model::MMI.Model, fitresult, x) + y = MMI.predict(model, fitresult, x) + return if :predict in MMI.reporting_operations(model) first(y) else y @@ -76,7 +109,7 @@ end """ rstar( rng::Random.AbstractRNG=Random.default_rng(), - classifier::MLJModelInterface.Supervised, + classifier, samples::AbstractArray{<:Real,3}; subset::Real=0.7, split_chains::Int=2, @@ -162,39 +195,38 @@ julia> round(value; digits=2) Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic with uncertainty using decision tree classifiers. """ -function rstar( - rng::Random.AbstractRNG, - classifier::MLJModelInterface.Supervised, - x::AbstractArray{<:Any,3}; - kwargs..., -) +function rstar(rng::Random.AbstractRNG, classifier, x::AbstractArray{<:Any,3}; kwargs...) samples = reshape(x, :, size(x, 3)) chain_inds = repeat(axes(x, 2); inner=size(x, 1)) return rstar(rng, classifier, samples, chain_inds; kwargs...) end -function rstar(classif::MLJModelInterface.Supervised, x, y::AbstractVector{Int}; kwargs...) - return rstar(Random.default_rng(), classif, x, y; kwargs...) +function rstar(classifier, x, y::AbstractVector{Int}; kwargs...) + return rstar(Random.default_rng(), classifier, x, y; kwargs...) end -function rstar(classif::MLJModelInterface.Supervised, x::AbstractArray{<:Any,3}; kwargs...) - return rstar(Random.default_rng(), classif, x; kwargs...) +function rstar(classifier, x::AbstractArray{<:Any,3}; kwargs...) + return rstar(Random.default_rng(), classifier, x; kwargs...) end # R⋆ for deterministic predictions (algorithm 1) function _rstar( - ::MLJModelInterface.Deterministic, predictions::AbstractVector, ytest::AbstractVector + ::Type{<:AbstractVector{<:MMI.Finite}}, + predictions::AbstractVector, + ytest::AbstractVector, ) length(predictions) == length(ytest) || error("numbers of predictions and targets must be equal") mean_accuracy = Statistics.mean(p == y for (p, y) in zip(predictions, ytest)) - nclasses = length(MLJModelInterface.classes(ytest)) + nclasses = length(MMI.classes(ytest)) return nclasses * mean_accuracy end # R⋆ for probabilistic predictions (algorithm 2) function _rstar( - ::MLJModelInterface.Probabilistic, predictions::AbstractVector, ytest::AbstractVector + ::Type{<:AbstractVector{<:MMI.Density{<:MMI.Finite}}}, + predictions::AbstractVector, + ytest::AbstractVector, ) length(predictions) == length(ytest) || error("numbers of predictions and targets must be equal") @@ -203,8 +235,17 @@ function _rstar( distribution = Distributions.PoissonBinomial(map(Distributions.pdf, predictions, ytest)) # scale distribution to support in `[0, nclasses]` - nclasses = length(MLJModelInterface.classes(ytest)) + nclasses = length(MMI.classes(ytest)) scaled_distribution = (nclasses//length(predictions)) * distribution return scaled_distribution end + +# unsupported types of predictions and targets +function _rstar(::Type, predictions::Any, targets::Any) + throw( + ArgumentError( + "unsupported types of predictions ($(typeof(predictions))) and targets ($(typeof(targets)))", + ), + ) +end From 101863737d62316da18a0e3b08e333c0fca8e026 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 17 Jan 2023 01:25:11 +0100 Subject: [PATCH 20/22] Update src/rstar.jl Co-authored-by: Seth Axen --- src/rstar.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rstar.jl b/src/rstar.jl index 3bf83bce..dfd2db83 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -242,7 +242,7 @@ function _rstar( end # unsupported types of predictions and targets -function _rstar(::Type, predictions::Any, targets::Any) +function _rstar(::Any, predictions, targets) throw( ArgumentError( "unsupported types of predictions ($(typeof(predictions))) and targets ($(typeof(targets)))", From c1ca14b45d8d60540e1703fd3584cdd9cd3b3543 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 17 Jan 2023 02:48:45 +0100 Subject: [PATCH 21/22] Some refactoring and additional tests --- src/rstar.jl | 70 +++++++++++++++++++++++++++-------------------- test/Project.toml | 2 ++ test/rstar.jl | 44 ++++++++++++++++++++++++++++- 3 files changed, 86 insertions(+), 30 deletions(-) diff --git a/src/rstar.jl b/src/rstar.jl index dfd2db83..26d761b0 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -28,7 +28,40 @@ function rstar( split_chains::Int=2, verbosity::Int=0, ) - # check that the model supports the inputs and targets, and has predictions of the desired form + # check the arguments + _check_model_supports_continuous_inputs(classifier) + _check_model_supports_multiclass_targets(classifier) + _check_model_supports_multiclass_predictions(classifier) + MMI.nrows(x) != length(y) && throw(DimensionMismatch()) + 0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)")) + + # randomly sub-select training and testing set + ysplit = split_chain_indices(y, split_chains) + train_ids, test_ids = shuffle_split_stratified(rng, ysplit, subset) + 0 < length(train_ids) < length(y) || + throw(ArgumentError("training and test data subsets must not be empty")) + + xtable = _astable(x) + ycategorical = MMI.categorical(ysplit) + xdata, ydata = MMI.reformat(classifier, xtable, ycategorical) + + # train classifier on training data + xtrain, ytrain = MMI.selectrows(classifier, train_ids, xdata, ydata) + fitresult, _ = MMI.fit(classifier, verbosity, xtrain, ytrain) + + # compute predictions on test data + xtest, = MMI.selectrows(classifier, test_ids, xdata) + ytest = ycategorical[test_ids] + predictions = _predict(classifier, fitresult, xtest) + + # compute statistic + result = _rstar(MMI.scitype(predictions), predictions, ytest) + + return result +end + +# check that the model supports the inputs and targets, and has predictions of the desired form +function _check_model_supports_continuous_inputs(classifier) # ideally we would not allow MMI.Unknown but some models do not implement the traits input_scitype_classifier = MMI.input_scitype(classifier) if input_scitype_classifier !== MMI.Unknown && @@ -39,6 +72,9 @@ function rstar( ), ) end + return nothing +end +function _check_model_supports_multiclass_targets(classifier) target_scitype_classifier = MMI.target_scitype(classifier) if target_scitype_classifier !== MMI.Unknown && !(AbstractVector{<:MMI.Finite} <: target_scitype_classifier) @@ -48,6 +84,9 @@ function rstar( ), ) end + return nothing +end +function _check_model_supports_multiclass_predictions(classifier) if !( MMI.predict_scitype(classifier) <: Union{ MMI.Unknown, @@ -61,34 +100,7 @@ function rstar( ), ) end - - # check the other arguments - MMI.nrows(x) != length(y) && throw(DimensionMismatch()) - 0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)")) - - # randomly sub-select training and testing set - ysplit = split_chain_indices(y, split_chains) - train_ids, test_ids = shuffle_split_stratified(rng, ysplit, subset) - 0 < length(train_ids) < length(y) || - throw(ArgumentError("training and test data subsets must not be empty")) - - xtable = _astable(x) - ycategorical = MMI.categorical(ysplit) - xdata, ydata = MMI.reformat(classifier, xtable, ycategorical) - - # train classifier on training data - xtrain, ytrain = MMI.selectrows(classifier, train_ids, xdata, ydata) - fitresult, _ = MMI.fit(classifier, verbosity, xtrain, ytrain) - - # compute predictions on test data - xtest, = MMI.selectrows(classifier, test_ids, xdata) - ytest = ycategorical[test_ids] - predictions = _predict(classifier, fitresult, xtest) - - # compute statistic - result = _rstar(MMI.scitype(predictions), predictions, ytest) - - return result + return nothing end _astable(x::AbstractVecOrMat) = Tables.table(x) diff --git a/test/Project.toml b/test/Project.toml index 65680ef7..13812c0a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,6 +9,7 @@ MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55" MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52" +MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -27,6 +28,7 @@ MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" MLJIteration = "0.5" MLJLIBSVMInterface = "0.2" +MLJModels = "0.16" MLJXGBoostInterface = "0.3" StatsBase = "0.33" Tables = "1" diff --git a/test/rstar.jl b/test/rstar.jl index 45217844..9dee8e43 100644 --- a/test/rstar.jl +++ b/test/rstar.jl @@ -2,8 +2,9 @@ using MCMCDiagnosticTools using Distributions using EvoTrees -using MLJBase +using MLJBase: MLJBase, Pipeline, predict_mode using MLJLIBSVMInterface +using MLJModels using MLJXGBoostInterface using Tables @@ -142,4 +143,45 @@ end @test typeof(rstar(classifier, samples)) === typeof(dist2) end end + + @testset "model traits requirements" begin + samples = randn(2, 3, 4) + + inputs_error = ArgumentError( + "classifier does not support tables of continuous values as inputs" + ) + model = UnivariateDiscretizer() + @test_throws inputs_error rstar(model, samples) + @test_throws inputs_error MCMCDiagnosticTools._check_model_supports_continuous_inputs( + model + ) + + targets_error = ArgumentError( + "classifier does not support vectors of multi-class labels as targets" + ) + predictions_error = ArgumentError( + "classifier does not support vectors of multi-class labels or their densities as predictions", + ) + models = if Sys.WORD_SIZE == 64 + (EvoTreeRegressor(), EvoTreeCount(), XGBoostRegressor(), XGBoostCount()) + else + (EvoTreeRegressor(), EvoTreeCount()) + end + for model in models + @test_throws targets_error rstar(model, samples) + @test_throws targets_error MCMCDiagnosticTools._check_model_supports_multiclass_targets( + model + ) + @test_throws predictions_error MCMCDiagnosticTools._check_model_supports_multiclass_predictions( + model + ) + end + end + + @testset "incorrect type of predictions" begin + @test_throws ArgumentError MCMCDiagnosticTools._rstar( + AbstractVector{<:MLJBase.Continuous}, rand(2), rand(3) + ) + @test_throws ArgumentError MCMCDiagnosticTools._rstar(1.0, rand(2), rand(2)) + end end From a1f428b2d252f962db4394d7e184d6ae30a91c54 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 17 Jan 2023 13:04:06 +0100 Subject: [PATCH 22/22] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 362c197d..9e702160 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCMCDiagnosticTools" uuid = "be115224-59cd-429b-ad48-344e309966f0" authors = ["David Widmann"] -version = "0.2.5" +version = "0.2.6" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"