Skip to content

Commit

Permalink
New Feature: Fix and improve coeftable for otpimize() output (#2034)
Browse files Browse the repository at this point in the history
* Fix #2033

* Propose additional columns

* Add other suggestions

* Update src/modes/OptimInterface.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/modes/OptimInterface.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/modes/OptimInterface.jl

Co-authored-by: David Widmann <[email protected]>

* bump version

* import SsStatsAPI

* Update Project.toml

* Update src/modes/OptimInterface.jl

Co-authored-by: David Widmann <[email protected]>

---------

Co-authored-by: David Widmann <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
3 people authored Jul 12, 2023
1 parent 90e1d21 commit 59e5cce
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 47 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.26.3"
version = "0.26.4"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -32,6 +32,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand Down Expand Up @@ -62,6 +63,7 @@ Requires = "0.5, 1.0"
SciMLBase = "1.37.1"
Setfield = "0.8, 1"
SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2"
StatsAPI = "1.6"
StatsBase = "0.32, 0.33, 0.34"
StatsFuns = "0.8, 0.9, 1"
Tracker = "0.2.3"
Expand Down
102 changes: 56 additions & 46 deletions src/modes/OptimInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,32 @@ import ..ForwardDiff
import NamedArrays
import StatsBase
import Printf
import StatsAPI


"""
ModeResult{
V<:NamedArrays.NamedArray,
M<:NamedArrays.NamedArray,
O<:Optim.MultivariateOptimizationResults,
V<:NamedArrays.NamedArray,
M<:NamedArrays.NamedArray,
O<:Optim.MultivariateOptimizationResults,
S<:NamedArrays.NamedArray
}
A wrapper struct to store various results from a MAP or MLE estimation.
"""
struct ModeResult{
V<:NamedArrays.NamedArray,
V<:NamedArrays.NamedArray,
O<:Optim.MultivariateOptimizationResults,
M<:OptimLogDensity
} <: StatsBase.StatisticalModel
"A vector with the resulting point estimates."
values :: V
values::V
"The stored Optim.jl results."
optim_result :: O
optim_result::O
"The final log likelihood or log joint, depending on whether `MAP` or `MLE` was run."
lp :: Float64
lp::Float64
"The evaluation function used to calculate the output."
f :: M
f::M
end
#############################
# Various StatsBase methods #
Expand All @@ -50,14 +51,23 @@ function Base.show(io::IO, m::ModeResult)
show(io, m.values.array)
end

function StatsBase.coeftable(m::ModeResult)
function StatsBase.coeftable(m::ModeResult; level::Real=0.95)
# Get columns for coeftable.
terms = StatsBase.coefnames(m)
estimates = m.values.array[:,1]
terms = string.(StatsBase.coefnames(m))
estimates = m.values.array[:, 1]
stderrors = StatsBase.stderror(m)
tstats = estimates ./ stderrors

StatsBase.CoefTable([estimates, stderrors, tstats], ["estimate", "stderror", "tstat"], terms)
zscore = estimates ./ stderrors
p = map(z -> StatsAPI.pvalue(Normal(), z; tail=:both), zscore)

# Confidence interval (CI)
q = quantile(Normal(), (1 + level) / 2)
ci_low = estimates .- q .* stderrors
ci_high = estimates .+ q .* stderrors

StatsBase.CoefTable(
[estimates, stderrors, zscore, p, ci_low, ci_high],
["Coef.", "Std. Error", "z", "Pr(>|z|)", "Lower 95%", "Upper 95%"],
terms)
end

function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff.hessian, kwargs...)
Expand Down Expand Up @@ -113,7 +123,7 @@ mle = optimize(model, MLE())
mle = optimize(model, MLE(), NelderMead())
```
"""
function Optim.optimize(model::Model, ::MLE, options::Optim.Options=Optim.Options(); kwargs...)
function Optim.optimize(model::Model, ::MLE, options::Optim.Options=Optim.Options(); kwargs...)
return _mle_optimize(model, options; kwargs...)
end
function Optim.optimize(model::Model, ::MLE, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...)
Expand All @@ -123,11 +133,11 @@ function Optim.optimize(model::Model, ::MLE, optimizer::Optim.AbstractOptimizer,
return _mle_optimize(model, optimizer, options; kwargs...)
end
function Optim.optimize(
model::Model,
::MLE,
init_vals::AbstractArray,
optimizer::Optim.AbstractOptimizer,
options::Optim.Options=Optim.Options();
model::Model,
::MLE,
init_vals::AbstractArray,
optimizer::Optim.AbstractOptimizer,
options::Optim.Options=Optim.Options();
kwargs...
)
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
Expand Down Expand Up @@ -159,7 +169,7 @@ map_est = optimize(model, MAP(), NelderMead())
```
"""

function Optim.optimize(model::Model, ::MAP, options::Optim.Options=Optim.Options(); kwargs...)
function Optim.optimize(model::Model, ::MAP, options::Optim.Options=Optim.Options(); kwargs...)
return _map_optimize(model, options; kwargs...)
end
function Optim.optimize(model::Model, ::MAP, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...)
Expand All @@ -169,11 +179,11 @@ function Optim.optimize(model::Model, ::MAP, optimizer::Optim.AbstractOptimizer,
return _map_optimize(model, optimizer, options; kwargs...)
end
function Optim.optimize(
model::Model,
::MAP,
init_vals::AbstractArray,
optimizer::Optim.AbstractOptimizer,
options::Optim.Options=Optim.Options();
model::Model,
::MAP,
init_vals::AbstractArray,
optimizer::Optim.AbstractOptimizer,
options::Optim.Options=Optim.Options();
kwargs...
)
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
Expand All @@ -190,43 +200,43 @@ end
Estimate a mode, i.e., compute a MLE or MAP estimate.
"""
function _optimize(
model::Model,
f::OptimLogDensity,
optimizer::Optim.AbstractOptimizer = Optim.LBFGS(),
args...;
model::Model,
f::OptimLogDensity,
optimizer::Optim.AbstractOptimizer=Optim.LBFGS(),
args...;
kwargs...
)
return _optimize(model, f, DynamicPPL.getparams(f), optimizer, args...; kwargs...)
end

