From f9dc7bec4b26d8538941d6d6547b892b767465ab Mon Sep 17 00:00:00 2001 From: Paulina Martin Date: Sun, 22 Aug 2021 14:32:14 -0500 Subject: [PATCH] Add draft code for Energy plots --- src/plot.jl | 139 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 136 insertions(+), 3 deletions(-) diff --git a/src/plot.jl b/src/plot.jl index 41edb8f2..be2d8f70 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -4,12 +4,15 @@ @shorthands pooleddensity @shorthands traceplot @shorthands corner +@userplot EnergyPlot +#@shorthands energyplot struct _TracePlot; c; val; end struct _MeanPlot; c; val; end struct _DensityPlot; c; val; end struct _HistogramPlot; c; val; end struct _AutocorPlot; lags; val; end +#struct _EnergyPlot; marginal_energy; energy_transition; p_type; n_chains; end # define alias functions for old syntax const translationdict = Dict( @@ -18,10 +21,10 @@ const translationdict = Dict( :density => _DensityPlot, :histogram => _HistogramPlot, :autocorplot => _AutocorPlot, - :pooleddensity => _DensityPlot + :pooleddensity => _DensityPlot, ) -const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :corner) +const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :corner, :energyplot) @recipe f(c::Chains, s::Symbol) = c, [s] @@ -30,7 +33,8 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor colordim = :chain, barbounds = (-Inf, Inf), maxlag = nothing, - append_chains = false + append_chains = false, + plot_type = :density ) st = get(plotattributes, :seriestype, :traceplot) c = append_chains || st == :pooleddensity ? pool_chain(chains) : chains @@ -64,6 +68,17 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor ac_mat = convert(Array, ac) val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :] _AutocorPlot(lags, val) + #elseif st == :energyplot + # p_type = plot_type + # energy_section = get(c, :hamiltonian_energy) + # #@show energy_section + # #@show params.hamiltonian_energy + # n_chains = (append_chains ? 1 : size(c, 3)) + # energy_data = (append_chains ? vec(energy_section.hamiltonian_energy.data) : energy_section.hamiltonian_energy.data) + # mean_energy = vec(mean(energy_data, dims = 1)) + # marginal_energy = [energy_data[:,i] .- mean_energy[i] for i in 1:n_chains] + # energy_transition = [energy_data[2:end,i] .- energy_data[1:end-1,i] for i in 1:n_chains] + # _EnergyPlot(marginal_energy, energy_transition, p_type, n_chains) elseif st ∈ supportedplots translationdict[st](c, val) else @@ -184,3 +199,121 @@ end ar = collect(Array(corner.c.value[:, corner.parameters,i]) for i in chains(corner.c)) RecipesBase.recipetype(:cornerplot, vcat(ar...)) end + +#function compute_energy( +# chains::Chains, +# combined = false, +# plot_type = :density +#) +# st = get(plotattributes, :seriestype, :traceplot) +# +# if st == :energyplot +# p_type = plot_type +# params = get(chains, :hamiltonian_energy) +# n_chains = (combined ? 1 : size(chains, 3)) +# energy_data = (combined ? vec(params.hamiltonian_energy.data) : params.hamiltonian_energy.data) +# mean_energy = vec(mean(energy_data, dims = 1)) +# marginal_energy = energy_data[:,i] .- mean_energy[i] +# energy_transition = energy_data[2:end,i] .- energy_data[1:end-1,i] +# _EnergyPlot(marginal_energy, energy_transition, p_type, n_chains) +# else +# +# end +#end + +#@recipe function f( +# chains::Chains; +# plot_type = :density, +# append_chains = false +#) +# +# st = get(plotattributes, :seriestype, :traceplot) +# if st == :energyplot +# p_type = plot_type +# energy_section = get(chains, :hamiltonian_energy) +# #@show energy_section +# #@show params.hamiltonian_energy +# n_chains = (append_chains ? 1 : size(chains, 3)) +# energy_data = (append_chains ? vec(energy_section.hamiltonian_energy.data) : energy_section.hamiltonian_energy.data) +# mean_energy = vec(mean(energy_data, dims = 1)) +# marginal_energy = [energy_data[:,i] .- mean_energy[i] for i in 1:n_chains] +# energy_transition = [energy_data[2:end,i] .- energy_data[1:end-1,i] for i in 1:n_chains] +# _EnergyPlot(marginal_energy, energy_transition, p_type, n_chains) +# elseif st ∈ supportedplots +# translationdict[st](c, val) +# end +#end + +function compute_energy( + chains::Chains, + combined = false, + plot_type = :density +) + p_type = plot_type + params = get(chains, :hamiltonian_energy) + isempty(params) && error("EnergyPlot receives a Chains object containing only the + :internals section. Please use Chains(chain, [:internals]) to create it") + n_chains = (combined ? 1 : size(chains, 3)) + energy_data = (combined ? vec(params.hamiltonian_energy.data) : params.hamiltonian_energy.data) + mean_energy = vec(mean(energy_data, dims = 1)) + marginal_energy = [energy_data[:,i] .- mean_energy[i] for i in 1:n_chains] + energy_transition = [energy_data[2:end,i] .- energy_data[1:end-1,i] for i in 1:n_chains] + return marginal_energy, energy_transition, p_type, n_chains + end + +@recipe function f( + p::EnergyPlot; + combined = false, + plot_type = :density + ) + + c = p.args[1] + #p_type = plot_type + #params = get(c, :hamiltonian_energy) + #isempty(params) && error("EnergyPlot receives a Chains object containing only the + # :internals section. Please use Chains(chain, [:internals]) to create it") + #n_chains = (combined ? 1 : size(c, 3)) + #energy_data = (combined ? vec(params.hamiltonian_energy.data) : params.hamiltonian_energy.data) + #mean_energy = vec(mean(energy_data, dims = 1)) + #marginal_energy = [energy_data[:,i] .- mean_energy[i] for i in 1:n_chains] + #energy_transition = [energy_data[2:end,i] .- energy_data[1:end-1,i] for i in 1:n_chains] + marginal_energy, energy_transition, p_type, n_chains = compute_energy(c, combined, plot_type) + k = 0 + for i in 1:n_chains + k += 1 + title --> "Chain $(MCMCChains.chains(c)[i])" + subplot := i + @series begin + seriestype := p_type + label --> "Marginal energy" + marginal_energy[i] + end + + @series begin + seriestype := p_type + label --> "Energy transition" + energy_transition[i] + end + end +end + +#@recipe function f(p::_EnergyPlot) +# +# k = 0 +# for i in 1:p.n_chains +# k = 1 +# @series begin +# subplot := i +# seriestype := p.p_type +# label --> "Marginal energy" +# p.marginal_energy[i] +# end +# +# @series begin +# subplot := i +# seriestype := p.p_type +# label --> "Energy transition" +# p.energy_transition[i] +# end +# end +#end \ No newline at end of file