From 383f651b8fd85f6e3ea6088d0a79854462f10af6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 16 Apr 2023 20:50:41 +0200 Subject: [PATCH 1/4] Add relative keyword to ess/ess_rhat --- src/ess_rhat.jl | 9 +++++++++ test/ess_rhat.jl | 13 +++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index a07f8633..89edea06 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -203,6 +203,7 @@ end ess( samples::AbstractArray{<:Union{Missing,Real},3}; kind=:bulk, + relative::Bool=false, autocov_method=AutocovMethod(), split_chains::Int=2, maxlag::Int=250, @@ -215,6 +216,9 @@ Estimate the effective sample size (ESS) of the `samples` of shape Optionally, the `kind` of ESS estimate to be computed can be specified (see below). Some `kind`s accept additional `kwargs`. +If `relative` is `true`, the relative ESS is returned, i.e. the ESS divided by the sample +size. + $_DOC_SPLIT_CHAINS There must be at least 3 draws in each chain after splitting. `maxlag` indicates the maximum lag for which autocovariance is computed and must be greater @@ -447,6 +451,7 @@ end function _ess_rhat( ::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real},3}; + relative::Bool=false, autocov_method::AbstractAutocovMethod=AutocovMethod(), split_chains::Int=2, maxlag::Int=250, @@ -569,6 +574,10 @@ function _ess_rhat( ess[i] = min(ntotal / τ, ess_max) end + if relative + ess ./= ntotal + end + return (; ess, rhat) end function _ess_rhat(::Val{:bulk}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) diff --git a/test/ess_rhat.jl b/test/ess_rhat.jl index dc251db5..4d165e15 100644 --- a/test/ess_rhat.jl +++ b/test/ess_rhat.jl @@ -88,6 +88,19 @@ mymean(x) = mean(x) @test_throws ArgumentError ess(x2; kind=mymean) end + @testset "relative=true" begin + @testset for kind in (:rank, :bulk, :tail, :basic), + niter in (50, 100), + nchains in (2, 4) + + ss = niter * nchains + x = rand(niter, nchains, 2) + kind === :rank || @test ess(x; kind, relative=true) == ess(x; kind) / ss + S, R = ess_rhat(x; kind) + @test ess_rhat(x; kind, relative=true) == (ess=S / ss, rhat=R) + end + end + @testset "Union{Missing,Float64} eltype" begin @testset for kind in (:rank, :bulk, :tail, :basic) x = Array{Union{Missing,Float64}}(undef, 1000, 4, 3) From fd3c4a0a572f8f382334f6e99127557644df7202 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 16 Apr 2023 20:50:56 +0200 Subject: [PATCH 2/4] Increment patch number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 848c78c3..88532b6c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCMCDiagnosticTools" uuid = "be115224-59cd-429b-ad48-344e309966f0" authors = ["David Widmann"] -version = "0.3.1" +version = "0.3.2" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" From 536f52f5f2a0e38babbdb643c59e46e4fbf58ffe Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 16 Apr 2023 20:52:33 +0200 Subject: [PATCH 3/4] Update docstring --- src/ess_rhat.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 89edea06..da5bcc49 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -216,8 +216,7 @@ Estimate the effective sample size (ESS) of the `samples` of shape Optionally, the `kind` of ESS estimate to be computed can be specified (see below). Some `kind`s accept additional `kwargs`. -If `relative` is `true`, the relative ESS is returned, i.e. the ESS divided by the sample -size. +If `relative` is `true`, the relative ESS is returned, i.e. `ess / (draws * chains)`. $_DOC_SPLIT_CHAINS There must be at least 3 draws in each chain after splitting. From 8d46fec5121c82c83fd08637d6e5b404b9116303 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 16 Apr 2023 21:03:10 +0200 Subject: [PATCH 4/4] Rename variables in terms of relative ESS --- src/ess_rhat.jl | 13 +++++++------ test/ess_rhat.jl | 6 ++++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index da5bcc49..5328dd4e 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -489,8 +489,8 @@ function _ess_rhat( # define cache for the computation of the autocorrelation esscache = build_cache(autocov_method, samples, chain_var) - # set maximum ess for antithetic chains, see below - ess_max = ntotal * log10(oftype(one(T), ntotal)) + # set maximum relative ess for antithetic chains, see below + rel_ess_max = log10(oftype(one(T), ntotal)) # for each parameter for (i, chains_slice) in zip(eachindex(ess), eachslice(chains; dims=3)) @@ -569,12 +569,13 @@ function _ess_rhat( ρ_even = maxlag > 1 ? 1 - inv_var₊ * (W - mean_autocov(k, esscache)) : zero(ρ_even) τ = max(0, 2 * sum_pₜ + max(0, ρ_even) - 1) - # estimate the effective sample size - ess[i] = min(ntotal / τ, ess_max) + # estimate the relative effective sample size + ess[i] = min(inv(τ), rel_ess_max) end - if relative - ess ./= ntotal + if !relative + # absolute effective sample size + ess .*= ntotal end return (; ess, rhat) diff --git a/test/ess_rhat.jl b/test/ess_rhat.jl index 4d165e15..ee8c7461 100644 --- a/test/ess_rhat.jl +++ b/test/ess_rhat.jl @@ -95,9 +95,11 @@ mymean(x) = mean(x) ss = niter * nchains x = rand(niter, nchains, 2) - kind === :rank || @test ess(x; kind, relative=true) == ess(x; kind) / ss + kind === :rank || @test ess(x; kind, relative=true) ≈ ess(x; kind) / ss S, R = ess_rhat(x; kind) - @test ess_rhat(x; kind, relative=true) == (ess=S / ss, rhat=R) + S2, R2 = ess_rhat(x; kind, relative=true) + @test S2 ≈ S / ss + @test R2 == R end end