diff --git a/Project.toml b/Project.toml index 16cf07b..1cf59e7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJGLMInterface" uuid = "caf8df21-4939-456d-ac9c-5fefbfb04c0c" authors = ["Anthony D. Blaom "] -version = "0.3.5" +version = "0.3.6" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -22,8 +22,9 @@ julia = "1.6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["LinearAlgebra", "MLJBase", "StableRNGs", "Statistics", "Test"] +test = ["LinearAlgebra", "MLJBase", "StableRNGs", "StatisticalMeasures", "Statistics", "Test"] diff --git a/src/MLJGLMInterface.jl b/src/MLJGLMInterface.jl index 0315bac..7a682e3 100644 --- a/src/MLJGLMInterface.jl +++ b/src/MLJGLMInterface.jl @@ -45,8 +45,16 @@ const LCR_DESCR = "Linear count regressor with specified "* # LinearBinaryClassifier --> Probabilistic w Binary target // logit,cauchit,.. # MulticlassClassifier --> Probabilistic w Multiclass target -const VALID_KEYS = [:deviance, :dof_residual, :stderror, :vcov, :coef_table] -const DEFAULT_KEYS = VALID_KEYS # For more understandable warning mssg by `@mlj_model`. 
+const VALID_KEYS = [ + :deviance, + :dof_residual, + :stderror, + :vcov, + :coef_table, + :glm_model, +] +const VALID_KEYS_LIST = join(map(k-> "`:$k`", VALID_KEYS), ", ", " and ") +const DEFAULT_KEYS = setdiff(VALID_KEYS, [:glm_model,]) const KEYS_TYPE = Union{Nothing, AbstractVector{Symbol}} @mlj_model mutable struct LinearRegressor <: MMI.Probabilistic @@ -287,6 +295,10 @@ function glm_report(glm_model, features, reportkeys) end report_dict[:coef_table] = coef_table end + if :glm_model in reportkeys + report_dict[:glm_model] = glm_model + end + return NamedTuple{Tuple(keys(report_dict))}(values(report_dict)) end @@ -589,9 +601,8 @@ Here - `offsetcol=nothing`: Name of the column to be used as an offset, if any. An offset is a variable which is known to have a coefficient of 1. -- `report_keys::Union{Symbol, Nothing}=DEFAULT_KEYS`: vector of keys to be used in - the report. Should be one of: `:deviance`, `:dof_residual`, `:stderror`, `:vcov`, - `:coef_table`. +- `report_keys`: `Vector` of keys for the report. Possible keys are: $VALID_KEYS_LIST. By + default only `:glm_model` is excluded. Train the machine using `fit!(mach, rows=...)`. @@ -619,7 +630,8 @@ The fields of `fitted_params(mach)` are: # Report -When all keys are enabled in `report_keys`, the following fields are available in `report(mach)`: +When all keys are enabled in `report_keys`, the following fields are available in +`report(mach)`: - `deviance`: Measure of deviance of fitted model with respect to a perfectly fitted model. For a linear model, this is the weighted @@ -634,6 +646,9 @@ When all keys are enabled in `report_keys`, the following fields are available i - `coef_table`: Table which displays coefficients and summarizes their significance and confidence intervals. +- `glm_model`: The raw fitted model returned by `GLM.lm`. Note this points to training + data. Refer to the GLM.jl documentation for usage. + # Examples ``` @@ -713,8 +728,8 @@ Train the machine using `fit!(mach, rows=...)`. 
- `minstepfac::Real=0.001`: Minimum step fraction. Must be between 0 and 1. Lower bound for the factor used to update the linear fit. -- `report_keys::Union{Symbol, Nothing}=DEFAULT_KEYS`: keys to be used in the report. Should - be one of: `:deviance`, `:dof_residual`, `:stderror`, `:vcov`, `:coef_table`. +- `report_keys`: `Vector` of keys for the report. Possible keys are: $VALID_KEYS_LIST. By + default only `:glm_model` is excluded. # Operations @@ -750,6 +765,9 @@ The fields of `report(mach)` are: - `coef_table`: Table which displays coefficients and summarizes their significance and confidence intervals. +- `glm_model`: The raw fitted model returned by `GLM.glm`. Note this points to training + data. Refer to the GLM.jl documentation for usage. + # Examples ``` @@ -842,8 +860,8 @@ Train the machine using `fit!(mach, rows=...)`. - `minstepfac::Real=0.001`: Minimum step fraction. Must be between 0 and 1. Lower bound for the factor used to update the linear fit. -- `report_keys::Union{Symbol, Nothing}=DEFAULT_KEYS`: keys to be used in the report. Should - be one of: `:deviance`, `:dof_residual`, `:stderror`, `:vcov`, `:coef_table`. +- `report_keys`: `Vector` of keys for the report. Possible keys are: $VALID_KEYS_LIST. By + default only `:glm_model` is excluded. # Operations @@ -880,6 +898,9 @@ The fields of `report(mach)` are: - `coef_table`: Table which displays coefficients and summarizes their significance and confidence intervals. +- `glm_model`: The raw fitted model returned by `GLM.glm`. Note this points to training
+ # Examples diff --git a/test/runtests.jl b/test/runtests.jl index 519cef4..a266fc6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Test using MLJBase +using StatisticalMeasures using LinearAlgebra using Statistics using MLJGLMInterface @@ -66,7 +67,7 @@ expit(X) = 1 ./ (1 .+ exp.(-X)) @test hyp_types[2] == "Bool" @test hyp_types[3] == "Union{Nothing, Symbol}" @test hyp_types[4] == "Union{Nothing, AbstractVector{Symbol}}" - + end ### @@ -86,18 +87,18 @@ end fitresult, _, report = fit(lr, 1, X, y) yhat = predict(lr, fitresult, X) - @test mean(cross_entropy(yhat, y)) < 0.25 + @test cross_entropy(yhat, y) < 0.25 fitresult1, _, report1 = fit(pr, 1, X, y) yhat1 = predict(pr, fitresult1, X) - @test mean(cross_entropy(yhat1, y)) < 0.25 + @test cross_entropy(yhat1, y) < 0.25 fitresultw, _, reportw = fit(lr, 1, X, y, w) yhatw = predict(lr, fitresultw, X) - @test mean(cross_entropy(yhatw, y)) < 0.25 + @test cross_entropy(yhatw, y) < 0.25 @test yhatw ≈ yhat fitresultw1, _, reportw1 = fit(pr, 1, X, y, w) yhatw1 = predict(pr, fitresultw1, X) - @test mean(cross_entropy(yhatw1, y)) < 0.25 + @test cross_entropy(yhatw1, y) < 0.25 @test yhatw1 ≈ yhat1 # check predict on `Xnew` with wrong dims @@ -124,6 +125,7 @@ end @test hyper_params[6] == :rtol @test hyper_params[7] == :minstepfac @test hyper_params[8] == :report_keys + end ### @@ -150,7 +152,7 @@ end fitresultw, _, _ = fit(lcr, 1, XTable, y, w) θ̂w = fitted_params(lcr, fitresultw).coef @test norm(θ̂w .- θ)/norm(θ) ≤ 0.03 - @test θ̂w ≈ θ̂ + @test θ̂w ≈ θ̂ # check predict on `Xnew` with wrong dims Xnew = MLJBase.table( @@ -278,7 +280,7 @@ end N = 1000 rng = StableRNGs.StableRNG(0) X = MLJBase.table(rand(rng, N, 3)) - y = 2*X.x1 + X.x2 - X.x3 + rand(rng, Normal(0,1), N) + y = 2*X.x1 + X.x2 - X.x3 + rand(rng, Normal(0,1), N) lr = LinearRegressor(fit_intercept=false, offsetcol=:x2) fitresult, _, report = fit(lr, 1, X, y) @@ -312,7 +314,7 @@ end @test parameters == ["a", "b", "c", "(Intercept)"] intercept = 
ctable.cols[1][4] yhat = predict(lr, fitresult, X) - @test mean(cross_entropy(yhat, y)) < 0.6 + @test cross_entropy(yhat, y) < 0.6 fp = fitted_params(lr, fitresult) @test fp.features == [:a, :b, :c] @@ -323,21 +325,25 @@ end @testset "Param names in report" begin X = (a=[1, 4, 3, 1], b=[2, 0, 1, 4], c=[7, 1, 7, 3]) y = categorical([true, false, true, false]) - # check that by default all possible keys are added in the report + # check that by default all possible keys are added in the report, + # except glm_model: lr = LinearBinaryClassifier() _, _, report = fit(lr, 1, X, y) - @test :deviance in keys(report) + @test :deviance in keys(report) @test :dof_residual in keys(report) @test :stderror in keys(report) @test :vcov in keys(report) @test :coef_table in keys(report) + @test :glm_model ∉ keys(report) # check that report is valid if only some keys are specified - lr = LinearBinaryClassifier(report_keys = [:stderror, :deviance]) + lr = LinearBinaryClassifier(report_keys = [:stderror, :glm_model]) _, _, report = fit(lr, 1, X, y) - @test :deviance in keys(report) + @test :deviance ∉ keys(report) @test :stderror in keys(report) @test :dof_residual ∉ keys(report) + @test :glm_model in keys(report) + @test report.glm_model isa GLM.GeneralizedLinearModel # check that an empty `NamedTuple` is outputed for # `report_params === nothing`