Skip to content

Commit

Permalink
raw_glm_model -> glm_model
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Jan 18, 2024
1 parent 12c9ff5 commit 6bba7aa
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
26 changes: 13 additions & 13 deletions src/MLJGLMInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ const VALID_KEYS = [
:stderror,
:vcov,
:coef_table,
:raw_glm_model,
:glm_model,
]
const VALID_KEYS_LIST = join(map(k-> "`:$k`", VALID_KEYS), ", ", " and ")
const DEFAULT_KEYS = setdiff(VALID_KEYS, [:raw_glm_model,])
const DEFAULT_KEYS = setdiff(VALID_KEYS, [:glm_model,])
const KEYS_TYPE = Union{Nothing, AbstractVector{Symbol}}

@mlj_model mutable struct LinearRegressor <: MMI.Probabilistic
Expand Down Expand Up @@ -295,8 +295,8 @@ function glm_report(glm_model, features, reportkeys)
end
report_dict[:coef_table] = coef_table
end
if :raw_glm_model in reportkeys
report_dict[:raw_glm_model] = glm_model
if :glm_model in reportkeys
report_dict[:glm_model] = glm_model
end

return NamedTuple{Tuple(keys(report_dict))}(values(report_dict))
Expand Down Expand Up @@ -602,7 +602,7 @@ Here
An offset is a variable which is known to have a coefficient of 1.
- `report_keys`: `Vector` of keys for the report. Possible keys are: $VALID_KEYS_LIST. By
default only `:raw_glm_model` is excluded.
default only `:glm_model` is excluded.
Train the machine using `fit!(mach, rows=...)`.
Expand Down Expand Up @@ -646,8 +646,8 @@ When all keys are enabled in `report_keys`, the following fields are available i
- `coef_table`: Table which displays coefficients and summarizes their significance
and confidence intervals.
- `raw_glm_model`: The raw fitted model returned by `GLM.lm`. Note this points to training
data.
- `glm_model`: The raw fitted model returned by `GLM.lm`. Note this points to training
data. Refer to the GLM.jl documentation for usage.
# Examples
Expand Down Expand Up @@ -729,7 +729,7 @@ Train the machine using `fit!(mach, rows=...)`.
the factor used to update the linear fit.
- `report_keys`: `Vector` of keys for the report. Possible keys are: $VALID_KEYS_LIST. By
default only `:raw_glm_model` is excluded.
default only `:glm_model` is excluded.
# Operations
Expand Down Expand Up @@ -765,8 +765,8 @@ The fields of `report(mach)` are:
- `coef_table`: Table which displays coefficients and summarizes their significance and
confidence intervals.
- `raw_glm_model`: The raw fitted model returned by `GLM.glm`. Note this points to training
data.
- `glm_model`: The raw fitted model returned by `GLM.lm`. Note this points to training
data. Refer to the GLM.jl documentation for usage.
# Examples
Expand Down Expand Up @@ -861,7 +861,7 @@ Train the machine using `fit!(mach, rows=...)`.
the factor used to update the linear fit.
- `report_keys`: `Vector` of keys for the report. Possible keys are: $VALID_KEYS_LIST. By
default only `:raw_glm_model` is excluded.
default only `:glm_model` is excluded.
# Operations
Expand Down Expand Up @@ -898,8 +898,8 @@ The fields of `report(mach)` are:
- `coef_table`: Table which displays coefficients and summarizes their significance and
confidence intervals.
- `raw_glm_model`: The raw fitted model returned by `GLM.glm`. Note this points to training
data.
- `glm_model`: The raw fitted model returned by `GLM.lm`. Note this points to training
data. Refer to the GLM.jl documentation for usage.
# Examples
Expand Down
10 changes: 5 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,24 +326,24 @@ end
X = (a=[1, 4, 3, 1], b=[2, 0, 1, 4], c=[7, 1, 7, 3])
y = categorical([true, false, true, false])
# check that by default all possible keys are added in the report,
# except raw_glm_model:
# except glm_model:
lr = LinearBinaryClassifier()
_, _, report = fit(lr, 1, X, y)
@test :deviance in keys(report)
@test :dof_residual in keys(report)
@test :stderror in keys(report)
@test :vcov in keys(report)
@test :coef_table in keys(report)
@test :raw_glm_model keys(report)
@test :glm_model keys(report)

# check that report is valid if only some keys are specified
lr = LinearBinaryClassifier(report_keys = [:stderror, :raw_glm_model])
lr = LinearBinaryClassifier(report_keys = [:stderror, :glm_model])
_, _, report = fit(lr, 1, X, y)
@test :deviance keys(report)
@test :stderror in keys(report)
@test :dof_residual keys(report)
@test :raw_glm_model in keys(report)
@test report.raw_glm_model isa GLM.GeneralizedLinearModel
@test :glm_model in keys(report)
@test report.glm_model isa GLM.GeneralizedLinearModel

# check that an empty `NamedTuple` is outputed for
# `report_params === nothing`
Expand Down

0 comments on commit 6bba7aa

Please sign in to comment.