Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose raw fitted GLM model in report #41

Merged
merged 5 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ julia = "1.6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["LinearAlgebra", "MLJBase", "StableRNGs", "Statistics", "Test"]
test = ["LinearAlgebra", "MLJBase", "StableRNGs", "StatisticalMeasures", "Statistics", "Test"]
37 changes: 29 additions & 8 deletions src/MLJGLMInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,15 @@ const LCR_DESCR = "Linear count regressor with specified "*
# LinearBinaryClassifier --> Probabilistic w Binary target // logit,cauchit,..
# MulticlassClassifier --> Probabilistic w Multiclass target

const VALID_KEYS = [:deviance, :dof_residual, :stderror, :vcov, :coef_table]
const VALID_KEYS = [
:deviance,
:dof_residual,
:stderror,
:vcov,
:coef_table,
:raw_glm_model,
]
const VALID_KEYS_LIST = join(map(k-> ":$k", VALID_KEYS), ", ", " and ")
const DEFAULT_KEYS = VALID_KEYS # For more understandable warning mssg by `@mlj_model`.
const KEYS_TYPE = Union{Nothing, AbstractVector{Symbol}}

Expand Down Expand Up @@ -287,6 +295,10 @@ function glm_report(glm_model, features, reportkeys)
end
report_dict[:coef_table] = coef_table
end
if :raw_glm_model in reportkeys
report_dict[:raw_glm_model] = glm_model
end

return NamedTuple{Tuple(keys(report_dict))}(values(report_dict))
end

Expand Down Expand Up @@ -590,8 +602,7 @@ Here
An offset is a variable which is known to have a coefficient of 1.

- `report_keys::Union{Symbol, Nothing}=DEFAULT_KEYS`: vector of keys to be used in
the report. Should be one of: `:deviance`, `:dof_residual`, `:stderror`, `:vcov`,
`:coef_table`.
the report. Possible keys are: $VALID_KEYS_LIST.

Train the machine using `fit!(mach, rows=...)`.

Expand Down Expand Up @@ -619,7 +630,8 @@ The fields of `fitted_params(mach)` are:

# Report

When all keys are enabled in `report_keys`, the following fields are available in `report(mach)`:
When all keys are enabled in `report_keys`, the following fields are available in
`report(mach)`:

- `deviance`: Measure of deviance of fitted model with respect to
a perfectly fitted model. For a linear model, this is the weighted
Expand All @@ -634,6 +646,9 @@ When all keys are enabled in `report_keys`, the following fields are available i
- `coef_table`: Table which displays coefficients and summarizes their significance
and confidence intervals.

- `raw_glm_model`: The raw fitted model returned by `GLM.lm`. Note this points to training
data.

# Examples

```
Expand Down Expand Up @@ -713,8 +728,8 @@ Train the machine using `fit!(mach, rows=...)`.
- `minstepfac::Real=0.001`: Minimum step fraction. Must be between 0 and 1. Lower bound for
the factor used to update the linear fit.

- `report_keys::Union{Symbol, Nothing}=DEFAULT_KEYS`: keys to be used in the report. Should
be one of: `:deviance`, `:dof_residual`, `:stderror`, `:vcov`, `:coef_table`.
- `report_keys::Union{Symbol, Nothing}=DEFAULT_KEYS`: keys to be used in the
report. Possible keys are: $VALID_KEYS_LIST.

# Operations

Expand Down Expand Up @@ -750,6 +765,9 @@ The fields of `report(mach)` are:
- `coef_table`: Table which displays coefficients and summarizes their significance and
confidence intervals.

- `raw_glm_model`: The raw fitted model returned by `GLM.glm`. Note this points to training
data.

# Examples

```
Expand Down Expand Up @@ -842,8 +860,8 @@ Train the machine using `fit!(mach, rows=...)`.
- `minstepfac::Real=0.001`: Minimum step fraction. Must be between 0 and 1. Lower bound for
the factor used to update the linear fit.

- `report_keys::Union{Symbol, Nothing}=DEFAULT_KEYS`: keys to be used in the report. Should
be one of: `:deviance`, `:dof_residual`, `:stderror`, `:vcov`, `:coef_table`.
- `report_keys::Union{Symbol, Nothing}=DEFAULT_KEYS`: keys to be used in the
report. Possible keys are: $VALID_KEYS_LIST.

# Operations

Expand Down Expand Up @@ -880,6 +898,9 @@ The fields of `report(mach)` are:
- `coef_table`: Table which displays coefficients and summarizes their significance and
confidence intervals.

