From 878cbe9d0c262a56c2f628ec198042e98ac4656f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Feb 2023 21:30:53 +0100 Subject: [PATCH 01/51] Create rhat.jl --- src/MCMCDiagnosticTools.jl | 2 + src/rhat.jl | 98 ++++++++++++++++++++++++++++++++++++++ src/utils.jl | 3 ++ 3 files changed, 103 insertions(+) create mode 100644 src/rhat.jl diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index e237abef..3d1eef33 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -22,6 +22,7 @@ export gewekediag export heideldiag export mcse export rafterydiag +export rhat export rstar include("utils.jl") @@ -33,5 +34,6 @@ include("gewekediag.jl") include("heideldiag.jl") include("mcse.jl") include("rafterydiag.jl") +include("rhat.jl") include("rstar.jl") end diff --git a/src/rhat.jl b/src/rhat.jl new file mode 100644 index 00000000..c9093425 --- /dev/null +++ b/src/rhat.jl @@ -0,0 +1,98 @@ +""" + rhat(samples::AbstractArray{Union{Real,Missing},3}; type=:rank, split_chains=2) + +Compute the ``\\widehat{R}`` diagnostics for each parameter in `samples` of shape +`(chains, draws, parameters)`. [^VehtariGelman2021] + +`type` indicates the type of ``\\widehat{R}`` to compute (see below). + +`split_chains` indicates the number of chains each chain is split into. +When `split_chains > 1`, then the diagnostics check for within-chain convergence. When +`d = mod(draws, split_chains) > 0`, i.e. the chains cannot be evenly split, then 1 draw +is discarded after each of the first `d` splits within each chain. + +## Types + +The following types are supported: +- `:rank`: maximum of ``\\widehat{R}`` with `type=:bulk` and `type=:tail`. +- `:bulk`: basic ``\\widehat{R}``` computed on rank-normalized draws. This type diagnoses + poor convergence in the bulk of the distribution due to trends or different locations of + the chains. +- `:tail`: ``\\widehat{R}`` computed on draws folded around the median and then + rank-normalized. This type diagnoses poor convergence in the tails of the distribution + due to different scales of the chains. +- `:basic`: Classic ``\\widehat{R}``. + +[^VehtariGelman2021]: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P. C. (2021). + Rank-normalization, folding, and localization: An improved ``\\widehat {R}`` for + assessing convergence of MCMC. Bayesian Analysis. + doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) + arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) +""" +function rhat(samples::AbstractArray{<:Union{Missing,Real},3}; type=Val(:rank), kwargs...) + return _rhat(_val(type), samples; kwargs...) +end + +function _rhat( + ::Val{:basic}, + chains::AbstractArray{<:Union{Missing,Real},3}; + split_chains::Int=2, +) + # compute size of matrices (each chain may be split!) + niter = size(chains, 1) ÷ split_chains + nchains = split_chains * size(chains, 2) + axes_out = (axes(chains, 3),) + T = promote_type(eltype(chains), typeof(zero(eltype(chains)) / 1)) + + # define output arrays + rhat = similar(chains, T, axes_out) + + T === Missing && return rhat + + # define caches for mean and variance + chain_mean = Array{T}(undef, 1, nchains) + chain_var = Array{T}(undef, nchains) + samples = Array{T}(undef, niter, nchains) + + # compute correction factor + correctionfactor = (niter - 1)//niter + + # for each parameter + for (i, chains_slice) in zip(eachindex(rhat), eachslice(chains; dims=3)) + # check that no values are missing + if any(x -> x === missing, chains_slice) + rhat[i] = missing + continue + end + + # split chains + copyto_split!(samples, chains_slice) + + # calculate mean of chains + Statistics.mean!(chain_mean, samples) + + # calculate within-chain variance + @inbounds for j in 1:nchains + chain_var[j] = Statistics.var( + view(samples, :, j); mean=chain_mean[j], corrected=true + ) + end + W = Statistics.mean(chain_var) + + # compute variance estimator var₊, which accounts for between-chain variance as well + # avoid NaN when nchains=1 and set the variance estimator var₊ to the the within-chain variance in that case + var₊ = correctionfactor * W + Statistics.var(chain_mean; corrected=(nchains > 1)) + + # estimate rhat + rhat[i] = sqrt(var₊ / W) + end + + return rhat +end +_rhat(::Val{:bulk}, x; kwargs...) = _rhat(Val(:basic), _rank_normalize(x); kwargs...) +_rhat(::Val{:tail}, x; kwargs...) = _rhat(Val(:bulk), _fold_around_median(x); kwargs...) +function _rhat(::Val{:rank}, x; kwargs...) + Rbulk = _rhat(Val(:bulk), x; kwargs...) + Rtail = _rhat(Val(:tail), x; kwargs...) + return map(max, Rtail, Rbulk) +end diff --git a/src/utils.jl b/src/utils.jl index 6ebfcb64..fe899340 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -185,3 +185,6 @@ function _normal_quantiles_from_ranks!(q, r; α=3//8) q .= (r .- α) ./ (n - 2α + 1) return q end + +_val(k) = Val(k) +_val(k::Val) = k From 67ba11c7b73a745fa6caede9f08e65905dac8004 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Feb 2023 21:31:08 +0100 Subject: [PATCH 02/51] Add missing _expectand_proxy method --- src/ess.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/ess.jl b/src/ess.jl index 241fc673..ba358a74 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -469,6 +469,7 @@ rhat_tail(x; kwargs...) = last(ess_rhat_bulk(_fold_around_median(x); kwargs...)) # Compute an expectand `z` such that ``\\textrm{mean-ESS}(z) ≈ \\textrm{f-ESS}(x)``. # If no proxy expectand for `f` is known, `nothing` is returned. _expectand_proxy(f, x) = nothing +_expectand_proxy(::typeof(Statistics.mean), x) = x function _expectand_proxy(::typeof(Statistics.median), x) y = similar(x) # avoid using the `dims` keyword for median because it From 067ce264a61efd1fc751658ae31ac6e86342c414 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Feb 2023 21:38:19 +0100 Subject: [PATCH 03/51] Remove duplicate rhat_tail --- src/ess.jl | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/src/ess.jl b/src/ess.jl index ba358a74..9a645465 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -440,32 +440,6 @@ function ess_tail( ) end -""" - rhat_tail(samples::AbstractArray{Union{Real,Missing},3}; kwargs...) - -Estimate the tail-``\\widehat{R}`` diagnostic for the `samples` of shape -`(draws, chains, parameters)`. - -For a description of `kwargs`, see [`ess_rhat`](@ref). - -The tail-``\\widehat{R}`` diagnostic is a variant of ``\\widehat{R}`` that diagnoses poor -convergence in the tails of the distribution. In particular, it can detect chains that have -similar locations but different scales.[^VehtariGelman2021] - -For each parameter matrix of draws `x` with size `(draws, chains)`, it is calculated by -computing bulk-``\\widehat{R}`` on the absolute deviation of the draws from the median: -`abs.(x .- median(x))`. - -See also: [`ess_tail`](@ref), [`ess_rhat_bulk`](@ref) - -[^VehtariGelman2021]: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P. C. (2021). - Rank-normalization, folding, and localization: An improved ``\\widehat {R}`` for - assessing convergence of MCMC. Bayesian Analysis. - doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) - arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) -""" -rhat_tail(x; kwargs...) = last(ess_rhat_bulk(_fold_around_median(x); kwargs...)) - # Compute an expectand `z` such that ``\\textrm{mean-ESS}(z) ≈ \\textrm{f-ESS}(x)``. # If no proxy expectand for `f` is known, `nothing` is returned. _expectand_proxy(f, x) = nothing From c4d50253ff440f6eaeb1e607477c3cd227d78989 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Feb 2023 21:38:48 +0100 Subject: [PATCH 04/51] Update ESS methods --- src/ess.jl | 173 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 97 insertions(+), 76 deletions(-) diff --git a/src/ess.jl b/src/ess.jl index 9a645465..c38c36c2 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -195,19 +195,21 @@ function mean_autocov(k::Int, cache::BDAESSCache) end """ - ess_rhat( - [estimator,] + ess( samples::AbstractArray{<:Union{Missing,Real},3}; + type=:bulk, + [estimator,] method=ESSMethod(), split_chains::Int=2, maxlag::Int=250, + kwargs... ) -Estimate the effective sample size and ``\\widehat{R}`` of the `samples` of shape +Estimate the effective sample size (ESS) of the `samples` of shape `(draws, chains, parameters)` with the `method`. -By default, the computed ESS and ``\\widehat{R}`` values correspond to the estimator `mean`. -Other estimators can be specified by passing a function `estimator` (see below). +Optionally, only one of the `type` of ESS estimate to return or the `estimator` for which +ESS is computed can be specified (see below). Some `type`s accept additional `kwargs`. `split_chains` indicates the number of chains each chain is split into. When `split_chains > 1`, then the diagnostics check for within-chain convergence. When @@ -233,25 +235,92 @@ The ESS and ``\\widehat{R}`` values can be computed for the following estimators - `StatsBase.mad` - `Base.Fix2(Statistics.quantile, p::Real)` +## Types + +If no `estimator` is provided, the following types of ESS estimates may be computed: +- `:bulk`: mean-ESS computed on rank-normalized draws. This type diagnoses poor + convergence in the bulk of the distribution due to trends or different locations of the + chains. +- `:tail`: minimum of the quantile-ESS for the symmetric quantiles where + `tail_prob=0.1` is the probability in the tails. This type diagnoses poor convergence in + the tails of the distribution. If this type is chosen, `kwargs` may contain a + `tail_prob` keyword. +- `:basic`: basic ESS, equivalent to specifying `estimator=Statistics.mean`. + +While Bulk-ESS is conceptually related to basic ESS, it is well-defined even if the chains +do not have finite variance.[^VehtariGelman2021]. For each parameter, rank-normalization +proceeds by first ranking the inputs using "tied ranking" and then transforming the ranks to +normal quantiles so that the result is standard normally distributed. This transform is +monotonic. + [^VehtariGelman2021]: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P. C. (2021). Rank-normalization, folding, and localization: An improved ``\\widehat {R}`` for assessing convergence of MCMC. Bayesian Analysis. doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) """ -function ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - return ess_rhat(Statistics.mean, samples; kwargs...) -end -function ess_rhat(f, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - x = _expectand_proxy(f, samples) - if x === nothing - throw(ArgumentError("the estimator $f is not yet supported by `ess_rhat`")) +ess(samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) = _ess(samples; kwargs...) + +function _ess( + samples::AbstractArray{<:Union{Missing,Real},3}; + estimator=nothing, + type=estimator === nothing ? Val(:bulk) : nothing, + kwargs..., +) + if estimator !== nothing && type !== nothing + throw(ArgumentError("only one of `estimator` and `type` can be specified")) + elseif estimator !== nothing + x = _expectand_proxy(estimator, samples) + if x === nothing + throw(ArgumentError("the estimator $estimator is not yet supported by `ess`")) + end + return _ess(Val(:basic), x; kwargs...) + else + return _ess(_val(type), samples; kwargs...) end - values = ess_rhat(Statistics.mean, x; kwargs...) - return values end -function ess_rhat( - ::typeof(Statistics.mean), +function _ess(::Val{T}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) where {T} + return throw(ArgumentError("the `type` `$T` is not supported by `ess`")) +end +function _ess(type::Val{:basic}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + return first(ess_rhat(samples; type=type, kwargs...)) +end +function _ess(type::Val{:bulk}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + return first(ess_rhat(samples; type=type, kwargs...)) +end +function _ess( + ::Val{:tail}, + x::AbstractArray{<:Union{Missing,Real},3}; tail_prob::Real=1//10, kwargs... +) + # workaround for https://github.com/JuliaStats/Statistics.jl/issues/136 + T = Base.promote_eltype(x, tail_prob) + S_lower = ess(x; estimator=Base.Fix2(Statistics.quantile, T(tail_prob / 2)), kwargs...) + S_upper = ess( + x; estimator=Base.Fix2(Statistics.quantile, T(1 - tail_prob / 2)), kwargs..., + ) + return map(min, S_lower, S_upper) +end + +""" + ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; type=:rank, kwargs...) + +Estimate the effective sample size and ``\\widehat{R}`` of the `samples` of shape +`(draws, chains, parameters)` with the `method`. + +When both ESS and ``\\widehat{R}`` are needed, this method is often more efficient than +calling `ess` and `rhat` separately. + +See [`rhat`](@ref) for a description of supported `type`s and [`ess`](@ref) for a +description of `kwargs`. +""" +function ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; type=Val(:rank), kwargs...) + return _ess_rhat(_val(type), samples; kwargs...) +end +function _ess_rhat(::Val{T}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) where {T} + return throw(ArgumentError("the `type` `$T` is not supported by `ess_rhat`")) +end +function _ess_rhat( + ::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real},3}; method::AbstractESSMethod=ESSMethod(), split_chains::Int=2, @@ -377,67 +446,19 @@ function ess_rhat( return ess, rhat end - -""" - ess_rhat_bulk(samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - -Estimate the bulk-effective sample size and bulk-``\\widehat{R}`` values for the `samples` of -shape `(draws, chains, parameters)`. - -For a description of `kwargs`, see [`ess_rhat`](@ref). - -The bulk-ESS and bulk-``\\widehat{R}`` are variants of ESS and ``\\widehat{R}`` that -diagnose poor convergence in the bulk of the distribution due to trends or different -locations of the chains. While it is conceptually related to [`ess_rhat`](@ref) for -`Statistics.mean`, it is well-defined even if the chains do not have finite variance.[^VehtariGelman2021] - -Bulk-ESS and bulk-``\\widehat{R}`` are computed by rank-normalizing the samples and then -computing `ess_rhat`. For each parameter, rank-normalization proceeds by first ranking the -inputs using "tied ranking" and then transforming the ranks to normal quantiles so that the -result is standard normally distributed. The transform is monotonic. - -See also: [`ess_tail`](@ref), [`rhat_tail`](@ref) - -[^VehtariGelman2021]: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P. C. (2021). - Rank-normalization, folding, and localization: An improved ``\\widehat {R}`` for - assessing convergence of MCMC. Bayesian Analysis. - doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) - arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) -""" -function ess_rhat_bulk(x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - return ess_rhat(Statistics.mean, _rank_normalize(x); kwargs...) +function _ess_rhat(::Val{:bulk}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + return _ess_rhat(Val(:basic), _rank_normalize(x); kwargs...) end - -""" - ess_tail(samples::AbstractArray{<:Union{Missing,Real},3}; tail_prob=1//10, kwargs...) - -Estimate the tail-effective sample size and for the `samples` of shape -`(draws, chains, parameters)`. - -For a description of `kwargs`, see [`ess_rhat`](@ref). - -The tail-ESS diagnoses poor convergence in the tails of the distribution. Specifically, it -is the minimum of the ESS of the estimate of the symmetric quantiles where `tail_prob` is -the probability in the tails. For example, with the default `tail_prob=1//10`, the tail-ESS -is the minimum of the ESS of the 0.5 and 0.95 sample quantiles.[^VehtariGelman2021] - -See also: [`ess_rhat_bulk`](@ref), [`rhat_tail`](@ref) - -[^VehtariGelman2021]: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P. C. (2021). - Rank-normalization, folding, and localization: An improved ``\\widehat {R}`` for - assessing convergence of MCMC. Bayesian Analysis. - doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) - arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) -""" -function ess_tail( - x::AbstractArray{<:Union{Missing,Real},3}; tail_prob::Real=1//10, kwargs... -) - # workaround for https://github.com/JuliaStats/Statistics.jl/issues/136 - T = Base.promote_eltype(x, tail_prob) - return min.( - first(ess_rhat(Base.Fix2(Statistics.quantile, T(tail_prob / 2)), x; kwargs...)), - first(ess_rhat(Base.Fix2(Statistics.quantile, T(1 - tail_prob / 2)), x; kwargs...)), - ) +function _ess_rhat(::Val{:tail}, x::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2, kwargs...) + S = ess(x; type=Val(:tail), split_chains=split_chains, kwargs...) + R = rhat(x; type=Val(:tail), split_chains=split_chains) + return S, R +end +function _ess_rhat(::Val{:rank}, x::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2, kwargs...) + Sbulk, Rbulk = _ess_rhat(Val(:bulk), x; split_chains=split_chains, kwargs...) + Rtail = _rhat(x; type=Val(:tail), split_chains=split_chains, kwargs...) + Rrank = map(max, Rtail, Rbulk) + return Sbulk, Rrank end # Compute an expectand `z` such that ``\\textrm{mean-ESS}(z) ≈ \\textrm{f-ESS}(x)``. From 4029eb8e68739a4876592834a8d4dd4a1b3815c4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Feb 2023 21:39:23 +0100 Subject: [PATCH 05/51] Define estimator keyword for mcse --- src/mcse.jl | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/mcse.jl b/src/mcse.jl index 779e4e14..4f8f9df8 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -2,7 +2,7 @@ const normcdf1 = 0.8413447460685429 # StatsFuns.normcdf(1) const normcdfn1 = 0.15865525393145705 # StatsFuns.normcdf(-1) """ - mcse(estimator, samples::AbstractArray{<:Union{Missing,Real}}; kwargs...) + mcse(samples::AbstractArray{<:Union{Missing,Real}}; estimator=Statistics.mean, kwargs...) Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples` of shape `(draws, chains, parameters)`. @@ -36,18 +36,22 @@ by checking the bulk- and tail-[`ess_rhat`](@ref) values. doi: [10.1007/978-3-642-27440-4_18](https://doi.org/10.1007/978-3-642-27440-4_18) """ -mcse(f, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) = _mcse_sbm(f, x; kwargs...) -function mcse( +function mcse(x::AbstractArray{<:Union{Missing,Real},3}; estimator=Statistics.mean, kwargs...) + return _mcse(estimator, x; kwargs...) +end + +_mcse(f, x; kwargs...) = _mcse_sbm(f, x; kwargs...) +function _mcse( ::typeof(Statistics.mean), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) - S = first(ess_rhat(Statistics.mean, samples; kwargs...)) + S = ess(samples; estimator=Statistics.mean, kwargs...) return dropdims(Statistics.std(samples; dims=(1, 2)); dims=(1, 2)) ./ sqrt.(S) end -function mcse( +function _mcse( ::typeof(Statistics.std), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) x = (samples .- Statistics.mean(samples; dims=(1, 2))) .^ 2 # expectand proxy - S = first(ess_rhat(Statistics.mean, x; kwargs...)) + S = ess(x; type=Statistics.mean, kwargs...) # asymptotic variance of sample variance estimate is Var[var] = E[μ₄] - E[var]², # where μ₄ is the 4th central moment # by the delta method, Var[std] = Var[var] / 4E[var] = (E[μ₄]/E[var] - E[var])/4, @@ -56,13 +60,13 @@ function mcse( mean_moment4 = dropdims(Statistics.mean(abs2, x; dims=(1, 2)); dims=(1, 2)) return @. sqrt((mean_moment4 / mean_var - mean_var) / S) / 2 end -function mcse( +function _mcse( f::Base.Fix2{typeof(Statistics.quantile),<:Real}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs..., ) p = f.x - S = first(ess_rhat(f, samples; kwargs...)) + S = ess(samples; estimator=f, kwargs...) T = eltype(S) R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T)))) values = similar(S, R) @@ -71,10 +75,10 @@ function mcse( end return values end -function mcse( +function _mcse( ::typeof(Statistics.median), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) - S = first(ess_rhat(Statistics.median, samples; kwargs...)) + S = ess(samples; estimator=Statistics.median, kwargs...) T = eltype(S) R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T)))) values = similar(S, R) From 4a9267b9964cb37fb6ba058b76deaf7ea538bbd8 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Feb 2023 21:39:32 +0100 Subject: [PATCH 06/51] Reduce exports --- src/MCMCDiagnosticTools.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index 3d1eef33..b197c7fd 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -16,7 +16,7 @@ using Statistics: Statistics export bfmi export discretediag -export ess_rhat, ess_rhat_bulk, ess_tail, rhat_tail, ESSMethod, FFTESSMethod, BDAESSMethod +export ess, ess_rhat, ESSMethod, FFTESSMethod, BDAESSMethod export gelmandiag, gelmandiag_multivariate export gewekediag export heideldiag From ca230658a7bf32a2b6525911b53d838fa999e31a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 00:13:59 +0100 Subject: [PATCH 07/51] Fix bug --- src/ess.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ess.jl b/src/ess.jl index c38c36c2..35a70412 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -456,7 +456,7 @@ function _ess_rhat(::Val{:tail}, x::AbstractArray{<:Union{Missing,Real},3}; spli end function _ess_rhat(::Val{:rank}, x::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2, kwargs...) Sbulk, Rbulk = _ess_rhat(Val(:bulk), x; split_chains=split_chains, kwargs...) - Rtail = _rhat(x; type=Val(:tail), split_chains=split_chains, kwargs...) + Rtail = rhat(x; type=Val(:tail), split_chains=split_chains) Rrank = map(max, Rtail, Rbulk) return Sbulk, Rrank end From bae3acb02746c961952f0305e5c65984dd964dfb Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 00:14:17 +0100 Subject: [PATCH 08/51] Run formatter --- src/ess.jl | 28 ++++++++++++++++++++-------- src/mcse.jl | 4 +++- src/rhat.jl | 6 +++--- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/ess.jl b/src/ess.jl index 35a70412..d6ec8e56 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -224,7 +224,7 @@ For a given estimand, it is recommended that the ESS is at least `100 * chains` ``\\widehat{R} < 1.01``.[^VehtariGelman2021] See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref), -[`ess_rhat_bulk`](@ref), [`ess_tail`](@ref), [`rhat_tail`](@ref), [`mcse`](@ref) +[`rhat`](@ref), [`ess_rhat`](@ref), [`mcse`](@ref) ## Estimators @@ -279,7 +279,9 @@ function _ess( return _ess(_val(type), samples; kwargs...) end end -function _ess(::Val{T}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) where {T} +function _ess( + ::Val{T}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... +) where {T} return throw(ArgumentError("the `type` `$T` is not supported by `ess`")) end function _ess(type::Val{:basic}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) @@ -290,13 +292,15 @@ function _ess(type::Val{:bulk}, samples::AbstractArray{<:Union{Missing,Real},3}; end function _ess( ::Val{:tail}, - x::AbstractArray{<:Union{Missing,Real},3}; tail_prob::Real=1//10, kwargs... + x::AbstractArray{<:Union{Missing,Real},3}; + tail_prob::Real=1//10, + kwargs..., ) # workaround for https://github.com/JuliaStats/Statistics.jl/issues/136 T = Base.promote_eltype(x, tail_prob) S_lower = ess(x; estimator=Base.Fix2(Statistics.quantile, T(tail_prob / 2)), kwargs...) S_upper = ess( - x; estimator=Base.Fix2(Statistics.quantile, T(1 - tail_prob / 2)), kwargs..., + x; estimator=Base.Fix2(Statistics.quantile, T(1 - tail_prob / 2)), kwargs... ) return map(min, S_lower, S_upper) end @@ -313,10 +317,14 @@ calling `ess` and `rhat` separately. See [`rhat`](@ref) for a description of supported `type`s and [`ess`](@ref) for a description of `kwargs`. """ -function ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; type=Val(:rank), kwargs...) +function ess_rhat( + samples::AbstractArray{<:Union{Missing,Real},3}; type=Val(:rank), kwargs... +) return _ess_rhat(_val(type), samples; kwargs...) end -function _ess_rhat(::Val{T}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) where {T} +function _ess_rhat( + ::Val{T}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... +) where {T} return throw(ArgumentError("the `type` `$T` is not supported by `ess_rhat`")) end function _ess_rhat( @@ -449,12 +457,16 @@ end function _ess_rhat(::Val{:bulk}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) return _ess_rhat(Val(:basic), _rank_normalize(x); kwargs...) end -function _ess_rhat(::Val{:tail}, x::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2, kwargs...) +function _ess_rhat( + ::Val{:tail}, x::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2, kwargs... +) S = ess(x; type=Val(:tail), split_chains=split_chains, kwargs...) R = rhat(x; type=Val(:tail), split_chains=split_chains) return S, R end -function _ess_rhat(::Val{:rank}, x::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2, kwargs...) +function _ess_rhat( + ::Val{:rank}, x::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2, kwargs... +) Sbulk, Rbulk = _ess_rhat(Val(:bulk), x; split_chains=split_chains, kwargs...) Rtail = rhat(x; type=Val(:tail), split_chains=split_chains) Rrank = map(max, Rtail, Rbulk) diff --git a/src/mcse.jl b/src/mcse.jl index 4f8f9df8..3bef83bf 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -36,7 +36,9 @@ by checking the bulk- and tail-[`ess_rhat`](@ref) values. doi: [10.1007/978-3-642-27440-4_18](https://doi.org/10.1007/978-3-642-27440-4_18) """ -function mcse(x::AbstractArray{<:Union{Missing,Real},3}; estimator=Statistics.mean, kwargs...) +function mcse( + x::AbstractArray{<:Union{Missing,Real},3}; estimator=Statistics.mean, kwargs... +) return _mcse(estimator, x; kwargs...) end diff --git a/src/rhat.jl b/src/rhat.jl index c9093425..d6a92f22 100644 --- a/src/rhat.jl +++ b/src/rhat.jl @@ -11,6 +11,8 @@ When `split_chains > 1`, then the diagnostics check for within-chain convergence `d = mod(draws, split_chains) > 0`, i.e. the chains cannot be evenly split, then 1 draw is discarded after each of the first `d` splits within each chain. +See also [`ess`](@ref), [`ess_rhat`](@ref), [`rstar`](@ref) + ## Types The following types are supported: @@ -34,9 +36,7 @@ function rhat(samples::AbstractArray{<:Union{Missing,Real},3}; type=Val(:rank), end function _rhat( - ::Val{:basic}, - chains::AbstractArray{<:Union{Missing,Real},3}; - split_chains::Int=2, + ::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2 ) # compute size of matrices (each chain may be split!) niter = size(chains, 1) ÷ split_chains From d7fc6643b491f66b4bb654e98d667587cd084cb1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 00:15:02 +0100 Subject: [PATCH 09/51] Shortcut rank-normalizing a `Missing` --- src/utils.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index fe899340..25d9bf54 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -171,6 +171,10 @@ function _rank_normalize(x::AbstractArray{<:Any,3}) return y end function _rank_normalize!(values, x) + if any(ismissing, x) + fill!(values, missing) + return values + end rank = StatsBase.tiedrank(x) _normal_quantiles_from_ranks!(values, rank) map!(StatsFuns.norminvcdf, values, values) From 1d8c21af8caef4b138fd2757baf1dfdbd76c321f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 00:15:22 +0100 Subject: [PATCH 10/51] Update ESS tests --- test/ess.jl | 41 +++++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/test/ess.jl b/test/ess.jl index 30a96e9a..eb710f7a 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -139,23 +139,23 @@ end @testset "Autocov of ESSMethod and FFTESSMethod equivalent to StatsBase" begin x = randn(1_000, 10, 40) - ess_exp = ess_rhat(x; method=ExplicitESSMethod())[1] + ess_exp = ess(x; method=ExplicitESSMethod()) @testset "$method" for method in [FFTESSMethod(), ESSMethod()] - @test ess_rhat(x; method=method)[1] ≈ ess_exp + @test ess(x; method=method) ≈ ess_exp end end @testset "ESS and R̂ for chains with 2 epochs that have not mixed" begin # checks that splitting yields lower ESS estimates and higher Rhat estimates x = randn(1000, 4, 10) .+ repeat([0, 10]; inner=(500, 1, 1)) - ess_array, rhat_array = ess_rhat(x; split_chains=1) + ess_array, rhat_array = ess_rhat(x; type=:basic, split_chains=1) @test all(x -> isapprox(x, 1; rtol=0.1), rhat_array) - ess_array2, rhat_array2 = ess_rhat(x; split_chains=2) + ess_array2, rhat_array2 = ess_rhat(x; type=:basic, split_chains=2) @test all(ess_array2 .< ess_array) @test all(>(2), rhat_array2) end - @testset "ess_rhat(f, x)[1]" begin + @testset "ess(x; estimator=f)[1]" begin # we check the ESS estimates by simulating uncorrelated, correlated, and # anticorrelated chains, mapping the draws to a target distribution, computing the # estimand, and estimating the ESS for the chosen estimator, computing the @@ -166,7 +166,7 @@ end nparams = 100 x = randn(ndraws, nchains, nparams) mymean(x; kwargs...) = mean(x; kwargs...) - @test_throws ArgumentError ess_rhat(mymean, x) + @test_throws ArgumentError ess(x; estimator=mymean) estimators = [mean, median, std, mad, Base.Fix2(quantile, 0.25)] dists = [Normal(10, 100), Exponential(10), TDist(7) * 10 - 20] # AR(1) coefficients. 0 is IID, -0.3 is slightly anticorrelated, 0.9 is highly autocorrelated @@ -181,7 +181,7 @@ end x .= quantile.(dist, cdf.(Normal(), x)) # stationary distribution is dist μ_mean = dropdims(mapslices(f ∘ vec, x; dims=(1, 2)); dims=(1, 2)) dist = asymptotic_dist(f, dist) - n = @inferred(ess_rhat(f, x))[1] + n = @inferred(ess(x; estimator=f)) μ = mean(dist) mcse = sqrt.(var(dist) ./ n) for i in eachindex(μ_mean, mcse) @@ -199,19 +199,19 @@ end nchains = 4 @testset for ndraws in (10, 100), φ in (-0.3, -0.9) x = ar1(φ, sqrt(1 - φ^2), ndraws, nchains, 1000) - Smin, Smax = extrema(ess_rhat(mean, x)[1]) + Smin, Smax = extrema(ess(x; estimator=mean)) ntotal = ndraws * nchains @test Smax == ntotal * log10(ntotal) @test Smin > 0 end end - @testset "ess_rhat_bulk(x)" begin + @testset "ess_rhat(x; type=:bulk)" begin xnorm = randn(1_000, 4, 10) - @test ess_rhat_bulk(xnorm) == ess_rhat(mean, _rank_normalize(xnorm)) + @test ess_rhat(xnorm; type=:bulk) == ess_rhat(_rank_normalize(xnorm); type=:basic) xcauchy = quantile.(Cauchy(), cdf.(Normal(), xnorm)) # transformation by any monotonic function should not change the bulk ESS/R-hat - @test ess_rhat_bulk(xnorm) == ess_rhat_bulk(xcauchy) + @test ess_rhat(xnorm; type=:bulk) == ess_rhat(xcauchy; type=:bulk) end @testset "tail- ESS and R-hat detect mismatched scales" begin @@ -230,19 +230,17 @@ end # sanity check that standard and bulk ESS and R-hat both fail to detect # mismatched scales - S, R = ess_rhat(x) + S, R = ess_rhat(x; type=:basic) @test all(≥(ess_cutoff), S) @test all(≤(rhat_cutoff), R) - Sbulk, Rbulk = ess_rhat_bulk(x) + Sbulk, Rbulk = ess_rhat(x; type=:bulk) @test all(≥(ess_cutoff), Sbulk) @test all(≤(rhat_cutoff), Rbulk) - # check that tail-Rhat and tail-ESS detect mismatched scales and signal - # poor convergence - S = ess_tail(x) - @test all(<(ess_cutoff), S) - R = rhat_tail(x) - @test all(>(rhat_cutoff), R) + # check that tail- ESS detects mismatched scales and signal poor convergence + S_tail, R_tail = ess_rhat(x; type=:tail) + @test all(<(ess_cutoff), S_tail) + @test all(>(rhat_cutoff), R_tail) end @testset "bulk and tail ESS and R-hat for heavy tailed" begin @@ -260,9 +258,8 @@ end end x = permutedims(cat(posterior_matrices...; dims=3), (2, 3, 1)) - Sbulk, Rbulk = ess_rhat_bulk(x) - Stail = ess_tail(x) - Rtail = rhat_tail(x) + Sbulk, Rbulk = ess_rhat(x; type=:bulk) + Stail, Rtail = ess_rhat(x; type=:tail) ess_cutoff = 100 * size(x, 2) # recommended cutoff is 100 * nchains @test mean(≥(ess_cutoff), Sbulk) > 0.9 @test mean(≥(ess_cutoff), Stail) < mean(≥(ess_cutoff), Sbulk) From a1508140d562ab9f5565ee3dbf52965f19061572 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 00:15:49 +0100 Subject: [PATCH 11/51] Use new mcse syntax --- test/mcse.jl | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/test/mcse.jl b/test/mcse.jl index 781143f5..bb4a0ff8 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -6,42 +6,43 @@ using Statistics using StatsBase @testset "mcse.jl" begin - @testset "estimator must be provided" begin + @testset "estimator defaults to mean" begin x = randn(100, 4, 10) - @test_throws MethodError mcse(x) + @test_throws mcse(x) == mcse(x; estimator=mean) end @testset "ESS-based methods forward kwargs to ess_rhat" begin x = randn(100, 4, 10) @testset for f in [mean, median, std, Base.Fix2(quantile, 0.1)] - @test @inferred(mcse(f, x; split_chains=1)) ≠ mcse(f, x) + @test @inferred(mcse(x; estimator=f, split_chains=1)) ≠ mcse(x; estimator=f) end end @testset "mcse falls back to _mcse_sbm" begin x = randn(100, 4, 10) - @test @inferred(mcse(mad, x)) == - MCMCDiagnosticTools._mcse_sbm(mad, x) ≠ - MCMCDiagnosticTools._mcse_sbm(mad, x; batch_size=16) == - mcse(mad, x; batch_size=16) + estimator = mad + @test @inferred(mcse(x; estimator=estimator)) == + MCMCDiagnosticTools._mcse_sbm(estimator, x) ≠ + MCMCDiagnosticTools._mcse_sbm(estimator, x; batch_size=16) == + mcse(x; estimator=estimator, batch_size=16) end @testset "mcse produces similar vectors to inputs" begin # simultaneously checks that we index correctly and that output types are correct @testset for T in (Float32, Float64), - f in [mean, median, std, Base.Fix2(quantile, T(0.1)), mad] + estimator in [mean, median, std, Base.Fix2(quantile, T(0.1)), mad] x = randn(T, 100, 4, 5) y = OffsetArray(x, -5:94, 2:5, 11:15) - se = mcse(f, y) + se = mcse(y; estimator=estimator) @test se isa OffsetVector{T} @test axes(se, 1) == axes(y, 3) - se2 = mcse(f, x) + se2 = mcse(x; estimator=estimator) @test se2 ≈ collect(se) # quantile errors if data contains missings f isa Base.Fix2{typeof(quantile)} && continue y = OffsetArray(similar(x, Missing), -5:94, 2:5, 11:15) - @test mcse(f, y) isa OffsetVector{Missing} + @test mcse(y; estimator=estimator) isa OffsetVector{Missing} end end @@ -50,7 +51,7 @@ using StatsBase x .= randn.() x[1, 1, 1] = missing @testset for f in [mean, median, std, mad] - se = mcse(f, x) + se = mcse(x; estimator=f) @test ismissing(se[1]) @test !any(ismissing, se[2:end]) end @@ -81,7 +82,7 @@ using StatsBase x .= quantile.(dist, cdf.(Normal(), x)) # stationary distribution is dist μ_mean = dropdims(mapslices(f ∘ vec, x; dims=(1, 2)); dims=(1, 2)) μ = mean(asymptotic_dist(f, dist)) - se = mcse(f, x) + se = mcse === MCMCDiagnosticTools._mcse_sbm ? mcse(f, x) : mcse(x; estimator=f) for i in eachindex(μ_mean, se) atol = quantile(Normal(0, se[i]), 1 - α) @test μ_mean[i] ≈ μ atol = atol From 1d833df74754d23ad75c9c375cbbcc49502c3f58 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 00:27:29 +0100 Subject: [PATCH 12/51] Update documented methods --- docs/src/index.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 54a29e74..39ca9004 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -9,10 +9,9 @@ CurrentModule = MCMCDiagnosticTools The effective sample size (ESS) and $\widehat{R}$ can be estimated with [`ess_rhat`](@ref). ```@docs +ess +rhat ess_rhat -ess_rhat_bulk -ess_tail -rhat_tail ``` The following methods are supported: From 49a69e741bcef889b4e62ba126aafce7fc327b6c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 00:27:37 +0100 Subject: [PATCH 13/51] Update text --- docs/src/index.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 39ca9004..5af1f44c 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -6,15 +6,13 @@ CurrentModule = MCMCDiagnosticTools ## Effective sample size and $\widehat{R}$ -The effective sample size (ESS) and $\widehat{R}$ can be estimated with [`ess_rhat`](@ref). - ```@docs ess rhat ess_rhat ``` -The following methods are supported: +The following `method`s are supported: ```@docs ESSMethod From dc51b6f5f02aef96bce5a02f7c5e604c3669d2ed Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 09:16:01 +0100 Subject: [PATCH 14/51] Fix test --- test/mcse.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcse.jl b/test/mcse.jl index bb4a0ff8..5a8217d2 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -8,7 +8,7 @@ using StatsBase @testset "mcse.jl" begin @testset "estimator defaults to mean" begin x = randn(100, 4, 10) - @test_throws mcse(x) == mcse(x; estimator=mean) + @test mcse(x) == mcse(x; estimator=mean) end @testset "ESS-based methods forward kwargs to ess_rhat" begin From 7d1e9eff100da3c4a3a8f2ad5a0a9685b5e0db37 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 09:16:14 +0100 Subject: [PATCH 15/51] Use new mcse signature --- src/gewekediag.jl | 4 ++-- src/heideldiag.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gewekediag.jl b/src/gewekediag.jl index 38bc788b..e6b86576 100644 --- a/src/gewekediag.jl +++ b/src/gewekediag.jl @@ -25,8 +25,8 @@ function gewekediag(x::AbstractVector{<:Real}; first::Real=0.1, last::Real=0.5, x1 = x[1:round(Int, first * n)] x2 = x[round(Int, n - last * n + 1):n] s = hypot( - Base.first(mcse(Statistics.mean, reshape(x1, :, 1, 1); split_chains=1, kwargs...)), - Base.first(mcse(Statistics.mean, reshape(x2, :, 1, 1); split_chains=1, kwargs...)), + Base.first(mcse(reshape(x1, :, 1, 1); split_chains=1, kwargs...)), + Base.first(mcse(reshape(x2, :, 1, 1); split_chains=1, kwargs...)), ) z = (Statistics.mean(x1) - Statistics.mean(x2)) / s p = SpecialFunctions.erfc(abs(z) / sqrt2) diff --git a/src/heideldiag.jl b/src/heideldiag.jl index 26ec5296..791e8bd4 100644 --- a/src/heideldiag.jl +++ b/src/heideldiag.jl @@ -20,7 +20,7 @@ function heideldiag( delta = trunc(Int, 0.10 * n) y = x[trunc(Int, n / 2):end] T = typeof(zero(eltype(x)) / 1) - s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...)) + s = first(mcse(reshape(y, :, 1, 1); split_chains=1, kwargs...)) S0 = length(y) * s^2 i, pvalue, converged, ybar = 1, one(T), false, T(NaN) while i < n / 2 @@ -37,7 +37,7 @@ function heideldiag( end i += delta end - s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...)) + s = first(mcse(reshape(y, :, 1, 1); split_chains=1, kwargs...)) halfwidth = sqrt2 * SpecialFunctions.erfcinv(T(alpha)) * s passed = halfwidth / abs(ybar) <= eps return ( From 3573277179dc33e81e5a0a1e502652bd09984916 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 09:59:20 +0100 Subject: [PATCH 16/51] Update type kwarg to estimator --- src/mcse.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcse.jl b/src/mcse.jl index 3bef83bf..d1d9ab69 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -53,7 +53,7 @@ function _mcse( ::typeof(Statistics.std), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) x = (samples .- Statistics.mean(samples; dims=(1, 2))) .^ 2 # expectand proxy - S = ess(x; type=Statistics.mean, kwargs...) + S = ess(x; estimator=Statistics.mean, kwargs...) # asymptotic variance of sample variance estimate is Var[var] = E[μ₄] - E[var]², # where μ₄ is the 4th central moment # by the delta method, Var[std] = Var[var] / 4E[var] = (E[μ₄]/E[var] - E[var])/4, From f4d9f675ca991d58efd83f25b2a0bd89d10e9142 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 10:01:25 +0100 Subject: [PATCH 17/51] Update mcse docstring --- src/mcse.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcse.jl b/src/mcse.jl index d1d9ab69..cc7e10a4 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -7,15 +7,15 @@ const normcdfn1 = 0.15865525393145705 # StatsFuns.normcdf(-1) Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples` of shape `(draws, chains, parameters)`. -See also: [`ess_rhat`](@ref) +See also: [`ess`](@ref) ## Estimators `estimator` must accept a vector of the same `eltype` as `samples` and return a real estimate. -For the following estimators, the effective sample size [`ess_rhat`](@ref) and an estimate +For the following estimators, the effective sample size [`ess`](@ref) and an estimate of the asymptotic variance are used to compute the MCSE, and `kwargs` are forwarded to -`ess_rhat`: +`ess`: - `Statistics.mean` - `Statistics.median` - `Statistics.std` @@ -26,7 +26,7 @@ is used as a fallback, and the only accepted `kwargs` are `batch_size`, which in size of the overlapping batches used to estimate the MCSE, defaulting to `floor(Int, sqrt(draws * chains))`. Note that SBM tends to underestimate the MCSE, especially for highly autocorrelated chains. One should verify that autocorrelation is low -by checking the bulk- and tail-[`ess_rhat`](@ref) values. +by checking the bulk- and tail-ESS values. [^FlegalJones2011]: Flegal JM, Jones GL. (2011) Implementing MCMC: estimating with confidence. Handbook of Markov Chain Monte Carlo. pp. 175-97. From 7ea7927d4caf91b5e1548b9193a31665841a8988 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 11:01:34 +0100 Subject: [PATCH 18/51] Fix variable name --- test/mcse.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcse.jl b/test/mcse.jl index 5a8217d2..25190e5b 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -40,7 +40,7 @@ using StatsBase se2 = mcse(x; estimator=estimator) @test se2 ≈ collect(se) # quantile errors if data contains missings - f isa Base.Fix2{typeof(quantile)} && continue + estimator isa Base.Fix2{typeof(quantile)} && continue y = OffsetArray(similar(x, Missing), -5:94, 2:5, 11:15) @test mcse(y; estimator=estimator) isa OffsetVector{Missing} end From 18997225eacadf3e93f80b20dd0bf7a3087c0759 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 21:01:10 +0100 Subject: [PATCH 19/51] Make tail-ESS work for typeunion with Missing --- src/ess.jl | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/ess.jl b/src/ess.jl index d6ec8e56..9eb2db52 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -298,10 +298,10 @@ function _ess( ) # workaround for https://github.com/JuliaStats/Statistics.jl/issues/136 T = Base.promote_eltype(x, tail_prob) - S_lower = ess(x; estimator=Base.Fix2(Statistics.quantile, T(tail_prob / 2)), kwargs...) - S_upper = ess( - x; estimator=Base.Fix2(Statistics.quantile, T(1 - tail_prob / 2)), kwargs... - ) + pl = convert(T, tail_prob / 2) + pu = convert(T, 1 - tail_prob / 2) + S_lower = ess(x; estimator=Base.Fix2(Statistics.quantile, pl), kwargs...) + S_upper = ess(x; estimator=Base.Fix2(Statistics.quantile, pu), kwargs...) return map(min, S_lower, S_upper) end @@ -498,7 +498,12 @@ function _expectand_proxy(f::Base.Fix2{typeof(Statistics.quantile),<:Real}, x) y = similar(x) # currently quantile does not support a dims keyword argument for (xi, yi) in zip(eachslice(x; dims=3), eachslice(y; dims=3)) - yi .= xi .≤ f(vec(xi)) + if any(ismissing, xi) + # quantile function raises an error if there are missing values + fill!(yi, missing) + else + yi .= xi .≤ f(vec(xi)) + end end return y end From 4f9c7094d09df3f34748858658d9dc61944a9954 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 21:01:20 +0100 Subject: [PATCH 20/51] Add rank for ESS --- src/ess.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ess.jl b/src/ess.jl index 9eb2db52..9b494bc8 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -238,7 +238,7 @@ The ESS and ``\\widehat{R}`` values can be computed for the following estimators ## Types If no `estimator` is provided, the following types of ESS estimates may be computed: -- `:bulk`: mean-ESS computed on rank-normalized draws. This type diagnoses poor +- `:bulk`/`:rank`: mean-ESS computed on rank-normalized draws. This type diagnoses poor convergence in the bulk of the distribution due to trends or different locations of the chains. - `:tail`: minimum of the quantile-ESS for the symmetric quantiles where @@ -304,6 +304,9 @@ function _ess( S_upper = ess(x; estimator=Base.Fix2(Statistics.quantile, pu), kwargs...) return map(min, S_lower, S_upper) end +function _ess(::Val{:rank}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + return _ess(Val(:bulk), samples; kwargs...) +end """ ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; type=:rank, kwargs...) From 7a4d4975334d28eb8a69a6f0d25b0d9f7c5b8797 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 21:38:28 +0100 Subject: [PATCH 21/51] Rearrange tests --- test/ess.jl | 134 +++++++++++++++++++++++++++++++--------------------- 1 file changed, 79 insertions(+), 55 deletions(-) diff --git a/test/ess.jl b/test/ess.jl index eb710f7a..3b0e16c4 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -33,6 +33,85 @@ function LogDensityProblems.capabilities(p::CauchyProblem) end @testset "ess.jl" begin + @testset "ess/ess_rhat basics" begin + @testset "only promote eltype when necessary" begin + @testset for type in [:rank, :bulk, :tail, :basic] + @testset for T in (Float32, Float64) + x = rand(T, 100, 4, 2) + TV = Vector{T} + @inferred TV ess(x; type=type) + @inferred Tuple{TV,TV} ess_rhat(x; type=type) + end + @testset "Int" begin + x = rand(1:10, 100, 4, 2) + TV = Vector{Float64} + @inferred TV ess(x; type=type) + @inferred Tuple{TV,TV} ess_rhat(x; type=type) + end + end + end + + @testset "errors" begin # check that issue #137 is fixed + x = rand(4, 3, 5) + x2 = rand(5, 3, 5) + x3 = rand(100, 3, 5) + @testset for type in [:rank, :bulk, :tail, :basic] + @test_throws ArgumentError ess(x; split_chains=1, type=type) + @test_throws ArgumentError ess_rhat(x; split_chains=1, type=type) + @test ess(x2; split_chains=1, type=type) == + ess_rhat(x2; split_chains=1, type=type)[1] + @test_throws ArgumentError ess(x2; split_chains=2, type=type) + @test_throws ArgumentError ess_rhat(x2; split_chains=2, type=type) + @test ess(x3; maxlag=1, type=type) == ess_rhat(x3; maxlag=1, type=type)[1] + @test_throws DomainError ess(x3; maxlag=0, type=type) + @test_throws DomainError ess_rhat(x3; maxlag=0, type=type) + end + end + + @testset "Union{Missing,Float64} eltype" begin + @testset for type in [:rank, :bulk, :tail, :basic] + x = Array{Union{Missing,Float64}}(undef, 1000, 4, 3) + x .= randn.() + x[1, 1, 1] = missing + S1 = ess(x; type=type) + S2, R = ess_rhat(x; type=type) + @test ismissing(S1[1]) + @test ismissing(S2[1]) + @test ismissing(R[1]) + @test !any(ismissing, S1[2:3]) + @test !any(ismissing, S2[2:3]) + @test !any(ismissing, R[2:3]) + end + end + + @testset "produces similar vectors to inputs" begin + @testset for type in [:rank, :bulk, :tail, :basic] + # simultaneously checks that we index correctly and that output types are correct + x = randn(100, 4, 5) + y = OffsetArray(x, -5:94, 2:5, 11:15) + S11 = ess(y; type=type) + S12, R1 = ess_rhat(y; type=type) + @test S11 isa OffsetVector{Float64} + @test S12 isa OffsetVector{Float64} + @test axes(S11, 1) == axes(S12, 1) == axes(y, 3) + @test R1 isa OffsetVector{Float64} + @test axes(R1, 1) == axes(y, 3) + S21 = ess(x; type=type) + S22, R2 = ess_rhat(x; type=type) + @test S22 == S21 == collect(S21) + @test R2 == collect(R1) + y = OffsetArray(similar(x, Missing), -5:94, 2:5, 11:15) + S31 = ess(y; type=type) + S32, R3 = ess_rhat(y; type=type) + @test S31 isa OffsetVector{Missing} + @test S32 isa OffsetVector{Missing} + @test axes(S31, 1) == axes(S32, 1) == axes(y, 3) + @test R3 isa OffsetVector{Missing} + @test axes(R3, 1) == axes(y, 3) + end + end + end + @testset "ESS and R̂ (IID samples)" begin # Repeat tests with different scales @testset for scale in (1, 50, 100), nchains in (1, 10), split_chains in (1, 2) @@ -65,39 +144,6 @@ end end end - @testset "ESS and R̂ only promote eltype when necessary" begin - @testset for T in (Float32, Float64) - x = rand(T, 100, 4, 2) - TV = Vector{T} - @inferred Tuple{TV,TV} ess_rhat(x) - end - @testset "Int" begin - x = rand(1:10, 100, 4, 2) - TV = Vector{Float64} - @inferred Tuple{TV,TV} ess_rhat(x) - end - end - - @testset "ESS and R̂ are similar vectors to inputs" begin - # simultaneously checks that we index correctly and that output types are correct - x = randn(100, 4, 5) - y = OffsetArray(x, -5:94, 2:5, 11:15) - S, R = ess_rhat(y) - @test S isa OffsetVector{Float64} - @test axes(S, 1) == axes(y, 3) - @test R isa OffsetVector{Float64} - @test axes(R, 1) == axes(y, 3) - S2, R2 = ess_rhat(x) - @test S2 == collect(S) - @test R2 == collect(R) - y = OffsetArray(similar(x, Missing), -5:94, 2:5, 11:15) - S3, R3 = ess_rhat(y) - @test S3 isa OffsetVector{Missing} - @test axes(S3, 1) == axes(y, 3) - @test R3 isa OffsetVector{Missing} - @test axes(R3, 1) == axes(y, 3) - end - @testset "ESS and R̂ (identical samples)" begin x = ones(10_000, 10, 40) @@ -115,28 +161,6 @@ end end end - @testset "ESS and R̂ errors" begin # check that issue #137 is fixed - x = rand(4, 3, 5) - x2 = rand(5, 3, 5) - @test_throws ArgumentError ess_rhat(x; split_chains=1) - ess_rhat(x2; split_chains=1) - @test_throws ArgumentError ess_rhat(x2; split_chains=2) - x3 = rand(100, 3, 5) - ess_rhat(x3; maxlag=1) - @test_throws DomainError ess_rhat(x3; maxlag=0) - end - - @testset "ESS and R̂ with Union{Missing,Float64} eltype" begin - x = Array{Union{Missing,Float64}}(undef, 1000, 4, 3) - x .= randn.() - x[1, 1, 1] = missing - S, R = ess_rhat(x) - @test ismissing(S[1]) - @test ismissing(R[1]) - @test !any(ismissing, S[2:3]) - @test !any(ismissing, R[2:3]) - end - @testset "Autocov of ESSMethod and FFTESSMethod equivalent to StatsBase" begin x = randn(1_000, 10, 40) ess_exp = ess(x; method=ExplicitESSMethod()) From 54d0046b11434d83fbe47ab5eec45e4013e68723 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 21:38:48 +0100 Subject: [PATCH 22/51] Test against ess --- test/ess.jl | 8 ++++---- test/mcse.jl | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/ess.jl b/test/ess.jl index 3b0e16c4..a378267a 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -179,7 +179,7 @@ end @test all(>(2), rhat_array2) end - @testset "ess(x; estimator=f)[1]" begin + @testset "ess(x; estimator=f)" begin # we check the ESS estimates by simulating uncorrelated, correlated, and # anticorrelated chains, mapping the draws to a target distribution, computing the # estimand, and estimating the ESS for the chosen estimator, computing the @@ -230,12 +230,12 @@ end end end - @testset "ess_rhat(x; type=:bulk)" begin + @testset "ess(x; type=:bulk)" begin xnorm = randn(1_000, 4, 10) - @test ess_rhat(xnorm; type=:bulk) == ess_rhat(_rank_normalize(xnorm); type=:basic) + @test ess(xnorm; type=:bulk) == ess(_rank_normalize(xnorm); type=:basic) xcauchy = quantile.(Cauchy(), cdf.(Normal(), xnorm)) # transformation by any monotonic function should not change the bulk ESS/R-hat - @test ess_rhat(xnorm; type=:bulk) == ess_rhat(xcauchy; type=:bulk) + @test ess(xnorm; type=:bulk) == ess(xcauchy; type=:bulk) end @testset "tail- ESS and R-hat detect mismatched scales" begin diff --git a/test/mcse.jl b/test/mcse.jl index 25190e5b..43ac247d 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -11,7 +11,7 @@ using StatsBase @test mcse(x) == mcse(x; estimator=mean) end - @testset "ESS-based methods forward kwargs to ess_rhat" begin + @testset "ESS-based methods forward kwargs to ess" begin x = randn(100, 4, 10) @testset for f in [mean, median, std, Base.Fix2(quantile, 0.1)] @test @inferred(mcse(x; estimator=f, split_chains=1)) ≠ mcse(x; estimator=f) From f13a45c50c33aa3a18a6a8eceba4e815d567791b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 21:38:55 +0100 Subject: [PATCH 23/51] Add consistency test --- test/ess.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/ess.jl b/test/ess.jl index a378267a..736e4e98 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -112,6 +112,23 @@ end end end + @testset "ess, ess_rhat, and rhat consistency" begin + x = randn(1000, 4, 10) + @testset for type in [:rank, :bulk, :tail, :basic], split_chains in [1, 2] + R1 = rhat(x; type=type, split_chains=split_chains) + @testset for method in [ESSMethod(), BDAESSMethod()], maxlag in [100, 10] + S1 = ess( + x; type=type, split_chains=split_chains, method=method, maxlag=maxlag + ) + S2, R2 = ess_rhat( + x; type=type, split_chains=split_chains, method=method, maxlag=maxlag + ) + @test S1 == S2 + @test R1 == R2 + end + end + end + @testset "ESS and R̂ (IID samples)" begin # Repeat tests with different scales @testset for scale in (1, 50, 100), nchains in (1, 10), split_chains in (1, 2) From 89facddfc4594b9f7bdf774bc134e54f574d4c01 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 22:05:16 +0100 Subject: [PATCH 24/51] Add inferrability tests for estimators --- test/ess.jl | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/test/ess.jl b/test/ess.jl index 736e4e98..6a553339 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -39,14 +39,24 @@ end @testset for T in (Float32, Float64) x = rand(T, 100, 4, 2) TV = Vector{T} - @inferred TV ess(x; type=type) - @inferred Tuple{TV,TV} ess_rhat(x; type=type) + @test @inferred(ess(x; type=type)) isa TV + @test @inferred(ess_rhat(x; type=type)) isa Tuple{TV,TV} end @testset "Int" begin x = rand(1:10, 100, 4, 2) TV = Vector{Float64} - @inferred TV ess(x; type=type) - @inferred Tuple{TV,TV} ess_rhat(x; type=type) + @test @inferred(ess(x; type=type)) isa TV + @test @inferred(ess_rhat(x; type=type)) isa Tuple{TV,TV} + end + end + @testset for estimator in [mean, median, mad, std, Base.Fix2(quantile, 0.25)] + @testset for T in (Float32, Float64) + x = rand(T, 100, 4, 2) + @test @inferred(ess(x; estimator=estimator)) isa Vector{T} + end + @testset "Int" begin + x = rand(1:10, 100, 4, 2) + @test @inferred(ess(x; estimator=estimator)) isa Vector{Float64} end end end From 19cececd4e9a04a383b3b69c1bb81f3bbbaf43d6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 23:31:11 +0100 Subject: [PATCH 25/51] Add more error tests --- test/ess.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/ess.jl b/test/ess.jl index 6a553339..ad079ec0 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -32,6 +32,8 @@ function LogDensityProblems.capabilities(p::CauchyProblem) return LogDensityProblems.LogDensityOrder{1}() end +mymean(x) = mean(x) + @testset "ess.jl" begin @testset "ess/ess_rhat basics" begin @testset "only promote eltype when necessary" begin @@ -76,6 +78,10 @@ end @test_throws DomainError ess(x3; maxlag=0, type=type) @test_throws DomainError ess_rhat(x3; maxlag=0, type=type) end + @test_throws ArgumentError ess(x2; type=:rank, estimator=mean) + @test_throws ArgumentError ess(x2; estimator=mymean) + @test_throws ArgumentError ess(x2; type=:foo) + @test_throws ArgumentError ess_rhat(x2; type=:foo) end @testset "Union{Missing,Float64} eltype" begin From 083011a651f7150b412ba0b10564a9b961a808de Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 24 Feb 2023 23:31:32 +0100 Subject: [PATCH 26/51] Try to resolve type-instability --- src/ess.jl | 44 +++++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/src/ess.jl b/src/ess.jl index 9b494bc8..be29cdac 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -259,36 +259,39 @@ monotonic. doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) """ -ess(samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) = _ess(samples; kwargs...) - -function _ess( +function ess( samples::AbstractArray{<:Union{Missing,Real},3}; estimator=nothing, - type=estimator === nothing ? Val(:bulk) : nothing, + type=nothing, kwargs..., ) if estimator !== nothing && type !== nothing throw(ArgumentError("only one of `estimator` and `type` can be specified")) elseif estimator !== nothing - x = _expectand_proxy(estimator, samples) - if x === nothing - throw(ArgumentError("the estimator $estimator is not yet supported by `ess`")) - end - return _ess(Val(:basic), x; kwargs...) - else + return _ess(estimator, samples; kwargs...) + elseif type !== nothing return _ess(_val(type), samples; kwargs...) + else + return _ess(Val(:basic), samples; kwargs...) end end +function _ess(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + x = _expectand_proxy(estimator, samples) + if x === nothing + throw(ArgumentError("the estimator $estimator is not yet supported by `ess`")) + end + return _ess(Val(:basic), x; kwargs...) +end function _ess( ::Val{T}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) where {T} return throw(ArgumentError("the `type` `$T` is not supported by `ess`")) end function _ess(type::Val{:basic}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - return first(ess_rhat(samples; type=type, kwargs...)) + return first(_ess_rhat(type, samples; kwargs...)) end function _ess(type::Val{:bulk}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - return first(ess_rhat(samples; type=type, kwargs...)) + return first(_ess_rhat(type, samples; kwargs...)) end function _ess( ::Val{:tail}, @@ -300,8 +303,8 @@ function _ess( T = Base.promote_eltype(x, tail_prob) pl = convert(T, tail_prob / 2) pu = convert(T, 1 - tail_prob / 2) - S_lower = ess(x; estimator=Base.Fix2(Statistics.quantile, pl), kwargs...) - S_upper = ess(x; estimator=Base.Fix2(Statistics.quantile, pu), kwargs...) + S_lower = _ess(Base.Fix2(Statistics.quantile, pl), x; kwargs...) + S_upper = _ess(Base.Fix2(Statistics.quantile, pu), x; kwargs...) return map(min, S_lower, S_upper) end function _ess(::Val{:rank}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) @@ -320,7 +323,7 @@ calling `ess` and `rhat` separately. See [`rhat`](@ref) for a description of supported `type`s and [`ess`](@ref) for a description of `kwargs`. """ -function ess_rhat( +@inline function ess_rhat( samples::AbstractArray{<:Union{Missing,Real},3}; type=Val(:rank), kwargs... ) return _ess_rhat(_val(type), samples; kwargs...) @@ -461,17 +464,20 @@ function _ess_rhat(::Val{:bulk}, x::AbstractArray{<:Union{Missing,Real},3}; kwar return _ess_rhat(Val(:basic), _rank_normalize(x); kwargs...) end function _ess_rhat( - ::Val{:tail}, x::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2, kwargs... + type::Val{:tail}, + x::AbstractArray{<:Union{Missing,Real},3}; + split_chains::Int=2, + kwargs..., ) - S = ess(x; type=Val(:tail), split_chains=split_chains, kwargs...) - R = rhat(x; type=Val(:tail), split_chains=split_chains) + S = _ess(type, x; split_chains=split_chains, kwargs...) + R = _rhat(type, x; split_chains=split_chains) return S, R end function _ess_rhat( ::Val{:rank}, x::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2, kwargs... ) Sbulk, Rbulk = _ess_rhat(Val(:bulk), x; split_chains=split_chains, kwargs...) - Rtail = rhat(x; type=Val(:tail), split_chains=split_chains) + Rtail = _rhat(Val(:tail), x; split_chains=split_chains) Rrank = map(max, Rtail, Rbulk) return Sbulk, Rrank end From fcbbc43684e491dd0cfa04ab6b7e456b73d0e2e6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 13:17:44 +0100 Subject: [PATCH 27/51] Add Compat for `@constprop` macro --- Project.toml | 2 ++ src/MCMCDiagnosticTools.jl | 1 + 2 files changed, 3 insertions(+) diff --git a/Project.toml b/Project.toml index 46e26b15..06ba3e2f 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.3.0-DEV" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -19,6 +20,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] AbstractFFTs = "0.5, 1" +Compat = "3.36.0, 4" DataAPI = "1.6" DataStructures = "0.18.3" Distributions = "0.25" diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index b197c7fd..990feb52 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -1,6 +1,7 @@ module MCMCDiagnosticTools using AbstractFFTs: AbstractFFTs +using Compat: @constprop using DataAPI: DataAPI using DataStructures: DataStructures using Distributions: Distributions From 50ec38de32efef3ea537fb790c1d07345d0a7232 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 13:34:26 +0100 Subject: [PATCH 28/51] Aggressively constprop --- src/ess.jl | 4 ++-- src/rhat.jl | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/ess.jl b/src/ess.jl index be29cdac..5a30a1b1 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -259,7 +259,7 @@ monotonic. doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) """ -function ess( +@constprop :aggressive function ess( samples::AbstractArray{<:Union{Missing,Real},3}; estimator=nothing, type=nothing, @@ -323,7 +323,7 @@ calling `ess` and `rhat` separately. See [`rhat`](@ref) for a description of supported `type`s and [`ess`](@ref) for a description of `kwargs`. """ -@inline function ess_rhat( +@constprop :aggressive function ess_rhat( samples::AbstractArray{<:Union{Missing,Real},3}; type=Val(:rank), kwargs... ) return _ess_rhat(_val(type), samples; kwargs...) diff --git a/src/rhat.jl b/src/rhat.jl index d6a92f22..b3210b9e 100644 --- a/src/rhat.jl +++ b/src/rhat.jl @@ -31,7 +31,9 @@ The following types are supported: doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) """ -function rhat(samples::AbstractArray{<:Union{Missing,Real},3}; type=Val(:rank), kwargs...) +@constprop :aggressive function rhat( + samples::AbstractArray{<:Union{Missing,Real},3}; type=Val(:rank), kwargs... +) return _rhat(_val(type), samples; kwargs...) end From c0076785d86e1b5de92e75d566786b86d9ff65ad Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 13:34:40 +0100 Subject: [PATCH 29/51] Test type-inferrability with Vals --- test/ess.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ess.jl b/test/ess.jl index ad079ec0..47e911e8 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -37,7 +37,7 @@ mymean(x) = mean(x) @testset "ess.jl" begin @testset "ess/ess_rhat basics" begin @testset "only promote eltype when necessary" begin - @testset for type in [:rank, :bulk, :tail, :basic] + @testset for type in map(Val, [:rank, :bulk, :tail, :basic]) @testset for T in (Float32, Float64) x = rand(T, 100, 4, 2) TV = Vector{T} From 546cb37972ab1b5648337ee78f75ddbaabbe6c6b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 14:47:50 +0100 Subject: [PATCH 30/51] Make work for typeunions with Missing on v1.6 --- src/utils.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 25d9bf54..c8c1d662 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -166,7 +166,8 @@ and then transforming the ranks to normal quantiles so that the result is standa normally distributed. """ function _rank_normalize(x::AbstractArray{<:Any,3}) - y = similar(x, float(eltype(x))) + T = promote_type(eltype(x), typeof(zero(eltype(x)) / 1)) + y = similar(x, T) map(_rank_normalize!, eachslice(y; dims=3), eachslice(x; dims=3)) return y end From d3f2a799a370b4365c25ded43108a55304e72aaa Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 14:59:15 +0100 Subject: [PATCH 31/51] Avoid repeating split_chains docs --- src/ess.jl | 12 +++++++----- src/rhat.jl | 5 +---- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/ess.jl b/src/ess.jl index 5a30a1b1..1ccb4923 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -1,6 +1,12 @@ # methods abstract type AbstractESSMethod end +const _DOC_SPLIT_CHAINS = + """`split_chains` indicates the number of chains each chain is split into. + When `split_chains > 1`, then the diagnostics check for within-chain convergence. When + `d = mod(draws, split_chains) > 0`, i.e. the chains cannot be evenly split, then 1 draw + is discarded after each of the first `d` splits within each chain.""" + """ ESSMethod <: AbstractESSMethod @@ -211,11 +217,7 @@ Estimate the effective sample size (ESS) of the `samples` of shape Optionally, only one of the `type` of ESS estimate to return or the `estimator` for which ESS is computed can be specified (see below). Some `type`s accept additional `kwargs`. -`split_chains` indicates the number of chains each chain is split into. -When `split_chains > 1`, then the diagnostics check for within-chain convergence. When -`d = mod(draws, split_chains) > 0`, i.e. the chains cannot be evenly split, then 1 draw -is discarded after each of the first `d` splits within each chain. There must be at least -3 draws in each chain after splitting. +$_DOC_SPLIT_CHAINS There must be at least 3 draws in each chain after splitting. `maxlag` indicates the maximum lag for which autocovariance is computed and must be greater than 0. diff --git a/src/rhat.jl b/src/rhat.jl index b3210b9e..f655e956 100644 --- a/src/rhat.jl +++ b/src/rhat.jl @@ -6,10 +6,7 @@ Compute the ``\\widehat{R}`` diagnostics for each parameter in `samples` of shap `type` indicates the type of ``\\widehat{R}`` to compute (see below). -`split_chains` indicates the number of chains each chain is split into. -When `split_chains > 1`, then the diagnostics check for within-chain convergence. When -`d = mod(draws, split_chains) > 0`, i.e. the chains cannot be evenly split, then 1 draw -is discarded after each of the first `d` splits within each chain. +$_DOC_SPLIT_CHAINS See also [`ess`](@ref), [`ess_rhat`](@ref), [`rstar`](@ref) From 79333428651240c534039754ecc8fdd021a0f4d6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 15:00:38 +0100 Subject: [PATCH 32/51] Rename ess.jl to ess_rhat.jl --- src/MCMCDiagnosticTools.jl | 2 +- src/{ess.jl => ess_rhat.jl} | 0 test/{ess.jl => ess_rhat.jl} | 0 test/runtests.jl | 4 ++-- 4 files changed, 3 insertions(+), 3 deletions(-) rename src/{ess.jl => ess_rhat.jl} (100%) rename test/{ess.jl => ess_rhat.jl} (100%) diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index 990feb52..be961c14 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -29,7 +29,7 @@ export rstar include("utils.jl") include("bfmi.jl") include("discretediag.jl") -include("ess.jl") +include("ess_rhat.jl") include("gelmandiag.jl") include("gewekediag.jl") include("heideldiag.jl") diff --git a/src/ess.jl b/src/ess_rhat.jl similarity index 100% rename from src/ess.jl rename to src/ess_rhat.jl diff --git a/test/ess.jl b/test/ess_rhat.jl similarity index 100% rename from test/ess.jl rename to test/ess_rhat.jl diff --git a/test/runtests.jl b/test/runtests.jl index 17507a39..b756fbe2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,8 +21,8 @@ Random.seed!(1) @testset "discrete diagnostic" begin include("discretediag.jl") end - @testset "ESS" begin - include("ess.jl") + @testset "ESS and R̂" begin + include("ess_rhat.jl") end @testset "Monte Carlo standard error" begin include("mcse.jl") From 60395395379f9618f6ebf2d03a56b6ea7ef6dc13 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 15:02:10 +0100 Subject: [PATCH 33/51] Move rhat to same file as ess --- src/ess_rhat.jl | 100 ++++++++++++++++++++++++++++++++++++++++++++++++ src/rhat.jl | 97 ---------------------------------------------- 2 files changed, 100 insertions(+), 97 deletions(-) delete mode 100644 src/rhat.jl diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 1ccb4923..11b668c5 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -313,6 +313,106 @@ function _ess(::Val{:rank}, samples::AbstractArray{<:Union{Missing,Real},3}; kwa return _ess(Val(:bulk), samples; kwargs...) end + +""" + rhat(samples::AbstractArray{Union{Real,Missing},3}; type=:rank, split_chains=2) + +Compute the ``\\widehat{R}`` diagnostics for each parameter in `samples` of shape +`(chains, draws, parameters)`. [^VehtariGelman2021] + +`type` indicates the type of ``\\widehat{R}`` to compute (see below). + +$_DOC_SPLIT_CHAINS + +See also [`ess`](@ref), [`ess_rhat`](@ref), [`rstar`](@ref) + +## Types + +The following types are supported: +- `:rank`: maximum of ``\\widehat{R}`` with `type=:bulk` and `type=:tail`. +- `:bulk`: basic ``\\widehat{R}``` computed on rank-normalized draws. This type diagnoses + poor convergence in the bulk of the distribution due to trends or different locations of + the chains. +- `:tail`: ``\\widehat{R}`` computed on draws folded around the median and then + rank-normalized. This type diagnoses poor convergence in the tails of the distribution + due to different scales of the chains. +- `:basic`: Classic ``\\widehat{R}``. + +[^VehtariGelman2021]: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P. C. (2021). + Rank-normalization, folding, and localization: An improved ``\\widehat {R}`` for + assessing convergence of MCMC. Bayesian Analysis. + doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) + arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) +""" +@constprop :aggressive function rhat( + samples::AbstractArray{<:Union{Missing,Real},3}; type=Val(:rank), kwargs... +) + return _rhat(_val(type), samples; kwargs...) +end + +function _rhat( + ::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2 +) + # compute size of matrices (each chain may be split!) + niter = size(chains, 1) ÷ split_chains + nchains = split_chains * size(chains, 2) + axes_out = (axes(chains, 3),) + T = promote_type(eltype(chains), typeof(zero(eltype(chains)) / 1)) + + # define output arrays + rhat = similar(chains, T, axes_out) + + T === Missing && return rhat + + # define caches for mean and variance + chain_mean = Array{T}(undef, 1, nchains) + chain_var = Array{T}(undef, nchains) + samples = Array{T}(undef, niter, nchains) + + # compute correction factor + correctionfactor = (niter - 1)//niter + + # for each parameter + for (i, chains_slice) in zip(eachindex(rhat), eachslice(chains; dims=3)) + # check that no values are missing + if any(x -> x === missing, chains_slice) + rhat[i] = missing + continue + end + + # split chains + copyto_split!(samples, chains_slice) + + # calculate mean of chains + Statistics.mean!(chain_mean, samples) + + # calculate within-chain variance + @inbounds for j in 1:nchains + chain_var[j] = Statistics.var( + view(samples, :, j); mean=chain_mean[j], corrected=true + ) + end + W = Statistics.mean(chain_var) + + # compute variance estimator var₊, which accounts for between-chain variance as well + # avoid NaN when nchains=1 and set the variance estimator var₊ to the the within-chain variance in that case + var₊ = correctionfactor * W + Statistics.var(chain_mean; corrected=(nchains > 1)) + + # estimate rhat + rhat[i] = sqrt(var₊ / W) + end + + return rhat +end +_rhat(::Val{:bulk}, x; kwargs...) = _rhat(Val(:basic), _rank_normalize(x); kwargs...) +_rhat(::Val{:tail}, x; kwargs...) = _rhat(Val(:bulk), _fold_around_median(x); kwargs...) +function _rhat(::Val{:rank}, x; kwargs...) + Rbulk = _rhat(Val(:bulk), x; kwargs...) + Rtail = _rhat(Val(:tail), x; kwargs...) + return map(max, Rtail, Rbulk) +end + + """ ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; type=:rank, kwargs...) diff --git a/src/rhat.jl b/src/rhat.jl deleted file mode 100644 index f655e956..00000000 --- a/src/rhat.jl +++ /dev/null @@ -1,97 +0,0 @@ -""" - rhat(samples::AbstractArray{Union{Real,Missing},3}; type=:rank, split_chains=2) - -Compute the ``\\widehat{R}`` diagnostics for each parameter in `samples` of shape -`(chains, draws, parameters)`. [^VehtariGelman2021] - -`type` indicates the type of ``\\widehat{R}`` to compute (see below). - -$_DOC_SPLIT_CHAINS - -See also [`ess`](@ref), [`ess_rhat`](@ref), [`rstar`](@ref) - -## Types - -The following types are supported: -- `:rank`: maximum of ``\\widehat{R}`` with `type=:bulk` and `type=:tail`. -- `:bulk`: basic ``\\widehat{R}``` computed on rank-normalized draws. This type diagnoses - poor convergence in the bulk of the distribution due to trends or different locations of - the chains. -- `:tail`: ``\\widehat{R}`` computed on draws folded around the median and then - rank-normalized. This type diagnoses poor convergence in the tails of the distribution - due to different scales of the chains. -- `:basic`: Classic ``\\widehat{R}``. - -[^VehtariGelman2021]: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P. C. (2021). - Rank-normalization, folding, and localization: An improved ``\\widehat {R}`` for - assessing convergence of MCMC. Bayesian Analysis. - doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) - arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) -""" -@constprop :aggressive function rhat( - samples::AbstractArray{<:Union{Missing,Real},3}; type=Val(:rank), kwargs... -) - return _rhat(_val(type), samples; kwargs...) -end - -function _rhat( - ::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2 -) - # compute size of matrices (each chain may be split!) - niter = size(chains, 1) ÷ split_chains - nchains = split_chains * size(chains, 2) - axes_out = (axes(chains, 3),) - T = promote_type(eltype(chains), typeof(zero(eltype(chains)) / 1)) - - # define output arrays - rhat = similar(chains, T, axes_out) - - T === Missing && return rhat - - # define caches for mean and variance - chain_mean = Array{T}(undef, 1, nchains) - chain_var = Array{T}(undef, nchains) - samples = Array{T}(undef, niter, nchains) - - # compute correction factor - correctionfactor = (niter - 1)//niter - - # for each parameter - for (i, chains_slice) in zip(eachindex(rhat), eachslice(chains; dims=3)) - # check that no values are missing - if any(x -> x === missing, chains_slice) - rhat[i] = missing - continue - end - - # split chains - copyto_split!(samples, chains_slice) - - # calculate mean of chains - Statistics.mean!(chain_mean, samples) - - # calculate within-chain variance - @inbounds for j in 1:nchains - chain_var[j] = Statistics.var( - view(samples, :, j); mean=chain_mean[j], corrected=true - ) - end - W = Statistics.mean(chain_var) - - # compute variance estimator var₊, which accounts for between-chain variance as well - # avoid NaN when nchains=1 and set the variance estimator var₊ to the the within-chain variance in that case - var₊ = correctionfactor * W + Statistics.var(chain_mean; corrected=(nchains > 1)) - - # estimate rhat - rhat[i] = sqrt(var₊ / W) - end - - return rhat -end -_rhat(::Val{:bulk}, x; kwargs...) = _rhat(Val(:basic), _rank_normalize(x); kwargs...) -_rhat(::Val{:tail}, x; kwargs...) = _rhat(Val(:bulk), _fold_around_median(x); kwargs...) -function _rhat(::Val{:rank}, x; kwargs...) - Rbulk = _rhat(Val(:bulk), x; kwargs...) - Rtail = _rhat(Val(:tail), x; kwargs...) - return map(max, Rtail, Rbulk) -end From 49ea9bc545a38af3367d78963b81dce178c48f53 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 15:03:22 +0100 Subject: [PATCH 34/51] Update costring --- src/ess_rhat.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 11b668c5..6c9ea3f0 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -239,7 +239,7 @@ The ESS and ``\\widehat{R}`` values can be computed for the following estimators ## Types -If no `estimator` is provided, the following types of ESS estimates may be computed: +If no `estimator` is provided, the following `type`s of ESS estimates may be computed: - `:bulk`/`:rank`: mean-ESS computed on rank-normalized draws. This type diagnoses poor convergence in the bulk of the distribution due to trends or different locations of the chains. @@ -328,7 +328,7 @@ See also [`ess`](@ref), [`ess_rhat`](@ref), [`rstar`](@ref) ## Types -The following types are supported: +The following `type`s are supported: - `:rank`: maximum of ``\\widehat{R}`` with `type=:bulk` and `type=:tail`. - `:bulk`: basic ``\\widehat{R}``` computed on rank-normalized draws. This type diagnoses poor convergence in the bulk of the distribution due to trends or different locations of From a6cb1e1e78b1ea80108a1606ddff338e052defbe Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 15:05:19 +0100 Subject: [PATCH 35/51] Run formatter --- src/ess_rhat.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 6c9ea3f0..163986f4 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -1,11 +1,10 @@ # methods abstract type AbstractESSMethod end -const _DOC_SPLIT_CHAINS = - """`split_chains` indicates the number of chains each chain is split into. - When `split_chains > 1`, then the diagnostics check for within-chain convergence. When - `d = mod(draws, split_chains) > 0`, i.e. the chains cannot be evenly split, then 1 draw - is discarded after each of the first `d` splits within each chain.""" +const _DOC_SPLIT_CHAINS = """`split_chains` indicates the number of chains each chain is split into. + When `split_chains > 1`, then the diagnostics check for within-chain convergence. When + `d = mod(draws, split_chains) > 0`, i.e. the chains cannot be evenly split, then 1 draw + is discarded after each of the first `d` splits within each chain.""" """ ESSMethod <: AbstractESSMethod @@ -313,7 +312,6 @@ function _ess(::Val{:rank}, samples::AbstractArray{<:Union{Missing,Real},3}; kwa return _ess(Val(:bulk), samples; kwargs...) end - """ rhat(samples::AbstractArray{Union{Real,Missing},3}; type=:rank, split_chains=2) @@ -412,7 +410,6 @@ function _rhat(::Val{:rank}, x; kwargs...) return map(max, Rtail, Rbulk) end - """ ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; type=:rank, kwargs...) From aa5ff07e378740af51f0ef372656dcd5b5520354 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 15:05:26 +0100 Subject: [PATCH 36/51] Remove rhat.jl --- src/MCMCDiagnosticTools.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index be961c14..9df40874 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -35,6 +35,5 @@ include("gewekediag.jl") include("heideldiag.jl") include("mcse.jl") include("rafterydiag.jl") -include("rhat.jl") include("rstar.jl") end From dc80418ff7dd37207beb795a41387654995dde36 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 15:34:14 +0100 Subject: [PATCH 37/51] Better unify tests --- test/ess_rhat.jl | 97 +++++++++++++++++++++++++++++------------------- 1 file changed, 58 insertions(+), 39 deletions(-) diff --git a/test/ess_rhat.jl b/test/ess_rhat.jl index 47e911e8..d97e5d9c 100644 --- a/test/ess_rhat.jl +++ b/test/ess_rhat.jl @@ -34,20 +34,22 @@ end mymean(x) = mean(x) -@testset "ess.jl" begin - @testset "ess/ess_rhat basics" begin +@testset "ess_rhat.jl" begin + @testset "ess/ess_rhat/rhat basics" begin @testset "only promote eltype when necessary" begin @testset for type in map(Val, [:rank, :bulk, :tail, :basic]) @testset for T in (Float32, Float64) x = rand(T, 100, 4, 2) TV = Vector{T} @test @inferred(ess(x; type=type)) isa TV + @test @inferred(rhat(x; type=type)) isa TV @test @inferred(ess_rhat(x; type=type)) isa Tuple{TV,TV} end @testset "Int" begin x = rand(1:10, 100, 4, 2) TV = Vector{Float64} @test @inferred(ess(x; type=type)) isa TV + @test @inferred(rhat(x; type=type)) isa TV @test @inferred(ess_rhat(x; type=type)) isa Tuple{TV,TV} end end @@ -67,21 +69,19 @@ mymean(x) = mean(x) x = rand(4, 3, 5) x2 = rand(5, 3, 5) x3 = rand(100, 3, 5) - @testset for type in [:rank, :bulk, :tail, :basic] - @test_throws ArgumentError ess(x; split_chains=1, type=type) - @test_throws ArgumentError ess_rhat(x; split_chains=1, type=type) - @test ess(x2; split_chains=1, type=type) == - ess_rhat(x2; split_chains=1, type=type)[1] - @test_throws ArgumentError ess(x2; split_chains=2, type=type) - @test_throws ArgumentError ess_rhat(x2; split_chains=2, type=type) - @test ess(x3; maxlag=1, type=type) == ess_rhat(x3; maxlag=1, type=type)[1] - @test_throws DomainError ess(x3; maxlag=0, type=type) - @test_throws DomainError ess_rhat(x3; maxlag=0, type=type) + @testset for f in [ess, ess_rhat] + @testset for type in [:rank, :bulk, :tail, :basic] + @test_throws ArgumentError f(x; split_chains=1, type=type) + f(x2; split_chains=1, type=type) + @test_throws ArgumentError f(x2; split_chains=2, type=type) + f(x3; maxlag=1, type=type) + @test_throws DomainError f(x3; maxlag=0, type=type) + end + @test_throws ArgumentError f(x2; type=:foo) end + @test_throws ArgumentError rhat(x2; type=:foo) @test_throws ArgumentError ess(x2; type=:rank, estimator=mean) @test_throws ArgumentError ess(x2; estimator=mymean) - @test_throws ArgumentError ess(x2; type=:foo) - @test_throws ArgumentError ess_rhat(x2; type=:foo) end @testset "Union{Missing,Float64} eltype" begin @@ -90,13 +90,16 @@ mymean(x) = mean(x) x .= randn.() x[1, 1, 1] = missing S1 = ess(x; type=type) - S2, R = ess_rhat(x; type=type) + R1 = rhat(x; type=type) + S2, R2 = ess_rhat(x; type=type) @test ismissing(S1[1]) + @test ismissing(R1[1]) @test ismissing(S2[1]) - @test ismissing(R[1]) + @test ismissing(R2[1]) @test !any(ismissing, S1[2:3]) + @test !any(ismissing, R1[2:3]) @test !any(ismissing, S2[2:3]) - @test !any(ismissing, R[2:3]) + @test !any(ismissing, R2[2:3]) end end @@ -106,45 +109,61 @@ mymean(x) = mean(x) x = randn(100, 4, 5) y = OffsetArray(x, -5:94, 2:5, 11:15) S11 = ess(y; type=type) - S12, R1 = ess_rhat(y; type=type) + R11 = rhat(y; type=type) + S12, R12 = ess_rhat(y; type=type) @test S11 isa OffsetVector{Float64} @test S12 isa OffsetVector{Float64} @test axes(S11, 1) == axes(S12, 1) == axes(y, 3) - @test R1 isa OffsetVector{Float64} - @test axes(R1, 1) == axes(y, 3) + @test R11 isa OffsetVector{Float64} + @test R12 isa OffsetVector{Float64} + @test axes(R11, 1) == axes(R12, 1) == axes(y, 3) S21 = ess(x; type=type) - S22, R2 = ess_rhat(x; type=type) + R21 = rhat(x; type=type) + S22, R22 = ess_rhat(x; type=type) @test S22 == S21 == collect(S21) - @test R2 == collect(R1) + @test R21 == R22 == collect(R11) y = OffsetArray(similar(x, Missing), -5:94, 2:5, 11:15) S31 = ess(y; type=type) - S32, R3 = ess_rhat(y; type=type) + R31 = rhat(y; type=type) + S32, R32 = ess_rhat(y; type=type) @test S31 isa OffsetVector{Missing} @test S32 isa OffsetVector{Missing} @test axes(S31, 1) == axes(S32, 1) == axes(y, 3) - @test R3 isa OffsetVector{Missing} - @test axes(R3, 1) == axes(y, 3) + @test R31 isa OffsetVector{Missing} + @test R32 isa OffsetVector{Missing} + @test axes(R31, 1) == axes(R32, 1) == axes(y, 3) end end - end - @testset "ess, ess_rhat, and rhat consistency" begin - x = randn(1000, 4, 10) - @testset for type in [:rank, :bulk, :tail, :basic], split_chains in [1, 2] - R1 = rhat(x; type=type, split_chains=split_chains) - @testset for method in [ESSMethod(), BDAESSMethod()], maxlag in [100, 10] - S1 = ess( - x; type=type, split_chains=split_chains, method=method, maxlag=maxlag - ) - S2, R2 = ess_rhat( - x; type=type, split_chains=split_chains, method=method, maxlag=maxlag - ) - @test S1 == S2 - @test R1 == R2 + @testset "ess, ess_rhat, and rhat consistency" begin + x = randn(1000, 4, 10) + @testset for type in [:rank, :bulk, :tail, :basic], split_chains in [1, 2] + R1 = rhat(x; type=type, split_chains=split_chains) + @testset for method in [ESSMethod(), BDAESSMethod()], maxlag in [100, 10] + S1 = ess( + x; + type=type, + split_chains=split_chains, + method=method, + maxlag=maxlag, + ) + S2, R2 = ess_rhat( + x; + type=type, + split_chains=split_chains, + method=method, + maxlag=maxlag, + ) + @test S1 == S2 + @test R1 == R2 + end end end end + # now that we have checked mutual consistency of each method, we perform all following + # checks for whichever method is most convenient + @testset "ESS and R̂ (IID samples)" begin # Repeat tests with different scales @testset for scale in (1, 50, 100), nchains in (1, 10), split_chains in (1, 2) From a2d9d8f95dec307396ddaf18628a59786c15a484 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 15:34:26 +0100 Subject: [PATCH 38/51] Add missing type annotations --- src/ess_rhat.jl | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 163986f4..609294a3 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -283,9 +283,7 @@ function _ess(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs end return _ess(Val(:basic), x; kwargs...) end -function _ess( - ::Val{T}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... -) where {T} +function _ess(::Val{T}, ::AbstractArray{<:Union{Missing,Real},3}; kwargs...) where {T} return throw(ArgumentError("the `type` `$T` is not supported by `ess`")) end function _ess(type::Val{:basic}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) @@ -347,7 +345,9 @@ The following `type`s are supported: ) return _rhat(_val(type), samples; kwargs...) end - +function _rhat(::Val{T}, ::AbstractArray{<:Union{Missing,Real},3}; kwargs...) where {T} + return throw(ArgumentError("the `type` `$T` is not supported by `rhat`")) +end function _rhat( ::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2 ) @@ -402,9 +402,13 @@ function _rhat( return rhat end -_rhat(::Val{:bulk}, x; kwargs...) = _rhat(Val(:basic), _rank_normalize(x); kwargs...) -_rhat(::Val{:tail}, x; kwargs...) = _rhat(Val(:bulk), _fold_around_median(x); kwargs...) -function _rhat(::Val{:rank}, x; kwargs...) +function _rhat(::Val{:bulk}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + return _rhat(Val(:basic), _rank_normalize(x); kwargs...) +end +function _rhat(::Val{:tail}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + return _rhat(Val(:bulk), _fold_around_median(x); kwargs...) +end +function _rhat(::Val{:rank}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) Rbulk = _rhat(Val(:bulk), x; kwargs...) Rtail = _rhat(Val(:tail), x; kwargs...) return map(max, Rtail, Rbulk) @@ -427,9 +431,7 @@ description of `kwargs`. ) return _ess_rhat(_val(type), samples; kwargs...) end -function _ess_rhat( - ::Val{T}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... -) where {T} +function _ess_rhat(::Val{T}, ::AbstractArray{<:Union{Missing,Real},3}; kwargs...) where {T} return throw(ArgumentError("the `type` `$T` is not supported by `ess_rhat`")) end function _ess_rhat( From a7977f6eb2c2fe8707490184bf857184ed7a9872 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 17:04:54 +0100 Subject: [PATCH 39/51] Apply suggestions from code review Co-authored-by: David Widmann --- src/MCMCDiagnosticTools.jl | 3 +-- src/ess_rhat.jl | 16 +++++----------- src/utils.jl | 2 +- 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index 9df40874..18edce0e 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -17,13 +17,12 @@ using Statistics: Statistics export bfmi export discretediag -export ess, ess_rhat, ESSMethod, FFTESSMethod, BDAESSMethod +export ess, ess_rhat, rhat, ESSMethod, FFTESSMethod, BDAESSMethod export gelmandiag, gelmandiag_multivariate export gewekediag export heideldiag export mcse export rafterydiag -export rhat export rstar include("utils.jl") diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 609294a3..7fafed15 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -213,8 +213,8 @@ end Estimate the effective sample size (ESS) of the `samples` of shape `(draws, chains, parameters)` with the `method`. -Optionally, only one of the `type` of ESS estimate to return or the `estimator` for which -ESS is computed can be specified (see below). Some `type`s accept additional `kwargs`. +Optionally, one can specify either the `type` of the ESS estimate or the `estimator` for which +ESS is computed (see below). Some `type`s accept additional `kwargs`. $_DOC_SPLIT_CHAINS There must be at least 3 draws in each chain after splitting. @@ -249,7 +249,7 @@ If no `estimator` is provided, the following `type`s of ESS estimates may be com - `:basic`: basic ESS, equivalent to specifying `estimator=Statistics.mean`. While Bulk-ESS is conceptually related to basic ESS, it is well-defined even if the chains -do not have finite variance.[^VehtariGelman2021]. For each parameter, rank-normalization +do not have finite variance.[^VehtariGelman2021] For each parameter, rank-normalization proceeds by first ranking the inputs using "tied ranking" and then transforming the ranks to normal quantiles so that the result is standard normally distributed. This transform is monotonic. @@ -283,13 +283,7 @@ function _ess(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs end return _ess(Val(:basic), x; kwargs...) end -function _ess(::Val{T}, ::AbstractArray{<:Union{Missing,Real},3}; kwargs...) where {T} - return throw(ArgumentError("the `type` `$T` is not supported by `ess`")) -end -function _ess(type::Val{:basic}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - return first(_ess_rhat(type, samples; kwargs...)) -end -function _ess(type::Val{:bulk}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) +function _ess(type::Val, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) return first(_ess_rhat(type, samples; kwargs...)) end function _ess( @@ -314,7 +308,7 @@ end rhat(samples::AbstractArray{Union{Real,Missing},3}; type=:rank, split_chains=2) Compute the ``\\widehat{R}`` diagnostics for each parameter in `samples` of shape -`(chains, draws, parameters)`. [^VehtariGelman2021] +`(chains, draws, parameters)`.[^VehtariGelman2021] `type` indicates the type of ``\\widehat{R}`` to compute (see below). diff --git a/src/utils.jl b/src/utils.jl index c8c1d662..fe7fb462 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -191,5 +191,5 @@ function _normal_quantiles_from_ranks!(q, r; α=3//8) return q end -_val(k) = Val(k) +_val(k::Symbol) = Val(k) _val(k::Val) = k From 3b6f1bc51a8b89814353bd92aa362db975a9a524 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 15:16:31 +0100 Subject: [PATCH 40/51] Avoid `@constprop`, rename `type` and `estimator` to `kind` --- Project.toml | 2 - src/MCMCDiagnosticTools.jl | 1 - src/ess_rhat.jl | 146 ++++++++++++++++++------------------- 3 files changed, 73 insertions(+), 76 deletions(-) diff --git a/Project.toml b/Project.toml index 06ba3e2f..46e26b15 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "0.3.0-DEV" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" -Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -20,7 +19,6 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] AbstractFFTs = "0.5, 1" -Compat = "3.36.0, 4" DataAPI = "1.6" DataStructures = "0.18.3" Distributions = "0.25" diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index 9df40874..fd704317 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -1,7 +1,6 @@ module MCMCDiagnosticTools using AbstractFFTs: AbstractFFTs -using Compat: @constprop using DataAPI: DataAPI using DataStructures: DataStructures using Distributions: Distributions diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 609294a3..e668f5b7 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -202,8 +202,7 @@ end """ ess( samples::AbstractArray{<:Union{Missing,Real},3}; - type=:bulk, - [estimator,] + kind=:bulk, method=ESSMethod(), split_chains::Int=2, maxlag::Int=250, @@ -213,8 +212,8 @@ end Estimate the effective sample size (ESS) of the `samples` of shape `(draws, chains, parameters)` with the `method`. -Optionally, only one of the `type` of ESS estimate to return or the `estimator` for which -ESS is computed can be specified (see below). Some `type`s accept additional `kwargs`. +Optionally, the `kind` of ESS estimate to be computed can be specified (see below). Some +`kind`s accept additional `kwargs`. $_DOC_SPLIT_CHAINS There must be at least 3 draws in each chain after splitting. @@ -227,32 +226,31 @@ For a given estimand, it is recommended that the ESS is at least `100 * chains` See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref), [`rhat`](@ref), [`ess_rhat`](@ref), [`mcse`](@ref) -## Estimators +## Kinds of ESS estimates -The ESS and ``\\widehat{R}`` values can be computed for the following estimators: -- `Statistics.mean` -- `Statistics.median` -- `Statistics.std` -- `StatsBase.mad` -- `Base.Fix2(Statistics.quantile, p::Real)` - -## Types - -If no `estimator` is provided, the following `type`s of ESS estimates may be computed: -- `:bulk`/`:rank`: mean-ESS computed on rank-normalized draws. This type diagnoses poor +If `kind` isa a `Symbol`, it may take one of the following values: +- `:bulk`/`:rank`: mean-ESS computed on rank-normalized draws. This kind diagnoses poor convergence in the bulk of the distribution due to trends or different locations of the chains. - `:tail`: minimum of the quantile-ESS for the symmetric quantiles where - `tail_prob=0.1` is the probability in the tails. This type diagnoses poor convergence in - the tails of the distribution. If this type is chosen, `kwargs` may contain a + `tail_prob=0.1` is the probability in the tails. This kind diagnoses poor convergence in + the tails of the distribution. If this kind is chosen, `kwargs` may contain a `tail_prob` keyword. -- `:basic`: basic ESS, equivalent to specifying `estimator=Statistics.mean`. +- `:basic`: basic ESS, equivalent to specifying `kind=Statistics.mean`. -While Bulk-ESS is conceptually related to basic ESS, it is well-defined even if the chains -do not have finite variance.[^VehtariGelman2021]. For each parameter, rank-normalization -proceeds by first ranking the inputs using "tied ranking" and then transforming the ranks to -normal quantiles so that the result is standard normally distributed. This transform is -monotonic. +!!! note + While Bulk-ESS is conceptually related to basic ESS, it is well-defined even if the + chains do not have finite variance.[^VehtariGelman2021]. For each parameter, + rank-normalization proceeds by first ranking the inputs using "tied ranking" and then + transforming the ranks to normal quantiles so that the result is standard normally + distributed. This transform is monotonic. + +Otherwise, `kind` specifies one of the following estimators, whose ESS is to be estimated: +- `Statistics.mean` +- `Statistics.median` +- `Statistics.std` +- `StatsBase.mad` +- `Base.Fix2(Statistics.quantile, p::Real)` [^VehtariGelman2021]: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P. C. (2021). Rank-normalization, folding, and localization: An improved ``\\widehat {R}`` for @@ -260,22 +258,20 @@ monotonic. doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) """ -@constprop :aggressive function ess( - samples::AbstractArray{<:Union{Missing,Real},3}; - estimator=nothing, - type=nothing, - kwargs..., -) - if estimator !== nothing && type !== nothing - throw(ArgumentError("only one of `estimator` and `type` can be specified")) - elseif estimator !== nothing - return _ess(estimator, samples; kwargs...) - elseif type !== nothing - return _ess(_val(type), samples; kwargs...) - else +function ess(samples::AbstractArray{<:Union{Missing,Real},3}; kind=:bulk, kwargs...) + if kind isa Symbol + return _ess(Val(:bulk), samples; kwargs...) + elseif kind === :tail + return _ess(Val(:tail), samples; kwargs...) + elseif kind === :basic return _ess(Val(:basic), samples; kwargs...) + else + return _ess(kind, samples; kwargs...) end end +function _ess(kind::Symbol, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + return throw(ArgumentError("the `kind` `$kind` is not supported by `ess`")) +end function _ess(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) x = _expectand_proxy(estimator, samples) if x === nothing @@ -283,14 +279,11 @@ function _ess(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs end return _ess(Val(:basic), x; kwargs...) end -function _ess(::Val{T}, ::AbstractArray{<:Union{Missing,Real},3}; kwargs...) where {T} - return throw(ArgumentError("the `type` `$T` is not supported by `ess`")) +function _ess(kind::Val{:basic}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + return first(_ess_rhat(kind, samples; kwargs...)) end -function _ess(type::Val{:basic}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - return first(_ess_rhat(type, samples; kwargs...)) -end -function _ess(type::Val{:bulk}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - return first(_ess_rhat(type, samples; kwargs...)) +function _ess(kind::Val{:bulk}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) + return first(_ess_rhat(kind, samples; kwargs...)) end function _ess( ::Val{:tail}, @@ -306,31 +299,28 @@ function _ess( S_upper = _ess(Base.Fix2(Statistics.quantile, pu), x; kwargs...) return map(min, S_lower, S_upper) end -function _ess(::Val{:rank}, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - return _ess(Val(:bulk), samples; kwargs...) -end """ - rhat(samples::AbstractArray{Union{Real,Missing},3}; type=:rank, split_chains=2) + rhat(samples::AbstractArray{Union{Real,Missing},3}; kind=:rank, split_chains=2) Compute the ``\\widehat{R}`` diagnostics for each parameter in `samples` of shape `(chains, draws, parameters)`. [^VehtariGelman2021] -`type` indicates the type of ``\\widehat{R}`` to compute (see below). +`kind` indicates the kind of ``\\widehat{R}`` to compute (see below). $_DOC_SPLIT_CHAINS See also [`ess`](@ref), [`ess_rhat`](@ref), [`rstar`](@ref) -## Types +## Kinds of ``\\widehat{R}`` -The following `type`s are supported: -- `:rank`: maximum of ``\\widehat{R}`` with `type=:bulk` and `type=:tail`. -- `:bulk`: basic ``\\widehat{R}``` computed on rank-normalized draws. This type diagnoses +The following `kind`s are supported: +- `:rank`: maximum of ``\\widehat{R}`` with `kind=:bulk` and `kind=:tail`. +- `:bulk`: basic ``\\widehat{R}``` computed on rank-normalized draws. This kind diagnoses poor convergence in the bulk of the distribution due to trends or different locations of the chains. - `:tail`: ``\\widehat{R}`` computed on draws folded around the median and then - rank-normalized. This type diagnoses poor convergence in the tails of the distribution + rank-normalized. This kind diagnoses poor convergence in the tails of the distribution due to different scales of the chains. - `:basic`: Classic ``\\widehat{R}``. @@ -340,13 +330,18 @@ The following `type`s are supported: doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) """ -@constprop :aggressive function rhat( - samples::AbstractArray{<:Union{Missing,Real},3}; type=Val(:rank), kwargs... -) - return _rhat(_val(type), samples; kwargs...) -end -function _rhat(::Val{T}, ::AbstractArray{<:Union{Missing,Real},3}; kwargs...) where {T} - return throw(ArgumentError("the `type` `$T` is not supported by `rhat`")) +function rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kind=:rank, kwargs...) + if kind === :rank + return _rhat(Val(:rank), samples; kwargs...) + elseif kind === :bulk + return _rhat(Val(:bulk), samples; kwargs...) + elseif kind === :tail + return _rhat(Val(:tail), samples; kwargs...) + elseif kind === :basic + return _rhat(Val(:basic), samples; kwargs...) + else + return throw(ArgumentError("the `kind` `$kind` is not supported by `rhat`")) + end end function _rhat( ::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2 @@ -415,7 +410,7 @@ function _rhat(::Val{:rank}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs.. end """ - ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; type=:rank, kwargs...) + ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kind=:rank, kwargs...) Estimate the effective sample size and ``\\widehat{R}`` of the `samples` of shape `(draws, chains, parameters)` with the `method`. @@ -423,16 +418,21 @@ Estimate the effective sample size and ``\\widehat{R}`` of the `samples` of shap When both ESS and ``\\widehat{R}`` are needed, this method is often more efficient than calling `ess` and `rhat` separately. -See [`rhat`](@ref) for a description of supported `type`s and [`ess`](@ref) for a +See [`rhat`](@ref) for a description of supported `kind`s and [`ess`](@ref) for a description of `kwargs`. """ -@constprop :aggressive function ess_rhat( - samples::AbstractArray{<:Union{Missing,Real},3}; type=Val(:rank), kwargs... -) - return _ess_rhat(_val(type), samples; kwargs...) -end -function _ess_rhat(::Val{T}, ::AbstractArray{<:Union{Missing,Real},3}; kwargs...) where {T} - return throw(ArgumentError("the `type` `$T` is not supported by `ess_rhat`")) +function ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kind=:rank, kwargs...) + if kind === :rank + return _ess_rhat(Val(:rank), samples; kwargs...) + elseif kind === :bulk + return _ess_rhat(Val(:bulk), samples; kwargs...) + elseif kind === :tail + return _ess_rhat(Val(:tail), samples; kwargs...) + elseif kind === :basic + return _ess_rhat(Val(:basic), samples; kwargs...) + else + return throw(ArgumentError("the `kind` `$kind` is not supported by `ess_rhat`")) + end end function _ess_rhat( ::Val{:basic}, @@ -565,13 +565,13 @@ function _ess_rhat(::Val{:bulk}, x::AbstractArray{<:Union{Missing,Real},3}; kwar return _ess_rhat(Val(:basic), _rank_normalize(x); kwargs...) end function _ess_rhat( - type::Val{:tail}, + kind::Val{:tail}, x::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2, kwargs..., ) - S = _ess(type, x; split_chains=split_chains, kwargs...) - R = _rhat(type, x; split_chains=split_chains) + S = _ess(kind, x; split_chains=split_chains, kwargs...) + R = _rhat(kind, x; split_chains=split_chains) return S, R end function _ess_rhat( From fa5484adcc96306dc47b4fd98b1af9bda3980263 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 15:24:16 +0100 Subject: [PATCH 41/51] Cleanup --- src/ess_rhat.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index e668f5b7..7e1de6f7 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -259,19 +259,18 @@ Otherwise, `kind` specifies one of the following estimators, whose ESS is to be arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) """ function ess(samples::AbstractArray{<:Union{Missing,Real},3}; kind=:bulk, kwargs...) - if kind isa Symbol + if kind === :bulk return _ess(Val(:bulk), samples; kwargs...) elseif kind === :tail return _ess(Val(:tail), samples; kwargs...) elseif kind === :basic return _ess(Val(:basic), samples; kwargs...) + elseif kind isa Symbol + throw(ArgumentError("the `kind` `$kind` is not supported by `ess`")) else return _ess(kind, samples; kwargs...) end end -function _ess(kind::Symbol, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - return throw(ArgumentError("the `kind` `$kind` is not supported by `ess`")) -end function _ess(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) x = _expectand_proxy(estimator, samples) if x === nothing From d2980b6ac3589682b8f1dc36a7f592899f5aa540 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 15:46:27 +0100 Subject: [PATCH 42/51] Rename `estimator` to `kind` for `mcse` --- src/mcse.jl | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/mcse.jl b/src/mcse.jl index cc7e10a4..20c28fcb 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -2,16 +2,17 @@ const normcdf1 = 0.8413447460685429 # StatsFuns.normcdf(1) const normcdfn1 = 0.15865525393145705 # StatsFuns.normcdf(-1) """ - mcse(samples::AbstractArray{<:Union{Missing,Real}}; estimator=Statistics.mean, kwargs...) + mcse(samples::AbstractArray{<:Union{Missing,Real}}; kind=Statistics.mean, kwargs...) -Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples` of -shape `(draws, chains, parameters)`. +Estimate the Monte Carlo standard errors (MCSE) of the estimator `kind` applied to `samples` +of shape `(draws, chains, parameters)`. See also: [`ess`](@ref) -## Estimators +## Kinds of MCSE estimates -`estimator` must accept a vector of the same `eltype` as `samples` and return a real estimate. +The estimator whose MCSE should be estimator is specified with `kind`. `kind` must accept a +vector of the same `eltype` as `samples` and return a real estimate. For the following estimators, the effective sample size [`ess`](@ref) and an estimate of the asymptotic variance are used to compute the MCSE, and `kwargs` are forwarded to @@ -36,24 +37,22 @@ by checking the bulk- and tail-ESS values. doi: [10.1007/978-3-642-27440-4_18](https://doi.org/10.1007/978-3-642-27440-4_18) """ -function mcse( - x::AbstractArray{<:Union{Missing,Real},3}; estimator=Statistics.mean, kwargs... -) - return _mcse(estimator, x; kwargs...) +function mcse(x::AbstractArray{<:Union{Missing,Real},3}; kind=Statistics.mean, kwargs...) + return _mcse(kind, x; kwargs...) end _mcse(f, x; kwargs...) = _mcse_sbm(f, x; kwargs...) function _mcse( ::typeof(Statistics.mean), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) - S = ess(samples; estimator=Statistics.mean, kwargs...) + S = _ess(Statistics.mean, samples; kwargs...) return dropdims(Statistics.std(samples; dims=(1, 2)); dims=(1, 2)) ./ sqrt.(S) end function _mcse( ::typeof(Statistics.std), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) x = (samples .- Statistics.mean(samples; dims=(1, 2))) .^ 2 # expectand proxy - S = ess(x; estimator=Statistics.mean, kwargs...) + S = _ess(Statistics.mean, x; kwargs...) # asymptotic variance of sample variance estimate is Var[var] = E[μ₄] - E[var]², # where μ₄ is the 4th central moment # by the delta method, Var[std] = Var[var] / 4E[var] = (E[μ₄]/E[var] - E[var])/4, @@ -68,7 +67,7 @@ function _mcse( kwargs..., ) p = f.x - S = ess(samples; estimator=f, kwargs...) + S = _ess(f, samples; kwargs...) T = eltype(S) R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T)))) values = similar(S, R) @@ -80,7 +79,7 @@ end function _mcse( ::typeof(Statistics.median), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) - S = ess(samples; estimator=Statistics.median, kwargs...) + S = _ess(Statistics.median, samples; kwargs...) T = eltype(S) R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T)))) values = similar(S, R) From 61a5c0737926eed5ea311a9523ca1f7159dbcaee Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 15:46:57 +0100 Subject: [PATCH 43/51] Update tests --- test/ess_rhat.jl | 104 +++++++++++++++++++++++------------------------ test/mcse.jl | 18 ++++---- 2 files changed, 61 insertions(+), 61 deletions(-) diff --git a/test/ess_rhat.jl b/test/ess_rhat.jl index d97e5d9c..1757c1ed 100644 --- a/test/ess_rhat.jl +++ b/test/ess_rhat.jl @@ -37,30 +37,30 @@ mymean(x) = mean(x) @testset "ess_rhat.jl" begin @testset "ess/ess_rhat/rhat basics" begin @testset "only promote eltype when necessary" begin - @testset for type in map(Val, [:rank, :bulk, :tail, :basic]) + @testset for kind in [:rank, :bulk, :tail, :basic] @testset for T in (Float32, Float64) x = rand(T, 100, 4, 2) TV = Vector{T} - @test @inferred(ess(x; type=type)) isa TV - @test @inferred(rhat(x; type=type)) isa TV - @test @inferred(ess_rhat(x; type=type)) isa Tuple{TV,TV} + kind === :rank || @test @inferred(ess(x; kind=kind)) isa TV + @test @inferred(rhat(x; kind=kind)) isa TV + @test @inferred(ess_rhat(x; kind=kind)) isa Tuple{TV,TV} end @testset "Int" begin x = rand(1:10, 100, 4, 2) TV = Vector{Float64} - @test @inferred(ess(x; type=type)) isa TV - @test @inferred(rhat(x; type=type)) isa TV - @test @inferred(ess_rhat(x; type=type)) isa Tuple{TV,TV} + kind === :rank || @test @inferred(ess(x; kind=kind)) isa TV + @test @inferred(rhat(x; kind=kind)) isa TV + @test @inferred(ess_rhat(x; kind=kind)) isa Tuple{TV,TV} end end - @testset for estimator in [mean, median, mad, std, Base.Fix2(quantile, 0.25)] + @testset for kind in [mean, median, mad, std, Base.Fix2(quantile, 0.25)] @testset for T in (Float32, Float64) x = rand(T, 100, 4, 2) - @test @inferred(ess(x; estimator=estimator)) isa Vector{T} + @test @inferred(ess(x; kind=kind)) isa Vector{T} end @testset "Int" begin x = rand(1:10, 100, 4, 2) - @test @inferred(ess(x; estimator=estimator)) isa Vector{Float64} + @test @inferred(ess(x; kind=kind)) isa Vector{Float64} end end end @@ -70,28 +70,28 @@ mymean(x) = mean(x) x2 = rand(5, 3, 5) x3 = rand(100, 3, 5) @testset for f in [ess, ess_rhat] - @testset for type in [:rank, :bulk, :tail, :basic] - @test_throws ArgumentError f(x; split_chains=1, type=type) - f(x2; split_chains=1, type=type) - @test_throws ArgumentError f(x2; split_chains=2, type=type) - f(x3; maxlag=1, type=type) - @test_throws DomainError f(x3; maxlag=0, type=type) + @testset for kind in [:rank, :bulk, :tail, :basic] + f === ess && kind === :rank && continue + @test_throws ArgumentError f(x; split_chains=1, kind=kind) + f(x2; split_chains=1, kind=kind) + @test_throws ArgumentError f(x2; split_chains=2, kind=kind) + f(x3; maxlag=1, kind=kind) + @test_throws DomainError f(x3; maxlag=0, kind=kind) end - @test_throws ArgumentError f(x2; type=:foo) + @test_throws ArgumentError f(x2; kind=:foo) end - @test_throws ArgumentError rhat(x2; type=:foo) - @test_throws ArgumentError ess(x2; type=:rank, estimator=mean) - @test_throws ArgumentError ess(x2; estimator=mymean) + @test_throws ArgumentError rhat(x2; kind=:foo) + @test_throws ArgumentError ess(x2; kind=mymean) end @testset "Union{Missing,Float64} eltype" begin - @testset for type in [:rank, :bulk, :tail, :basic] + @testset for kind in [:rank, :bulk, :tail, :basic] x = Array{Union{Missing,Float64}}(undef, 1000, 4, 3) x .= randn.() x[1, 1, 1] = missing - S1 = ess(x; type=type) - R1 = rhat(x; type=type) - S2, R2 = ess_rhat(x; type=type) + S1 = ess(x; kind=kind === :rank ? :bulk : kind) + R1 = rhat(x; kind=kind) + S2, R2 = ess_rhat(x; kind=kind) @test ismissing(S1[1]) @test ismissing(R1[1]) @test ismissing(S2[1]) @@ -104,28 +104,28 @@ mymean(x) = mean(x) end @testset "produces similar vectors to inputs" begin - @testset for type in [:rank, :bulk, :tail, :basic] + @testset for kind in [:rank, :bulk, :tail, :basic] # simultaneously checks that we index correctly and that output types are correct x = randn(100, 4, 5) y = OffsetArray(x, -5:94, 2:5, 11:15) - S11 = ess(y; type=type) - R11 = rhat(y; type=type) - S12, R12 = ess_rhat(y; type=type) + S11 = ess(y; kind=kind === :rank ? :bulk : kind) + R11 = rhat(y; kind=kind) + S12, R12 = ess_rhat(y; kind=kind) @test S11 isa OffsetVector{Float64} @test S12 isa OffsetVector{Float64} @test axes(S11, 1) == axes(S12, 1) == axes(y, 3) @test R11 isa OffsetVector{Float64} @test R12 isa OffsetVector{Float64} @test axes(R11, 1) == axes(R12, 1) == axes(y, 3) - S21 = ess(x; type=type) - R21 = rhat(x; type=type) - S22, R22 = ess_rhat(x; type=type) + S21 = ess(x; kind=kind === :rank ? :bulk : kind) + R21 = rhat(x; kind=kind) + S22, R22 = ess_rhat(x; kind=kind) @test S22 == S21 == collect(S21) @test R21 == R22 == collect(R11) y = OffsetArray(similar(x, Missing), -5:94, 2:5, 11:15) - S31 = ess(y; type=type) - R31 = rhat(y; type=type) - S32, R32 = ess_rhat(y; type=type) + S31 = ess(y; kind=kind === :rank ? :bulk : kind) + R31 = rhat(y; kind=kind) + S32, R32 = ess_rhat(y; kind=kind) @test S31 isa OffsetVector{Missing} @test S32 isa OffsetVector{Missing} @test axes(S31, 1) == axes(S32, 1) == axes(y, 3) @@ -137,19 +137,19 @@ mymean(x) = mean(x) @testset "ess, ess_rhat, and rhat consistency" begin x = randn(1000, 4, 10) - @testset for type in [:rank, :bulk, :tail, :basic], split_chains in [1, 2] - R1 = rhat(x; type=type, split_chains=split_chains) + @testset for kind in [:rank, :bulk, :tail, :basic], split_chains in [1, 2] + R1 = rhat(x; kind=kind, split_chains=split_chains) @testset for method in [ESSMethod(), BDAESSMethod()], maxlag in [100, 10] S1 = ess( x; - type=type, + kind=kind === :rank ? :bulk : kind, split_chains=split_chains, method=method, maxlag=maxlag, ) S2, R2 = ess_rhat( x; - type=type, + kind=kind, split_chains=split_chains, method=method, maxlag=maxlag, @@ -224,14 +224,14 @@ mymean(x) = mean(x) @testset "ESS and R̂ for chains with 2 epochs that have not mixed" begin # checks that splitting yields lower ESS estimates and higher Rhat estimates x = randn(1000, 4, 10) .+ repeat([0, 10]; inner=(500, 1, 1)) - ess_array, rhat_array = ess_rhat(x; type=:basic, split_chains=1) + ess_array, rhat_array = ess_rhat(x; kind=:basic, split_chains=1) @test all(x -> isapprox(x, 1; rtol=0.1), rhat_array) - ess_array2, rhat_array2 = ess_rhat(x; type=:basic, split_chains=2) + ess_array2, rhat_array2 = ess_rhat(x; kind=:basic, split_chains=2) @test all(ess_array2 .< ess_array) @test all(>(2), rhat_array2) end - @testset "ess(x; estimator=f)" begin + @testset "ess(x; kind=f)" begin # we check the ESS estimates by simulating uncorrelated, correlated, and # anticorrelated chains, mapping the draws to a target distribution, computing the # estimand, and estimating the ESS for the chosen estimator, computing the @@ -242,7 +242,7 @@ mymean(x) = mean(x) nparams = 100 x = randn(ndraws, nchains, nparams) mymean(x; kwargs...) = mean(x; kwargs...) - @test_throws ArgumentError ess(x; estimator=mymean) + @test_throws ArgumentError ess(x; kind=mymean) estimators = [mean, median, std, mad, Base.Fix2(quantile, 0.25)] dists = [Normal(10, 100), Exponential(10), TDist(7) * 10 - 20] # AR(1) coefficients. 0 is IID, -0.3 is slightly anticorrelated, 0.9 is highly autocorrelated @@ -257,7 +257,7 @@ mymean(x) = mean(x) x .= quantile.(dist, cdf.(Normal(), x)) # stationary distribution is dist μ_mean = dropdims(mapslices(f ∘ vec, x; dims=(1, 2)); dims=(1, 2)) dist = asymptotic_dist(f, dist) - n = @inferred(ess(x; estimator=f)) + n = @inferred(ess(x; kind=f)) μ = mean(dist) mcse = sqrt.(var(dist) ./ n) for i in eachindex(μ_mean, mcse) @@ -275,19 +275,19 @@ mymean(x) = mean(x) nchains = 4 @testset for ndraws in (10, 100), φ in (-0.3, -0.9) x = ar1(φ, sqrt(1 - φ^2), ndraws, nchains, 1000) - Smin, Smax = extrema(ess(x; estimator=mean)) + Smin, Smax = extrema(ess(x; kind=mean)) ntotal = ndraws * nchains @test Smax == ntotal * log10(ntotal) @test Smin > 0 end end - @testset "ess(x; type=:bulk)" begin + @testset "ess(x; kind=:bulk)" begin xnorm = randn(1_000, 4, 10) - @test ess(xnorm; type=:bulk) == ess(_rank_normalize(xnorm); type=:basic) + @test ess(xnorm; kind=:bulk) == ess(_rank_normalize(xnorm); kind=:basic) xcauchy = quantile.(Cauchy(), cdf.(Normal(), xnorm)) # transformation by any monotonic function should not change the bulk ESS/R-hat - @test ess(xnorm; type=:bulk) == ess(xcauchy; type=:bulk) + @test ess(xnorm; kind=:bulk) == ess(xcauchy; kind=:bulk) end @testset "tail- ESS and R-hat detect mismatched scales" begin @@ -306,15 +306,15 @@ mymean(x) = mean(x) # sanity check that standard and bulk ESS and R-hat both fail to detect # mismatched scales - S, R = ess_rhat(x; type=:basic) + S, R = ess_rhat(x; kind=:basic) @test all(≥(ess_cutoff), S) @test all(≤(rhat_cutoff), R) - Sbulk, Rbulk = ess_rhat(x; type=:bulk) + Sbulk, Rbulk = ess_rhat(x; kind=:bulk) @test all(≥(ess_cutoff), Sbulk) @test all(≤(rhat_cutoff), Rbulk) # check that tail- ESS detects mismatched scales and signal poor convergence - S_tail, R_tail = ess_rhat(x; type=:tail) + S_tail, R_tail = ess_rhat(x; kind=:tail) @test all(<(ess_cutoff), S_tail) @test all(>(rhat_cutoff), R_tail) end @@ -334,8 +334,8 @@ mymean(x) = mean(x) end x = permutedims(cat(posterior_matrices...; dims=3), (2, 3, 1)) - Sbulk, Rbulk = ess_rhat(x; type=:bulk) - Stail, Rtail = ess_rhat(x; type=:tail) + Sbulk, Rbulk = ess_rhat(x; kind=:bulk) + Stail, Rtail = ess_rhat(x; kind=:tail) ess_cutoff = 100 * size(x, 2) # recommended cutoff is 100 * nchains @test mean(≥(ess_cutoff), Sbulk) > 0.9 @test mean(≥(ess_cutoff), Stail) < mean(≥(ess_cutoff), Sbulk) diff --git a/test/mcse.jl b/test/mcse.jl index 43ac247d..40e888ab 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -8,23 +8,23 @@ using StatsBase @testset "mcse.jl" begin @testset "estimator defaults to mean" begin x = randn(100, 4, 10) - @test mcse(x) == mcse(x; estimator=mean) + @test mcse(x) == mcse(x; kind=mean) end @testset "ESS-based methods forward kwargs to ess" begin x = randn(100, 4, 10) @testset for f in [mean, median, std, Base.Fix2(quantile, 0.1)] - @test @inferred(mcse(x; estimator=f, split_chains=1)) ≠ mcse(x; estimator=f) + @test @inferred(mcse(x; kind=f, split_chains=1)) ≠ mcse(x; kind=f) end end @testset "mcse falls back to _mcse_sbm" begin x = randn(100, 4, 10) estimator = mad - @test @inferred(mcse(x; estimator=estimator)) == + @test @inferred(mcse(x; kind=estimator)) == MCMCDiagnosticTools._mcse_sbm(estimator, x) ≠ MCMCDiagnosticTools._mcse_sbm(estimator, x; batch_size=16) == - mcse(x; estimator=estimator, batch_size=16) + mcse(x; kind=estimator, batch_size=16) end @testset "mcse produces similar vectors to inputs" begin @@ -34,15 +34,15 @@ using StatsBase x = randn(T, 100, 4, 5) y = OffsetArray(x, -5:94, 2:5, 11:15) - se = mcse(y; estimator=estimator) + se = mcse(y; kind=estimator) @test se isa OffsetVector{T} @test axes(se, 1) == axes(y, 3) - se2 = mcse(x; estimator=estimator) + se2 = mcse(x; kind=estimator) @test se2 ≈ collect(se) # quantile errors if data contains missings estimator isa Base.Fix2{typeof(quantile)} && continue y = OffsetArray(similar(x, Missing), -5:94, 2:5, 11:15) - @test mcse(y; estimator=estimator) isa OffsetVector{Missing} + @test mcse(y; kind=estimator) isa OffsetVector{Missing} end end @@ -51,7 +51,7 @@ using StatsBase x .= randn.() x[1, 1, 1] = missing @testset for f in [mean, median, std, mad] - se = mcse(x; estimator=f) + se = mcse(x; kind=f) @test ismissing(se[1]) @test !any(ismissing, se[2:end]) end @@ -82,7 +82,7 @@ using StatsBase x .= quantile.(dist, cdf.(Normal(), x)) # stationary distribution is dist μ_mean = dropdims(mapslices(f ∘ vec, x; dims=(1, 2)); dims=(1, 2)) μ = mean(asymptotic_dist(f, dist)) - se = mcse === MCMCDiagnosticTools._mcse_sbm ? mcse(f, x) : mcse(x; estimator=f) + se = mcse === MCMCDiagnosticTools._mcse_sbm ? mcse(f, x) : mcse(x; kind=f) for i in eachindex(μ_mean, se) atol = quantile(Normal(0, se[i]), 1 - α) @test μ_mean[i] ≈ μ atol = atol From 8e08a39a99359b20208965c467305d6bf2926767 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 16:36:33 +0100 Subject: [PATCH 44/51] Update kinds accepted for ESS --- src/ess_rhat.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index a0b91f42..0528fff3 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -229,9 +229,8 @@ See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref), ## Kinds of ESS estimates If `kind` isa a `Symbol`, it may take one of the following values: -- `:bulk`/`:rank`: mean-ESS computed on rank-normalized draws. This kind diagnoses poor - convergence in the bulk of the distribution due to trends or different locations of the - chains. +- `:bulk`: basic ESS computed on rank-normalized draws. This kind diagnoses poor convergence + in the bulk of the distribution due to trends or different locations of the chains. - `:tail`: minimum of the quantile-ESS for the symmetric quantiles where `tail_prob=0.1` is the probability in the tails. This kind diagnoses poor convergence in the tails of the distribution. If this kind is chosen, `kwargs` may contain a From c5bb96f621516202f3be484d76e4dac2fe4d76b0 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 16:37:12 +0100 Subject: [PATCH 45/51] Remove unused utility --- src/utils.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index fe7fb462..fbb6e37c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -190,6 +190,3 @@ function _normal_quantiles_from_ranks!(q, r; α=3//8) q .= (r .- α) ./ (n - 2α + 1) return q end - -_val(k::Symbol) = Val(k) -_val(k::Val) = k From 2534c4728f70db7bd96f96da371c989b651ca885 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 16:39:10 +0100 Subject: [PATCH 46/51] Constrain kind to Symbol when applicable --- src/ess_rhat.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 0528fff3..4798adc5 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -296,7 +296,7 @@ function _ess( end """ - rhat(samples::AbstractArray{Union{Real,Missing},3}; kind=:rank, split_chains=2) + rhat(samples::AbstractArray{Union{Real,Missing},3}; kind::Symbol=:rank, split_chains=2) Compute the ``\\widehat{R}`` diagnostics for each parameter in `samples` of shape `(chains, draws, parameters)`.[^VehtariGelman2021] @@ -325,7 +325,9 @@ The following `kind`s are supported: doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) """ -function rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kind=:rank, kwargs...) +function rhat( + samples::AbstractArray{<:Union{Missing,Real},3}; kind::Symbol=:rank, kwargs... +) if kind === :rank return _rhat(Val(:rank), samples; kwargs...) elseif kind === :bulk @@ -405,7 +407,7 @@ function _rhat(::Val{:rank}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs.. end """ - ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kind=:rank, kwargs...) + ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kind::Symbol=:rank, kwargs...) Estimate the effective sample size and ``\\widehat{R}`` of the `samples` of shape `(draws, chains, parameters)` with the `method`. @@ -416,7 +418,9 @@ calling `ess` and `rhat` separately. See [`rhat`](@ref) for a description of supported `kind`s and [`ess`](@ref) for a description of `kwargs`. """ -function ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kind=:rank, kwargs...) +function ess_rhat( + samples::AbstractArray{<:Union{Missing,Real},3}; kind::Symbol=:rank, kwargs... +) if kind === :rank return _ess_rhat(Val(:rank), samples; kwargs...) elseif kind === :bulk From 84e8f1e2731639acc3a061e6c69719b26c71a37e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 18:04:41 +0100 Subject: [PATCH 47/51] Remove stray tick --- src/ess_rhat.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 4798adc5..6d185cc5 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -311,7 +311,7 @@ See also [`ess`](@ref), [`ess_rhat`](@ref), [`rstar`](@ref) The following `kind`s are supported: - `:rank`: maximum of ``\\widehat{R}`` with `kind=:bulk` and `kind=:tail`. -- `:bulk`: basic ``\\widehat{R}``` computed on rank-normalized draws. This kind diagnoses +- `:bulk`: basic ``\\widehat{R}`` computed on rank-normalized draws. This kind diagnoses poor convergence in the bulk of the distribution due to trends or different locations of the chains. - `:tail`: ``\\widehat{R}`` computed on draws folded around the median and then From 07f72fda2b0be7295b3912ee2102b62f77479cc8 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 22:43:22 +0100 Subject: [PATCH 48/51] Apply suggestions from code review Co-authored-by: David Widmann --- test/ess_rhat.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ess_rhat.jl b/test/ess_rhat.jl index 1757c1ed..ab0480a3 100644 --- a/test/ess_rhat.jl +++ b/test/ess_rhat.jl @@ -37,7 +37,7 @@ mymean(x) = mean(x) @testset "ess_rhat.jl" begin @testset "ess/ess_rhat/rhat basics" begin @testset "only promote eltype when necessary" begin - @testset for kind in [:rank, :bulk, :tail, :basic] + @testset for kind in (:rank, :bulk, :tail, :basic) @testset for T in (Float32, Float64) x = rand(T, 100, 4, 2) TV = Vector{T} From a5fb3073f573de0fce5cba469bd010d5a74d030d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 22:43:43 +0100 Subject: [PATCH 49/51] Apply suggestions from code review Co-authored-by: David Widmann --- src/mcse.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcse.jl b/src/mcse.jl index 20c28fcb..3e13f3a9 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -11,7 +11,7 @@ See also: [`ess`](@ref) ## Kinds of MCSE estimates -The estimator whose MCSE should be estimator is specified with `kind`. `kind` must accept a +The estimator whose MCSE should be estimated is specified with `kind`. `kind` must accept a vector of the same `eltype` as `samples` and return a real estimate. For the following estimators, the effective sample size [`ess`](@ref) and an estimate From 4ff8f8680bf154c8aa7e24550f038325e03a7742 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 22:48:30 +0100 Subject: [PATCH 50/51] Add comment explaining kind checks --- src/ess_rhat.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 6d185cc5..3421e523 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -258,6 +258,8 @@ Otherwise, `kind` specifies one of the following estimators, whose ESS is to be arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) """ function ess(samples::AbstractArray{<:Union{Missing,Real},3}; kind=:bulk, kwargs...) + # if we just call _ess(Val(kind), ...) Julia cannot infer the return type with default + # const-propagation. We keep this type-inferrable by manually dispatching to the cases. if kind === :bulk return _ess(Val(:bulk), samples; kwargs...) elseif kind === :tail @@ -328,6 +330,8 @@ The following `kind`s are supported: function rhat( samples::AbstractArray{<:Union{Missing,Real},3}; kind::Symbol=:rank, kwargs... ) + # if we just call _rhat(Val(kind), ...) Julia cannot infer the return type with default + # const-propagation. We keep this type-inferrable by manually dispatching to the cases. if kind === :rank return _rhat(Val(:rank), samples; kwargs...) elseif kind === :bulk @@ -421,6 +425,9 @@ description of `kwargs`. function ess_rhat( samples::AbstractArray{<:Union{Missing,Real},3}; kind::Symbol=:rank, kwargs... ) + # if we just call _ess_rhat(Val(kind), ...) Julia cannot infer the return type with + # default const-propagation. We keep this type-inferrable by manually dispatching to the + # cases. if kind === :rank return _ess_rhat(Val(:rank), samples; kwargs...) elseif kind === :bulk From bcc70539ca2bb7c299ea1b3d571e216e3f7d7230 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 26 Feb 2023 22:51:16 +0100 Subject: [PATCH 51/51] Change vectors to tuples --- test/ess_rhat.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/ess_rhat.jl b/test/ess_rhat.jl index ab0480a3..a639b480 100644 --- a/test/ess_rhat.jl +++ b/test/ess_rhat.jl @@ -53,7 +53,7 @@ mymean(x) = mean(x) @test @inferred(ess_rhat(x; kind=kind)) isa Tuple{TV,TV} end end - @testset for kind in [mean, median, mad, std, Base.Fix2(quantile, 0.25)] + @testset for kind in (mean, median, mad, std, Base.Fix2(quantile, 0.25)) @testset for T in (Float32, Float64) x = rand(T, 100, 4, 2) @test @inferred(ess(x; kind=kind)) isa Vector{T} @@ -69,8 +69,8 @@ mymean(x) = mean(x) x = rand(4, 3, 5) x2 = rand(5, 3, 5) x3 = rand(100, 3, 5) - @testset for f in [ess, ess_rhat] - @testset for kind in [:rank, :bulk, :tail, :basic] + @testset for f in (ess, ess_rhat) + @testset for kind in (:rank, :bulk, :tail, :basic) f === ess && kind === :rank && continue @test_throws ArgumentError f(x; split_chains=1, kind=kind) f(x2; split_chains=1, kind=kind) @@ -85,7 +85,7 @@ mymean(x) = mean(x) end @testset "Union{Missing,Float64} eltype" begin - @testset for kind in [:rank, :bulk, :tail, :basic] + @testset for kind in (:rank, :bulk, :tail, :basic) x = Array{Union{Missing,Float64}}(undef, 1000, 4, 3) x .= randn.() x[1, 1, 1] = missing @@ -104,7 +104,7 @@ mymean(x) = mean(x) end @testset "produces similar vectors to inputs" begin - @testset for kind in [:rank, :bulk, :tail, :basic] + @testset for kind in (:rank, :bulk, :tail, :basic) # simultaneously checks that we index correctly and that output types are correct x = randn(100, 4, 5) y = OffsetArray(x, -5:94, 2:5, 11:15) @@ -137,9 +137,9 @@ mymean(x) = mean(x) @testset "ess, ess_rhat, and rhat consistency" begin x = randn(1000, 4, 10) - @testset for kind in [:rank, :bulk, :tail, :basic], split_chains in [1, 2] + @testset for kind in (:rank, :bulk, :tail, :basic), split_chains in (1, 2) R1 = rhat(x; kind=kind, split_chains=split_chains) - @testset for method in [ESSMethod(), BDAESSMethod()], maxlag in [100, 10] + @testset for method in (ESSMethod(), BDAESSMethod()), maxlag in (100, 10) S1 = ess( x; kind=kind === :rank ? :bulk : kind,