Skip to content

Commit

Permalink
fix posterior mode and logging
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelJuillard committed Dec 27, 2023
1 parent 774698e commit ec0b859
Showing 1 changed file with 92 additions and 65 deletions.
157 changes: 92 additions & 65 deletions src/estimation/estimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ Base.@kwdef struct EstimationOptions
data::AxisArrayTable = AxisArrayTable(AxisArrayTables.AxisArray(Matrix{Float64}(undef, 0, 0), PeriodsSinceEpoch[], Symbol[]))
datafile::String = ""
diffuse_filter::Bool = false
display::Bool = false
display::Bool = true
fast_kalman_filter::Bool = true
first_obs::PeriodsSinceEpoch = Undated(typemin(Int))
last_obs::PeriodsSinceEpoch = Undated(typemin(Int))
mcmc_chains::Int = 1
mcmc_init_scale::Float64 = 0
mcmc_jscale::Float64 = 0
mcmc_jscale::Float64 = 0.2
mcmc_replic::Int = 0
mode_check::Bool = false
mode_compute::Bool = true
Expand Down Expand Up @@ -81,28 +81,35 @@ function estimation!(context, field::Dict{String, Any})
results = context.results.model_results[1]
lre_results = results.linearrationalexpectations
estimation_results = results.estimation
observations,obsvarnames = get_observations(context, options.datafile, options.data, options.first_obs, options.last_obs, options.nobs)
observations, obsvarnames = get_observations(context, options.datafile, options.data, options.first_obs, options.last_obs, options.nobs)
nobs = size(observations, 2)
estimated_parameters = context.work.estimated_parameters
initial_parameter_values = get_initial_value_or_mean()
set_estimated_parameters!(context, initial_parameter_values)
if options.plot_priors
plot_priors(context, estimated_parameters.name)
plot_priors()
end
if options.mode_compute
(res, mode, tstdh, mode_covariance) = posterior_mode!(context, initial_parameter_values, observations, obsvarnames, transformed_parameters = true)
@debug res
end
options.display && log_result(res)
estimation_result_table(
get_parameter_names(estimated_parameters),
mode,
tstdh,
"Posterior mode"
)
end

if !isempty(options.mode_file)
mode, mode_covariance = get_mode_file(options.mode_file)
end

if options.mcmc_replic > 0
chain = mh_estimation(context, observations, mode,
options.mcmc_jscale*mode_covariance,
mcmc_replic=options.mcmc_replic)
StatsPlots.plot(chain)
chain, back_transformed_chain = mh_estimation(context, observations, obsvarnames, mode,
covariance = options.mcmc_jscale*mode_covariance,
mcmc_replic = options.mcmc_replic)
output_MCMCChains(context, chain, options.display, options.display)
options.display && plot_prior_posterior(back_transformed_chain)
end

