Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename method in ess to autocov_method #73

Merged
merged 5 commits into from
Feb 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ rhat
ess_rhat
```

The following `method`s are supported:
The following `autocov_method`s are supported:

```@docs
ESSMethod
FFTESSMethod
BDAESSMethod
AutocovMethod
FFTAutocovMethod
BDAAutocovMethod
```

## Monte Carlo standard error
Expand Down
2 changes: 1 addition & 1 deletion src/MCMCDiagnosticTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using Statistics: Statistics

export bfmi
export discretediag
export ess, ess_rhat, rhat, ESSMethod, FFTESSMethod, BDAESSMethod
export ess, ess_rhat, rhat, AutocovMethod, FFTAutocovMethod, BDAAutocovMethod
export gelmandiag, gelmandiag_multivariate
export gewekediag
export heideldiag
Expand Down
68 changes: 34 additions & 34 deletions src/ess_rhat.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# methods
abstract type AbstractESSMethod end
abstract type AbstractAutocovMethod end

const _DOC_SPLIT_CHAINS = """`split_chains` indicates the number of chains each chain is split into.
When `split_chains > 1`, then the diagnostics check for within-chain convergence. When
`d = mod(draws, split_chains) > 0`, i.e. the chains cannot be evenly split, then 1 draw
is discarded after each of the first `d` splits within each chain."""

"""
ESSMethod <: AbstractESSMethod
AutocovMethod <: AbstractAutocovMethod

The `ESSMethod` uses a standard algorithm for estimating the
effective sample size of MCMC chains.
The `AutocovMethod` uses a standard algorithm for estimating the mean autocovariance of MCMC
chains.

It is is based on the discussion by [^VehtariGelman2021] and uses the
biased estimator of the autocovariance, as discussed by [^Geyer1992].
Expand All @@ -22,15 +22,15 @@ biased estimator of the autocovariance, as discussed by [^Geyer1992].
arXiv: [1903.08008](https://arxiv.org/abs/1903.08008)
[^Geyer1992]: Geyer, C. J. (1992). Practical Markov Chain Monte Carlo. Statistical Science, 473-483.
"""
struct ESSMethod <: AbstractESSMethod end
struct AutocovMethod <: AbstractAutocovMethod end

"""
FFTESSMethod <: AbstractESSMethod
FFTAutocovMethod <: AbstractAutocovMethod

The `FFTESSMethod` uses a standard algorithm for estimating
the effective sample size of MCMC chains.
The `FFTAutocovMethod` uses a standard algorithm for estimating the mean autocovariance of
MCMC chains.

The algorithm is the same as the one of [`ESSMethod`](@ref) but this method uses fast
The algorithm is the same as the one of [`AutocovMethod`](@ref) but this method uses fast
Fourier transforms (FFTs) for estimating the autocorrelation.

!!! info
Expand All @@ -39,12 +39,12 @@ Fourier transforms (FFTs) for estimating the autocorrelation.
as [FFTW.jl](https://github.com/JuliaMath/FFTW.jl) or
[FastTransforms.jl](https://github.com/JuliaApproximation/FastTransforms.jl).
"""
struct FFTESSMethod <: AbstractESSMethod end
struct FFTAutocovMethod <: AbstractAutocovMethod end

"""
BDAESSMethod <: AbstractESSMethod
BDAAutocovMethod <: AbstractAutocovMethod

The `BDAESSMethod` uses a standard algorithm for estimating the effective sample size of
The `BDAAutocovMethod` uses a standard algorithm for estimating the mean autocovariance of
MCMC chains.

It is is based on the discussion by [^VehtariGelman2021]. and uses the
Expand All @@ -57,37 +57,37 @@ variogram estimator of the autocorrelation function discussed by [^BDA3].
arXiv: [1903.08008](https://arxiv.org/abs/1903.08008)
[^BDA3]: Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A., & Rubin, D. B. (2013). Bayesian data analysis. CRC press.
"""
struct BDAESSMethod <: AbstractESSMethod end
struct BDAAutocovMethod <: AbstractAutocovMethod end

# caches
struct ESSCache{T,S}
struct AutocovCache{T,S}
samples::Matrix{T}
chain_var::Vector{S}
end

struct FFTESSCache{T,S,C,P,I}
struct FFTAutocovCache{T,S,C,P,I}
samples::Matrix{T}
chain_var::Vector{S}
samples_cache::C
plan::P
invplan::I
end

mutable struct BDAESSCache{T,S,M}
mutable struct BDAAutocovCache{T,S,M}
samples::Matrix{T}
chain_var::Vector{S}
mean_chain_var::M
end

function build_cache(::ESSMethod, samples::Matrix, var::Vector)
function build_cache(::AutocovMethod, samples::Matrix, var::Vector)
# check arguments
niter, nchains = size(samples)
length(var) == nchains || throw(DimensionMismatch())

return ESSCache(samples, var)
return AutocovCache(samples, var)
end

function build_cache(::FFTESSMethod, samples::Matrix, var::Vector)
function build_cache(::FFTAutocovMethod, samples::Matrix, var::Vector)
# check arguments
niter, nchains = size(samples)
length(var) == nchains || throw(DimensionMismatch())
Expand All @@ -101,20 +101,20 @@ function build_cache(::FFTESSMethod, samples::Matrix, var::Vector)
fft_plan = AbstractFFTs.plan_fft!(samples_cache, 1)
ifft_plan = AbstractFFTs.plan_ifft!(samples_cache, 1)

return FFTESSCache(samples, var, samples_cache, fft_plan, ifft_plan)
return FFTAutocovCache(samples, var, samples_cache, fft_plan, ifft_plan)
end

function build_cache(::BDAESSMethod, samples::Matrix, var::Vector)
function build_cache(::BDAAutocovMethod, samples::Matrix, var::Vector)
# check arguments
nchains = size(samples, 2)
length(var) == nchains || throw(DimensionMismatch())

return BDAESSCache(samples, var, Statistics.mean(var))
return BDAAutocovCache(samples, var, Statistics.mean(var))
end

update!(cache::ESSCache) = nothing
update!(cache::AutocovCache) = nothing

function update!(cache::FFTESSCache)
function update!(cache::FFTAutocovCache)
# copy samples and add zero padding
samples = cache.samples
samples_cache = cache.samples_cache
Expand All @@ -138,14 +138,14 @@ function update!(cache::FFTESSCache)
return nothing
end

function update!(cache::BDAESSCache)
function update!(cache::BDAAutocovCache)
# recompute mean of within-chain variances
cache.mean_chain_var = Statistics.mean(cache.chain_var)

return nothing
end

function mean_autocov(k::Int, cache::ESSCache)
function mean_autocov(k::Int, cache::AutocovCache)
# check arguments
samples = cache.samples
niter, nchains = size(samples)
Expand All @@ -165,7 +165,7 @@ function mean_autocov(k::Int, cache::ESSCache)
return s / niter
end

function mean_autocov(k::Int, cache::FFTESSCache)
function mean_autocov(k::Int, cache::FFTAutocovCache)
# check arguments
niter, nchains = size(cache.samples)
0 ≤ k < niter || throw(ArgumentError("only lags ≥ 0 and < $niter are supported"))
Expand All @@ -181,7 +181,7 @@ function mean_autocov(k::Int, cache::FFTESSCache)
return result * uncorrection_factor
end

function mean_autocov(k::Int, cache::BDAESSCache)
function mean_autocov(k::Int, cache::BDAAutocovCache)
# check arguments
samples = cache.samples
niter, nchains = size(samples)
Expand All @@ -203,14 +203,14 @@ end
ess(
samples::AbstractArray{<:Union{Missing,Real},3};
kind=:bulk,
method=ESSMethod(),
autocov_method=AutocovMethod(),
split_chains::Int=2,
maxlag::Int=250,
kwargs...
)

Estimate the effective sample size (ESS) of the `samples` of shape
`(draws, chains, parameters)` with the `method`.
`(draws, chains, parameters)` with the `autocov_method`.

Optionally, the `kind` of ESS estimate to be computed can be specified (see below). Some
`kind`s accept additional `kwargs`.
Expand All @@ -223,7 +223,7 @@ than 0.
For a given estimand, it is recommended that the ESS is at least `100 * chains` and that
``\\widehat{R} < 1.01``.[^VehtariGelman2021]

See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref),
See also: [`AutocovMethod`](@ref), [`FFTAutocovMethod`](@ref), [`BDAAutocovMethod`](@ref),
[`rhat`](@ref), [`ess_rhat`](@ref), [`mcse`](@ref)

## Kinds of ESS estimates
Expand Down Expand Up @@ -414,7 +414,7 @@ end
ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kind::Symbol=:rank, kwargs...)

Estimate the effective sample size and ``\\widehat{R}`` of the `samples` of shape
`(draws, chains, parameters)` with the `method`.
`(draws, chains, parameters)`.

When both ESS and ``\\widehat{R}`` are needed, this method is often more efficient than
calling `ess` and `rhat` separately.
Expand Down Expand Up @@ -443,7 +443,7 @@ end
function _ess_rhat(
::Val{:basic},
chains::AbstractArray{<:Union{Missing,Real},3};
method::AbstractESSMethod=ESSMethod(),
autocov_method::AbstractAutocovMethod=AutocovMethod(),
split_chains::Int=2,
maxlag::Int=250,
)
Expand Down Expand Up @@ -479,7 +479,7 @@ function _ess_rhat(
correctionfactor = (niter - 1)//niter

# define cache for the computation of the autocorrelation
esscache = build_cache(method, samples, chain_var)
esscache = build_cache(autocov_method, samples, chain_var)

# set maximum ess for antithetic chains, see below
ess_max = ntotal * log10(oftype(one(T), ntotal))
Expand Down
35 changes: 20 additions & 15 deletions test/ess_rhat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ using Statistics
using StatsBase
using Test

struct ExplicitESSMethod <: MCMCDiagnosticTools.AbstractESSMethod end
struct ExplicitAutocovMethod <: MCMCDiagnosticTools.AbstractAutocovMethod end
struct ExplicitESSCache{S}
samples::S
end
function MCMCDiagnosticTools.build_cache(::ExplicitESSMethod, samples::Matrix, var::Vector)
function MCMCDiagnosticTools.build_cache(
::ExplicitAutocovMethod, samples::Matrix, var::Vector
)
return ExplicitESSCache(samples)
end
MCMCDiagnosticTools.update!(::ExplicitESSCache) = nothing
Expand Down Expand Up @@ -139,19 +141,21 @@ mymean(x) = mean(x)
x = randn(1000, 4, 10)
@testset for kind in (:rank, :bulk, :tail, :basic), split_chains in (1, 2)
R1 = rhat(x; kind=kind, split_chains=split_chains)
@testset for method in (ESSMethod(), BDAESSMethod()), maxlag in (100, 10)
@testset for autocov_method in (AutocovMethod(), BDAAutocovMethod()),
maxlag in (100, 10)

S1 = ess(
x;
kind=kind === :rank ? :bulk : kind,
split_chains=split_chains,
method=method,
autocov_method=autocov_method,
maxlag=maxlag,
)
S2, R2 = ess_rhat(
x;
kind=kind,
split_chains=split_chains,
method=method,
autocov_method=autocov_method,
maxlag=maxlag,
)
@test S1 == S2
Expand All @@ -172,13 +176,13 @@ mymean(x) = mean(x)

ess_standard, rhat_standard = ess_rhat(x; split_chains=split_chains)
ess_standard2, rhat_standard2 = ess_rhat(
x; split_chains=split_chains, method=ESSMethod()
x; split_chains=split_chains, autocov_method=AutocovMethod()
)
ess_fft, rhat_fft = ess_rhat(
x; split_chains=split_chains, method=FFTESSMethod()
x; split_chains=split_chains, autocov_method=FFTAutocovMethod()
)
ess_bda, rhat_bda = ess_rhat(
x; split_chains=split_chains, method=BDAESSMethod()
x; split_chains=split_chains, autocov_method=BDAAutocovMethod()
)

# check that we get (roughly) the same results
Expand All @@ -200,9 +204,9 @@ mymean(x) = mean(x)
x = ones(10_000, 10, 40)

ess_standard, rhat_standard = ess_rhat(x)
ess_standard2, rhat_standard2 = ess_rhat(x; method=ESSMethod())
ess_fft, rhat_fft = ess_rhat(x; method=FFTESSMethod())
ess_bda, rhat_bda = ess_rhat(x; method=BDAESSMethod())
ess_standard2, rhat_standard2 = ess_rhat(x; autocov_method=AutocovMethod())
ess_fft, rhat_fft = ess_rhat(x; autocov_method=FFTAutocovMethod())
ess_bda, rhat_bda = ess_rhat(x; autocov_method=BDAAutocovMethod())

# check that the estimates are all NaN
for ess in (ess_standard, ess_standard2, ess_fft, ess_bda)
Expand All @@ -213,11 +217,12 @@ mymean(x) = mean(x)
end
end

@testset "Autocov of ESSMethod and FFTESSMethod equivalent to StatsBase" begin
@testset "Autocov of AutocovMethod and FFTAutocovMethod equivalent to StatsBase" begin
x = randn(1_000, 10, 40)
ess_exp = ess(x; method=ExplicitESSMethod())
@testset "$method" for method in [FFTESSMethod(), ESSMethod()]
@test ess(x; method=method) ≈ ess_exp
ess_exp = ess(x; autocov_method=ExplicitAutocovMethod())
@testset "$autocov_method" for autocov_method in
[FFTAutocovMethod(), AutocovMethod()]
@test ess(x; autocov_method=autocov_method) ≈ ess_exp
end
end

Expand Down