diff --git a/src/regressionmodel.jl b/src/regressionmodel.jl index d7a2afc..2cc1510 100644 --- a/src/regressionmodel.jl +++ b/src/regressionmodel.jl @@ -40,6 +40,17 @@ Return the model matrix (a.k.a. the design matrix). """ function modelmatrix end +""" + hasintercept(model::RegressionModel) + +Indicate whether the model has an intercept. +""" +function hasintercept(model::RegressionModel) + X = modelmatrix(model) + any(i -> all(==(1), view(X , :, i)), 1:size(X, 2)) +end + + """ crossmodelmatrix(model::RegressionModel) diff --git a/test/regressionmodel.jl b/test/regressionmodel.jl index a8892fa..a18c556 100644 --- a/test/regressionmodel.jl +++ b/test/regressionmodel.jl @@ -1,7 +1,7 @@ module TestRegressionModel using Test, LinearAlgebra, StatsAPI -using StatsAPI: RegressionModel, crossmodelmatrix +using StatsAPI: RegressionModel, hasintercept, crossmodelmatrix struct MyRegressionModel <: RegressionModel end @@ -10,9 +10,10 @@ StatsAPI.modelmatrix(::MyRegressionModel) = [1 2; 3 4] @testset "TestRegressionModel" begin m = MyRegressionModel() - + + @test !hasintercept(m) @test crossmodelmatrix(m) == [10 14; 14 20] @test crossmodelmatrix(m) isa Symmetric end -end # module TestRegressionModel \ No newline at end of file +end # module TestRegressionModel