From 2937084fafe43be94aa0cea1b502087226b19f3b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 23:57:15 +0100 Subject: [PATCH 1/5] Rename method to autocov_method --- src/ess_rhat.jl | 68 ++++++++++++++++++++++++------------------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 3421e523..5fcb73cd 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -1,5 +1,5 @@ # methods -abstract type AbstractESSMethod end +abstract type AbstractAutoCovMethod end const _DOC_SPLIT_CHAINS = """`split_chains` indicates the number of chains each chain is split into. When `split_chains > 1`, then the diagnostics check for within-chain convergence. When @@ -7,10 +7,10 @@ const _DOC_SPLIT_CHAINS = """`split_chains` indicates the number of chains each is discarded after each of the first `d` splits within each chain.""" """ - ESSMethod <: AbstractESSMethod + AutoCovMethod <: AbstractAutoCovMethod -The `ESSMethod` uses a standard algorithm for estimating the -effective sample size of MCMC chains. +The `AutoCovMethod` uses a standard algorithm for estimating the mean autocovariance of MCMC +chains. It is is based on the discussion by [^VehtariGelman2021] and uses the biased estimator of the autocovariance, as discussed by [^Geyer1992]. @@ -22,15 +22,15 @@ biased estimator of the autocovariance, as discussed by [^Geyer1992]. arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) [^Geyer1992]: Geyer, C. J. (1992). Practical Markov Chain Monte Carlo. Statistical Science, 473-483. """ -struct ESSMethod <: AbstractESSMethod end +struct AutoCovMethod <: AbstractAutoCovMethod end """ - FFTESSMethod <: AbstractESSMethod + FFTAutoCovMethod <: AbstractAutoCovMethod -The `FFTESSMethod` uses a standard algorithm for estimating -the effective sample size of MCMC chains. +The `FFTAutoCovMethod` uses a standard algorithm for estimating the mean autocovariance of +MCMC chains. -The algorithm is the same as the one of [`ESSMethod`](@ref) but this method uses fast +The algorithm is the same as the one of [`AutoCovMethod`](@ref) but this method uses fast Fourier transforms (FFTs) for estimating the autocorrelation. !!! info @@ -39,12 +39,12 @@ Fourier transforms (FFTs) for estimating the autocorrelation. as [FFTW.jl](https://github.com/JuliaMath/FFTW.jl) or [FastTransforms.jl](https://github.com/JuliaApproximation/FastTransforms.jl). """ -struct FFTESSMethod <: AbstractESSMethod end +struct FFTAutoCovMethod <: AbstractAutoCovMethod end """ - BDAESSMethod <: AbstractESSMethod + BDAAutoCovMethod <: AbstractAutoCovMethod -The `BDAESSMethod` uses a standard algorithm for estimating the effective sample size of +The `BDAAutoCovMethod` uses a standard algorithm for estimating the mean autocovariance of MCMC chains. It is is based on the discussion by [^VehtariGelman2021]. and uses the @@ -57,15 +57,15 @@ variogram estimator of the autocorrelation function discussed by [^BDA3]. arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) [^BDA3]: Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A., & Rubin, D. B. (2013). Bayesian data analysis. CRC press. """ -struct BDAESSMethod <: AbstractESSMethod end +struct BDAAutoCovMethod <: AbstractAutoCovMethod end # caches -struct ESSCache{T,S} +struct AutoCovCache{T,S} samples::Matrix{T} chain_var::Vector{S} end -struct FFTESSCache{T,S,C,P,I} +struct FFTAutoCovCache{T,S,C,P,I} samples::Matrix{T} chain_var::Vector{S} samples_cache::C @@ -73,21 +73,21 @@ struct FFTESSCache{T,S,C,P,I} invplan::I end -mutable struct BDAESSCache{T,S,M} +mutable struct BDAAutoCovCache{T,S,M} samples::Matrix{T} chain_var::Vector{S} mean_chain_var::M end -function build_cache(::ESSMethod, samples::Matrix, var::Vector) +function build_cache(::AutoCovMethod, samples::Matrix, var::Vector) # check arguments niter, nchains = size(samples) length(var) == nchains || throw(DimensionMismatch()) - return ESSCache(samples, var) + return AutoCovCache(samples, var) end -function build_cache(::FFTESSMethod, samples::Matrix, var::Vector) +function build_cache(::FFTAutoCovMethod, samples::Matrix, var::Vector) # check arguments niter, nchains = size(samples) length(var) == nchains || throw(DimensionMismatch()) @@ -101,20 +101,20 @@ function build_cache(::FFTESSMethod, samples::Matrix, var::Vector) fft_plan = AbstractFFTs.plan_fft!(samples_cache, 1) ifft_plan = AbstractFFTs.plan_ifft!(samples_cache, 1) - return FFTESSCache(samples, var, samples_cache, fft_plan, ifft_plan) + return FFTAutoCovCache(samples, var, samples_cache, fft_plan, ifft_plan) end -function build_cache(::BDAESSMethod, samples::Matrix, var::Vector) +function build_cache(::BDAAutoCovMethod, samples::Matrix, var::Vector) # check arguments nchains = size(samples, 2) length(var) == nchains || throw(DimensionMismatch()) - return BDAESSCache(samples, var, Statistics.mean(var)) + return BDAAutoCovCache(samples, var, Statistics.mean(var)) end -update!(cache::ESSCache) = nothing +update!(cache::AutoCovCache) = nothing -function update!(cache::FFTESSCache) +function update!(cache::FFTAutoCovCache) # copy samples and add zero padding samples = cache.samples samples_cache = cache.samples_cache @@ -138,14 +138,14 @@ function update!(cache::FFTESSCache) return nothing end -function update!(cache::BDAESSCache) +function update!(cache::BDAAutoCovCache) # recompute mean of within-chain variances cache.mean_chain_var = Statistics.mean(cache.chain_var) return nothing end -function mean_autocov(k::Int, cache::ESSCache) +function mean_autocov(k::Int, cache::AutoCovCache) # check arguments samples = cache.samples niter, nchains = size(samples) @@ -165,7 +165,7 @@ function mean_autocov(k::Int, cache::ESSCache) return s / niter end -function mean_autocov(k::Int, cache::FFTESSCache) +function mean_autocov(k::Int, cache::FFTAutoCovCache) # check arguments niter, nchains = size(cache.samples) 0 ≤ k < niter || throw(ArgumentError("only lags ≥ 0 and < $niter are supported")) @@ -181,7 +181,7 @@ function mean_autocov(k::Int, cache::FFTESSCache) return result * uncorrection_factor end -function mean_autocov(k::Int, cache::BDAESSCache) +function mean_autocov(k::Int, cache::BDAAutoCovCache) # check arguments samples = cache.samples niter, nchains = size(samples) @@ -203,14 +203,14 @@ end ess( samples::AbstractArray{<:Union{Missing,Real},3}; kind=:bulk, - method=ESSMethod(), + autocov_method=AutoCovMethod(), split_chains::Int=2, maxlag::Int=250, kwargs... ) Estimate the effective sample size (ESS) of the `samples` of shape -`(draws, chains, parameters)` with the `method`. +`(draws, chains, parameters)` with the `autocov_method`. Optionally, the `kind` of ESS estimate to be computed can be specified (see below). Some `kind`s accept additional `kwargs`. @@ -223,7 +223,7 @@ than 0. For a given estimand, it is recommended that the ESS is at least `100 * chains` and that ``\\widehat{R} < 1.01``.[^VehtariGelman2021] -See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref), +See also: [`AutoCovMethod`](@ref), [`FFTAutoCovMethod`](@ref), [`BDAAutoCovMethod`](@ref), [`rhat`](@ref), [`ess_rhat`](@ref), [`mcse`](@ref) ## Kinds of ESS estimates @@ -414,7 +414,7 @@ end ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kind::Symbol=:rank, kwargs...) Estimate the effective sample size and ``\\widehat{R}`` of the `samples` of shape -`(draws, chains, parameters)` with the `method`. +`(draws, chains, parameters)`. When both ESS and ``\\widehat{R}`` are needed, this method is often more efficient than calling `ess` and `rhat` separately. @@ -443,7 +443,7 @@ end function _ess_rhat( ::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real},3}; - method::AbstractESSMethod=ESSMethod(), + autocov_method::AbstractAutoCovMethod=AutoCovMethod(), split_chains::Int=2, maxlag::Int=250, ) @@ -479,7 +479,7 @@ function _ess_rhat( correctionfactor = (niter - 1)//niter # define cache for the computation of the autocorrelation - esscache = build_cache(method, samples, chain_var) + esscache = build_cache(autocov_method, samples, chain_var) # set maximum ess for antithetic chains, see below ess_max = ntotal * log10(oftype(one(T), ntotal)) From d3c56fe517deb7add11ab842a49b01ab2e7b0d3d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 23:57:27 +0100 Subject: [PATCH 2/5] Update exports --- src/MCMCDiagnosticTools.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index 4710f483..842aa824 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -16,7 +16,7 @@ using Statistics: Statistics export bfmi export discretediag -export ess, ess_rhat, rhat, ESSMethod, FFTESSMethod, BDAESSMethod +export ess, ess_rhat, rhat, AutoCovMethod, FFTAutoCovMethod, BDAAutoCovMethod export gelmandiag, gelmandiag_multivariate export gewekediag export heideldiag From c42e25bebd357e21564f13545ec2db5d712d055b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 23:57:33 +0100 Subject: [PATCH 3/5] Update tests --- test/ess_rhat.jl | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/test/ess_rhat.jl b/test/ess_rhat.jl index a639b480..2102ed8f 100644 --- a/test/ess_rhat.jl +++ b/test/ess_rhat.jl @@ -10,11 +10,13 @@ using Statistics using StatsBase using Test -struct ExplicitESSMethod <: MCMCDiagnosticTools.AbstractESSMethod end +struct ExplicitAutoCovMethod <: MCMCDiagnosticTools.AbstractAutoCovMethod end struct ExplicitESSCache{S} samples::S end -function MCMCDiagnosticTools.build_cache(::ExplicitESSMethod, samples::Matrix, var::Vector) +function MCMCDiagnosticTools.build_cache( + ::ExplicitAutoCovMethod, samples::Matrix, var::Vector +) return ExplicitESSCache(samples) end MCMCDiagnosticTools.update!(::ExplicitESSCache) = nothing @@ -139,19 +141,21 @@ mymean(x) = mean(x) x = randn(1000, 4, 10) @testset for kind in (:rank, :bulk, :tail, :basic), split_chains in (1, 2) R1 = rhat(x; kind=kind, split_chains=split_chains) - @testset for method in (ESSMethod(), BDAESSMethod()), maxlag in (100, 10) + @testset for autocov_method in (AutoCovMethod(), BDAAutoCovMethod()), + maxlag in (100, 10) + S1 = ess( x; kind=kind === :rank ? :bulk : kind, split_chains=split_chains, - method=method, + autocov_method=autocov_method, maxlag=maxlag, ) S2, R2 = ess_rhat( x; kind=kind, split_chains=split_chains, - method=method, + autocov_method=autocov_method, maxlag=maxlag, ) @test S1 == S2 @@ -172,13 +176,13 @@ mymean(x) = mean(x) ess_standard, rhat_standard = ess_rhat(x; split_chains=split_chains) ess_standard2, rhat_standard2 = ess_rhat( - x; split_chains=split_chains, method=ESSMethod() + x; split_chains=split_chains, autocov_method=AutoCovMethod() ) ess_fft, rhat_fft = ess_rhat( - x; split_chains=split_chains, method=FFTESSMethod() + x; split_chains=split_chains, autocov_method=FFTAutoCovMethod() ) ess_bda, rhat_bda = ess_rhat( - x; split_chains=split_chains, method=BDAESSMethod() + x; split_chains=split_chains, autocov_method=BDAAutoCovMethod() ) # check that we get (roughly) the same results @@ -200,9 +204,9 @@ mymean(x) = mean(x) x = ones(10_000, 10, 40) ess_standard, rhat_standard = ess_rhat(x) - ess_standard2, rhat_standard2 = ess_rhat(x; method=ESSMethod()) - ess_fft, rhat_fft = ess_rhat(x; method=FFTESSMethod()) - ess_bda, rhat_bda = ess_rhat(x; method=BDAESSMethod()) + ess_standard2, rhat_standard2 = ess_rhat(x; autocov_method=AutoCovMethod()) + ess_fft, rhat_fft = ess_rhat(x; autocov_method=FFTAutoCovMethod()) + ess_bda, rhat_bda = ess_rhat(x; autocov_method=BDAAutoCovMethod()) # check that the estimates are all NaN for ess in (ess_standard, ess_standard2, ess_fft, ess_bda) @@ -213,11 +217,12 @@ mymean(x) = mean(x) end end - @testset "Autocov of ESSMethod and FFTESSMethod equivalent to StatsBase" begin + @testset "Autocov of AutoCovMethod and FFTAutoCovMethod equivalent to StatsBase" begin x = randn(1_000, 10, 40) - ess_exp = ess(x; method=ExplicitESSMethod()) - @testset "$method" for method in [FFTESSMethod(), ESSMethod()] - @test ess(x; method=method) ≈ ess_exp + ess_exp = ess(x; autocov_method=ExplicitAutoCovMethod()) + @testset "$autocov_method" for autocov_method in + [FFTAutoCovMethod(), AutoCovMethod()] + @test ess(x; autocov_method=autocov_method) ≈ ess_exp end end From f4851b515a862f01c08b1bc7d68c2c59e5c00fa4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 23:57:38 +0100 Subject: [PATCH 4/5] Update docs --- docs/src/index.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 5af1f44c..b64b73da 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -12,12 +12,12 @@ rhat ess_rhat ``` -The following `method`s are supported: +The following `autocov_method`s are supported: ```@docs -ESSMethod -FFTESSMethod -BDAESSMethod +AutoCovMethod +FFTAutoCovMethod +BDAAutoCovMethod ``` ## Monte Carlo standard error From 212c39b409104d02ae9ad0a141064bc3b4b6aca4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 23:59:11 +0100 Subject: [PATCH 5/5] Update capitalization --- docs/src/index.md | 6 ++-- src/MCMCDiagnosticTools.jl | 2 +- src/ess_rhat.jl | 58 +++++++++++++++++++------------------- test/ess_rhat.jl | 24 ++++++++-------- 4 files changed, 45 insertions(+), 45 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index b64b73da..7f280f17 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -15,9 +15,9 @@ ess_rhat The following `autocov_method`s are supported: ```@docs -AutoCovMethod -FFTAutoCovMethod -BDAAutoCovMethod +AutocovMethod +FFTAutocovMethod +BDAAutocovMethod ``` ## Monte Carlo standard error diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index 842aa824..1ae21a77 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -16,7 +16,7 @@ using Statistics: Statistics export bfmi export discretediag -export ess, ess_rhat, rhat, AutoCovMethod, FFTAutoCovMethod, BDAAutoCovMethod +export ess, ess_rhat, rhat, AutocovMethod, FFTAutocovMethod, BDAAutocovMethod export gelmandiag, gelmandiag_multivariate export gewekediag export heideldiag diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 5fcb73cd..039ae3ed 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -1,5 +1,5 @@ # methods -abstract type AbstractAutoCovMethod end +abstract type AbstractAutocovMethod end const _DOC_SPLIT_CHAINS = """`split_chains` indicates the number of chains each chain is split into. When `split_chains > 1`, then the diagnostics check for within-chain convergence. When @@ -7,9 +7,9 @@ const _DOC_SPLIT_CHAINS = """`split_chains` indicates the number of chains each is discarded after each of the first `d` splits within each chain.""" """ - AutoCovMethod <: AbstractAutoCovMethod + AutocovMethod <: AbstractAutocovMethod -The `AutoCovMethod` uses a standard algorithm for estimating the mean autocovariance of MCMC +The `AutocovMethod` uses a standard algorithm for estimating the mean autocovariance of MCMC chains. It is is based on the discussion by [^VehtariGelman2021] and uses the @@ -22,15 +22,15 @@ biased estimator of the autocovariance, as discussed by [^Geyer1992]. arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) [^Geyer1992]: Geyer, C. J. (1992). Practical Markov Chain Monte Carlo. Statistical Science, 473-483. """ -struct AutoCovMethod <: AbstractAutoCovMethod end +struct AutocovMethod <: AbstractAutocovMethod end """ - FFTAutoCovMethod <: AbstractAutoCovMethod + FFTAutocovMethod <: AbstractAutocovMethod -The `FFTAutoCovMethod` uses a standard algorithm for estimating the mean autocovariance of +The `FFTAutocovMethod` uses a standard algorithm for estimating the mean autocovariance of MCMC chains. -The algorithm is the same as the one of [`AutoCovMethod`](@ref) but this method uses fast +The algorithm is the same as the one of [`AutocovMethod`](@ref) but this method uses fast Fourier transforms (FFTs) for estimating the autocorrelation. !!! info @@ -39,12 +39,12 @@ Fourier transforms (FFTs) for estimating the autocorrelation. as [FFTW.jl](https://github.com/JuliaMath/FFTW.jl) or [FastTransforms.jl](https://github.com/JuliaApproximation/FastTransforms.jl). """ -struct FFTAutoCovMethod <: AbstractAutoCovMethod end +struct FFTAutocovMethod <: AbstractAutocovMethod end """ - BDAAutoCovMethod <: AbstractAutoCovMethod + BDAAutocovMethod <: AbstractAutocovMethod -The `BDAAutoCovMethod` uses a standard algorithm for estimating the mean autocovariance of +The `BDAAutocovMethod` uses a standard algorithm for estimating the mean autocovariance of MCMC chains. It is is based on the discussion by [^VehtariGelman2021]. and uses the @@ -57,15 +57,15 @@ variogram estimator of the autocorrelation function discussed by [^BDA3]. arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) [^BDA3]: Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A., & Rubin, D. B. (2013). Bayesian data analysis. CRC press. """ -struct BDAAutoCovMethod <: AbstractAutoCovMethod end +struct BDAAutocovMethod <: AbstractAutocovMethod end # caches -struct AutoCovCache{T,S} +struct AutocovCache{T,S} samples::Matrix{T} chain_var::Vector{S} end -struct FFTAutoCovCache{T,S,C,P,I} +struct FFTAutocovCache{T,S,C,P,I} samples::Matrix{T} chain_var::Vector{S} samples_cache::C @@ -73,21 +73,21 @@ struct FFTAutoCovCache{T,S,C,P,I} invplan::I end -mutable struct BDAAutoCovCache{T,S,M} +mutable struct BDAAutocovCache{T,S,M} samples::Matrix{T} chain_var::Vector{S} mean_chain_var::M end -function build_cache(::AutoCovMethod, samples::Matrix, var::Vector) +function build_cache(::AutocovMethod, samples::Matrix, var::Vector) # check arguments niter, nchains = size(samples) length(var) == nchains || throw(DimensionMismatch()) - return AutoCovCache(samples, var) + return AutocovCache(samples, var) end -function build_cache(::FFTAutoCovMethod, samples::Matrix, var::Vector) +function build_cache(::FFTAutocovMethod, samples::Matrix, var::Vector) # check arguments niter, nchains = size(samples) length(var) == nchains || throw(DimensionMismatch()) @@ -101,20 +101,20 @@ function build_cache(::FFTAutoCovMethod, samples::Matrix, var::Vector) fft_plan = AbstractFFTs.plan_fft!(samples_cache, 1) ifft_plan = AbstractFFTs.plan_ifft!(samples_cache, 1) - return FFTAutoCovCache(samples, var, samples_cache, fft_plan, ifft_plan) + return FFTAutocovCache(samples, var, samples_cache, fft_plan, ifft_plan) end -function build_cache(::BDAAutoCovMethod, samples::Matrix, var::Vector) +function build_cache(::BDAAutocovMethod, samples::Matrix, var::Vector) # check arguments nchains = size(samples, 2) length(var) == nchains || throw(DimensionMismatch()) - return BDAAutoCovCache(samples, var, Statistics.mean(var)) + return BDAAutocovCache(samples, var, Statistics.mean(var)) end -update!(cache::AutoCovCache) = nothing +update!(cache::AutocovCache) = nothing -function update!(cache::FFTAutoCovCache) +function update!(cache::FFTAutocovCache) # copy samples and add zero padding samples = cache.samples samples_cache = cache.samples_cache @@ -138,14 +138,14 @@ function update!(cache::FFTAutoCovCache) return nothing end -function update!(cache::BDAAutoCovCache) +function update!(cache::BDAAutocovCache) # recompute mean of within-chain variances cache.mean_chain_var = Statistics.mean(cache.chain_var) return nothing end -function mean_autocov(k::Int, cache::AutoCovCache) +function mean_autocov(k::Int, cache::AutocovCache) # check arguments samples = cache.samples niter, nchains = size(samples) @@ -165,7 +165,7 @@ function mean_autocov(k::Int, cache::AutoCovCache) return s / niter end -function mean_autocov(k::Int, cache::FFTAutoCovCache) +function mean_autocov(k::Int, cache::FFTAutocovCache) # check arguments niter, nchains = size(cache.samples) 0 ≤ k < niter || throw(ArgumentError("only lags ≥ 0 and < $niter are supported")) @@ -181,7 +181,7 @@ function mean_autocov(k::Int, cache::FFTAutoCovCache) return result * uncorrection_factor end -function mean_autocov(k::Int, cache::BDAAutoCovCache) +function mean_autocov(k::Int, cache::BDAAutocovCache) # check arguments samples = cache.samples niter, nchains = size(samples) @@ -203,7 +203,7 @@ end ess( samples::AbstractArray{<:Union{Missing,Real},3}; kind=:bulk, - autocov_method=AutoCovMethod(), + autocov_method=AutocovMethod(), split_chains::Int=2, maxlag::Int=250, kwargs... @@ -223,7 +223,7 @@ than 0. For a given estimand, it is recommended that the ESS is at least `100 * chains` and that ``\\widehat{R} < 1.01``.[^VehtariGelman2021] -See also: [`AutoCovMethod`](@ref), [`FFTAutoCovMethod`](@ref), [`BDAAutoCovMethod`](@ref), +See also: [`AutocovMethod`](@ref), [`FFTAutocovMethod`](@ref), [`BDAAutocovMethod`](@ref), [`rhat`](@ref), [`ess_rhat`](@ref), [`mcse`](@ref) ## Kinds of ESS estimates @@ -443,7 +443,7 @@ end function _ess_rhat( ::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real},3}; - autocov_method::AbstractAutoCovMethod=AutoCovMethod(), + autocov_method::AbstractAutocovMethod=AutocovMethod(), split_chains::Int=2, maxlag::Int=250, ) diff --git a/test/ess_rhat.jl b/test/ess_rhat.jl index 2102ed8f..14c8fcfc 100644 --- a/test/ess_rhat.jl +++ b/test/ess_rhat.jl @@ -10,12 +10,12 @@ using Statistics using StatsBase using Test -struct ExplicitAutoCovMethod <: MCMCDiagnosticTools.AbstractAutoCovMethod end +struct ExplicitAutocovMethod <: MCMCDiagnosticTools.AbstractAutocovMethod end struct ExplicitESSCache{S} samples::S end function MCMCDiagnosticTools.build_cache( - ::ExplicitAutoCovMethod, samples::Matrix, var::Vector + ::ExplicitAutocovMethod, samples::Matrix, var::Vector ) return ExplicitESSCache(samples) end @@ -141,7 +141,7 @@ mymean(x) = mean(x) x = randn(1000, 4, 10) @testset for kind in (:rank, :bulk, :tail, :basic), split_chains in (1, 2) R1 = rhat(x; kind=kind, split_chains=split_chains) - @testset for autocov_method in (AutoCovMethod(), BDAAutoCovMethod()), + @testset for autocov_method in (AutocovMethod(), BDAAutocovMethod()), maxlag in (100, 10) S1 = ess( @@ -176,13 +176,13 @@ mymean(x) = mean(x) ess_standard, rhat_standard = ess_rhat(x; split_chains=split_chains) ess_standard2, rhat_standard2 = ess_rhat( - x; split_chains=split_chains, autocov_method=AutoCovMethod() + x; split_chains=split_chains, autocov_method=AutocovMethod() ) ess_fft, rhat_fft = ess_rhat( - x; split_chains=split_chains, autocov_method=FFTAutoCovMethod() + x; split_chains=split_chains, autocov_method=FFTAutocovMethod() ) ess_bda, rhat_bda = ess_rhat( - x; split_chains=split_chains, autocov_method=BDAAutoCovMethod() + x; split_chains=split_chains, autocov_method=BDAAutocovMethod() ) # check that we get (roughly) the same results @@ -204,9 +204,9 @@ mymean(x) = mean(x) x = ones(10_000, 10, 40) ess_standard, rhat_standard = ess_rhat(x) - ess_standard2, rhat_standard2 = ess_rhat(x; autocov_method=AutoCovMethod()) - ess_fft, rhat_fft = ess_rhat(x; autocov_method=FFTAutoCovMethod()) - ess_bda, rhat_bda = ess_rhat(x; autocov_method=BDAAutoCovMethod()) + ess_standard2, rhat_standard2 = ess_rhat(x; autocov_method=AutocovMethod()) + ess_fft, rhat_fft = ess_rhat(x; autocov_method=FFTAutocovMethod()) + ess_bda, rhat_bda = ess_rhat(x; autocov_method=BDAAutocovMethod()) # check that the estimates are all NaN for ess in (ess_standard, ess_standard2, ess_fft, ess_bda) @@ -217,11 +217,11 @@ mymean(x) = mean(x) end end - @testset "Autocov of AutoCovMethod and FFTAutoCovMethod equivalent to StatsBase" begin + @testset "Autocov of AutocovMethod and FFTAutocovMethod equivalent to StatsBase" begin x = randn(1_000, 10, 40) - ess_exp = ess(x; autocov_method=ExplicitAutoCovMethod()) + ess_exp = ess(x; autocov_method=ExplicitAutocovMethod()) @testset "$autocov_method" for autocov_method in - [FFTAutoCovMethod(), AutoCovMethod()] + [FFTAutocovMethod(), AutocovMethod()] @test ess(x; autocov_method=autocov_method) ≈ ess_exp end end