function _optimize(
model::Model,
f::OptimLogDensity,
options::Optim.Options = Optim.Options(),
args...;
model::Model,
f::OptimLogDensity,
options::Optim.Options=Optim.Options(),
args...;
kwargs...
)
return _optimize(model, f, DynamicPPL.getparams(f), Optim.LBFGS(), args...; kwargs...)
end

function _optimize(
model::Model,
f::OptimLogDensity,
init_vals::AbstractArray = DynamicPPL.getparams(f),
options::Optim.Options = Optim.Options(),
args...;
model::Model,
f::OptimLogDensity,
init_vals::AbstractArray=DynamicPPL.getparams(f),
options::Optim.Options=Optim.Options(),
args...;
kwargs...
)
return _optimize(model, f, init_vals, Optim.LBFGS(), options, args...; kwargs...)
end

function _optimize(
model::Model,
f::OptimLogDensity,
init_vals::AbstractArray = DynamicPPL.getparams(f),
optimizer::Optim.AbstractOptimizer = Optim.LBFGS(),
options::Optim.Options = Optim.Options(),
args...;
model::Model,
f::OptimLogDensity,
init_vals::AbstractArray=DynamicPPL.getparams(f),
optimizer::Optim.AbstractOptimizer=Optim.LBFGS(),
options::Optim.Options=Optim.Options(),
args...;
kwargs...
)
# Convert the initial values, since it is assumed that users provide them
Expand All @@ -243,7 +253,7 @@ function _optimize(
@warn "Optimization did not converge! You may need to correct your model or adjust the Optim parameters."
end

# Get the VarInfo at the MLE/MAP point, and run the model to ensure
# Get the VarInfo at the MLE/MAP point, and run the model to ensure
# correct dimensionality.
@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
@set! f.varinfo = invlink!!(f.varinfo, model)
Expand Down

2 comments on commit 59e5cce

@yebai
Copy link
Member

@yebai yebai commented on 59e5cce Jul 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/87338

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.26.4 -m "<description of version>" 59e5cce0f00cfc27e2ddcb8e6dfc8387b502065d
git push origin v0.26.4

Please sign in to comment.