Skip to content

Commit

Permalink
Merge pull request #41 from JuliaAI/raw-fitresult-expose
Browse files Browse the repository at this point in the history
Expose raw fitted GLM model in report
  • Loading branch information
ablaom authored Jan 18, 2024
2 parents 924d0e9 + 6bba7aa commit 7f48db6
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 24 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJGLMInterface"
uuid = "caf8df21-4939-456d-ac9c-5fefbfb04c0c"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.3.5"
version = "0.3.6"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -22,8 +22,9 @@ julia = "1.6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["LinearAlgebra", "MLJBase", "StableRNGs", "Statistics", "Test"]
test = ["LinearAlgebra", "MLJBase", "StableRNGs", "StatisticalMeasures", "Statistics", "Test"]
41 changes: 31 additions & 10 deletions src/MLJGLMInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,16 @@ const LCR_DESCR = "Linear count regressor with specified "*
# LinearBinaryClassifier --> Probabilistic w Binary target // logit,cauchit,..
# MulticlassClassifier --> Probabilistic w Multiclass target

const VALID_KEYS = [:deviance, :dof_residual, :stderror, :vcov, :coef_table]
const DEFAULT_KEYS = VALID_KEYS # For more understandable warning mssg by `@mlj_model`.
const VALID_KEYS = [
:deviance,
:dof_residual,
:stderror,
:vcov,
:coef_table,
:glm_model,
]
const VALID_KEYS_LIST = join(map(k-> "`:$k`", VALID_KEYS), ", ", " and ")
const DEFAULT_KEYS = setdiff(VALID_KEYS, [:glm_model,])
const KEYS_TYPE = Union{Nothing, AbstractVector{Symbol}}

@mlj_model mutable struct LinearRegressor <: MMI.Probabilistic
Expand Down Expand Up @@ -287,6 +295,10 @@ function glm_report(glm_model, features, reportkeys)
end
report_dict[:coef_table] = coef_table
end
if :glm_model in reportkeys
report_dict[:glm_model] = glm_model
end

return NamedTuple{Tuple(keys(report_dict))}(values(report_dict))
end

Expand Down Expand Up @@ -589,9 +601,8 @@ Here
- `offsetcol=nothing`: Name of the column to be used as an offset, if any.
An offset is a variable which is known to have a coefficient of 1.
- `report_keys::Union{Symbol, Nothing}=DEFAULT_KEYS`: vector of keys to be used in
the report. Should be one of: `:deviance`, `:dof_residual`, `:stderror`, `:vcov`,
`:coef_table`.
- `report_keys`: `Vector` of keys for the report. Possible keys are: $VALID_KEYS_LIST. By
default only `:glm_model` is excluded.
Train the machine using `fit!(mach, rows=...)`.
Expand Down Expand Up @@ -619,7 +630,8 @@ The fields of `fitted_params(mach)` are:
# Report
When all keys are enabled in `report_keys`, the following fields are available in `report(mach)`:
When all keys are enabled in `report_keys`, the following fields are available in
`report(mach)`:
- `deviance`: Measure of deviance of fitted model with respect to
a perfectly fitted model. For a linear model, this is the weighted
Expand All @@ -634,6 +646,9 @@ When all keys are enabled in `report_keys`, the following fields are available i
- `coef_table`: Table which displays coefficients and summarizes their significance
and confidence intervals.
- `glm_model`: The raw fitted model returned by `GLM.lm`. Note this points to training
data. Refer to the GLM.jl documentation for usage.
# Examples
```
Expand Down Expand Up @@ -713,8 +728,8 @@ Train the machine using `fit!(mach, rows=...)`.
- `minstepfac::Real=0.001`: Minimum step fraction. Must be between 0 and 1. Lower bound for
the factor used to update the linear fit.
- `report_keys::Union{Symbol, Nothing}=DEFAULT_KEYS`: keys to be used in the report. Should
be one of: `:deviance`, `:dof_residual`, `:stderror`, `:vcov`, `:coef_table`.
- `report_keys`: `Vector` of keys for the report. Possible keys are: $VALID_KEYS_LIST. By
default only `:glm_model` is excluded.
# Operations
Expand Down Expand Up @@ -750,6 +765,9 @@ The fields of `report(mach)` are:
- `coef_table`: Table which displays coefficients and summarizes their significance and
confidence intervals.
- `glm_model`: The raw fitted model returned by `GLM.glm`. Note this points to training
  data. Refer to the GLM.jl documentation for usage.
# Examples
```
Expand Down Expand Up @@ -842,8 +860,8 @@ Train the machine using `fit!(mach, rows=...)`.
- `minstepfac::Real=0.001`: Minimum step fraction. Must be between 0 and 1. Lower bound for
the factor used to update the linear fit.
- `report_keys::Union{Symbol, Nothing}=DEFAULT_KEYS`: keys to be used in the report. Should
be one of: `:deviance`, `:dof_residual`, `:stderror`, `:vcov`, `:coef_table`.
- `report_keys`: `Vector` of keys for the report. Possible keys are: $VALID_KEYS_LIST. By
default only `:glm_model` is excluded.
# Operations
Expand Down Expand Up @@ -880,6 +898,9 @@ The fields of `report(mach)` are:
- `coef_table`: Table which displays coefficients and summarizes their significance and
confidence intervals.
- `glm_model`: The raw fitted model returned by `GLM.glm`. Note this points to training
  data. Refer to the GLM.jl documentation for usage.
