diff --git a/Project.toml b/Project.toml
index af68e063..73087b7e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "MCMCDiagnosticTools"
 uuid = "be115224-59cd-429b-ad48-344e309966f0"
 authors = ["David Widmann"]
-version = "0.2.2"
+version = "0.2.3"
 
 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
diff --git a/src/ess.jl b/src/ess.jl
index da6d9ea2..c31ece80 100644
--- a/src/ess.jl
+++ b/src/ess.jl
@@ -9,8 +9,6 @@
 effective sample size of MCMC chains. It is based on the discussion by
 [^VehtariGelman2021] and uses the biased estimator of the autocovariance, as discussed by
 [^Geyer1992].
-In contrast to Geyer, the divisor `n - 1` is used in the estimation of
-the autocovariance to obtain the unbiased estimator of the variance for lag 0.
 
 [^VehtariGelman2021]: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P. C. (2021).
     Rank-normalization, folding, and localization: An improved ``\\widehat {R}`` for
@@ -157,12 +155,9 @@ function mean_autocov(k::Int, cache::ESSCache)
         )
     end
 
-    # normalize autocovariance estimators by `niter - 1` instead
-    # of `niter - k` to obtain
-    # - unbiased estimators of the variance for lag 0
-    # - biased but more stable estimators for all other lags as discussed by
-    #   Geyer (1992)
-    return s / (niter - 1)
+    # normalize autocovariance estimators by `niter` instead of `niter - k` to obtain biased
+    # but more stable estimators for all lags as discussed by Geyer (1992)
+    return s / niter
 end
 
 function mean_autocov(k::Int, cache::FFTESSCache)
@@ -174,9 +169,11 @@ function mean_autocov(k::Int, cache::FFTESSCache)
     # we use biased but more stable estimators as discussed by Geyer (1992)
     samples_cache = cache.samples_cache
     chain_var = cache.chain_var
-    return Statistics.mean(1:nchains) do i
+    uncorrection_factor = (niter - 1)//niter # undo corrected=true for chain_var
+    result = Statistics.mean(1:nchains) do i
         @inbounds(real(samples_cache[k + 1, i]) / real(samples_cache[1, i])) * chain_var[i]
     end
+    return result * uncorrection_factor
 end
 
 function mean_autocov(k::Int, cache::BDAESSCache)
diff --git a/test/ess.jl b/test/ess.jl
index ec1100f9..0e2752cf 100644
--- a/test/ess.jl
+++ b/test/ess.jl
@@ -9,6 +9,18 @@
 using Statistics
 using StatsBase
 using Test
 
+struct ExplicitESSMethod <: MCMCDiagnosticTools.AbstractESSMethod end
+struct ExplicitESSCache{S}
+    samples::S
+end
+function MCMCDiagnosticTools.build_cache(::ExplicitESSMethod, samples::Matrix, var::Vector)
+    return ExplicitESSCache(samples)
+end
+MCMCDiagnosticTools.update!(::ExplicitESSCache) = nothing
+function MCMCDiagnosticTools.mean_autocov(k::Int, cache::ExplicitESSCache)
+    return mean(autocov(cache.samples, k:k; demean=true))
+end
+
 struct CauchyProblem end
 LogDensityProblems.logdensity(p::CauchyProblem, θ) = -sum(log1psq, θ)
 function LogDensityProblems.logdensity_and_gradient(p::CauchyProblem, θ)
@@ -121,6 +133,14 @@
         end
     end
 
+    @testset "Autocov of ESSMethod and FFTESSMethod equivalent to StatsBase" begin
+        x = randn(1_000, 10, 40)
+        ess_exp = ess_rhat(x; method=ExplicitESSMethod())[1]
+        @testset "$method" for method in [FFTESSMethod(), ESSMethod()]
+            @test ess_rhat(x; method=method)[1] ≈ ess_exp
+        end
+    end
+
     @testset "ESS and R̂ for chains with 2 epochs that have not mixed" begin
         # checks that splitting yields lower ESS estimates and higher Rhat estimates
         x = randn(1000, 4, 10) .+ repeat([0, 10]; inner=(500, 1, 1))
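
For reference (not part of the patch): a minimal Julia sketch of the normalization identity behind the new `uncorrection_factor`. `StatsBase.autocov` returns the biased estimator (divisor `n`), whereas `chain_var` in the FFT path is computed with `corrected=true` (divisor `n - 1`), so scaling by `(n - 1)/n` undoes that correction and makes the FFT path match the biased estimator used by `ESSCache`; variable names here (`n`, `x`, `c0`) are illustrative only.

    using Statistics, StatsBase

    n = 1_000
    x = randn(n)

    # biased lag-0 autocovariance (divisor `n`), as returned by StatsBase ...
    c0 = only(autocov(x, 0:0; demean=true))

    # ... equals the corrected sample variance (divisor `n - 1`) rescaled by
    # (n - 1)/n, which is what `uncorrection_factor` applies in `mean_autocov`
    @assert c0 ≈ var(x; corrected=true) * (n - 1) / n
    @assert c0 ≈ var(x; corrected=false)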