diff --git a/Project.toml b/Project.toml index 78d3ebe0..abac7047 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCMCDiagnosticTools" uuid = "be115224-59cd-429b-ad48-344e309966f0" authors = ["David Widmann"] -version = "0.2.4" +version = "0.2.5" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/ess.jl b/src/ess.jl index 506bd8e7..7cda9938 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -204,7 +204,9 @@ end ) Estimate the effective sample size and ``\\widehat{R}`` of the `samples` of shape -`(draws, chains, parameters)` with the `method` and a maximum lag of `maxlag`. +`(draws, chains, parameters)` with the `method`. + +`maxlag` indicates the maximum lag for which autocovariance is computed. By default, the computed ESS and ``\\widehat{R}`` values correspond to the estimator `mean`. Other estimators can be specified by passing a function `estimator` (see below). @@ -258,15 +260,18 @@ function ess_rhat( nchains = split_chains * size(chains, 2) ntotal = niter * nchains axes_out = (axes(chains, 3),) + T = promote_type(eltype(chains), typeof(zero(eltype(chains)) / 1)) - # do not compute estimates if there is only one sample or lag - maxlag = min(maxlag, niter - 1) - if !(maxlag > 0) + # discard the last pair of autocorrelations, which are poorly estimated and only matter + # when chains have mixed poorly anyways. + # leave the last even autocorrelation as a bias term that reduces variance for + # case of antithetical chains, see below + maxlag = min(maxlag, niter - 4) + if !(maxlag > 0) || T === Missing return similar(chains, Missing, axes_out), similar(chains, Missing, axes_out) end # define caches for mean and variance - T = promote_type(eltype(chains), typeof(zero(eltype(chains)) / 1)) chain_mean = Array{T}(undef, 1, nchains) chain_var = Array{T}(undef, nchains) samples = Array{T}(undef, niter, nchains) @@ -281,6 +286,9 @@ function ess_rhat( ess = similar(chains, T, axes_out) rhat = similar(chains, T, axes_out) + # set maximum ess for antithetic chains, see below + ess_max = ntotal * log10(oftype(one(T), ntotal)) + # for each parameter for (i, chains_slice) in zip(eachindex(ess), eachslice(chains; dims=3)) # check that no values are missing @@ -328,7 +336,7 @@ function ess_rhat( sum_pₜ = pₜ k = 2 - while k < maxlag + while k < (maxlag - 1) # compute subsequent autocorrelation of all chains # by combining estimates of each chain ρ_even = 1 - inv_var₊ * (W - mean_autocov(k, esscache)) @@ -347,10 +355,19 @@ function ess_rhat( # update indices k += 2 end + # for antithetic chains + # - reduce variance by averaging truncation to odd lag and truncation to next even lag + # - prevent negative ESS for short chains by ensuring τ is nonnegative + # See discussions in: + # - § 3.2 of Vehtari et al. https://arxiv.org/pdf/1903.08008v5.pdf + # - https://github.com/TuringLang/MCMCDiagnosticTools.jl/issues/40 + # - https://github.com/stan-dev/rstan/pull/618 + # - https://github.com/stan-dev/stan/pull/2774 + ρ_even = maxlag > 1 ? 1 - inv_var₊ * (W - mean_autocov(k, esscache)) : zero(ρ_even) + τ = max(0, 2 * sum_pₜ + max(0, ρ_even) - 1) # estimate the effective sample size - τ = 2 * sum_pₜ - 1 - ess[i] = ntotal / τ + ess[i] = min(ntotal / τ, ess_max) end return ess, rhat diff --git a/test/ess.jl b/test/ess.jl index 647fba86..3b875c53 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -160,19 +160,30 @@ end end @testset "ESS and R̂ (single sample)" begin # check that issue #137 is fixed - x = rand(1, 3, 5) + x = rand(4, 3, 5) for method in (ESSMethod(), FFTESSMethod(), BDAESSMethod()) # analyze array - ess_array, rhat_array = ess_rhat(x; method=method) + ess_array, rhat_array = ess_rhat(x; method=method, split_chains=1) @test length(ess_array) == size(x, 3) - @test all(ismissing, ess_array) # since min(maxlag, niter - 1) = 0 + @test all(ismissing, ess_array) # since min(maxlag, niter - 4) = 0 @test length(rhat_array) == size(x, 3) @test all(ismissing, rhat_array) end end + @testset "ESS and R̂ with Union{Missing,Float64} eltype" begin + x = Array{Union{Missing,Float64}}(undef, 1000, 4, 3) + x .= randn.() + x[1, 1, 1] = missing + S, R = ess_rhat(x) + @test ismissing(S[1]) + @test ismissing(R[1]) + @test !any(ismissing, S[2:3]) + @test !any(ismissing, R[2:3]) + end + @testset "Autocov of ESSMethod and FFTESSMethod equivalent to StatsBase" begin x = randn(1_000, 10, 40) ess_exp = ess_rhat(x; method=ExplicitESSMethod())[1] @@ -227,6 +238,21 @@ end end end + @testset "ESS thresholded for antithetic chains" begin + # for φ = -0.3 (slightly antithetic), ESS without thresholding for low ndraws is + # often >ndraws*log10(ndraws) + # for φ = -0.9 (highly antithetic), ESS without thresholding for low ndraws is + # usually negative + nchains = 4 + @testset for ndraws in (10, 100), φ in (-0.3, -0.9) + x = ar1(φ, sqrt(1 - φ^2), ndraws, nchains, 1000) + Smin, Smax = extrema(ess_rhat(mean, x)[1]) + ntotal = ndraws * nchains + @test Smax == ntotal * log10(ntotal) + @test Smin > 0 + end + end + @testset "ess_rhat_bulk(x)" begin xnorm = randn(1_000, 4, 10) @test ess_rhat_bulk(xnorm) == ess_rhat(mean, _rank_normalize(xnorm))