Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retain data in tabular form #45

Merged
merged 7 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 58 additions & 124 deletions src/MLJGLMInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import MLJModelInterface
import MLJModelInterface: metadata_pkg, metadata_model, Table, Continuous, Count, Finite,
OrderedFactor, Multiclass, @mlj_model
using Distributions: Bernoulli, Distribution, Poisson
using StatsModels: ConstantTerm, Term, FormulaTerm, term
using StatsModels: ConstantTerm, Term, FormulaTerm, term, modelcols
using Tables
import GLM

Expand Down Expand Up @@ -97,20 +97,6 @@ const GLM_MODELS = Union{
## Helper functions
###

"""
augment_X(X::AbstractMatrix, b::Bool)

Augment the matrix `X` with a column of ones if the intercept should
be fitted (`b==true`) and return `X` otherwise.
"""
function augment_X(X::AbstractMatrix, b::Bool)
if b
return hcat(X, ones(float(Int), size(X, 1), 1))
else
return X
end
end

# Materialize an iterable as a `Vector`, avoiding a copy when the input
# already is one.
_to_vector(v::Vector) = v
_to_vector(itr) = collect(itr)

# Identity for values that are already arrays.
_to_array(arr::AbstractArray) = arr
Expand Down Expand Up @@ -203,38 +189,15 @@ function check_sample_size(model, n, p)
return nothing
end

# Convert the column table `Xcols` to a plain matrix together with its
# column names. During `fit` (`handle_intercept=false`) the sample size is
# validated; an intercept column of ones is appended only when both
# `handle_intercept` and `model.fit_intercept` are set.
function _matrix_and_features(model, Xcols, handle_intercept=false)
    features = Tables.columnnames(Xcols)
    n = Tables.rowcount(Xcols)
    p = length(features)

    # Sample-size validation only applies at fit time.
    handle_intercept || check_sample_size(model, n, p)

    # `Tables.matrix` on a zero-column table is not meaningful; build an
    # empty matrix of the default float type instead.
    Xmatrix = p == 0 ? Matrix{float(Int)}(undef, n, p) : Tables.matrix(Xcols)
    Xmatrix = augment_X(Xmatrix, handle_intercept && model.fit_intercept)

    return Xmatrix, features
end

# Return a columns-style view of the table `t`, skipping the wrapper when
# `t` already satisfies the `Tables.AbstractColumns` interface.
_to_columns(t::Tables.AbstractColumns) = t
_to_columns(t) = Tables.Columns(t)

"""
prepare_inputs(model, X; handle_intercept=false)

Handle `model.offsetcol` and `model.fit_intercept` if `handle_intercept=true`.
`handle_intercept` is disabled for fitting since the StatsModels.@formula handles the intercept.
"""
function prepare_inputs(model, X; handle_intercept=false)
Xcols = _to_columns(X)
table_features = Tables.columnnames(Xcols)
function prepare_inputs(model, X)
Xcols = Tables.columntable(X)
table_features = Base.keys(Xcols)
p = length(table_features)
p >= 1 || throw(
ArgumentError("`X` must contain at least one feature column.")
Expand All @@ -253,9 +216,11 @@ function prepare_inputs(model, X; handle_intercept=false)
end
end
Xminoffset, offset = split_X_offset(Xcols, model.offsetcol)
Xminoffset_cols = _to_columns(Xminoffset)
Xmatrix, features = _matrix_and_features(model, Xminoffset_cols , handle_intercept)
return Xmatrix, offset, _to_array(features)
features = Tables.columnnames(Xminoffset)

check_sample_size(model, length(first(Xminoffset)), p)

return Xminoffset, offset, _to_array(features)
end

"""
Expand Down Expand Up @@ -285,14 +250,6 @@ function glm_report(glm_model, features, reportkeys)
end
if :coef_table in reportkeys
coef_table = GLM.coeftable(glm_model)
# Update the variable names in the `coef_table` with the actual variable
# names seen during fit.
if length(coef_table.rownms) == length(features)
# This means `fit_intercept` is false
coef_table.rownms = string.(features)
else
coef_table.rownms = [string.(features); "(Intercept)"]
end
report_dict[:coef_table] = coef_table
end
if :glm_model in reportkeys
Expand All @@ -302,15 +259,6 @@ function glm_report(glm_model, features, reportkeys)
return NamedTuple{Tuple(keys(report_dict))}(values(report_dict))
end

"""
    _ensure_intercept_at_end(form::FormulaTerm)

Return a `FormulaTerm` equivalent to `form` but with any leading
`ConstantTerm` (the intercept term) moved to the end of the right-hand
side, so the intercept coefficient always comes last. The RHS is left
unchanged when its first term is not a `ConstantTerm`.
"""
function _ensure_intercept_at_end(form::FormulaTerm)
    # NOTE(review): assumes `form.rhs` is an indexable tuple of terms, as
    # produced by `glm_formula` — confirm for single-term formulas.
    fixed_rhs = if first(form.rhs) isa ConstantTerm
        (form.rhs[2:end]..., form.rhs[1])
    else
        form.rhs
    end
    return FormulaTerm(form.lhs, fixed_rhs)
end

"""
glm_formula(model, features::AbstractVector{Symbol}) -> FormulaTerm

Expand All @@ -321,19 +269,8 @@ function glm_formula(model, features::AbstractVector{Symbol})::FormulaTerm
# Adding a zero term explicitly disables the intercept.
# See the StatsModels.jl tests for more information.
intercept_term = model.fit_intercept ? 1 : 0
form = FormulaTerm(Term(:y), sum(term.(features)) + term(intercept_term))
fixed_form = _ensure_intercept_at_end(form)
return fixed_form
end

"""
glm_data(model, Xmatrix, y, features)

Return data which is ready to be passed to `fit(form, data, ...)`.
"""
function glm_data(model, Xmatrix, y, features)
data = Tables.table([Xmatrix y]; header=[features...; :y])
return data
form = FormulaTerm(Term(:y), term(intercept_term) + sum(term.(features)))
return form
end

"""
Expand Down Expand Up @@ -364,7 +301,9 @@ end
#### FIT FUNCTIONS
####

struct FitResult{V<:AbstractVector, T, R}
struct FitResult{F, V<:AbstractVector, T, R}
"Formula containing all coefficients and their types"
formula::F
"Vector containg coeficients of the predictors and intercept"
coefs::V
"An estimate of the dispersion parameter of the glm model. "
Expand All @@ -373,20 +312,23 @@ struct FitResult{V<:AbstractVector, T, R}
params::R
end

# Build a `FitResult` from a fitted GLM: stores the fitted formula, the
# coefficient vector, the dispersion estimate (read from the wrapped
# `.model` — NOTE(review): assumes `fitted_glm` is a table-regression
# wrapper exposing `.model`; confirm against the GLM.jl version in use),
# and the feature names.
FitResult(fitted_glm, features) = FitResult(GLM.formula(fitted_glm), GLM.coef(fitted_glm), GLM.dispersion(fitted_glm.model), (features = features,))

# Field accessors for `FitResult`.
dispersion(fr::FitResult) = fr.dispersion
params(fr::FitResult) = fr.params
coefs(fr::FitResult) = fr.coefs

function MMI.fit(model::LinearRegressor, verbosity::Int, X, y, w=nothing)
# apply the model
Xmatrix, offset, features = prepare_inputs(model, X)
X_col_table, offset, features = prepare_inputs(model, X)
y_ = isempty(offset) ? y : y .- offset
wts = check_weights(w, y_)
data = glm_data(model, Xmatrix, y_, features)
data = merge(X_col_table, (; y = y_))
form = glm_formula(model, features)
fitted_lm = GLM.lm(form, data; model.dropcollinear, wts).model
fitresult = FitResult(
GLM.coef(fitted_lm), GLM.dispersion(fitted_lm), (features = features,)
)
fitted_lm = GLM.lm(form, data; model.dropcollinear, wts)

fitresult = FitResult(fitted_lm, features)

# form the report
report = glm_report(fitted_lm, features, model.report_keys)
cache = nothing
Expand All @@ -396,11 +338,11 @@ end

function MMI.fit(model::LinearCountRegressor, verbosity::Int, X, y, w=nothing)
# apply the model
Xmatrix, offset, features = prepare_inputs(model, X)
data = glm_data(model, Xmatrix, y, features)
X_col_table, offset, features = prepare_inputs(model, X)
data = merge(X_col_table, (; y))
wts = check_weights(w, y)
form = glm_formula(model, features)
fitted_glm_frame = GLM.glm(
fitted_glm = GLM.glm(
form, data, model.distribution, model.link;
offset,
model.maxiter,
Expand All @@ -409,10 +351,9 @@ function MMI.fit(model::LinearCountRegressor, verbosity::Int, X, y, w=nothing)
model.minstepfac,
wts
)
fitted_glm = fitted_glm_frame.model
fitresult = FitResult(
GLM.coef(fitted_glm), GLM.dispersion(fitted_glm), (features = features,)
)

fitresult = FitResult(fitted_glm, features)

# form the report
report = glm_report(fitted_glm, features, model.report_keys)
cache = nothing
Expand All @@ -422,13 +363,13 @@ end

function MMI.fit(model::LinearBinaryClassifier, verbosity::Int, X, y, w=nothing)
# apply the model
decode = y[1]
decode = MMI.classes(y)
y_plain = MMI.int(y) .- 1 # 0, 1 of type Int
wts = check_weights(w, y_plain)
Xmatrix, offset, features = prepare_inputs(model, X)
data = glm_data(model, Xmatrix, y_plain, features)
X_col_table, offset, features = prepare_inputs(model, X)
data = merge(X_col_table, (; y = y_plain))
form = glm_formula(model, features)
fitted_glm_frame = GLM.glm(
fitted_glm = GLM.glm(
form, data, Bernoulli(), model.link;
offset,
model.maxiter,
Expand All @@ -437,10 +378,9 @@ function MMI.fit(model::LinearBinaryClassifier, verbosity::Int, X, y, w=nothing)
model.minstepfac,
wts
)
fitted_glm = fitted_glm_frame.model
fitresult = FitResult(
GLM.coef(fitted_glm), GLM.dispersion(fitted_glm), (features = features,)
)

fitresult = FitResult(fitted_glm, features)

# form the report
report = glm_report(fitted_glm, features, model.report_keys)
cache = nothing
Expand All @@ -452,54 +392,48 @@ glm_fitresult(::LinearRegressor, fitresult) = fitresult
# For the binary classifier the stored fitresult is a `(fitresult, decode)`
# tuple, so extract the `FitResult` part; other models store it directly.
glm_fitresult(::LinearCountRegressor, fitresult) = fitresult
glm_fitresult(::LinearBinaryClassifier, fitresult) = fitresult[1]

# NOTE(review): duplicates the `coefs` accessor defined earlier in this
# file; looks like leftover from a code move — consider removing one copy.
coefs(fr::FitResult) = fr.coefs

"""
    MMI.fitted_params(model::GLM_MODELS, fitresult)

Return a named tuple `(; features, coef, intercept)` extracted from the
stored fit result. When `model.fit_intercept` is true, the intercept is
the FIRST coefficient (the formula places the intercept term first) and
the remaining entries are the per-feature coefficients; otherwise the
intercept is reported as zero of the coefficient element type.
"""
function MMI.fitted_params(model::GLM_MODELS, fitresult)
    result = glm_fitresult(model, fitresult)
    coef = coefs(result)
    # Copy so callers cannot mutate the stored feature list.
    features = copy(params(result).features)
    if model.fit_intercept
        intercept = coef[1]
        coef_ = coef[2:end]
    else
        intercept = zero(eltype(coef))
        coef_ = copy(coef)
    end
    return (; features, coef=coef_, intercept)
end


####
#### PREDICT FUNCTIONS
####

# Link function used at prediction time: models carry their own `link`
# field, except `LinearRegressor`, which always uses the identity link.
glm_link(model) = model.link
glm_link(::LinearRegressor) = GLM.IdentityLink()

# more efficient than MLJBase fallback
function MMI.predict_mean(model::GLM_MODELS, fitresult, Xnew)
Xmatrix, offset, _ = prepare_inputs(model, Xnew; handle_intercept=true)
result = glm_fitresult(model, fitresult) # ::FitResult
coef = coefs(result)
p = size(Xmatrix, 2)
if p != length(coef)
throw(
DimensionMismatch(
"The number of features in training and prediction datasets must be equal"
)
)
end
link = glm_link(model)
return glm_predict(link, coef, Xmatrix, model.offsetcol, offset)
# Mean response for `Xnew` when no offset column is configured: build the
# model matrix from the stored formula terms, form the linear predictor,
# and apply the inverse link elementwise.
function glm_predict(link, terms, coef, offsetcol::Nothing, Xnew)
    linpred = modelcols(terms, Xnew) * coef
    return GLM.linkinv.(link, linpred)
end

# Barrier function to aid performance — offset-column method: build the
# model matrix from the stored formula terms, add the offset column read
# from `Xnew`, and apply the inverse link elementwise.
function glm_predict(link, terms, coef, offsetcol::Symbol, Xnew)
    mm = modelcols(terms, Xnew)
    offset = Tables.getcolumn(Xnew, offsetcol)
    η = mm * coef .+ offset
    μ = GLM.linkinv.(link, η)
    return μ
end

# More efficient than the MLJBase fallback. `predict_mean` is not defined
# for `LinearBinaryClassifier`.
function MMI.predict_mean(model::Union{LinearRegressor, LinearCountRegressor}, fitresult, Xnew)
    return glm_predict(
        glm_link(model),
        fitresult.formula.rhs,
        fitresult.coefs,
        model.offsetcol,
        Xnew,
    )
end

function MMI.predict(model::LinearRegressor, fitresult, Xnew)
μ = MMI.predict_mean(model, fitresult, Xnew)
σ̂ = dispersion(fitresult)
Expand All @@ -512,8 +446,8 @@ function MMI.predict(model::LinearCountRegressor, fitresult, Xnew)
end

"""
    MMI.predict(model::LinearBinaryClassifier, (fitresult, decode), Xnew)

Return a `UnivariateFinite` distribution over the two classes, with
positive-class probabilities computed directly from the fitted GLM via the
model's link (`predict_mean` is not defined for `LinearBinaryClassifier`,
so `glm_predict` is called directly).
"""
function MMI.predict(model::LinearBinaryClassifier, (fitresult, decode), Xnew)
    p = glm_predict(glm_link(model), fitresult.formula.rhs, fitresult.coefs, model.offsetcol, Xnew)
    # `augment=true`: `p` holds only positive-class probabilities.
    return MMI.UnivariateFinite(decode, p, augment=true)
end

# NOTE: predict_mode uses MLJBase's fallback
Expand All @@ -539,23 +473,23 @@ metadata_pkg.(

metadata_model(
    LinearRegressor,
    # Finite (categorical) columns are supported now that data is kept
    # in tabular form and handled via the formula interface.
    input = Table(Continuous, Finite),
    target = AbstractVector{Continuous},
    supports_weights = true,
    path = "$PKG.LinearRegressor"
)

metadata_model(
    LinearBinaryClassifier,
    # Finite (categorical) columns are supported now that data is kept
    # in tabular form and handled via the formula interface.
    input = Table(Continuous, Finite),
    target = AbstractVector{<:Finite{2}},
    supports_weights = true,
    path = "$PKG.LinearBinaryClassifier"
)

metadata_model(
LinearCountRegressor,
input = Table(Continuous),
input = Table(Continuous, Finite),
target = AbstractVector{Count},
supports_weights = true,
path = "$PKG.LinearCountRegressor"
Expand Down
Loading
Loading