Skip to content

Commit

Permalink
Add accessor functions (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
eliascarv authored Dec 15, 2023
1 parent 507684f commit fe0fda9
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 2 deletions.
21 changes: 21 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,24 @@ function Base.show(io::IO, model::StatsLearnModel{M}) where {M}
println(io, "├─ input: $(model.input)")
print(io, "└─ output: $(model.output)")
end

"""
StatsLearnModels.model(lmodel::StatsLearnModel)
Returns the model of the `lmodel`.
"""
model(lmodel::StatsLearnModel) = lmodel.model

"""
StatsLearnModels.input(lmodel::StatsLearnModel)
Returns the input column selection of the `lmodel`.
"""
input(lmodel::StatsLearnModel) = lmodel.input

"""
StatsLearnModels.output(lmodel::StatsLearnModel)
Returns the output column selection of the `lmodel`.
"""
output(lmodel::StatsLearnModel) = lmodel.output
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
Expand Down
14 changes: 12 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Test

using GLM: ProbitLink
using Distributions: Binomial
using ColumnSelectors: selector

import MLJ, MLJDecisionTreeInterface

Expand All @@ -24,8 +25,17 @@ const SLM = StatsLearnModels
end

@testset "StatsLearnModel" begin
model = SLM.StatsLearnModel(DecisionTreeClassifier(), [:a, :b], :c)
@test sprint(show, model) == """
# accessor functions
model = DecisionTreeClassifier()
incols = selector([:a, :b])
outcols = selector(:c)
lmodel = SLM.StatsLearnModel(model, incols, outcols)
@test SLM.model(lmodel) === model
@test SLM.input(lmodel) === incols
@test SLM.output(lmodel) === outcols
# show
lmodel = SLM.StatsLearnModel(DecisionTreeClassifier(), [:a, :b], :c)
@test sprint(show, lmodel) == """
StatsLearnModel{DecisionTreeClassifier}
├─ input: [:a, :b]
└─ output: :c"""
Expand Down

0 comments on commit fe0fda9

Please sign in to comment.