- `raw_glm_model`: The raw fitted model returned by `GLM.glm`. Note this points to training
data.


# Examples

Expand Down
28 changes: 17 additions & 11 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test

using MLJBase
using StatisticalMeasures
using LinearAlgebra
using Statistics
using MLJGLMInterface
Expand Down Expand Up @@ -66,7 +67,7 @@ expit(X) = 1 ./ (1 .+ exp.(-X))
@test hyp_types[2] == "Bool"
@test hyp_types[3] == "Union{Nothing, Symbol}"
@test hyp_types[4] == "Union{Nothing, AbstractVector{Symbol}}"

end

###
Expand All @@ -86,18 +87,18 @@ end

fitresult, _, report = fit(lr, 1, X, y)
yhat = predict(lr, fitresult, X)
@test mean(cross_entropy(yhat, y)) < 0.25
@test cross_entropy(yhat, y) < 0.25
fitresult1, _, report1 = fit(pr, 1, X, y)
yhat1 = predict(pr, fitresult1, X)
@test mean(cross_entropy(yhat1, y)) < 0.25
@test cross_entropy(yhat1, y) < 0.25

fitresultw, _, reportw = fit(lr, 1, X, y, w)
yhatw = predict(lr, fitresultw, X)
@test mean(cross_entropy(yhatw, y)) < 0.25
@test cross_entropy(yhatw, y) < 0.25
@test yhatw ≈ yhat
fitresultw1, _, reportw1 = fit(pr, 1, X, y, w)
yhatw1 = predict(pr, fitresultw1, X)
@test mean(cross_entropy(yhatw1, y)) < 0.25
@test cross_entropy(yhatw1, y) < 0.25
@test yhatw1 ≈ yhat1

# check predict on `Xnew` with wrong dims
Expand All @@ -124,6 +125,7 @@ end
@test hyper_params[6] == :rtol
@test hyper_params[7] == :minstepfac
@test hyper_params[8] == :report_keys

end

###
Expand All @@ -150,7 +152,7 @@ end
fitresultw, _, _ = fit(lcr, 1, XTable, y, w)
θ̂w = fitted_params(lcr, fitresultw).coef
@test norm(θ̂w .- θ)/norm(θ) ≤ 0.03
@test θ̂w ≈ θ̂
@test θ̂w ≈ θ̂

# check predict on `Xnew` with wrong dims
Xnew = MLJBase.table(
Expand Down Expand Up @@ -278,7 +280,7 @@ end
N = 1000
rng = StableRNGs.StableRNG(0)
X = MLJBase.table(rand(rng, N, 3))
y = 2*X.x1 + X.x2 - X.x3 + rand(rng, Normal(0,1), N)
y = 2*X.x1 + X.x2 - X.x3 + rand(rng, Normal(0,1), N)

lr = LinearRegressor(fit_intercept=false, offsetcol=:x2)
fitresult, _, report = fit(lr, 1, X, y)
Expand Down Expand Up @@ -312,7 +314,7 @@ end
@test parameters == ["a", "b", "c", "(Intercept)"]
intercept = ctable.cols[1][4]
yhat = predict(lr, fitresult, X)
@test mean(cross_entropy(yhat, y)) < 0.6
@test cross_entropy(yhat, y) < 0.6

fp = fitted_params(lr, fitresult)
@test fp.features == [:a, :b, :c]
Expand All @@ -326,18 +328,22 @@ end
# check that by default all possible keys are added in the report
lr = LinearBinaryClassifier()
_, _, report = fit(lr, 1, X, y)
@test :deviance in keys(report)
@test :deviance in keys(report)
@test :dof_residual in keys(report)
@test :stderror in keys(report)
@test :vcov in keys(report)
@test :coef_table in keys(report)
@test :raw_glm_model in keys(report)

@test report.raw_glm_model isa GLM.GeneralizedLinearModel

# check that report is valid if only some keys are specified
lr = LinearBinaryClassifier(report_keys = [:stderror, :deviance])
_, _, report = fit(lr, 1, X, y)
@test :deviance in keys(report)
@test :deviance in keys(report)
@test :stderror in keys(report)
@test :dof_residual ∉ keys(report)
@test :dof_residua ∉ keys(report)
@test :raw_glm_model ∉ keys(report)

# check that an empty `NamedTuple` is outputed for
# `report_params === nothing`
Expand Down
Loading