Skip to content

Commit

Permalink
Fix issues with MLJDecisionTreeInterface (#76)
Browse files Browse the repository at this point in the history
* Add tests with MLJDecisionTreeInterface

* Fix use of MLJ interface

* Fix seed of DecisionTreeClassifier in reproducibility tests
  • Loading branch information
devmotion authored Feb 27, 2023
1 parent 5955f82 commit 5ad916f
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MCMCDiagnosticTools"
uuid = "be115224-59cd-429b-ad48-344e309966f0"
authors = ["David Widmann"]
version = "0.3.0"
version = "0.3.1"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
15 changes: 9 additions & 6 deletions src/rstar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,21 @@ function rstar(

xtable = _astable(x)
ycategorical = MMI.categorical(ysplit)
xdata, ydata = MMI.reformat(classifier, xtable, ycategorical)

# train classifier on training data
xtrain, ytrain = MMI.selectrows(classifier, train_ids, xdata, ydata)
fitresult, _ = MMI.fit(classifier, verbosity, xtrain, ytrain)
data = MMI.reformat(classifier, xtable, ycategorical)
train_data = MMI.selectrows(classifier, train_ids, data...)
fitresult, _ = MMI.fit(classifier, verbosity, train_data...)

# compute predictions on test data
xtest, = MMI.selectrows(classifier, test_ids, xdata)
ytest = ycategorical[test_ids]
predictions = _predict(classifier, fitresult, xtest)
# we exploit that MLJ demands that
# reformat(model, args...)[1] = reformat(model, args[1])
# (https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/#Implementing-a-data-front-end)
test_data = MMI.selectrows(classifier, test_ids, data[1])
predictions = _predict(classifier, fitresult, test_data...)

# compute statistic
ytest = ycategorical[test_ids]
result = _rstar(MMI.scitype(predictions), predictions, ytest)

return result
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
Expand All @@ -25,6 +26,7 @@ FFTW = "1.1"
LogDensityProblems = "0.12, 1, 2"
LogExpFunctions = "0.3"
MLJBase = "0.19, 0.20, 0.21"
MLJDecisionTreeInterface = "0.3"
MLJIteration = "0.5"
MLJLIBSVMInterface = "0.2"
MLJModels = "0.16"
Expand Down
3 changes: 3 additions & 0 deletions test/rstar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using MCMCDiagnosticTools
using Distributions
using EvoTrees
using MLJBase: MLJBase, Pipeline, predict_mode
using MLJDecisionTreeInterface
using MLJLIBSVMInterface
using MLJModels
using MLJXGBoostInterface
Expand All @@ -26,6 +27,7 @@ end
classifiers = (
EvoTreeClassifier(; nrounds=100, eta=0.3),
Pipeline(EvoTreeClassifier(; nrounds=100, eta=0.3); operation=predict_mode),
DecisionTreeClassifier(),
SVC(),
XGBoostClassifiers...,
)
Expand Down Expand Up @@ -131,6 +133,7 @@ end
Pipeline(
EvoTreeClassifier(; rng=rng, nrounds=100, eta=0.3); operation=predict_mode
),
DecisionTreeClassifier(; rng=rng),
SVC(),
XGBoostClassifiers...,
)
Expand Down

2 comments on commit 5ad916f

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/78641

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.1 -m "<description of version>" 5ad916f91770f653d6d4e14436ad1fbfee26e719
git push origin v0.3.1

Please sign in to comment.