From 0bd9367a21887400674b623b512a91f588840ba3 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 10 Jan 2023 11:52:18 +0100 Subject: [PATCH 1/6] Use biased autocov for ESSMethod and FFTESSMethod --- src/ess.jl | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/ess.jl b/src/ess.jl index d385ad8d..ab3735ec 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -9,8 +9,6 @@ effective sample size of MCMC chains. It is is based on the discussion by Vehtari et al. and uses the biased estimator of the autocovariance, as discussed by Geyer. -In contrast to Geyer, the divisor `n - 1` is used in the estimation of -the autocovariance to obtain the unbiased estimator of the variance for lag 0. # References @@ -155,12 +153,9 @@ function mean_autocov(k::Int, cache::ESSCache) ) end - # normalize autocovariance estimators by `niter - 1` instead - # of `niter - k` to obtain - # - unbiased estimators of the variance for lag 0 - # - biased but more stable estimators for all other lags as discussed by - # Geyer (1992) - return s / (niter - 1) + # normalize autocovariance estimators by `niter` instead of `niter - k` to obtain biased + # but more stable estimators for all lags as discussed by Geyer (1992) + return s / niter end function mean_autocov(k::Int, cache::FFTESSCache) @@ -172,9 +167,10 @@ function mean_autocov(k::Int, cache::FFTESSCache) # we use biased but more stable estimators as discussed by Geyer (1992) samples_cache = cache.samples_cache chain_var = cache.chain_var + uncorrection_factor = (niter - 1) // niter # undo corrected=true for chain_var return Statistics.mean(1:nchains) do i @inbounds(real(samples_cache[k + 1, i]) / real(samples_cache[1, i])) * chain_var[i] - end + end * uncorrection_factor end function mean_autocov(k::Int, cache::BDAESSCache) From 7784237bb56b45671bff26b84a99dcb5eb168070 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 10 Jan 2023 11:52:36 +0100 Subject: [PATCH 2/6] Increment patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7d5cd9c9..f248508d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCMCDiagnosticTools" uuid = "be115224-59cd-429b-ad48-344e309966f0" authors = ["David Widmann"] -version = "0.2.1" +version = "0.2.2" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" From e94bd9fd9bf4d06d121664f79b89843ecdaf61d9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 10 Jan 2023 12:10:19 +0100 Subject: [PATCH 3/6] Run formatter --- src/ess.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/ess.jl b/src/ess.jl index ab3735ec..21ac72d9 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -167,10 +167,11 @@ function mean_autocov(k::Int, cache::FFTESSCache) # we use biased but more stable estimators as discussed by Geyer (1992) samples_cache = cache.samples_cache chain_var = cache.chain_var - uncorrection_factor = (niter - 1) // niter # undo corrected=true for chain_var - return Statistics.mean(1:nchains) do i + uncorrection_factor = (niter - 1)//niter # undo corrected=true for chain_var + result = Statistics.mean(1:nchains) do i @inbounds(real(samples_cache[k + 1, i]) / real(samples_cache[1, i])) * chain_var[i] - end * uncorrection_factor + end + return result * uncorrection_factor end function mean_autocov(k::Int, cache::BDAESSCache) From fda500d44610544e9a34beeeb9f32f6ea0d2ca6f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 10 Jan 2023 12:10:46 +0100 Subject: [PATCH 4/6] Add test for autocov methods being equivalent to StatsBase --- test/ess.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/ess.jl b/test/ess.jl index c58c00b7..60b138b9 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -1,3 +1,19 @@ +using MCMCDiagnosticTools +using StatsBase +using Test + +struct ExplicitESSMethod <: MCMCDiagnosticTools.AbstractESSMethod end +struct ExplicitESSCache{S} + samples::S +end +function MCMCDiagnosticTools.build_cache(::ExplicitESSMethod, samples::Matrix, var::Vector) + return ExplicitESSCache(samples) +end +MCMCDiagnosticTools.update!(::ExplicitESSCache) = nothing +function MCMCDiagnosticTools.mean_autocov(k::Int, cache::ExplicitESSCache) + return mean(autocov(cache.samples, k:k; demean=true)) +end + @testset "ess.jl" begin @testset "copy and split" begin # check a matrix with even number of rows @@ -87,4 +103,12 @@ @test all(ismissing, rhat_array) end end + + @testset "Autocov of ESSMethod and FFTESSMethod equivalent to StatsBase" begin + x = randn(1_000, 10, 40) + ess_exp = ess_rhat(x; method=ExplicitESSMethod())[1] + @testset "$method" for method in [FFTESSMethod(), ESSMethod()] + @test ess_rhat(x; method=method)[1] ≈ ess_exp + end + end end From aaae588fab84b66f2fd815680de7cc1dcbfde6b1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 10 Jan 2023 12:11:04 +0100 Subject: [PATCH 5/6] Add StatsBase as a test dependency --- test/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index 09fca5b2..a2bd2fbc 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -18,5 +19,6 @@ MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" MLJLIBSVMInterface = "0.1, 0.2" MLJXGBoostInterface = "0.1, 0.2, 0.3" +StatsBase = "0.33" Tables = "1" julia = "1.3" From 06747e578d5542377d94c00afb962aa00b790b29 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 12 Jan 2023 22:28:45 +0100 Subject: [PATCH 6/6] Increment patch number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index af68e063..73087b7e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCMCDiagnosticTools" uuid = "be115224-59cd-429b-ad48-344e309966f0" authors = ["David Widmann"] -version = "0.2.2" +version = "0.2.3" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"