From 5ad916f91770f653d6d4e14436ad1fbfee26e719 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 28 Feb 2023 00:03:25 +0100 Subject: [PATCH] Fix issues with MLJDecisionTreeInterface (#76) * Add tests with MLJDecisionTreeInterface * Fix use of MLJ interface * Fix seed of DecisionTreeClassifier in reproducibility tests --- Project.toml | 2 +- src/rstar.jl | 15 +++++++++------ test/Project.toml | 2 ++ test/rstar.jl | 3 +++ 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index d4eb3152..848c78c3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCMCDiagnosticTools" uuid = "be115224-59cd-429b-ad48-344e309966f0" authors = ["David Widmann"] -version = "0.3.0" +version = "0.3.1" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/rstar.jl b/src/rstar.jl index 26d761b0..9b3f2ba6 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -43,18 +43,21 @@ function rstar( xtable = _astable(x) ycategorical = MMI.categorical(ysplit) - xdata, ydata = MMI.reformat(classifier, xtable, ycategorical) # train classifier on training data - xtrain, ytrain = MMI.selectrows(classifier, train_ids, xdata, ydata) - fitresult, _ = MMI.fit(classifier, verbosity, xtrain, ytrain) + data = MMI.reformat(classifier, xtable, ycategorical) + train_data = MMI.selectrows(classifier, train_ids, data...) + fitresult, _ = MMI.fit(classifier, verbosity, train_data...) # compute predictions on test data - xtest, = MMI.selectrows(classifier, test_ids, xdata) - ytest = ycategorical[test_ids] - predictions = _predict(classifier, fitresult, xtest) + # we exploit that MLJ demands that + # reformat(model, args...)[1] = reformat(model, args[1]) + # (https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/#Implementing-a-data-front-end) + test_data = MMI.selectrows(classifier, test_ids, data[1]) + predictions = _predict(classifier, fitresult, test_data...) # compute statistic + ytest = ycategorical[test_ids] result = _rstar(MMI.scitype(predictions), predictions, ytest) return result diff --git a/test/Project.toml b/test/Project.toml index 05c58580..ad2ebb0a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661" MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55" MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52" MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" @@ -25,6 +26,7 @@ FFTW = "1.1" LogDensityProblems = "0.12, 1, 2" LogExpFunctions = "0.3" MLJBase = "0.19, 0.20, 0.21" +MLJDecisionTreeInterface = "0.3" MLJIteration = "0.5" MLJLIBSVMInterface = "0.2" MLJModels = "0.16" diff --git a/test/rstar.jl b/test/rstar.jl index 9dee8e43..99522945 100644 --- a/test/rstar.jl +++ b/test/rstar.jl @@ -3,6 +3,7 @@ using MCMCDiagnosticTools using Distributions using EvoTrees using MLJBase: MLJBase, Pipeline, predict_mode +using MLJDecisionTreeInterface using MLJLIBSVMInterface using MLJModels using MLJXGBoostInterface @@ -26,6 +27,7 @@ end classifiers = ( EvoTreeClassifier(; nrounds=100, eta=0.3), Pipeline(EvoTreeClassifier(; nrounds=100, eta=0.3); operation=predict_mode), + DecisionTreeClassifier(), SVC(), XGBoostClassifiers..., ) @@ -131,6 +133,7 @@ end Pipeline( EvoTreeClassifier(; rng=rng, nrounds=100, eta=0.3); operation=predict_mode ), + DecisionTreeClassifier(; rng=rng), SVC(), XGBoostClassifiers..., )