Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retain data in tabular form #45

Merged
merged 7 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 54 additions & 100 deletions src/MLJGLMInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import MLJModelInterface: metadata_pkg, metadata_model, Table, Continuous, Count, Finite,
OrderedFactor, Multiclass, @mlj_model
using Distributions: Bernoulli, Distribution, Poisson
using StatsModels: ConstantTerm, Term, FormulaTerm, term
using StatsModels: ConstantTerm, Term, FormulaTerm, term, modelcols
using Tables
import GLM

Expand Down Expand Up @@ -97,20 +97,6 @@
## Helper functions
###

"""
augment_X(X::AbstractMatrix, b::Bool)

Augment the matrix `X` with a column of ones if the intercept should
be fitted (`b == true`); otherwise return `X` unchanged.
"""
function augment_X(X::AbstractMatrix, b::Bool)
if b
return hcat(X, ones(float(Int), size(X, 1), 1))
else
return X
end
end

_to_vector(v::Vector) = v
_to_vector(v) = collect(v)
_to_array(v::AbstractArray) = v
Expand Down Expand Up @@ -203,38 +189,15 @@
return nothing
end

function _matrix_and_features(model, Xcols, handle_intercept=false)
col_names = Tables.columnnames(Xcols)
n, p = Tables.rowcount(Xcols), length(col_names)
augment = handle_intercept && model.fit_intercept

if !handle_intercept # i.e This only runs during `fit`
check_sample_size(model, n, p)
end

if p == 0
Xmatrix = Matrix{float(Int)}(undef, n, p)
else
Xmatrix = Tables.matrix(Xcols)
end

Xmatrix = augment_X(Xmatrix, augment)

return Xmatrix, col_names
end

_to_columns(t::Tables.AbstractColumns) = t
_to_columns(t) = Tables.Columns(t)

