Skip to content

Commit

Permalink
Add StatsLearnModel (#9)
Browse files Browse the repository at this point in the history
* Add 'StatsLearnModel'

* Update 'Learn'
  • Loading branch information
eliascarv authored Dec 15, 2023
1 parent 88c84c8 commit 507684f
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/StatsLearnModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Tables
using Distances
using DataScienceTraits
using StatsBase: mode, mean
using ColumnSelectors: selector
using ColumnSelectors: ColumnSelector, selector
using TableTransforms: StatelessFeatureTransform

import DataScienceTraits as DST
Expand Down
19 changes: 19 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,22 @@ struct FittedModel{M,C}
end

Base.show(io::IO, ::FittedModel{M}) where {M} = print(io, "FittedModel{$(nameof(M))}")

"""
StatsLearnModels.StatsLearnModel(model, incols, outcols)
Wrapper type for learning models used for dispatch purposes.
"""
struct StatsLearnModel{M,I<:ColumnSelector,O<:ColumnSelector}
model::M
input::I
output::O
end

StatsLearnModel(model, incols, outcols) = StatsLearnModel(model, selector(incols), selector(outcols))

function Base.show(io::IO, model::StatsLearnModel{M}) where {M}
println(io, "StatsLearnModel{$(nameof(M))}")
println(io, "├─ input: $(model.input)")
print(io, "└─ output: $(model.output)")
end
10 changes: 6 additions & 4 deletions src/learn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,22 @@ struct Learn{M<:FittedModel} <: StatelessFeatureTransform
input::Vector{Symbol}
end

function Learn(train, model, (incols, outcols)::Pair)
Learn(train, model, (incols, outcols)::Pair) = Learn(train, StatsLearnModel(model, incols, outcols))

function Learn(train, lmodel::StatsLearnModel)
if !Tables.istable(train)
throw(ArgumentError("training data must be a table"))
end

cols = Tables.columns(train)
names = Tables.columnnames(cols)
innms = selector(incols)(names)
outnms = selector(outcols)(names)
innms = lmodel.input(names)
outnms = lmodel.output(names)

input = (; (nm => Tables.getcolumn(cols, nm) for nm in innms)...)
output = (; (nm => Tables.getcolumn(cols, nm) for nm in outnms)...)

fmodel = fit(model, input, output)
fmodel = fit(lmodel.model, input, output)
Learn(fmodel, innms)
end

Expand Down
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ const SLM = StatsLearnModels
@test sprint(show, fmodel) == "FittedModel{DecisionTreeClassifier}"
end

@testset "StatsLearnModel" begin
model = SLM.StatsLearnModel(DecisionTreeClassifier(), [:a, :b], :c)
@test sprint(show, model) == """
StatsLearnModel{DecisionTreeClassifier}
├─ input: [:a, :b]
└─ output: :c"""
end

@testset "models" begin
@testset "MLJ" begin
Random.seed!(123)
Expand Down

0 comments on commit 507684f

Please sign in to comment.