# Examples
Expand Down
30 changes: 18 additions & 12 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test

using MLJBase
using StatisticalMeasures
using LinearAlgebra
using Statistics
using MLJGLMInterface
Expand Down Expand Up @@ -66,7 +67,7 @@ expit(X) = 1 ./ (1 .+ exp.(-X))
@test hyp_types[2] == "Bool"
@test hyp_types[3] == "Union{Nothing, Symbol}"
@test hyp_types[4] == "Union{Nothing, AbstractVector{Symbol}}"

end

###
Expand All @@ -86,18 +87,18 @@ end

fitresult, _, report = fit(lr, 1, X, y)
yhat = predict(lr, fitresult, X)
@test mean(cross_entropy(yhat, y)) < 0.25
@test cross_entropy(yhat, y) < 0.25
fitresult1, _, report1 = fit(pr, 1, X, y)
yhat1 = predict(pr, fitresult1, X)
@test mean(cross_entropy(yhat1, y)) < 0.25
@test cross_entropy(yhat1, y) < 0.25

fitresultw, _, reportw = fit(lr, 1, X, y, w)
yhatw = predict(lr, fitresultw, X)
@test mean(cross_entropy(yhatw, y)) < 0.25
@test cross_entropy(yhatw, y) < 0.25
@test yhatw ≈ yhat
fitresultw1, _, reportw1 = fit(pr, 1, X, y, w)
yhatw1 = predict(pr, fitresultw1, X)
@test mean(cross_entropy(yhatw1, y)) < 0.25
@test cross_entropy(yhatw1, y) < 0.25
@test yhatw1 ≈ yhat1

# check predict on `Xnew` with wrong dims
Expand All @@ -124,6 +125,7 @@ end
@test hyper_params[6] == :rtol
@test hyper_params[7] == :minstepfac
@test hyper_params[8] == :report_keys

end

###
Expand All @@ -150,7 +152,7 @@ end
fitresultw, _, _ = fit(lcr, 1, XTable, y, w)
θ̂w = fitted_params(lcr, fitresultw).coef
@test norm(θ̂w .- θ)/norm(θ) ≤ 0.03
@test θ̂w ≈ θ̂
@test θ̂w ≈ θ̂

# check predict on `Xnew` with wrong dims
Xnew = MLJBase.table(
Expand Down Expand Up @@ -278,7 +280,7 @@ end
N = 1000
rng = StableRNGs.StableRNG(0)
X = MLJBase.table(rand(rng, N, 3))
y = 2*X.x1 + X.x2 - X.x3 + rand(rng, Normal(0,1), N)
y = 2*X.x1 + X.x2 - X.x3 + rand(rng, Normal(0,1), N)

lr = LinearRegressor(fit_intercept=false, offsetcol=:x2)
fitresult, _, report = fit(lr, 1, X, y)
Expand Down Expand Up @@ -312,7 +314,7 @@ end
@test parameters == ["a", "b", "c", "(Intercept)"]
intercept = ctable.cols[1][4]
yhat = predict(lr, fitresult, X)
@test mean(cross_entropy(yhat, y)) < 0.6
@test cross_entropy(yhat, y) < 0.6

fp = fitted_params(lr, fitresult)
@test fp.features == [:a, :b, :c]
Expand All @@ -323,21 +325,25 @@ end
@testset "Param names in report" begin
X = (a=[1, 4, 3, 1], b=[2, 0, 1, 4], c=[7, 1, 7, 3])
y = categorical([true, false, true, false])
# check that by default all possible keys are added in the report
# check that by default all possible keys are added in the report,
# except glm_model:
lr = LinearBinaryClassifier()
_, _, report = fit(lr, 1, X, y)
@test :deviance in keys(report)
@test :deviance in keys(report)
@test :dof_residual in keys(report)
@test :stderror in keys(report)
@test :vcov in keys(report)
@test :coef_table in keys(report)
@test :glm_model ∉ keys(report)

# check that report is valid if only some keys are specified
lr = LinearBinaryClassifier(report_keys = [:stderror, :deviance])
lr = LinearBinaryClassifier(report_keys = [:stderror, :glm_model])
_, _, report = fit(lr, 1, X, y)
@test :deviance in keys(report)
@test :deviance ∉ keys(report)
@test :stderror in keys(report)
@test :dof_residual ∉ keys(report)
@test :glm_model in keys(report)
@test report.glm_model isa GLM.GeneralizedLinearModel

# check that an empty `NamedTuple` is outputed for
# `report_params === nothing`
Expand Down

0 comments on commit 7f48db6

Please sign in to comment.