diff --git a/Project.toml b/Project.toml index 848c78c3..88532b6c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCMCDiagnosticTools" uuid = "be115224-59cd-429b-ad48-344e309966f0" authors = ["David Widmann"] -version = "0.3.1" +version = "0.3.2" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index a07f8633..5328dd4e 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -203,6 +203,7 @@ end ess( samples::AbstractArray{<:Union{Missing,Real},3}; kind=:bulk, + relative::Bool=false, autocov_method=AutocovMethod(), split_chains::Int=2, maxlag::Int=250, @@ -215,6 +216,8 @@ Estimate the effective sample size (ESS) of the `samples` of shape Optionally, the `kind` of ESS estimate to be computed can be specified (see below). Some `kind`s accept additional `kwargs`. +If `relative` is `true`, the relative ESS is returned, i.e. `ess / (draws * chains)`. + $_DOC_SPLIT_CHAINS There must be at least 3 draws in each chain after splitting. `maxlag` indicates the maximum lag for which autocovariance is computed and must be greater @@ -447,6 +450,7 @@ end function _ess_rhat( ::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real},3}; + relative::Bool=false, autocov_method::AbstractAutocovMethod=AutocovMethod(), split_chains::Int=2, maxlag::Int=250, @@ -485,8 +489,8 @@ function _ess_rhat( # define cache for the computation of the autocorrelation esscache = build_cache(autocov_method, samples, chain_var) - # set maximum ess for antithetic chains, see below - ess_max = ntotal * log10(oftype(one(T), ntotal)) + # set maximum relative ess for antithetic chains, see below + rel_ess_max = log10(oftype(one(T), ntotal)) # for each parameter for (i, chains_slice) in zip(eachindex(ess), eachslice(chains; dims=3)) @@ -565,8 +569,13 @@ function _ess_rhat( ρ_even = maxlag > 1 ? 1 - inv_var₊ * (W - mean_autocov(k, esscache)) : zero(ρ_even) τ = max(0, 2 * sum_pₜ + max(0, ρ_even) - 1) - # estimate the effective sample size - ess[i] = min(ntotal / τ, ess_max) + # estimate the relative effective sample size + ess[i] = min(inv(τ), rel_ess_max) + end + + if !relative + # absolute effective sample size + ess .*= ntotal end return (; ess, rhat) diff --git a/test/ess_rhat.jl b/test/ess_rhat.jl index dc251db5..ee8c7461 100644 --- a/test/ess_rhat.jl +++ b/test/ess_rhat.jl @@ -88,6 +88,21 @@ mymean(x) = mean(x) @test_throws ArgumentError ess(x2; kind=mymean) end + @testset "relative=true" begin + @testset for kind in (:rank, :bulk, :tail, :basic), + niter in (50, 100), + nchains in (2, 4) + + ss = niter * nchains + x = rand(niter, nchains, 2) + kind === :rank || @test ess(x; kind, relative=true) ≈ ess(x; kind) / ss + S, R = ess_rhat(x; kind) + S2, R2 = ess_rhat(x; kind, relative=true) + @test S2 ≈ S / ss + @test R2 == R + end + end + @testset "Union{Missing,Float64} eltype" begin @testset for kind in (:rank, :bulk, :tail, :basic) x = Array{Union{Missing,Float64}}(undef, 1000, 4, 3)