"""
prepare_inputs(model, X; handle_intercept=false)

Handle `model.offsetcol` and `model.fit_intercept` if `handle_intercept=true`.
`handle_intercept` is disabled for fitting since the StatsModels.@formula handles the intercept.
"""
function prepare_inputs(model, X; handle_intercept=false)
Xcols = _to_columns(X)
table_features = Tables.columnnames(Xcols)
function prepare_inputs(model, X)
Xcols = Tables.columntable(X)
table_features = Base.keys(Xcols)
p = length(table_features)
p >= 1 || throw(
ArgumentError("`X` must contain at least one feature column.")
Expand All @@ -253,9 +216,11 @@
end
end
Xminoffset, offset = split_X_offset(Xcols, model.offsetcol)
Xminoffset_cols = _to_columns(Xminoffset)
Xmatrix, features = _matrix_and_features(model, Xminoffset_cols , handle_intercept)
return Xmatrix, offset, _to_array(features)
features = Tables.columnnames(Xminoffset)

check_sample_size(model, length(first(Xminoffset)), p)

return Xminoffset, offset, _to_array(features)
end

"""
Expand Down Expand Up @@ -285,14 +250,6 @@
end
if :coef_table in reportkeys
coef_table = GLM.coeftable(glm_model)
# Update the variable names in the `coef_table` with the actual variable
# names seen during fit.
if length(coef_table.rownms) == length(features)
# This means `fit_intercept` is false
coef_table.rownms = string.(features)
else
coef_table.rownms = [string.(features); "(Intercept)"]
end
report_dict[:coef_table] = coef_table
end
if :glm_model in reportkeys
Expand Down Expand Up @@ -364,7 +321,9 @@
#### FIT FUNCTIONS
####

struct FitResult{V<:AbstractVector, T, R}
struct FitResult{F, V<:AbstractVector, T, R}
"Formula containing all coefficients and their types"
formula::F
"Vector containing coefficients of the predictors and intercept"
coefs::V
"An estimate of the dispersion parameter of the glm model. "
Expand All @@ -373,20 +332,23 @@
params::R
end

FitResult(fitted_glm, features) = FitResult(GLM.formula(fitted_glm), GLM.coef(fitted_glm), GLM.dispersion(fitted_glm.model), (features = features,))

# Field accessors for `FitResult`; prefer these over direct field access.
dispersion(fr::FitResult) = getfield(fr, :dispersion)
params(fr::FitResult)     = getfield(fr, :params)
coefs(fr::FitResult)      = getfield(fr, :coefs)

function MMI.fit(model::LinearRegressor, verbosity::Int, X, y, w=nothing)
# apply the model
Xmatrix, offset, features = prepare_inputs(model, X)
X_col_table, offset, features = prepare_inputs(model, X)
y_ = isempty(offset) ? y : y .- offset
wts = check_weights(w, y_)
data = glm_data(model, Xmatrix, y_, features)
data = merge(X_col_table, (; y = y_))
form = glm_formula(model, features)
fitted_lm = GLM.lm(form, data; model.dropcollinear, wts).model
fitresult = FitResult(
GLM.coef(fitted_lm), GLM.dispersion(fitted_lm), (features = features,)
)
fitted_lm = GLM.lm(form, data; model.dropcollinear, wts)

fitresult = FitResult(fitted_lm, features)

# form the report
report = glm_report(fitted_lm, features, model.report_keys)
cache = nothing
Expand All @@ -396,11 +358,11 @@

function MMI.fit(model::LinearCountRegressor, verbosity::Int, X, y, w=nothing)
# apply the model
Xmatrix, offset, features = prepare_inputs(model, X)
data = glm_data(model, Xmatrix, y, features)
X_col_table, offset, features = prepare_inputs(model, X)
data = merge(X_col_table, (; y))
wts = check_weights(w, y)
form = glm_formula(model, features)
fitted_glm_frame = GLM.glm(
fitted_glm = GLM.glm(
form, data, model.distribution, model.link;
offset,
model.maxiter,
Expand All @@ -409,10 +371,9 @@
model.minstepfac,
wts
)
fitted_glm = fitted_glm_frame.model
fitresult = FitResult(
GLM.coef(fitted_glm), GLM.dispersion(fitted_glm), (features = features,)
)

fitresult = FitResult(fitted_glm, features)

# form the report
report = glm_report(fitted_glm, features, model.report_keys)
cache = nothing
Expand All @@ -422,13 +383,13 @@

function MMI.fit(model::LinearBinaryClassifier, verbosity::Int, X, y, w=nothing)
# apply the model
decode = y[1]
decode = MMI.classes(y)
y_plain = MMI.int(y) .- 1 # 0, 1 of type Int
wts = check_weights(w, y_plain)
Xmatrix, offset, features = prepare_inputs(model, X)
data = glm_data(model, Xmatrix, y_plain, features)
X_col_table, offset, features = prepare_inputs(model, X)
data = merge(X_col_table, (; y = y_plain))
form = glm_formula(model, features)
fitted_glm_frame = GLM.glm(
fitted_glm = GLM.glm(
form, data, Bernoulli(), model.link;
offset,
model.maxiter,
Expand All @@ -437,10 +398,9 @@
model.minstepfac,
wts
)
fitted_glm = fitted_glm_frame.model
fitresult = FitResult(
GLM.coef(fitted_glm), GLM.dispersion(fitted_glm), (features = features,)
)

fitresult = FitResult(fitted_glm, features)

# form the report
report = glm_report(fitted_glm, features, model.report_keys)
cache = nothing
Expand All @@ -452,7 +412,6 @@
glm_fitresult(::LinearCountRegressor, fitresult) = fitresult
glm_fitresult(::LinearBinaryClassifier, fitresult) = fitresult[1]

coefs(fr::FitResult) = fr.coefs

function MMI.fitted_params(model::GLM_MODELS, fitresult)
result = glm_fitresult(model, fitresult)
Expand All @@ -468,38 +427,33 @@
return (; features, coef=coef_, intercept)
end


####
#### PREDICT FUNCTIONS
####

glm_link(model) = model.link
glm_link(::LinearRegressor) = GLM.IdentityLink()

# more efficient than MLJBase fallback
function MMI.predict_mean(model::GLM_MODELS, fitresult, Xnew)
Xmatrix, offset, _ = prepare_inputs(model, Xnew; handle_intercept=true)
result = glm_fitresult(model, fitresult) # ::FitResult
coef = coefs(result)
p = size(Xmatrix, 2)
if p != length(coef)
throw(
DimensionMismatch(
"The number of features in training and prediction datasets must be equal"
)
)
end
link = glm_link(model)
return glm_predict(link, coef, Xmatrix, model.offsetcol, offset)
"""
    glm_predict(link, terms, coef, offsetcol, Xnew)

Return mean predictions for table `Xnew`: build the design matrix from
the stored formula `terms` via `modelcols`, form the linear predictor
with `coef`, and apply the inverse `link`. This method handles models
fitted without an offset column (`offsetcol === nothing`).
"""
function glm_predict(link, terms, coef, offsetcol::Nothing, Xnew)
    design = modelcols(terms, Xnew)
    linpred = design * coef
    return GLM.linkinv.(link, linpred)
end

# barrier function to aid performance
function glm_predict(link, coef, Xmatrix, offsetcol, offset)
η = offsetcol === nothing ? (Xmatrix * coef) : (Xmatrix * coef .+ offset)
"""
    glm_predict(link, terms, coef, offsetcol::Symbol, Xnew)

Return mean predictions for table `Xnew` when the model was fitted with
an offset: the column named `offsetcol` is read from `Xnew`, added to
the linear predictor, and the inverse `link` is applied.
"""
function glm_predict(link, terms, coef, offsetcol::Symbol, Xnew)
    mm = modelcols(terms, Xnew)
    # the offset column is not part of the formula terms, so fetch it directly
    offset = Tables.getcolumn(Xnew, offsetcol)
    η = mm * coef .+ offset
    μ = GLM.linkinv.(link, η)
    return μ
end

# More efficient than the MLJBase fallback. `predict_mean` is not defined
# for `LinearBinaryClassifier`.
function MMI.predict_mean(
    model::Union{LinearRegressor, LinearCountRegressor}, fitresult, Xnew
)
    return glm_predict(
        glm_link(model),
        fitresult.formula.rhs,
        fitresult.coefs,
        model.offsetcol,
        Xnew,
    )
end

function MMI.predict(model::LinearRegressor, fitresult, Xnew)
μ = MMI.predict_mean(model, fitresult, Xnew)
σ̂ = dispersion(fitresult)
Expand All @@ -512,8 +466,8 @@
end

function MMI.predict(model::LinearBinaryClassifier, (fitresult, decode), Xnew)
    # Probability of the positive class from the fitted formula/coefficients.
    p = glm_predict(
        glm_link(model), fitresult.formula.rhs, fitresult.coefs,
        model.offsetcol, Xnew,
    )
    # `decode` holds the target classes captured at fit time — TODO confirm it
    # is the full class pool (`MMI.classes(y)`), as `augment=true` requires.
    return MMI.UnivariateFinite(decode, p, augment=true)
end

# NOTE: predict_mode uses MLJBase's fallback
Expand All @@ -539,23 +493,23 @@

metadata_model(
    LinearRegressor,
    # Finite (categorical) inputs are now supported via the formula interface
    input = Table(Continuous, Finite),
    target = AbstractVector{Continuous},
    supports_weights = true,
    path = "$PKG.LinearRegressor"
)

metadata_model(
    LinearBinaryClassifier,
    # Finite (categorical) inputs are now supported via the formula interface
    input = Table(Continuous, Finite),
    target = AbstractVector{<:Finite{2}},
    supports_weights = true,
    path = "$PKG.LinearBinaryClassifier"
)

metadata_model(
LinearCountRegressor,
input = Table(Continuous),
input = Table(Continuous, Finite),
target = AbstractVector{Count},
supports_weights = true,
path = "$PKG.LinearCountRegressor"
Expand Down
Loading
Loading