From fe0fda9a972a21fe257110dbfb7bc8d4a7f1d7cb Mon Sep 17 00:00:00 2001 From: Elias Carvalho <73039601+eliascarv@users.noreply.github.com> Date: Fri, 15 Dec 2023 17:06:35 -0300 Subject: [PATCH] Add accessor functions (#10) --- src/interface.jl | 21 +++++++++++++++++++++ test/Project.toml | 1 + test/runtests.jl | 14 ++++++++++++-- 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 1eeaf2d..12f4278 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -50,3 +50,24 @@ function Base.show(io::IO, model::StatsLearnModel{M}) where {M} println(io, "├─ input: $(model.input)") print(io, "└─ output: $(model.output)") end + +""" + StatsLearnModels.model(lmodel::StatsLearnModel) + +Returns the model of the `lmodel`. +""" +model(lmodel::StatsLearnModel) = lmodel.model + +""" + StatsLearnModels.input(lmodel::StatsLearnModel) + +Returns the input column selection of the `lmodel`. +""" +input(lmodel::StatsLearnModel) = lmodel.input + +""" + StatsLearnModels.output(lmodel::StatsLearnModel) + +Returns the output column selection of the `lmodel`. +""" +output(lmodel::StatsLearnModel) = lmodel.output diff --git a/test/Project.toml b/test/Project.toml index ce7ea51..4a857b1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" diff --git a/test/runtests.jl b/test/runtests.jl index 087347b..4f9cb7c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,7 @@ using Test using GLM: ProbitLink using Distributions: Binomial +using ColumnSelectors: selector import MLJ, MLJDecisionTreeInterface @@ -24,8 +25,17 @@ const SLM = StatsLearnModels end @testset "StatsLearnModel" begin - model = SLM.StatsLearnModel(DecisionTreeClassifier(), [:a, :b], :c) - @test sprint(show, model) == """ + # accessor functions + model = DecisionTreeClassifier() + incols = selector([:a, :b]) + outcols = selector(:c) + lmodel = SLM.StatsLearnModel(model, incols, outcols) + @test SLM.model(lmodel) === model + @test SLM.input(lmodel) === incols + @test SLM.output(lmodel) === outcols + # show + lmodel = SLM.StatsLearnModel(DecisionTreeClassifier(), [:a, :b], :c) + @test sprint(show, lmodel) == """ StatsLearnModel{DecisionTreeClassifier} ├─ input: [:a, :b] └─ output: :c"""