diff --git a/docs/src/index.md b/docs/src/index.md index 5af1f44c..7f280f17 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -12,12 +12,12 @@ rhat ess_rhat ``` -The following `method`s are supported: +The following `autocov_method`s are supported: ```@docs -ESSMethod -FFTESSMethod -BDAESSMethod +AutocovMethod +FFTAutocovMethod +BDAAutocovMethod ``` ## Monte Carlo standard error diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index 4710f483..1ae21a77 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -16,7 +16,7 @@ using Statistics: Statistics export bfmi export discretediag -export ess, ess_rhat, rhat, ESSMethod, FFTESSMethod, BDAESSMethod +export ess, ess_rhat, rhat, AutocovMethod, FFTAutocovMethod, BDAAutocovMethod export gelmandiag, gelmandiag_multivariate export gewekediag export heideldiag diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 3421e523..039ae3ed 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -1,5 +1,5 @@ # methods -abstract type AbstractESSMethod end +abstract type AbstractAutocovMethod end const _DOC_SPLIT_CHAINS = """`split_chains` indicates the number of chains each chain is split into. When `split_chains > 1`, then the diagnostics check for within-chain convergence. When @@ -7,10 +7,10 @@ const _DOC_SPLIT_CHAINS = """`split_chains` indicates the number of chains each is discarded after each of the first `d` splits within each chain.""" """ - ESSMethod <: AbstractESSMethod + AutocovMethod <: AbstractAutocovMethod -The `ESSMethod` uses a standard algorithm for estimating the -effective sample size of MCMC chains. +The `AutocovMethod` uses a standard algorithm for estimating the mean autocovariance of MCMC +chains. It is based on the discussion by [^VehtariGelman2021] and uses the biased estimator of the autocovariance, as discussed by [^Geyer1992]. @@ -22,15 +22,15 @@ biased estimator of the autocovariance, as discussed by [^Geyer1992]. 
arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) [^Geyer1992]: Geyer, C. J. (1992). Practical Markov Chain Monte Carlo. Statistical Science, 473-483. """ -struct ESSMethod <: AbstractESSMethod end +struct AutocovMethod <: AbstractAutocovMethod end """ - FFTESSMethod <: AbstractESSMethod + FFTAutocovMethod <: AbstractAutocovMethod -The `FFTESSMethod` uses a standard algorithm for estimating -the effective sample size of MCMC chains. +The `FFTAutocovMethod` uses a standard algorithm for estimating the mean autocovariance of +MCMC chains. -The algorithm is the same as the one of [`ESSMethod`](@ref) but this method uses fast +The algorithm is the same as the one of [`AutocovMethod`](@ref) but this method uses fast Fourier transforms (FFTs) for estimating the autocorrelation. !!! info @@ -39,12 +39,12 @@ Fourier transforms (FFTs) for estimating the autocorrelation. as [FFTW.jl](https://github.com/JuliaMath/FFTW.jl) or [FastTransforms.jl](https://github.com/JuliaApproximation/FastTransforms.jl). """ -struct FFTESSMethod <: AbstractESSMethod end +struct FFTAutocovMethod <: AbstractAutocovMethod end """ - BDAESSMethod <: AbstractESSMethod + BDAAutocovMethod <: AbstractAutocovMethod -The `BDAESSMethod` uses a standard algorithm for estimating the effective sample size of +The `BDAAutocovMethod` uses a standard algorithm for estimating the mean autocovariance of MCMC chains. It is based on the discussion by [^VehtariGelman2021] and uses the @@ -57,15 +57,15 @@ variogram estimator of the autocorrelation function discussed by [^BDA3]. arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) [^BDA3]: Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A., & Rubin, D. B. (2013). Bayesian data analysis. CRC press. 
""" -struct BDAESSMethod <: AbstractESSMethod end +struct BDAAutocovMethod <: AbstractAutocovMethod end # caches -struct ESSCache{T,S} +struct AutocovCache{T,S} samples::Matrix{T} chain_var::Vector{S} end -struct FFTESSCache{T,S,C,P,I} +struct FFTAutocovCache{T,S,C,P,I} samples::Matrix{T} chain_var::Vector{S} samples_cache::C @@ -73,21 +73,21 @@ struct FFTESSCache{T,S,C,P,I} invplan::I end -mutable struct BDAESSCache{T,S,M} +mutable struct BDAAutocovCache{T,S,M} samples::Matrix{T} chain_var::Vector{S} mean_chain_var::M end -function build_cache(::ESSMethod, samples::Matrix, var::Vector) +function build_cache(::AutocovMethod, samples::Matrix, var::Vector) # check arguments niter, nchains = size(samples) length(var) == nchains || throw(DimensionMismatch()) - return ESSCache(samples, var) + return AutocovCache(samples, var) end -function build_cache(::FFTESSMethod, samples::Matrix, var::Vector) +function build_cache(::FFTAutocovMethod, samples::Matrix, var::Vector) # check arguments niter, nchains = size(samples) length(var) == nchains || throw(DimensionMismatch()) @@ -101,20 +101,20 @@ function build_cache(::FFTESSMethod, samples::Matrix, var::Vector) fft_plan = AbstractFFTs.plan_fft!(samples_cache, 1) ifft_plan = AbstractFFTs.plan_ifft!(samples_cache, 1) - return FFTESSCache(samples, var, samples_cache, fft_plan, ifft_plan) + return FFTAutocovCache(samples, var, samples_cache, fft_plan, ifft_plan) end -function build_cache(::BDAESSMethod, samples::Matrix, var::Vector) +function build_cache(::BDAAutocovMethod, samples::Matrix, var::Vector) # check arguments nchains = size(samples, 2) length(var) == nchains || throw(DimensionMismatch()) - return BDAESSCache(samples, var, Statistics.mean(var)) + return BDAAutocovCache(samples, var, Statistics.mean(var)) end -update!(cache::ESSCache) = nothing +update!(cache::AutocovCache) = nothing -function update!(cache::FFTESSCache) +function update!(cache::FFTAutocovCache) # copy samples and add zero padding samples = cache.samples 
samples_cache = cache.samples_cache @@ -138,14 +138,14 @@ function update!(cache::FFTESSCache) return nothing end -function update!(cache::BDAESSCache) +function update!(cache::BDAAutocovCache) # recompute mean of within-chain variances cache.mean_chain_var = Statistics.mean(cache.chain_var) return nothing end -function mean_autocov(k::Int, cache::ESSCache) +function mean_autocov(k::Int, cache::AutocovCache) # check arguments samples = cache.samples niter, nchains = size(samples) @@ -165,7 +165,7 @@ function mean_autocov(k::Int, cache::ESSCache) return s / niter end -function mean_autocov(k::Int, cache::FFTESSCache) +function mean_autocov(k::Int, cache::FFTAutocovCache) # check arguments niter, nchains = size(cache.samples) 0 ≤ k < niter || throw(ArgumentError("only lags ≥ 0 and < $niter are supported")) @@ -181,7 +181,7 @@ function mean_autocov(k::Int, cache::FFTESSCache) return result * uncorrection_factor end -function mean_autocov(k::Int, cache::BDAESSCache) +function mean_autocov(k::Int, cache::BDAAutocovCache) # check arguments samples = cache.samples niter, nchains = size(samples) @@ -203,14 +203,14 @@ end ess( samples::AbstractArray{<:Union{Missing,Real},3}; kind=:bulk, - method=ESSMethod(), + autocov_method=AutocovMethod(), split_chains::Int=2, maxlag::Int=250, kwargs... ) Estimate the effective sample size (ESS) of the `samples` of shape -`(draws, chains, parameters)` with the `method`. +`(draws, chains, parameters)` with the `autocov_method`. Optionally, the `kind` of ESS estimate to be computed can be specified (see below). Some `kind`s accept additional `kwargs`. @@ -223,7 +223,7 @@ than 0. 
For a given estimand, it is recommended that the ESS is at least `100 * chains` and that ``\\widehat{R} < 1.01``.[^VehtariGelman2021] -See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref), +See also: [`AutocovMethod`](@ref), [`FFTAutocovMethod`](@ref), [`BDAAutocovMethod`](@ref), [`rhat`](@ref), [`ess_rhat`](@ref), [`mcse`](@ref) ## Kinds of ESS estimates @@ -414,7 +414,7 @@ end ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kind::Symbol=:rank, kwargs...) Estimate the effective sample size and ``\\widehat{R}`` of the `samples` of shape -`(draws, chains, parameters)` with the `method`. +`(draws, chains, parameters)`. When both ESS and ``\\widehat{R}`` are needed, this method is often more efficient than calling `ess` and `rhat` separately. @@ -443,7 +443,7 @@ end function _ess_rhat( ::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real},3}; - method::AbstractESSMethod=ESSMethod(), + autocov_method::AbstractAutocovMethod=AutocovMethod(), split_chains::Int=2, maxlag::Int=250, ) @@ -479,7 +479,7 @@ function _ess_rhat( correctionfactor = (niter - 1)//niter # define cache for the computation of the autocorrelation - esscache = build_cache(method, samples, chain_var) + esscache = build_cache(autocov_method, samples, chain_var) # set maximum ess for antithetic chains, see below ess_max = ntotal * log10(oftype(one(T), ntotal)) diff --git a/test/ess_rhat.jl b/test/ess_rhat.jl index a639b480..14c8fcfc 100644 --- a/test/ess_rhat.jl +++ b/test/ess_rhat.jl @@ -10,11 +10,13 @@ using Statistics using StatsBase using Test -struct ExplicitESSMethod <: MCMCDiagnosticTools.AbstractESSMethod end +struct ExplicitAutocovMethod <: MCMCDiagnosticTools.AbstractAutocovMethod end struct ExplicitESSCache{S} samples::S end -function MCMCDiagnosticTools.build_cache(::ExplicitESSMethod, samples::Matrix, var::Vector) +function MCMCDiagnosticTools.build_cache( + ::ExplicitAutocovMethod, samples::Matrix, var::Vector +) return ExplicitESSCache(samples) end 
MCMCDiagnosticTools.update!(::ExplicitESSCache) = nothing @@ -139,19 +141,21 @@ mymean(x) = mean(x) x = randn(1000, 4, 10) @testset for kind in (:rank, :bulk, :tail, :basic), split_chains in (1, 2) R1 = rhat(x; kind=kind, split_chains=split_chains) - @testset for method in (ESSMethod(), BDAESSMethod()), maxlag in (100, 10) + @testset for autocov_method in (AutocovMethod(), BDAAutocovMethod()), + maxlag in (100, 10) + S1 = ess( x; kind=kind === :rank ? :bulk : kind, split_chains=split_chains, - method=method, + autocov_method=autocov_method, maxlag=maxlag, ) S2, R2 = ess_rhat( x; kind=kind, split_chains=split_chains, - method=method, + autocov_method=autocov_method, maxlag=maxlag, ) @test S1 == S2 @@ -172,13 +176,13 @@ mymean(x) = mean(x) ess_standard, rhat_standard = ess_rhat(x; split_chains=split_chains) ess_standard2, rhat_standard2 = ess_rhat( - x; split_chains=split_chains, method=ESSMethod() + x; split_chains=split_chains, autocov_method=AutocovMethod() ) ess_fft, rhat_fft = ess_rhat( - x; split_chains=split_chains, method=FFTESSMethod() + x; split_chains=split_chains, autocov_method=FFTAutocovMethod() ) ess_bda, rhat_bda = ess_rhat( - x; split_chains=split_chains, method=BDAESSMethod() + x; split_chains=split_chains, autocov_method=BDAAutocovMethod() ) # check that we get (roughly) the same results @@ -200,9 +204,9 @@ mymean(x) = mean(x) x = ones(10_000, 10, 40) ess_standard, rhat_standard = ess_rhat(x) - ess_standard2, rhat_standard2 = ess_rhat(x; method=ESSMethod()) - ess_fft, rhat_fft = ess_rhat(x; method=FFTESSMethod()) - ess_bda, rhat_bda = ess_rhat(x; method=BDAESSMethod()) + ess_standard2, rhat_standard2 = ess_rhat(x; autocov_method=AutocovMethod()) + ess_fft, rhat_fft = ess_rhat(x; autocov_method=FFTAutocovMethod()) + ess_bda, rhat_bda = ess_rhat(x; autocov_method=BDAAutocovMethod()) # check that the estimates are all NaN for ess in (ess_standard, ess_standard2, ess_fft, ess_bda) @@ -213,11 +217,12 @@ mymean(x) = mean(x) end end - @testset "Autocov of 
ESSMethod and FFTESSMethod equivalent to StatsBase" begin + @testset "Autocov of AutocovMethod and FFTAutocovMethod equivalent to StatsBase" begin x = randn(1_000, 10, 40) - ess_exp = ess(x; method=ExplicitESSMethod()) - @testset "$method" for method in [FFTESSMethod(), ESSMethod()] - @test ess(x; method=method) ≈ ess_exp + ess_exp = ess(x; autocov_method=ExplicitAutocovMethod()) + @testset "$autocov_method" for autocov_method in + [FFTAutocovMethod(), AutocovMethod()] + @test ess(x; autocov_method=autocov_method) ≈ ess_exp end end