From 6088765e3ff7c9c47df54ff5f0aa8a58532406eb Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 23:46:40 +0100 Subject: [PATCH] Use keyword syntax for ess, ess_rhat, rhat, and mcse (#72) * Create rhat.jl * Add missing _expectand_proxy method * Remove duplicate rhat_tail * Update ESS methods * Define estimator keyword for mcse * Reduce exports * Fix bug * Run formatter * Shortcut rank-normalizing a `Missing` * Update ESS tests * Use new mcse syntax * Update documented methods * Update text * Fix test * Use new mcse signature * Update type kwarg to estimator * Update mcse docstring * Fix variable name * Make tail-ESS work for typeunion with Missing * Add rank for ESS * Rearrange tests * Test against ess * Add consistency test * Add inferrability tests for estimators * Add more error tests * Try to resolve type-instability * Add Compat for `@constprop` macro * Aggressively constprop * Test type-inferrability with Vals * Make work for typeunions with Missing on v1.6 * Avoid repeating split_chains docs * Rename ess.jl to ess_rhat.jl * Move rhat to same file as ess * Update costring * Run formatter * Remove rhat.jl * Better unify tests * Add missing type annotations * Apply suggestions from code review Co-authored-by: David Widmann * Avoid `@constprop`, rename `type` and `estimator` to `kind` * Cleanup * Rename `estimator` to `kind` for `mcse` * Update tests * Update kinds accepted for ESS * Remove unused utility * Constrain kind to Symbol when applicable * Remove stray tick * Apply suggestions from code review Co-authored-by: David Widmann * Apply suggestions from code review Co-authored-by: David Widmann * Add comment explaining kind checks * Change vectors to tuples --------- Co-authored-by: David Widmann --- docs/src/index.md | 9 +- src/MCMCDiagnosticTools.jl | 4 +- src/{ess.jl => ess_rhat.jl} | 347 ++++++++++++++++++++++++----------- src/gewekediag.jl | 4 +- src/heideldiag.jl | 4 +- src/mcse.jl | 41 +++-- src/utils.jl | 7 +- test/{ess.jl => 
ess_rhat.jl} | 229 +++++++++++++++-------- test/mcse.jl | 31 ++-- test/runtests.jl | 4 +- 10 files changed, 445 insertions(+), 235 deletions(-) rename src/{ess.jl => ess_rhat.jl} (59%) rename test/{ess.jl => ess_rhat.jl} (54%) diff --git a/docs/src/index.md b/docs/src/index.md index 54a29e74..5af1f44c 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -6,16 +6,13 @@ CurrentModule = MCMCDiagnosticTools ## Effective sample size and $\widehat{R}$ -The effective sample size (ESS) and $\widehat{R}$ can be estimated with [`ess_rhat`](@ref). - ```@docs +ess +rhat ess_rhat -ess_rhat_bulk -ess_tail -rhat_tail ``` -The following methods are supported: +The following `method`s are supported: ```@docs ESSMethod diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index e237abef..4710f483 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -16,7 +16,7 @@ using Statistics: Statistics export bfmi export discretediag -export ess_rhat, ess_rhat_bulk, ess_tail, rhat_tail, ESSMethod, FFTESSMethod, BDAESSMethod +export ess, ess_rhat, rhat, ESSMethod, FFTESSMethod, BDAESSMethod export gelmandiag, gelmandiag_multivariate export gewekediag export heideldiag @@ -27,7 +27,7 @@ export rstar include("utils.jl") include("bfmi.jl") include("discretediag.jl") -include("ess.jl") +include("ess_rhat.jl") include("gelmandiag.jl") include("gewekediag.jl") include("heideldiag.jl") diff --git a/src/ess.jl b/src/ess_rhat.jl similarity index 59% rename from src/ess.jl rename to src/ess_rhat.jl index 241fc673..3421e523 100644 --- a/src/ess.jl +++ b/src/ess_rhat.jl @@ -1,6 +1,11 @@ # methods abstract type AbstractESSMethod end +const _DOC_SPLIT_CHAINS = """`split_chains` indicates the number of chains each chain is split into. + When `split_chains > 1`, then the diagnostics check for within-chain convergence. When + `d = mod(draws, split_chains) > 0`, i.e. 
the chains cannot be evenly split, then 1 draw + is discarded after each of the first `d` splits within each chain.""" + """ ESSMethod <: AbstractESSMethod @@ -195,25 +200,22 @@ function mean_autocov(k::Int, cache::BDAESSCache) end """ - ess_rhat( - [estimator,] + ess( samples::AbstractArray{<:Union{Missing,Real},3}; + kind=:bulk, method=ESSMethod(), split_chains::Int=2, maxlag::Int=250, + kwargs... ) -Estimate the effective sample size and ``\\widehat{R}`` of the `samples` of shape +Estimate the effective sample size (ESS) of the `samples` of shape `(draws, chains, parameters)` with the `method`. -By default, the computed ESS and ``\\widehat{R}`` values correspond to the estimator `mean`. -Other estimators can be specified by passing a function `estimator` (see below). +Optionally, the `kind` of ESS estimate to be computed can be specified (see below). Some +`kind`s accept additional `kwargs`. -`split_chains` indicates the number of chains each chain is split into. -When `split_chains > 1`, then the diagnostics check for within-chain convergence. When -`d = mod(draws, split_chains) > 0`, i.e. the chains cannot be evenly split, then 1 draw -is discarded after each of the first `d` splits within each chain. There must be at least -3 draws in each chain after splitting. +$_DOC_SPLIT_CHAINS There must be at least 3 draws in each chain after splitting. `maxlag` indicates the maximum lag for which autocovariance is computed and must be greater than 0. 
@@ -222,11 +224,27 @@ For a given estimand, it is recommended that the ESS is at least `100 * chains` ``\\widehat{R} < 1.01``.[^VehtariGelman2021] See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref), -[`ess_rhat_bulk`](@ref), [`ess_tail`](@ref), [`rhat_tail`](@ref), [`mcse`](@ref) - -## Estimators - -The ESS and ``\\widehat{R}`` values can be computed for the following estimators: +[`rhat`](@ref), [`ess_rhat`](@ref), [`mcse`](@ref) + +## Kinds of ESS estimates + +If `kind` is a `Symbol`, it may take one of the following values: +- `:bulk`: basic ESS computed on rank-normalized draws. This kind diagnoses poor convergence + in the bulk of the distribution due to trends or different locations of the chains. +- `:tail`: minimum of the quantile-ESS for the symmetric quantiles where + `tail_prob=0.1` is the probability in the tails. This kind diagnoses poor convergence in + the tails of the distribution. If this kind is chosen, `kwargs` may contain a + `tail_prob` keyword. +- `:basic`: basic ESS, equivalent to specifying `kind=Statistics.mean`. + +!!! note + While Bulk-ESS is conceptually related to basic ESS, it is well-defined even if the + chains do not have finite variance.[^VehtariGelman2021] For each parameter, + rank-normalization proceeds by first ranking the inputs using "tied ranking" and then + transforming the ranks to normal quantiles so that the result is standard normally + distributed. This transform is monotonic. + +Otherwise, `kind` specifies one of the following estimators, whose ESS is to be estimated: - `Statistics.mean` - `Statistics.median` - `Statistics.std` @@ -239,19 +257,191 @@ The ESS and ``\\widehat{R}`` values can be computed for the following estimators doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) """ -function ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - return ess_rhat(Statistics.mean, samples; kwargs...)
+function ess(samples::AbstractArray{<:Union{Missing,Real},3}; kind=:bulk, kwargs...) + # if we just call _ess(Val(kind), ...) Julia cannot infer the return type with default + # const-propagation. We keep this type-inferrable by manually dispatching to the cases. + if kind === :bulk + return _ess(Val(:bulk), samples; kwargs...) + elseif kind === :tail + return _ess(Val(:tail), samples; kwargs...) + elseif kind === :basic + return _ess(Val(:basic), samples; kwargs...) + elseif kind isa Symbol + throw(ArgumentError("the `kind` `$kind` is not supported by `ess`")) + else + return _ess(kind, samples; kwargs...) + end end -function ess_rhat(f, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - x = _expectand_proxy(f, samples) +function _ess(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + x = _expectand_proxy(estimator, samples) if x === nothing - throw(ArgumentError("the estimator $f is not yet supported by `ess_rhat`")) + throw(ArgumentError("the estimator $estimator is not yet supported by `ess`")) end - values = ess_rhat(Statistics.mean, x; kwargs...) - return values + return _ess(Val(:basic), x; kwargs...) +end +function _ess(kind::Val, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + return first(_ess_rhat(kind, samples; kwargs...)) +end +function _ess( + ::Val{:tail}, + x::AbstractArray{<:Union{Missing,Real},3}; + tail_prob::Real=1//10, + kwargs..., +) + # workaround for https://github.com/JuliaStats/Statistics.jl/issues/136 + T = Base.promote_eltype(x, tail_prob) + pl = convert(T, tail_prob / 2) + pu = convert(T, 1 - tail_prob / 2) + S_lower = _ess(Base.Fix2(Statistics.quantile, pl), x; kwargs...) + S_upper = _ess(Base.Fix2(Statistics.quantile, pu), x; kwargs...) 
+ return map(min, S_lower, S_upper) end + +""" + rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kind::Symbol=:rank, split_chains=2) + +Compute the ``\\widehat{R}`` diagnostics for each parameter in `samples` of shape +`(draws, chains, parameters)`.[^VehtariGelman2021] + +`kind` indicates the kind of ``\\widehat{R}`` to compute (see below). + +$_DOC_SPLIT_CHAINS + +See also [`ess`](@ref), [`ess_rhat`](@ref), [`rstar`](@ref) + +## Kinds of ``\\widehat{R}`` + +The following `kind`s are supported: +- `:rank`: maximum of ``\\widehat{R}`` with `kind=:bulk` and `kind=:tail`. +- `:bulk`: basic ``\\widehat{R}`` computed on rank-normalized draws. This kind diagnoses + poor convergence in the bulk of the distribution due to trends or different locations of + the chains. +- `:tail`: ``\\widehat{R}`` computed on draws folded around the median and then + rank-normalized. This kind diagnoses poor convergence in the tails of the distribution + due to different scales of the chains. +- `:basic`: Classic ``\\widehat{R}``. + +[^VehtariGelman2021]: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P. C. (2021). + Rank-normalization, folding, and localization: An improved ``\\widehat {R}`` for + assessing convergence of MCMC. Bayesian Analysis. + doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) + arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) +""" +function rhat( + samples::AbstractArray{<:Union{Missing,Real},3}; kind::Symbol=:rank, kwargs... +) + # if we just call _rhat(Val(kind), ...) Julia cannot infer the return type with default + # const-propagation. We keep this type-inferrable by manually dispatching to the cases. + if kind === :rank + return _rhat(Val(:rank), samples; kwargs...) + elseif kind === :bulk + return _rhat(Val(:bulk), samples; kwargs...) + elseif kind === :tail + return _rhat(Val(:tail), samples; kwargs...) + elseif kind === :basic + return _rhat(Val(:basic), samples; kwargs...)
+ else + return throw(ArgumentError("the `kind` `$kind` is not supported by `rhat`")) + end +end +function _rhat( + ::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2 +) + # compute size of matrices (each chain may be split!) + niter = size(chains, 1) ÷ split_chains + nchains = split_chains * size(chains, 2) + axes_out = (axes(chains, 3),) + T = promote_type(eltype(chains), typeof(zero(eltype(chains)) / 1)) + + # define output arrays + rhat = similar(chains, T, axes_out) + + T === Missing && return rhat + + # define caches for mean and variance + chain_mean = Array{T}(undef, 1, nchains) + chain_var = Array{T}(undef, nchains) + samples = Array{T}(undef, niter, nchains) + + # compute correction factor + correctionfactor = (niter - 1)//niter + + # for each parameter + for (i, chains_slice) in zip(eachindex(rhat), eachslice(chains; dims=3)) + # check that no values are missing + if any(x -> x === missing, chains_slice) + rhat[i] = missing + continue + end + + # split chains + copyto_split!(samples, chains_slice) + + # calculate mean of chains + Statistics.mean!(chain_mean, samples) + + # calculate within-chain variance + @inbounds for j in 1:nchains + chain_var[j] = Statistics.var( + view(samples, :, j); mean=chain_mean[j], corrected=true + ) + end + W = Statistics.mean(chain_var) + + # compute variance estimator var₊, which accounts for between-chain variance as well + # avoid NaN when nchains=1 and set the variance estimator var₊ to the within-chain variance in that case + var₊ = correctionfactor * W + Statistics.var(chain_mean; corrected=(nchains > 1)) + + # estimate rhat + rhat[i] = sqrt(var₊ / W) + end + + return rhat +end +function _rhat(::Val{:bulk}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + return _rhat(Val(:basic), _rank_normalize(x); kwargs...) +end +function _rhat(::Val{:tail}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + return _rhat(Val(:bulk), _fold_around_median(x); kwargs...)
+end +function _rhat(::Val{:rank}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + Rbulk = _rhat(Val(:bulk), x; kwargs...) + Rtail = _rhat(Val(:tail), x; kwargs...) + return map(max, Rtail, Rbulk) +end + +""" + ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kind::Symbol=:rank, kwargs...) + +Estimate the effective sample size and ``\\widehat{R}`` of the `samples` of shape +`(draws, chains, parameters)` with the `method`. + +When both ESS and ``\\widehat{R}`` are needed, this method is often more efficient than +calling `ess` and `rhat` separately. + +See [`rhat`](@ref) for a description of supported `kind`s and [`ess`](@ref) for a +description of `kwargs`. +""" function ess_rhat( - ::typeof(Statistics.mean), + samples::AbstractArray{<:Union{Missing,Real},3}; kind::Symbol=:rank, kwargs... +) + # if we just call _ess_rhat(Val(kind), ...) Julia cannot infer the return type with + # default const-propagation. We keep this type-inferrable by manually dispatching to the + # cases. + if kind === :rank + return _ess_rhat(Val(:rank), samples; kwargs...) + elseif kind === :bulk + return _ess_rhat(Val(:bulk), samples; kwargs...) + elseif kind === :tail + return _ess_rhat(Val(:tail), samples; kwargs...) + elseif kind === :basic + return _ess_rhat(Val(:basic), samples; kwargs...) + else + return throw(ArgumentError("the `kind` `$kind` is not supported by `ess_rhat`")) + end +end +function _ess_rhat( + ::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real},3}; method::AbstractESSMethod=ESSMethod(), split_chains::Int=2, @@ -377,98 +567,32 @@ function ess_rhat( return ess, rhat end - -""" - ess_rhat_bulk(samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - -Estimate the bulk-effective sample size and bulk-``\\widehat{R}`` values for the `samples` of -shape `(draws, chains, parameters)`. - -For a description of `kwargs`, see [`ess_rhat`](@ref). 
- -The bulk-ESS and bulk-``\\widehat{R}`` are variants of ESS and ``\\widehat{R}`` that -diagnose poor convergence in the bulk of the distribution due to trends or different -locations of the chains. While it is conceptually related to [`ess_rhat`](@ref) for -`Statistics.mean`, it is well-defined even if the chains do not have finite variance.[^VehtariGelman2021] - -Bulk-ESS and bulk-``\\widehat{R}`` are computed by rank-normalizing the samples and then -computing `ess_rhat`. For each parameter, rank-normalization proceeds by first ranking the -inputs using "tied ranking" and then transforming the ranks to normal quantiles so that the -result is standard normally distributed. The transform is monotonic. - -See also: [`ess_tail`](@ref), [`rhat_tail`](@ref) - -[^VehtariGelman2021]: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P. C. (2021). - Rank-normalization, folding, and localization: An improved ``\\widehat {R}`` for - assessing convergence of MCMC. Bayesian Analysis. - doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) - arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) -""" -function ess_rhat_bulk(x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - return ess_rhat(Statistics.mean, _rank_normalize(x); kwargs...) +function _ess_rhat(::Val{:bulk}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + return _ess_rhat(Val(:basic), _rank_normalize(x); kwargs...) end - -""" - ess_tail(samples::AbstractArray{<:Union{Missing,Real},3}; tail_prob=1//10, kwargs...) - -Estimate the tail-effective sample size and for the `samples` of shape -`(draws, chains, parameters)`. - -For a description of `kwargs`, see [`ess_rhat`](@ref). - -The tail-ESS diagnoses poor convergence in the tails of the distribution. Specifically, it -is the minimum of the ESS of the estimate of the symmetric quantiles where `tail_prob` is -the probability in the tails. 
For example, with the default `tail_prob=1//10`, the tail-ESS -is the minimum of the ESS of the 0.5 and 0.95 sample quantiles.[^VehtariGelman2021] - -See also: [`ess_rhat_bulk`](@ref), [`rhat_tail`](@ref) - -[^VehtariGelman2021]: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P. C. (2021). - Rank-normalization, folding, and localization: An improved ``\\widehat {R}`` for - assessing convergence of MCMC. Bayesian Analysis. - doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) - arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) -""" -function ess_tail( - x::AbstractArray{<:Union{Missing,Real},3}; tail_prob::Real=1//10, kwargs... +function _ess_rhat( + kind::Val{:tail}, + x::AbstractArray{<:Union{Missing,Real},3}; + split_chains::Int=2, + kwargs..., ) - # workaround for https://github.com/JuliaStats/Statistics.jl/issues/136 - T = Base.promote_eltype(x, tail_prob) - return min.( - first(ess_rhat(Base.Fix2(Statistics.quantile, T(tail_prob / 2)), x; kwargs...)), - first(ess_rhat(Base.Fix2(Statistics.quantile, T(1 - tail_prob / 2)), x; kwargs...)), - ) + S = _ess(kind, x; split_chains=split_chains, kwargs...) + R = _rhat(kind, x; split_chains=split_chains) + return S, R +end +function _ess_rhat( + ::Val{:rank}, x::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2, kwargs... +) + Sbulk, Rbulk = _ess_rhat(Val(:bulk), x; split_chains=split_chains, kwargs...) + Rtail = _rhat(Val(:tail), x; split_chains=split_chains) + Rrank = map(max, Rtail, Rbulk) + return Sbulk, Rrank end - -""" - rhat_tail(samples::AbstractArray{Union{Real,Missing},3}; kwargs...) - -Estimate the tail-``\\widehat{R}`` diagnostic for the `samples` of shape -`(draws, chains, parameters)`. - -For a description of `kwargs`, see [`ess_rhat`](@ref). - -The tail-``\\widehat{R}`` diagnostic is a variant of ``\\widehat{R}`` that diagnoses poor -convergence in the tails of the distribution. 
In particular, it can detect chains that have -similar locations but different scales.[^VehtariGelman2021] - -For each parameter matrix of draws `x` with size `(draws, chains)`, it is calculated by -computing bulk-``\\widehat{R}`` on the absolute deviation of the draws from the median: -`abs.(x .- median(x))`. - -See also: [`ess_tail`](@ref), [`ess_rhat_bulk`](@ref) - -[^VehtariGelman2021]: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P. C. (2021). - Rank-normalization, folding, and localization: An improved ``\\widehat {R}`` for - assessing convergence of MCMC. Bayesian Analysis. - doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) - arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) -""" -rhat_tail(x; kwargs...) = last(ess_rhat_bulk(_fold_around_median(x); kwargs...)) # Compute an expectand `z` such that ``\\textrm{mean-ESS}(z) ≈ \\textrm{f-ESS}(x)``. # If no proxy expectand for `f` is known, `nothing` is returned. _expectand_proxy(f, x) = nothing +_expectand_proxy(::typeof(Statistics.mean), x) = x function _expectand_proxy(::typeof(Statistics.median), x) y = similar(x) # avoid using the `dims` keyword for median because it @@ -490,7 +614,12 @@ function _expectand_proxy(f::Base.Fix2{typeof(Statistics.quantile),<:Real}, x) y = similar(x) # currently quantile does not support a dims keyword argument for (xi, yi) in zip(eachslice(x; dims=3), eachslice(y; dims=3)) - yi .= xi .≤ f(vec(xi)) + if any(ismissing, xi) + # quantile function raises an error if there are missing values + fill!(yi, missing) + else + yi .= xi .≤ f(vec(xi)) + end end return y end diff --git a/src/gewekediag.jl b/src/gewekediag.jl index 38bc788b..e6b86576 100644 --- a/src/gewekediag.jl +++ b/src/gewekediag.jl @@ -25,8 +25,8 @@ function gewekediag(x::AbstractVector{<:Real}; first::Real=0.1, last::Real=0.5, x1 = x[1:round(Int, first * n)] x2 = x[round(Int, n - last * n + 1):n] s = hypot( - Base.first(mcse(Statistics.mean, reshape(x1, :, 1, 1); split_chains=1, 
kwargs...)), - Base.first(mcse(Statistics.mean, reshape(x2, :, 1, 1); split_chains=1, kwargs...)), + Base.first(mcse(reshape(x1, :, 1, 1); split_chains=1, kwargs...)), + Base.first(mcse(reshape(x2, :, 1, 1); split_chains=1, kwargs...)), ) z = (Statistics.mean(x1) - Statistics.mean(x2)) / s p = SpecialFunctions.erfc(abs(z) / sqrt2) diff --git a/src/heideldiag.jl b/src/heideldiag.jl index 26ec5296..791e8bd4 100644 --- a/src/heideldiag.jl +++ b/src/heideldiag.jl @@ -20,7 +20,7 @@ function heideldiag( delta = trunc(Int, 0.10 * n) y = x[trunc(Int, n / 2):end] T = typeof(zero(eltype(x)) / 1) - s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...)) + s = first(mcse(reshape(y, :, 1, 1); split_chains=1, kwargs...)) S0 = length(y) * s^2 i, pvalue, converged, ybar = 1, one(T), false, T(NaN) while i < n / 2 @@ -37,7 +37,7 @@ function heideldiag( end i += delta end - s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...)) + s = first(mcse(reshape(y, :, 1, 1); split_chains=1, kwargs...)) halfwidth = sqrt2 * SpecialFunctions.erfcinv(T(alpha)) * s passed = halfwidth / abs(ybar) <= eps return ( diff --git a/src/mcse.jl b/src/mcse.jl index 779e4e14..3e13f3a9 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -2,20 +2,21 @@ const normcdf1 = 0.8413447460685429 # StatsFuns.normcdf(1) const normcdfn1 = 0.15865525393145705 # StatsFuns.normcdf(-1) """ - mcse(estimator, samples::AbstractArray{<:Union{Missing,Real}}; kwargs...) + mcse(samples::AbstractArray{<:Union{Missing,Real}}; kind=Statistics.mean, kwargs...) -Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples` of -shape `(draws, chains, parameters)`. +Estimate the Monte Carlo standard errors (MCSE) of the estimator `kind` applied to `samples` +of shape `(draws, chains, parameters)`. 
-See also: [`ess_rhat`](@ref) +See also: [`ess`](@ref) -## Estimators +## Kinds of MCSE estimates -`estimator` must accept a vector of the same `eltype` as `samples` and return a real estimate. +The estimator whose MCSE should be estimated is specified with `kind`. `kind` must accept a +vector of the same `eltype` as `samples` and return a real estimate. -For the following estimators, the effective sample size [`ess_rhat`](@ref) and an estimate +For the following estimators, the effective sample size [`ess`](@ref) and an estimate of the asymptotic variance are used to compute the MCSE, and `kwargs` are forwarded to -`ess_rhat`: +`ess`: - `Statistics.mean` - `Statistics.median` - `Statistics.std` @@ -26,7 +27,7 @@ is used as a fallback, and the only accepted `kwargs` are `batch_size`, which in size of the overlapping batches used to estimate the MCSE, defaulting to `floor(Int, sqrt(draws * chains))`. Note that SBM tends to underestimate the MCSE, especially for highly autocorrelated chains. One should verify that autocorrelation is low -by checking the bulk- and tail-[`ess_rhat`](@ref) values. +by checking the bulk- and tail-ESS values. [^FlegalJones2011]: Flegal JM, Jones GL. (2011) Implementing MCMC: estimating with confidence. Handbook of Markov Chain Monte Carlo. pp. 175-97. @@ -36,18 +37,22 @@ by checking the bulk- and tail-[`ess_rhat`](@ref) values. doi: [10.1007/978-3-642-27440-4_18](https://doi.org/10.1007/978-3-642-27440-4_18) """ -mcse(f, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) = _mcse_sbm(f, x; kwargs...) -function mcse( +function mcse(x::AbstractArray{<:Union{Missing,Real},3}; kind=Statistics.mean, kwargs...) + return _mcse(kind, x; kwargs...) +end + +_mcse(f, x; kwargs...) = _mcse_sbm(f, x; kwargs...) +function _mcse( ::typeof(Statistics.mean), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) - S = first(ess_rhat(Statistics.mean, samples; kwargs...)) + S = _ess(Statistics.mean, samples; kwargs...) 
return dropdims(Statistics.std(samples; dims=(1, 2)); dims=(1, 2)) ./ sqrt.(S) end -function mcse( +function _mcse( ::typeof(Statistics.std), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) x = (samples .- Statistics.mean(samples; dims=(1, 2))) .^ 2 # expectand proxy - S = first(ess_rhat(Statistics.mean, x; kwargs...)) + S = _ess(Statistics.mean, x; kwargs...) # asymptotic variance of sample variance estimate is Var[var] = E[μ₄] - E[var]², # where μ₄ is the 4th central moment # by the delta method, Var[std] = Var[var] / 4E[var] = (E[μ₄]/E[var] - E[var])/4, @@ -56,13 +61,13 @@ function mcse( mean_moment4 = dropdims(Statistics.mean(abs2, x; dims=(1, 2)); dims=(1, 2)) return @. sqrt((mean_moment4 / mean_var - mean_var) / S) / 2 end -function mcse( +function _mcse( f::Base.Fix2{typeof(Statistics.quantile),<:Real}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs..., ) p = f.x - S = first(ess_rhat(f, samples; kwargs...)) + S = _ess(f, samples; kwargs...) T = eltype(S) R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T)))) values = similar(S, R) @@ -71,10 +76,10 @@ function mcse( end return values end -function mcse( +function _mcse( ::typeof(Statistics.median), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) - S = first(ess_rhat(Statistics.median, samples; kwargs...)) + S = _ess(Statistics.median, samples; kwargs...) T = eltype(S) R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T)))) values = similar(S, R) diff --git a/src/utils.jl b/src/utils.jl index 6ebfcb64..fbb6e37c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -166,11 +166,16 @@ and then transforming the ranks to normal quantiles so that the result is standa normally distributed. 
""" function _rank_normalize(x::AbstractArray{<:Any,3}) - y = similar(x, float(eltype(x))) + T = promote_type(eltype(x), typeof(zero(eltype(x)) / 1)) + y = similar(x, T) map(_rank_normalize!, eachslice(y; dims=3), eachslice(x; dims=3)) return y end function _rank_normalize!(values, x) + if any(ismissing, x) + fill!(values, missing) + return values + end rank = StatsBase.tiedrank(x) _normal_quantiles_from_ranks!(values, rank) map!(StatsFuns.norminvcdf, values, values) diff --git a/test/ess.jl b/test/ess_rhat.jl similarity index 54% rename from test/ess.jl rename to test/ess_rhat.jl index 30a96e9a..a639b480 100644 --- a/test/ess.jl +++ b/test/ess_rhat.jl @@ -32,7 +32,138 @@ function LogDensityProblems.capabilities(p::CauchyProblem) return LogDensityProblems.LogDensityOrder{1}() end -@testset "ess.jl" begin +mymean(x) = mean(x) + +@testset "ess_rhat.jl" begin + @testset "ess/ess_rhat/rhat basics" begin + @testset "only promote eltype when necessary" begin + @testset for kind in (:rank, :bulk, :tail, :basic) + @testset for T in (Float32, Float64) + x = rand(T, 100, 4, 2) + TV = Vector{T} + kind === :rank || @test @inferred(ess(x; kind=kind)) isa TV + @test @inferred(rhat(x; kind=kind)) isa TV + @test @inferred(ess_rhat(x; kind=kind)) isa Tuple{TV,TV} + end + @testset "Int" begin + x = rand(1:10, 100, 4, 2) + TV = Vector{Float64} + kind === :rank || @test @inferred(ess(x; kind=kind)) isa TV + @test @inferred(rhat(x; kind=kind)) isa TV + @test @inferred(ess_rhat(x; kind=kind)) isa Tuple{TV,TV} + end + end + @testset for kind in (mean, median, mad, std, Base.Fix2(quantile, 0.25)) + @testset for T in (Float32, Float64) + x = rand(T, 100, 4, 2) + @test @inferred(ess(x; kind=kind)) isa Vector{T} + end + @testset "Int" begin + x = rand(1:10, 100, 4, 2) + @test @inferred(ess(x; kind=kind)) isa Vector{Float64} + end + end + end + + @testset "errors" begin # check that issue #137 is fixed + x = rand(4, 3, 5) + x2 = rand(5, 3, 5) + x3 = rand(100, 3, 5) + @testset for f in (ess, 
ess_rhat) + @testset for kind in (:rank, :bulk, :tail, :basic) + f === ess && kind === :rank && continue + @test_throws ArgumentError f(x; split_chains=1, kind=kind) + f(x2; split_chains=1, kind=kind) + @test_throws ArgumentError f(x2; split_chains=2, kind=kind) + f(x3; maxlag=1, kind=kind) + @test_throws DomainError f(x3; maxlag=0, kind=kind) + end + @test_throws ArgumentError f(x2; kind=:foo) + end + @test_throws ArgumentError rhat(x2; kind=:foo) + @test_throws ArgumentError ess(x2; kind=mymean) + end + + @testset "Union{Missing,Float64} eltype" begin + @testset for kind in (:rank, :bulk, :tail, :basic) + x = Array{Union{Missing,Float64}}(undef, 1000, 4, 3) + x .= randn.() + x[1, 1, 1] = missing + S1 = ess(x; kind=kind === :rank ? :bulk : kind) + R1 = rhat(x; kind=kind) + S2, R2 = ess_rhat(x; kind=kind) + @test ismissing(S1[1]) + @test ismissing(R1[1]) + @test ismissing(S2[1]) + @test ismissing(R2[1]) + @test !any(ismissing, S1[2:3]) + @test !any(ismissing, R1[2:3]) + @test !any(ismissing, S2[2:3]) + @test !any(ismissing, R2[2:3]) + end + end + + @testset "produces similar vectors to inputs" begin + @testset for kind in (:rank, :bulk, :tail, :basic) + # simultaneously checks that we index correctly and that output types are correct + x = randn(100, 4, 5) + y = OffsetArray(x, -5:94, 2:5, 11:15) + S11 = ess(y; kind=kind === :rank ? :bulk : kind) + R11 = rhat(y; kind=kind) + S12, R12 = ess_rhat(y; kind=kind) + @test S11 isa OffsetVector{Float64} + @test S12 isa OffsetVector{Float64} + @test axes(S11, 1) == axes(S12, 1) == axes(y, 3) + @test R11 isa OffsetVector{Float64} + @test R12 isa OffsetVector{Float64} + @test axes(R11, 1) == axes(R12, 1) == axes(y, 3) + S21 = ess(x; kind=kind === :rank ? :bulk : kind) + R21 = rhat(x; kind=kind) + S22, R22 = ess_rhat(x; kind=kind) + @test S22 == S21 == collect(S21) + @test R21 == R22 == collect(R11) + y = OffsetArray(similar(x, Missing), -5:94, 2:5, 11:15) + S31 = ess(y; kind=kind === :rank ? 
:bulk : kind) + R31 = rhat(y; kind=kind) + S32, R32 = ess_rhat(y; kind=kind) + @test S31 isa OffsetVector{Missing} + @test S32 isa OffsetVector{Missing} + @test axes(S31, 1) == axes(S32, 1) == axes(y, 3) + @test R31 isa OffsetVector{Missing} + @test R32 isa OffsetVector{Missing} + @test axes(R31, 1) == axes(R32, 1) == axes(y, 3) + end + end + + @testset "ess, ess_rhat, and rhat consistency" begin + x = randn(1000, 4, 10) + @testset for kind in (:rank, :bulk, :tail, :basic), split_chains in (1, 2) + R1 = rhat(x; kind=kind, split_chains=split_chains) + @testset for method in (ESSMethod(), BDAESSMethod()), maxlag in (100, 10) + S1 = ess( + x; + kind=kind === :rank ? :bulk : kind, + split_chains=split_chains, + method=method, + maxlag=maxlag, + ) + S2, R2 = ess_rhat( + x; + kind=kind, + split_chains=split_chains, + method=method, + maxlag=maxlag, + ) + @test S1 == S2 + @test R1 == R2 + end + end + end + end + + # now that we have checked mutual consistency of each method, we perform all following + # checks for whichever method is most convenient + @testset "ESS and R̂ (IID samples)" begin # Repeat tests with different scales @testset for scale in (1, 50, 100), nchains in (1, 10), split_chains in (1, 2) @@ -65,39 +196,6 @@ end end end - @testset "ESS and R̂ only promote eltype when necessary" begin - @testset for T in (Float32, Float64) - x = rand(T, 100, 4, 2) - TV = Vector{T} - @inferred Tuple{TV,TV} ess_rhat(x) - end - @testset "Int" begin - x = rand(1:10, 100, 4, 2) - TV = Vector{Float64} - @inferred Tuple{TV,TV} ess_rhat(x) - end - end - - @testset "ESS and R̂ are similar vectors to inputs" begin - # simultaneously checks that we index correctly and that output types are correct - x = randn(100, 4, 5) - y = OffsetArray(x, -5:94, 2:5, 11:15) - S, R = ess_rhat(y) - @test S isa OffsetVector{Float64} - @test axes(S, 1) == axes(y, 3) - @test R isa OffsetVector{Float64} - @test axes(R, 1) == axes(y, 3) - S2, R2 = ess_rhat(x) - @test S2 == collect(S) - @test R2 == 
collect(R) - y = OffsetArray(similar(x, Missing), -5:94, 2:5, 11:15) - S3, R3 = ess_rhat(y) - @test S3 isa OffsetVector{Missing} - @test axes(S3, 1) == axes(y, 3) - @test R3 isa OffsetVector{Missing} - @test axes(R3, 1) == axes(y, 3) - end - @testset "ESS and R̂ (identical samples)" begin x = ones(10_000, 10, 40) @@ -115,47 +213,25 @@ end end end - @testset "ESS and R̂ errors" begin # check that issue #137 is fixed - x = rand(4, 3, 5) - x2 = rand(5, 3, 5) - @test_throws ArgumentError ess_rhat(x; split_chains=1) - ess_rhat(x2; split_chains=1) - @test_throws ArgumentError ess_rhat(x2; split_chains=2) - x3 = rand(100, 3, 5) - ess_rhat(x3; maxlag=1) - @test_throws DomainError ess_rhat(x3; maxlag=0) - end - - @testset "ESS and R̂ with Union{Missing,Float64} eltype" begin - x = Array{Union{Missing,Float64}}(undef, 1000, 4, 3) - x .= randn.() - x[1, 1, 1] = missing - S, R = ess_rhat(x) - @test ismissing(S[1]) - @test ismissing(R[1]) - @test !any(ismissing, S[2:3]) - @test !any(ismissing, R[2:3]) - end - @testset "Autocov of ESSMethod and FFTESSMethod equivalent to StatsBase" begin x = randn(1_000, 10, 40) - ess_exp = ess_rhat(x; method=ExplicitESSMethod())[1] + ess_exp = ess(x; method=ExplicitESSMethod()) @testset "$method" for method in [FFTESSMethod(), ESSMethod()] - @test ess_rhat(x; method=method)[1] ≈ ess_exp + @test ess(x; method=method) ≈ ess_exp end end @testset "ESS and R̂ for chains with 2 epochs that have not mixed" begin # checks that splitting yields lower ESS estimates and higher Rhat estimates x = randn(1000, 4, 10) .+ repeat([0, 10]; inner=(500, 1, 1)) - ess_array, rhat_array = ess_rhat(x; split_chains=1) + ess_array, rhat_array = ess_rhat(x; kind=:basic, split_chains=1) @test all(x -> isapprox(x, 1; rtol=0.1), rhat_array) - ess_array2, rhat_array2 = ess_rhat(x; split_chains=2) + ess_array2, rhat_array2 = ess_rhat(x; kind=:basic, split_chains=2) @test all(ess_array2 .< ess_array) @test all(>(2), rhat_array2) end - @testset "ess_rhat(f, x)[1]" begin + 
@testset "ess(x; kind=f)" begin # we check the ESS estimates by simulating uncorrelated, correlated, and # anticorrelated chains, mapping the draws to a target distribution, computing the # estimand, and estimating the ESS for the chosen estimator, computing the @@ -166,7 +242,7 @@ end nparams = 100 x = randn(ndraws, nchains, nparams) mymean(x; kwargs...) = mean(x; kwargs...) - @test_throws ArgumentError ess_rhat(mymean, x) + @test_throws ArgumentError ess(x; kind=mymean) estimators = [mean, median, std, mad, Base.Fix2(quantile, 0.25)] dists = [Normal(10, 100), Exponential(10), TDist(7) * 10 - 20] # AR(1) coefficients. 0 is IID, -0.3 is slightly anticorrelated, 0.9 is highly autocorrelated @@ -181,7 +257,7 @@ end x .= quantile.(dist, cdf.(Normal(), x)) # stationary distribution is dist μ_mean = dropdims(mapslices(f ∘ vec, x; dims=(1, 2)); dims=(1, 2)) dist = asymptotic_dist(f, dist) - n = @inferred(ess_rhat(f, x))[1] + n = @inferred(ess(x; kind=f)) μ = mean(dist) mcse = sqrt.(var(dist) ./ n) for i in eachindex(μ_mean, mcse) @@ -199,19 +275,19 @@ end nchains = 4 @testset for ndraws in (10, 100), φ in (-0.3, -0.9) x = ar1(φ, sqrt(1 - φ^2), ndraws, nchains, 1000) - Smin, Smax = extrema(ess_rhat(mean, x)[1]) + Smin, Smax = extrema(ess(x; kind=mean)) ntotal = ndraws * nchains @test Smax == ntotal * log10(ntotal) @test Smin > 0 end end - @testset "ess_rhat_bulk(x)" begin + @testset "ess(x; kind=:bulk)" begin xnorm = randn(1_000, 4, 10) - @test ess_rhat_bulk(xnorm) == ess_rhat(mean, _rank_normalize(xnorm)) + @test ess(xnorm; kind=:bulk) == ess(_rank_normalize(xnorm); kind=:basic) xcauchy = quantile.(Cauchy(), cdf.(Normal(), xnorm)) # transformation by any monotonic function should not change the bulk ESS/R-hat - @test ess_rhat_bulk(xnorm) == ess_rhat_bulk(xcauchy) + @test ess(xnorm; kind=:bulk) == ess(xcauchy; kind=:bulk) end @testset "tail- ESS and R-hat detect mismatched scales" begin @@ -230,19 +306,17 @@ end # sanity check that standard and bulk ESS and R-hat both 
fail to detect # mismatched scales - S, R = ess_rhat(x) + S, R = ess_rhat(x; kind=:basic) @test all(≥(ess_cutoff), S) @test all(≤(rhat_cutoff), R) - Sbulk, Rbulk = ess_rhat_bulk(x) + Sbulk, Rbulk = ess_rhat(x; kind=:bulk) @test all(≥(ess_cutoff), Sbulk) @test all(≤(rhat_cutoff), Rbulk) - # check that tail-Rhat and tail-ESS detect mismatched scales and signal - # poor convergence - S = ess_tail(x) - @test all(<(ess_cutoff), S) - R = rhat_tail(x) - @test all(>(rhat_cutoff), R) + # check that tail- ESS and R-hat detect mismatched scales and signal poor convergence + S_tail, R_tail = ess_rhat(x; kind=:tail) + @test all(<(ess_cutoff), S_tail) + @test all(>(rhat_cutoff), R_tail) end @testset "bulk and tail ESS and R-hat for heavy tailed" begin @@ -260,9 +334,8 @@ end end x = permutedims(cat(posterior_matrices...; dims=3), (2, 3, 1)) - Sbulk, Rbulk = ess_rhat_bulk(x) - Stail = ess_tail(x) - Rtail = rhat_tail(x) + Sbulk, Rbulk = ess_rhat(x; kind=:bulk) + Stail, Rtail = ess_rhat(x; kind=:tail) ess_cutoff = 100 * size(x, 2) # recommended cutoff is 100 * nchains @test mean(≥(ess_cutoff), Sbulk) > 0.9 @test mean(≥(ess_cutoff), Stail) < mean(≥(ess_cutoff), Sbulk) diff --git a/test/mcse.jl b/test/mcse.jl index 781143f5..40e888ab 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -6,42 +6,43 @@ using Statistics using StatsBase @testset "mcse.jl" begin - @testset "estimator must be provided" begin + @testset "estimator defaults to mean" begin x = randn(100, 4, 10) - @test_throws MethodError mcse(x) + @test mcse(x) == mcse(x; kind=mean) end - @testset "ESS-based methods forward kwargs to ess_rhat" begin + @testset "ESS-based methods forward kwargs to ess" begin x = randn(100, 4, 10) @testset for f in [mean, median, std, Base.Fix2(quantile, 0.1)] - @test @inferred(mcse(f, x; split_chains=1)) ≠ mcse(f, x) + @test @inferred(mcse(x; kind=f, split_chains=1)) ≠ mcse(x; kind=f) end end @testset "mcse falls back to _mcse_sbm" begin x = randn(100, 4, 10) - @test @inferred(mcse(mad, x)) == -
MCMCDiagnosticTools._mcse_sbm(mad, x) ≠ - MCMCDiagnosticTools._mcse_sbm(mad, x; batch_size=16) == - mcse(mad, x; batch_size=16) + estimator = mad + @test @inferred(mcse(x; kind=estimator)) == + MCMCDiagnosticTools._mcse_sbm(estimator, x) ≠ + MCMCDiagnosticTools._mcse_sbm(estimator, x; batch_size=16) == + mcse(x; kind=estimator, batch_size=16) end @testset "mcse produces similar vectors to inputs" begin # simultaneously checks that we index correctly and that output types are correct @testset for T in (Float32, Float64), - f in [mean, median, std, Base.Fix2(quantile, T(0.1)), mad] + estimator in [mean, median, std, Base.Fix2(quantile, T(0.1)), mad] x = randn(T, 100, 4, 5) y = OffsetArray(x, -5:94, 2:5, 11:15) - se = mcse(f, y) + se = mcse(y; kind=estimator) @test se isa OffsetVector{T} @test axes(se, 1) == axes(y, 3) - se2 = mcse(f, x) + se2 = mcse(x; kind=estimator) @test se2 ≈ collect(se) # quantile errors if data contains missings - f isa Base.Fix2{typeof(quantile)} && continue + estimator isa Base.Fix2{typeof(quantile)} && continue y = OffsetArray(similar(x, Missing), -5:94, 2:5, 11:15) - @test mcse(f, y) isa OffsetVector{Missing} + @test mcse(y; kind=estimator) isa OffsetVector{Missing} end end @@ -50,7 +51,7 @@ using StatsBase x .= randn.() x[1, 1, 1] = missing @testset for f in [mean, median, std, mad] - se = mcse(f, x) + se = mcse(x; kind=f) @test ismissing(se[1]) @test !any(ismissing, se[2:end]) end @@ -81,7 +82,7 @@ using StatsBase x .= quantile.(dist, cdf.(Normal(), x)) # stationary distribution is dist μ_mean = dropdims(mapslices(f ∘ vec, x; dims=(1, 2)); dims=(1, 2)) μ = mean(asymptotic_dist(f, dist)) - se = mcse(f, x) + se = mcse === MCMCDiagnosticTools._mcse_sbm ? 
mcse(f, x) : mcse(x; kind=f) for i in eachindex(μ_mean, se) atol = quantile(Normal(0, se[i]), 1 - α) @test μ_mean[i] ≈ μ atol = atol diff --git a/test/runtests.jl b/test/runtests.jl index 17507a39..b756fbe2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,8 +21,8 @@ Random.seed!(1) @testset "discrete diagnostic" begin include("discretediag.jl") end - @testset "ESS" begin - include("ess.jl") + @testset "ESS and R̂" begin + include("ess_rhat.jl") end @testset "Monte Carlo standard error" begin include("mcse.jl")