
Use biased autocov for all lags #61

Merged: 7 commits, Jan 12, 2023
2 changes: 1 addition & 1 deletion Project.toml
```diff
@@ -1,7 +1,7 @@
 name = "MCMCDiagnosticTools"
 uuid = "be115224-59cd-429b-ad48-344e309966f0"
 authors = ["David Widmann"]
-version = "0.2.2"
+version = "0.2.3"
 
 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
```
15 changes: 6 additions & 9 deletions src/ess.jl
```diff
@@ -9,8 +9,6 @@ effective sample size of MCMC chains.
 
 It is based on the discussion by [^VehtariGelman2021] and uses the
 biased estimator of the autocovariance, as discussed by [^Geyer1992].
-In contrast to Geyer, the divisor `n - 1` is used in the estimation of
-the autocovariance to obtain the unbiased estimator of the variance for lag 0.
 
 [^VehtariGelman2021]: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P. C. (2021).
     Rank-normalization, folding, and localization: An improved ``\\widehat {R}`` for
@@ -157,12 +155,9 @@ function mean_autocov(k::Int, cache::ESSCache)
     )
 end
 
-    # normalize autocovariance estimators by `niter - 1` instead
-    # of `niter - k` to obtain
-    # - unbiased estimators of the variance for lag 0
-    # - biased but more stable estimators for all other lags as discussed by
-    #   Geyer (1992)
-    return s / (niter - 1)
+    # normalize autocovariance estimators by `niter` instead of `niter - k` to obtain biased
+    # but more stable estimators for all lags as discussed by Geyer (1992)
+    return s / niter
 end
 
 function mean_autocov(k::Int, cache::FFTESSCache)
@@ -174,9 +169,11 @@ function mean_autocov(k::Int, cache::FFTESSCache)
     # we use biased but more stable estimators as discussed by Geyer (1992)
     samples_cache = cache.samples_cache
     chain_var = cache.chain_var
-    return Statistics.mean(1:nchains) do i
+    uncorrection_factor = (niter - 1)//niter # undo corrected=true for chain_var
+    result = Statistics.mean(1:nchains) do i
         @inbounds(real(samples_cache[k + 1, i]) / real(samples_cache[1, i])) * chain_var[i]
     end
+    return result * uncorrection_factor
 end
 
 function mean_autocov(k::Int, cache::BDAESSCache)
```
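For context (an editorial note, not part of the PR): the two normalizations of the lag-`k` autocovariance estimator differ only in the divisor. This PR switches from a mixed convention to the divisor-`n` form for every lag, following Geyer (1992):

```math
\hat{\gamma}_k^{\text{unbiased}} = \frac{1}{n - k} \sum_{t=1}^{n-k} (x_t - \bar{x})(x_{t+k} - \bar{x}),
\qquad
\hat{\gamma}_k^{\text{biased}} = \frac{1}{n} \sum_{t=1}^{n-k} (x_t - \bar{x})(x_{t+k} - \bar{x}).
```

The biased form has smaller variance at large lags (where few summands remain), which stabilizes the downstream ESS estimate.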
20 changes: 20 additions & 0 deletions test/ess.jl
```diff
@@ -9,6 +9,18 @@ using Statistics
 using StatsBase
 using Test
 
+struct ExplicitESSMethod <: MCMCDiagnosticTools.AbstractESSMethod end
+struct ExplicitESSCache{S}
+    samples::S
+end
+function MCMCDiagnosticTools.build_cache(::ExplicitESSMethod, samples::Matrix, var::Vector)
+    return ExplicitESSCache(samples)
+end
+MCMCDiagnosticTools.update!(::ExplicitESSCache) = nothing
+function MCMCDiagnosticTools.mean_autocov(k::Int, cache::ExplicitESSCache)
+    return mean(autocov(cache.samples, k:k; demean=true))
+end
+
 struct CauchyProblem end
 LogDensityProblems.logdensity(p::CauchyProblem, θ) = -sum(log1psq, θ)
 function LogDensityProblems.logdensity_and_gradient(p::CauchyProblem, θ)
@@ -121,6 +133,14 @@ end
     end
 end
 
+@testset "Autocov of ESSMethod and FFTESSMethod equivalent to StatsBase" begin
+    x = randn(1_000, 10, 40)
+    ess_exp = ess_rhat(x; method=ExplicitESSMethod())[1]
+    @testset "$method" for method in [FFTESSMethod(), ESSMethod()]
+        @test ess_rhat(x; method=method)[1] ≈ ess_exp
+    end
+end
+
 @testset "ESS and R̂ for chains with 2 epochs that have not mixed" begin
     # checks that splitting yields lower ESS estimates and higher Rhat estimates
     x = randn(1000, 4, 10) .+ repeat([0, 10]; inner=(500, 1, 1))
```

Review comment (Member), on the `ExplicitESSMethod` test helper: Nice!
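The `uncorrection_factor` in the FFT path can be sanity-checked in isolation. The sketch below (not part of the PR; it assumes StatsBase's default divisor-`n` convention for `autocov`) shows that the biased lag-0 autocovariance equals the corrected sample variance scaled by `(n - 1)/n`, which is exactly the factor applied to `chain_var` above:

```julia
using Statistics, StatsBase

x = randn(1_000)
n = length(x)

# Biased lag-0 autocovariance (divisor `n`) equals the corrected sample
# variance (divisor `n - 1`) scaled by (n - 1)/n:
biased_lag0 = sum(abs2, x .- mean(x)) / n
@assert biased_lag0 ≈ var(x) * (n - 1) / n

# StatsBase.autocov uses the biased (divisor `n`) convention by default:
@assert biased_lag0 ≈ only(autocov(x, 0:0; demean=true))
```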