Skip to content

Commit

Permalink
Add option to compute relative ESS (#81)
Browse files Browse the repository at this point in the history
* Add relative keyword to ess/ess_rhat

* Increment patch number

* Update docstring

* Rename variables in terms of relative ESS
  • Loading branch information
sethaxen authored May 4, 2023
1 parent f2aca2d commit 5b3e960
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MCMCDiagnosticTools"
uuid = "be115224-59cd-429b-ad48-344e309966f0"
authors = ["David Widmann"]
version = "0.3.1"
version = "0.3.2"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
17 changes: 13 additions & 4 deletions src/ess_rhat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ end
ess(
samples::AbstractArray{<:Union{Missing,Real},3};
kind=:bulk,
relative::Bool=false,
autocov_method=AutocovMethod(),
split_chains::Int=2,
maxlag::Int=250,
Expand All @@ -215,6 +216,8 @@ Estimate the effective sample size (ESS) of the `samples` of shape
Optionally, the `kind` of ESS estimate to be computed can be specified (see below). Some
`kind`s accept additional `kwargs`.
If `relative` is `true`, the relative ESS is returned, i.e. `ess / (draws * chains)`.
$_DOC_SPLIT_CHAINS There must be at least 3 draws in each chain after splitting.
`maxlag` indicates the maximum lag for which autocovariance is computed and must be greater
Expand Down Expand Up @@ -447,6 +450,7 @@ end
function _ess_rhat(
::Val{:basic},
chains::AbstractArray{<:Union{Missing,Real},3};
relative::Bool=false,
autocov_method::AbstractAutocovMethod=AutocovMethod(),
split_chains::Int=2,
maxlag::Int=250,
Expand Down Expand Up @@ -485,8 +489,8 @@ function _ess_rhat(
# define cache for the computation of the autocorrelation
esscache = build_cache(autocov_method, samples, chain_var)

# set maximum ess for antithetic chains, see below
ess_max = ntotal * log10(oftype(one(T), ntotal))
# set maximum relative ess for antithetic chains, see below
rel_ess_max = log10(oftype(one(T), ntotal))

# for each parameter
for (i, chains_slice) in zip(eachindex(ess), eachslice(chains; dims=3))
Expand Down Expand Up @@ -565,8 +569,13 @@ function _ess_rhat(
ρ_even = maxlag > 1 ? 1 - inv_var₊ * (W - mean_autocov(k, esscache)) : zero(ρ_even)
τ = max(0, 2 * sum_pₜ + max(0, ρ_even) - 1)

# estimate the effective sample size
ess[i] = min(ntotal / τ, ess_max)
# estimate the relative effective sample size
ess[i] = min(inv(τ), rel_ess_max)
end

if !relative
# absolute effective sample size
ess .*= ntotal
end

return (; ess, rhat)
Expand Down
15 changes: 15 additions & 0 deletions test/ess_rhat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,21 @@ mymean(x) = mean(x)
@test_throws ArgumentError ess(x2; kind=mymean)
end

@testset "relative=true" begin
@testset for kind in (:rank, :bulk, :tail, :basic),
niter in (50, 100),
nchains in (2, 4)

ss = niter * nchains
x = rand(niter, nchains, 2)
kind === :rank || @test ess(x; kind, relative=true) ess(x; kind) / ss
S, R = ess_rhat(x; kind)
S2, R2 = ess_rhat(x; kind, relative=true)
@test S2 S / ss
@test R2 == R
end
end

@testset "Union{Missing,Float64} eltype" begin
@testset for kind in (:rank, :bulk, :tail, :basic)
x = Array{Union{Missing,Float64}}(undef, 1000, 4, 3)
Expand Down

2 comments on commit 5b3e960

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/82924

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.2 -m "<description of version>" 5b3e9604876284620a3e4d5178c4e961557e9c51
git push origin v0.3.2

Please sign in to comment.