return nothing
Expand All @@ -127,7 +134,7 @@ end
computes the posterior mode.
# Keywork arguments
# Keyword arguments
- `context::Context=context`: context of the computation
- `data::AxisArrayTable`: AxisArrayTable containing observed variables
- `datafile::String: data filename
Expand Down Expand Up @@ -162,6 +169,7 @@ function mode_compute!(; algorithm = BFGS,

observations, obsvarnames = get_observations(context, datafile, data, first_obs, last_obs, nobs)
(res, mode, tstdh, mode_covariance) = posterior_mode!(context, initial_values, observations, obsvarnames, algorithm = algorithm, transformed_parameters = transformed_parameters)
display && log_result(res)
end

"""
Expand Down Expand Up @@ -192,7 +200,7 @@ end
runs random walk Monte Carlo simulations of the posterior
# Keywork arguments
# Keyword arguments
- `context::Context=context`: context of the computation
- `covariance::AbstractMatrix{Float64}`:
- `data::AxisArrayTable`: AxisArrayTable containing observed variables
Expand Down Expand Up @@ -255,16 +263,15 @@ function rwmh_compute!(;context::Context=context,
)
output_MCMCChains(context, chain, display, plot_chain)
plot_posterior_density && plot_prior_posterior(context, back_transformed_chain)
plot_chain && StatsPlots.plot(chain)
return chain
end

function output_MCMCChains(context, chain, display, plot_chain)
estimation_results = context.results.model_results[1].estimation
n = estimation_results.posterior_mcmc_chains_nbr += 1
path = mkpath(joinpath(context.modfileinfo.modfilepath, "output"))
serialize("$path/mcmc_chain_$n.jls", chain)
display && Base.display(chain)
serialize("$path/mcmc_chain_$n.jls", chain)
display && log_result(chain)
plot_chain && plot_MCMCChains(chain, n, "$path/graphs", display)
end

Expand Down Expand Up @@ -513,7 +520,6 @@ function make_logposteriordensity(context, observations, ssws)
try
lpd += loglikelihood(x, context, observations, ssws)
catch e
error(e)
@debug e
lpd = -Inf
end
Expand Down Expand Up @@ -940,12 +946,6 @@ function posterior_mode!(
results.posterior_mode_covariance = copy(invhess)
end

estimation_result_table(
get_parameter_names(ep),
results.posterior_mode,
results.posterior_mode_std,
"Posterior mode"
)
return(res,
results.posterior_mode,
results.posterior_mode_std,
Expand All @@ -964,7 +964,7 @@ function mh_estimation(
last_obs = 0,
mcmc_chains = 1,
mcmc_replic = 100000,
transformed_covariance = Matrix(under, 0, 0),
transformed_covariance = Matrix(undef, 0, 0),
transformed_parameters = true,
kwargs...,
)
Expand Down Expand Up @@ -1294,31 +1294,6 @@ function mcmc_diagnostics(chains, context, names)
f
end

"""
plot_priors(context, names, n_points = 100)
plots prior density
"""
function plot_priors(context, names, n_points = 100)
indices = [find(context.work.estimated_parameters.name, e) for e in names]
prior_pdfs = []
n_plots = length(indices)
prior_x_axes = []
for i in indices
prior = context.work.estimated_parameters.prior[i]
m, v = mean(prior), var(prior)
prior_x_axis = LinRange(m-15*v, m+15*v, n_points)
prior_pdf = [pdf(prior, e) for e in prior_x_axis]
push!(prior_pdfs, prior_pdf/sum(prior_pdf))
push!(prior_x_axes, prior_x_axis)
end
prior_pdfs = hcat(prior_pdfs...)
prior_x_axes = hcat(prior_x_axes...)
names = hcat(names...)
f = Plots.plot(prior_x_axes, prior_pdfs, layout=n_plots, title=names, linecolor=:darkgrey, labels=false, linewidth=2)
display(f)
end

function transform_chains(chains, t, posterior_density)
y = chains.value.data
nparams = size(y, 2) - 1
Expand Down Expand Up @@ -1500,7 +1475,60 @@ function get_index_name(s::Symbol, symboltable::SymbolTable)
return (index, name)
end

function plot_prior_posterior(context::Context, chains)
"""
plot_priors(; context::Context = context, n_points::Int = 100)
plots prior density
# Keyword arguments
- `context::Context = context`: context in which to take the date to be ploted
- `n_points::Int = 100`: number of points used for a curve
"""
function plot_priors(; context::Context = context, n_points = 100)
ep = context.work.estimated_parameters
@assert length(ep.prior) > 0 "There is no defined prior"

path = "$(context.modfileinfo.modfilepath)/graphs/"
mkpath(path)
filename = "$(path)/Priors"
nprior = length(ep.prior)
(nbplt, nr, nc, lr, lc, nstar) = pltorg(nprior)
ivars = collect(1:nr*nc)
for p = 1:nbplt
filename1 = "$(filename)_$(p).png"
p == nbplt && (ivars = ivars[1:nstar])
plot_panel_priors(
ep.prior,
ep.name,
ivars,
nr,
nc,
filename1
)
ivars .+= nr * nc
end
end

function plot_panel_priors(
prior,
ylabels,
ivars,
nr,
nc,
filename;
kwargs...
)
sp = [Plots.plot(showaxis = false, ticks = false, grid = false) for i = 1:nr*nc]
for (i, j) in enumerate(ivars)
sp[i] = Plots.plot(prior[j], title = ylabels[j], labels = "Prior", kwargs...)
end

pl = Plots.plot(sp..., layout = (nr, nc), size = (900, 900), plot_title = "Prior distributions")
graph_display(pl)
savefig(filename)
end

function plot_prior_posterior(chains; context::Context=context)
ep = context.work.estimated_parameters
mode = context.results.model_results[1].estimation.posterior_mode
@assert length(ep.prior) > 0 "There is no defined prior"
Expand All @@ -1518,12 +1546,10 @@ function plot_prior_posterior(context::Context, chains)
ep.prior,
chains,
mode,
"Priors",
ep.name,
ivars,
nr,
nc,
nr * nc,
filename1
)
ivars .+= nr * nc
Expand All @@ -1534,12 +1560,10 @@ function plot_prior_posterior(context::Context, chains)
ep.prior,
chains,
mode,
"Priors",
ep.name,
ivars,
lr,
lc,
nstar,
filename
)
end
Expand All @@ -1548,25 +1572,28 @@ function plot_panel_prior_posterior(
prior,
chains,
mode,
title,
ylabels,
ivars,
nr,
nc,
nstar,
filename;
kwargs...
)
sp = [Plots.plot(showaxis = false, ticks = false, grid = false) for i = 1:nr*nc]
for (i, j) in enumerate(ivars)
posterior_density = kde(vec(get(chains, Symbol(ylabels[j]))[1].data))
sp[i] = Plots.plot(prior[j], title = ylabels[j], labels = "Prior", kwargs...)
plot!(posterior_density, labels = "Posterior")
end
sp = [Plots.plot(showaxis = false, ticks = false, grid = false) for i = 1:nr*nc]
for (i, j) in enumerate(ivars)
posterior_density = kde(vec(get(chains, Symbol(ylabels[j]))[1].data))

sp[i] = Plots.plot(prior[j], title = ylabels[j], labels = "Prior", kwargs...)
plot!(posterior_density, labels = "Posterior")
end

pl = Plots.plot(sp..., layout = (nr, nc), size = (900, 900), plot_title = "Prior and posterior distributions")
graph_display(pl)
savefig(filename)
pl = Plots.plot(sp..., layout = (nr, nc), size = (900, 900), plot_title = "Prior and posterior distributions")
graph_display(pl)
savefig(filename)
end

function log_result(result)
io = IOBuffer()
show(io, "text/plain", result)
@info String(take!(io))
end

0 comments on commit ec0b859

Please sign in to comment.