From 2d822224d62b0c1fd6c88eebb096003340a148d6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Jan 2023 13:24:21 +0100 Subject: [PATCH 01/45] Add mcse_sbm --- src/mcse.jl | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/mcse.jl b/src/mcse.jl index 063ab8b4..69ba106e 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -70,3 +70,36 @@ function mcse_ipse(x::AbstractVector{<:Real}) return mcse end + +""" + mcse_sbm(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; batch_size) + +Estimate the Monte Carlo standard errors (MCSE) of the `estimator` appplied to `samples` +using the subsampling bootstrap method. + +`samples` has shape `(draws, chains, parameters)`, and `estimator` must accept a vector of +the same eltype as `x`. + +`batch_size` indicates the size of the overlapping batches used to estimate the MCSE, +defaulting to `floor(Int, sqrt(draws * chains))`. +""" +function mcse_sbm( + f, + x::AbstractArray{<:Union{Missing,Real},3}; + batch_size::Int=floor(Int, sqrt(size(x, 1) * size(x, 2))), +) + T = promote_type(eltype(x), typeof(zero(eltype(x)) / 1)) + values = similar(x, T, (axes(x, 3),)) + map!(values, eachslice(x; dims=3)) do xi + return _mcse_sbm(f, vec(xi); batch_size=batch_size) + end + return values +end +function _mcse_sbm(f, x; batch_size) + n = length(x) + i1 = firstindex(x) + v = Statistics.var( + f(view(x, i:(i + size - 1))) for i in i1:(i1 + n - batch_size); corrected=false + ) + return sqrt(v * (batch_size//n)) +end From e4b067b0840c8fe0b5ce53fb65ffcc762bfddac3 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Jan 2023 13:26:06 +0100 Subject: [PATCH 02/45] Update description of `estimator` --- src/mcse.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcse.jl b/src/mcse.jl index 69ba106e..80891dfb 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -78,7 +78,7 @@ Estimate the Monte Carlo standard errors (MCSE) of the `estimator` appplied to ` using the subsampling bootstrap method. `samples` has shape `(draws, chains, parameters)`, and `estimator` must accept a vector of -the same eltype as `x`. +the same eltype as `x` and return a real estimate. `batch_size` indicates the size of the overlapping batches used to estimate the MCSE, defaulting to `floor(Int, sqrt(draws * chains))`. From 34e02219d2f3ddada963f7aebcf510021698d5e3 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Jan 2023 13:27:44 +0100 Subject: [PATCH 03/45] Add specialized estimators for mean, std, and quantile --- src/mcse.jl | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/src/mcse.jl b/src/mcse.jl index 80891dfb..e03c29bb 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -1,3 +1,6 @@ +Base.@irrational normcdf1 0.8413447460685429486 StatsFuns.normcdf(big(1)) +Base.@irrational normcdfn1 0.1586552539314570514 StatsFuns.normcdf(big(-1)) + """ mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...) @@ -71,6 +74,62 @@ function mcse_ipse(x::AbstractVector{<:Real}) return mcse end +function mcse( + ::typeof(Statistics.mean), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... +) + S = ess_rhat(Statistics.mean, samples; kwargs...)[1] + return dropdims(Statistics.std(samples; dims=(1, 2)); dims=(1, 2)) ./ sqrt.(S) +end +function mcse( + ::typeof(Statistics.std), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... 
+) + x = (samples .- Statistics.mean(samples; dims=(1, 2))) .^ 2 + S = ess_rhat(Statistics.mean, x; kwargs...)[1] + mean_var = dropdims(Statistics.mean(x; dims=(1, 2)); dims=(1, 2)) + mean_moment4 = dropdims(Statistics.mean(abs2, x; dims=(1, 2)); dims=(1, 2)) + return @. sqrt((mean_moment4 / mean_var - mean_var) / S) / 2 +end +function mcse( + f::Base.Fix2{typeof(Statistics.quantile),<:Real}, + samples::AbstractArray{<:Union{Missing,Real},3}; + kwargs..., +) + p = f.x + S = ess_rhat(f, samples; kwargs...)[1] + T = eltype(S) + R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T)))) + values = similar(S, R) + map!(values, eachslice(samples; dims=3), S) do xi, Si + return _mcse_quantile(vec(xi), p, Si) + end + return values +end +function _mcse_quantile(x, p, Seff) + Seff === missing && return missing + S = length(x) + # quantile error distribution is asymptotically normal; estimate σ (mcse) with 2 + # quadrature points: xl and xu, chosen as quantiles so that xu - xl = 2σ + # compute quantiles of error distribution in probability space (i.e. quantiles passed through CDF) + # Beta(α,β) is the approximate error distribution of quantile estimates + α = Seff * p + 1 + β = Seff * (1 - p) + 1 + prob_x_upper = StatsFuns.betainvcdf(α, β, normcdf1) + prob_x_lower = StatsFuns.betainvcdf(α, β, normcdfn1) + # use inverse ECDF to get quantiles in quantile (x) space + l = max(floor(Int, prob_x_lower * S), 1) + u = min(ceil(Int, prob_x_upper * S), S) + iperm = partialsortperm(x, l:u) # sort as little of x as possible + xl = x[first(iperm)] + xu = x[last(iperm)] + # estimate mcse from quantiles + return (xu - xl) / 2 +end +function mcse( + ::typeof(Statistics.median), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... +) + return mcse(Base.Fix2(Statistics.quantile, 1//2), samples; kwargs...) +end + """ mcse_sbm(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; batch_size) From 6839cd774d7479ae8429881532600e74978c62bd Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Jan 2023 13:28:30 +0100 Subject: [PATCH 04/45] Remove vector methods, defaulting to sbm --- src/mcse.jl | 59 +---------------------------------------------------- 1 file changed, 1 insertion(+), 58 deletions(-) diff --git a/src/mcse.jl b/src/mcse.jl index e03c29bb..f574e515 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -15,64 +15,7 @@ The optional argument `method` describes how the errors are estimated. Possible [^Geyer1992]: Geyer, C. J. (1992). Practical Markov Chain Monte Carlo. Statistical Science, 473-483. """ -function mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...) - return if method === :bm - mcse_bm(x; kwargs...) 
- elseif method === :imse - mcse_imse(x) - elseif method === :ipse - mcse_ipse(x) - else - throw(ArgumentError("unsupported MCSE method $method")) - end -end - -function mcse_bm(x::AbstractVector{<:Real}; size::Int=floor(Int, sqrt(length(x)))) - n = length(x) - m = min(div(n, 2), size) - m == size || @warn "batch size was reduced to $m" - mcse = StatsBase.sem(Statistics.mean(@view(x[(i + 1):(i + m)])) for i in 0:m:(n - m)) - return mcse -end - -function mcse_imse(x::AbstractVector{<:Real}) - n = length(x) - lags = [0, 1] - ghat = StatsBase.autocov(x, lags) - Ghat = sum(ghat) - @inbounds value = Ghat + ghat[2] - @inbounds for i in 2:2:(n - 2) - lags[1] = i - lags[2] = i + 1 - StatsBase.autocov!(ghat, x, lags) - Ghat = min(Ghat, sum(ghat)) - Ghat > 0 || break - value += 2 * Ghat - end - - mcse = sqrt(value / n) - - return mcse -end - -function mcse_ipse(x::AbstractVector{<:Real}) - n = length(x) - lags = [0, 1] - ghat = StatsBase.autocov(x, lags) - @inbounds value = ghat[1] + 2 * ghat[2] - @inbounds for i in 2:2:(n - 2) - lags[1] = i - lags[2] = i + 1 - StatsBase.autocov!(ghat, x, lags) - Ghat = sum(ghat) - Ghat > 0 || break - value += 2 * Ghat - end - - mcse = sqrt(value / n) - - return mcse -end +mcse(f, x::AbstractArray{Union{Missing,<:Real},3}; kwargs...) = mcse_sbm(f, x; kwargs...) function mcse( ::typeof(Statistics.mean), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... From 7dbda2d140228b530e821f7de7220756d80ad203 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Jan 2023 13:30:45 +0100 Subject: [PATCH 05/45] Update docstring --- src/mcse.jl | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/mcse.jl b/src/mcse.jl index f574e515..14d1f07d 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -2,21 +2,14 @@ Base.@irrational normcdf1 0.8413447460685429486 StatsFuns.normcdf(big(1)) Base.@irrational normcdfn1 0.1586552539314570514 StatsFuns.normcdf(big(-1)) """ - mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...) + mcse(estimator, samples::AbstractArray{<:Union{Missing,Real}}; kwargs...) -Compute the Monte Carlo standard error (MCSE) of samples `x`. -The optional argument `method` describes how the errors are estimated. Possible options are: +Estimate the Monte Carlo standard errors (MCSE) of the `estimator` appplied to `samples`. -- `:bm` for batch means [^Glynn1991] -- `:imse` initial monotone sequence estimator [^Geyer1992] -- `:ipse` initial positive sequence estimator [^Geyer1992] - -[^Glynn1991]: Glynn, P. W., & Whitt, W. (1991). Estimating the asymptotic variance with batch means. Operations Research Letters, 10(8), 431-435. - -[^Geyer1992]: Geyer, C. J. (1992). Practical Markov Chain Monte Carlo. Statistical Science, 473-483. +`samples` has shape `(draws, chains, parameters)`, and `estimator` must accept a vector of +the same eltype as `x` and return a real estimate. """ mcse(f, x::AbstractArray{Union{Missing,<:Real},3}; kwargs...) = mcse_sbm(f, x; kwargs...) - function mcse( ::typeof(Statistics.mean), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... 
) From b93ce5b3b292b12be1b3a661097645ba8d9a45f6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Jan 2023 13:40:39 +0100 Subject: [PATCH 06/45] Fix bugs --- src/mcse.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/mcse.jl b/src/mcse.jl index 14d1f07d..73201f9d 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -9,7 +9,7 @@ Estimate the Monte Carlo standard errors (MCSE) of the `estimator` appplied to ` `samples` has shape `(draws, chains, parameters)`, and `estimator` must accept a vector of the same eltype as `x` and return a real estimate. """ -mcse(f, x::AbstractArray{Union{Missing,<:Real},3}; kwargs...) = mcse_sbm(f, x; kwargs...) +mcse(f, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) = mcse_sbm(f, x; kwargs...) function mcse( ::typeof(Statistics.mean), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) @@ -35,8 +35,8 @@ function mcse( T = eltype(S) R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T)))) values = similar(S, R) - map!(values, eachslice(samples; dims=3), S) do xi, Si - return _mcse_quantile(vec(xi), p, Si) + for (i, xi, Si) in zip(eachindex(values), eachslice(samples; dims=3), S) + values[i] = _mcse_quantile(vec(xi), p, Si) end return values end @@ -85,8 +85,8 @@ function mcse_sbm( ) T = promote_type(eltype(x), typeof(zero(eltype(x)) / 1)) values = similar(x, T, (axes(x, 3),)) - map!(values, eachslice(x; dims=3)) do xi - return _mcse_sbm(f, vec(xi); batch_size=batch_size) + for (i, xi) in zip(eachindex(values), eachslice(x; dims=3)) + values[i] = _mcse_sbm(f, vec(xi); batch_size=batch_size) end return values end @@ -94,7 +94,8 @@ function _mcse_sbm(f, x; batch_size) n = length(x) i1 = firstindex(x) v = Statistics.var( - f(view(x, i:(i + size - 1))) for i in i1:(i1 + n - batch_size); corrected=false + f(view(x, i:(i + batch_size - 1))) for i in i1:(i1 + n - batch_size); + corrected=false, ) return sqrt(v * (batch_size//n)) end From 790a99bca874da0eeb35d4d9f0ea87f797f5a5d1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Jan 2023 16:32:19 +0100 Subject: [PATCH 07/45] Update docstrings --- src/mcse.jl | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/src/mcse.jl b/src/mcse.jl index 73201f9d..b81bef4e 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -4,10 +4,23 @@ Base.@irrational normcdfn1 0.1586552539314570514 StatsFuns.normcdf(big(-1)) """ mcse(estimator, samples::AbstractArray{<:Union{Missing,Real}}; kwargs...) -Estimate the Monte Carlo standard errors (MCSE) of the `estimator` appplied to `samples`. +Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples` of +shape `(draws, chains, parameters)` -`samples` has shape `(draws, chains, parameters)`, and `estimator` must accept a vector of -the same eltype as `x` and return a real estimate. +## Estimators + +`estimator` must accept a vector of the same eltype as `samples` and return a real estimate. + +For the following estimators, the effective sample size [`ess_rhat`](@ref) and an estimate +of the asymptotic variance are used to compute the MCSE, and `kwargs` are forwarded to +`ess_rhat`: +- `Statistics.mean` +- `Statistics.median` +- `Statistics.std` +- `Base.Fix2(Statistics.quantile, p::Real)` + +For arbitrary estimator, the subsampling bootstrap method [`mcse_sbm`](@ref) is used, and +`kwargs` are forwarded to that function. """ mcse(f, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) = mcse_sbm(f, x; kwargs...) 
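A usage sketch of the interface these patches converge on (illustrative only; it assumes the definitions above, and `mcse_sbm` is only exported in a later patch of this series):

```julia
using MCMCDiagnosticTools, Statistics, StatsBase

x = randn(1000, 4, 10)  # (draws, chains, parameters)

mcse(Statistics.mean, x)                       # ESS-based MCSE, one value per parameter
mcse(Base.Fix2(Statistics.quantile, 0.25), x)  # MCSE of the 0.25-quantile
mcse(StatsBase.mad, x)                         # no specialized method: falls back to mcse_sbm
mcse_sbm(StatsBase.mad, x; batch_size=50)      # subsampling bootstrap with an explicit batch size
```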
function mcse( @@ -69,14 +82,18 @@ end """ mcse_sbm(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; batch_size) -Estimate the Monte Carlo standard errors (MCSE) of the `estimator` appplied to `samples` -using the subsampling bootstrap method. +Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples` +using the subsampling bootstrap method.[^FlegalJones2011] `samples` has shape `(draws, chains, parameters)`, and `estimator` must accept a vector of -the same eltype as `x` and return a real estimate. +the same eltype as `samples` and return a real estimate. `batch_size` indicates the size of the overlapping batches used to estimate the MCSE, defaulting to `floor(Int, sqrt(draws * chains))`. + +[^FlegalJones2011]: Flegal JM, Jones GL. Implementing MCMC: estimating with confidence. + Handbook of Markov Chain Monte Carlo. 2011. 175-97. + [pdf](http://faculty.ucr.edu/~jflegal/EstimatingWithConfidence.pdf) """ function mcse_sbm( f, From 9a1d3d5c17b95b004bdfa9666d6aecb8816ecfd5 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Jan 2023 16:41:08 +0100 Subject: [PATCH 08/45] Update docstring --- src/mcse.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcse.jl b/src/mcse.jl index b81bef4e..762bc564 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -83,7 +83,7 @@ end mcse_sbm(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; batch_size) Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples` -using the subsampling bootstrap method.[^FlegalJones2011] +using the subsampling bootstrap method (SBM).[^FlegalJones2011] `samples` has shape `(draws, chains, parameters)`, and `estimator` must accept a vector of the same eltype as `samples` and return a real estimate. From c0d5a944f767662b43bd4aa09de84b8e9d691043 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Jan 2023 17:03:12 +0100 Subject: [PATCH 09/45] Move helper functions to own file --- test/ess.jl | 38 -------------------------------------- test/helpers.jl | 39 +++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 ++ 3 files changed, 41 insertions(+), 38 deletions(-) create mode 100644 test/helpers.jl diff --git a/test/ess.jl b/test/ess.jl index 0e2752cf..e5024b51 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -31,44 +31,6 @@ function LogDensityProblems.capabilities(p::CauchyProblem) return LogDensityProblems.LogDensityOrder{1}() end -# AR(1) process -function ar1(φ::Real, σ::Real, n::Int...) - T = float(Base.promote_eltype(φ, σ)) - x = randn(T, n...) - x .*= σ - accumulate!(x, x; dims=1) do xi, ϵ - return muladd(φ, xi, ϵ) - end - return x -end - -asymptotic_dist(::typeof(mean), dist) = Normal(mean(dist), std(dist)) -function asymptotic_dist(::typeof(var), dist) - μ = var(dist) - σ = μ * sqrt(kurtosis(dist) + 2) - return Normal(μ, σ) -end -function asymptotic_dist(::typeof(std), dist) - μ = std(dist) - σ = μ * sqrt(kurtosis(dist) + 2) / 2 - return Normal(μ, σ) -end -asymptotic_dist(::typeof(median), dist) = asymptotic_dist(Base.Fix2(quantile, 1//2), dist) -function asymptotic_dist(f::Base.Fix2{typeof(quantile),<:Real}, dist) - p = f.x - μ = quantile(dist, p) - σ = sqrt(p * (1 - p)) / pdf(dist, μ) - return Normal(μ, σ) -end -function asymptotic_dist(::typeof(mad), dist::Normal) - # Example 21.10 of Asymptotic Statistics. 
Van der Vaart - d = Normal(zero(dist.μ), dist.σ) - dtrunc = truncated(d; lower=0) - μ = median(dtrunc) - σ = 1 / (4 * pdf(d, quantile(d, 3//4))) - return Normal(μ, σ) / quantile(Normal(), 3//4) -end - @testset "ess.jl" begin @testset "ESS and R̂ (IID samples)" begin # Repeat tests with different scales diff --git a/test/helpers.jl b/test/helpers.jl new file mode 100644 index 00000000..6dd39cb5 --- /dev/null +++ b/test/helpers.jl @@ -0,0 +1,39 @@ +using Distributions, Statistics, StatsBase + +# AR(1) process +function ar1(φ::Real, σ::Real, n::Int...) + T = float(Base.promote_eltype(φ, σ)) + x = randn(T, n...) + x .*= σ + accumulate!(x, x; dims=1) do xi, ϵ + return muladd(φ, xi, ϵ) + end + return x +end + +asymptotic_dist(::typeof(mean), dist) = Normal(mean(dist), std(dist)) +function asymptotic_dist(::typeof(var), dist) + μ = var(dist) + σ = μ * sqrt(kurtosis(dist) + 2) + return Normal(μ, σ) +end +function asymptotic_dist(::typeof(std), dist) + μ = std(dist) + σ = μ * sqrt(kurtosis(dist) + 2) / 2 + return Normal(μ, σ) +end +asymptotic_dist(::typeof(median), dist) = asymptotic_dist(Base.Fix2(quantile, 1//2), dist) +function asymptotic_dist(f::Base.Fix2{typeof(quantile),<:Real}, dist) + p = f.x + μ = quantile(dist, p) + σ = sqrt(p * (1 - p)) / pdf(dist, μ) + return Normal(μ, σ) +end +function asymptotic_dist(::typeof(mad), dist::Normal) + # Example 21.10 of Asymptotic Statistics. Van der Vaart + d = Normal(zero(dist.μ), dist.σ) + dtrunc = truncated(d; lower=0) + μ = median(dtrunc) + σ = 1 / (4 * pdf(d, quantile(d, 3//4))) + return Normal(μ, σ) / quantile(Normal(), 3//4) +end diff --git a/test/runtests.jl b/test/runtests.jl index af19e7be..cd99385b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,8 @@ using Test Random.seed!(1) @testset "MCMCDiagnosticTools.jl" begin + include("helpers.jl") + @testset "utils" begin include("utils.jl") end From cf908af63c31bd77da11a623d999ebf4ee1a4b99 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Jan 2023 17:03:20 +0100 Subject: [PATCH 10/45] Rearrange tests --- test/runtests.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index cd99385b..f6494d60 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,6 +26,9 @@ Random.seed!(1) @testset "ESS" begin include("ess.jl") end + @testset "Monte Carlo standard error" begin + include("mcse.jl") + end @testset "Gelman, Rubin and Brooks diagnostic" begin include("gelmandiag.jl") end @@ -35,9 +38,6 @@ Random.seed!(1) @testset "Heidelberger and Welch diagnostic" begin include("heideldiag.jl") end - @testset "Monte Carlo standard error" begin - include("mcse.jl") - end @testset "Raftery and Lewis diagnostic" begin include("rafterydiag.jl") end From 93d121e3413ec5a43d458f3ee9d39dc2c924b40d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Jan 2023 17:03:34 +0100 Subject: [PATCH 11/45] Update mcse tests --- test/mcse.jl | 55 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/test/mcse.jl b/test/mcse.jl index 3e54d447..e3b70f9f 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -1,25 +1,38 @@ -@testset "mcse.jl" begin - samples = randn(100) - - @testset "results" begin - result = @inferred(mcse(samples)) - @test result isa Float64 - @test result > 0 - - for method in (:imse, :ipse, :bm) - result = @inferred(mcse(samples; method=method)) - @test result isa Float64 - @test result > 0 - end - end +using Test +using MCMCDiagnosticTools +using Statistics +using StatsBase - 
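The `ar1` helper moved into `test/helpers.jl` above is what the reworked tests below build on; a quick sanity sketch of the variance convention they rely on (illustrative, not part of the patch; the include path is assumed):

```julia
using Statistics
include("test/helpers.jl")  # provides `ar1` as added above

# For x_t = φ*x_{t-1} + ε_t with ε_t ~ N(0, σ²), the stationary variance is σ² / (1 - φ²),
# so choosing σ = sqrt(1 - φ²) makes the stationary distribution N(0, 1) whenever |φ| < 1.
φ = 0.7
σ = sqrt(1 - φ^2)
x = ar1(φ, σ, 100_000)
var(x)  # ≈ 1 up to Monte Carlo error
```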
@testset "warning" begin - for size in (51, 75, 100, 153) - @test_logs (:warn,) mcse(samples; method=:bm, size=size) +@testset "mcse.jl" begin + @testset "estimand is within interval defined by MCSE estimate" begin + # we check the ESS estimates by simulating uncorrelated, correlated, and + # anticorrelated chains, mapping the draws to a target distribution, computing the + # estimand, and estimating the ESS for the chosen estimator, computing the + # corresponding MCSE, and checking that the mean estimand is close to the asymptotic + # value of the estimand, with a tolerance chosen using the MCSE. + ndraws = 1_000 + nchains = 4 + nparams = 100 + estimators = [mean, median, std, Base.Fix2(quantile, 0.25)] + dists = [Normal(10, 100), Exponential(10), TDist(7) * 10 - 20] + mcse_methods = [mcse, mcse_sbm] + # AR(1) coefficients. 0 is IID, -0.3 is slightly anticorrelated, 0.9 is highly autocorrelated + φs = [-0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9] + # account for all but the 2 skipped checks + nchecks = + nparams * length(φs) * length(estimators) * length(dists) * length(mcse_methods) + α = (0.1 / nchecks) / 2 # multiple correction + @testset for mcse in mcse_methods, f in estimators, dist in dists, φ in φs + σ = sqrt(1 - φ^2) # ensures stationary distribution is N(0, 1) + x = ar1(φ, σ, ndraws, nchains, nparams) + x .= quantile.(dist, cdf.(Normal(), x)) # stationary distribution is dist + μ_mean = dropdims(mapslices(f ∘ vec, x; dims=(1, 2)); dims=(1, 2)) + μ = mean(asymptotic_dist(f, dist)) + se = mcse(f, x) + for i in eachindex(μ_mean, se) + atol = quantile(Normal(0, se[i]), 1 - α) + @test μ_mean[i] ≈ μ atol = atol + end end end - - @testset "exception" begin - @test_throws ArgumentError mcse(samples; method=:somemethod) - end end From fe993564574089f9c2ecf0046d7f9f3c3382b99c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Jan 2023 17:03:41 +0100 Subject: [PATCH 12/45] Export mcse_sbm --- src/MCMCDiagnosticTools.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index beb1ca11..f6eb0c19 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -20,7 +20,7 @@ export ess_rhat, ess_rhat_bulk, ess_tail, rhat_tail, ESSMethod, FFTESSMethod, BD export gelmandiag, gelmandiag_multivariate export gewekediag export heideldiag -export mcse +export mcse, mcse_sbm export rafterydiag export rstar From d12648b1cbe66e62ae4e19363f3a91175f556eb9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Jan 2023 17:04:16 +0100 Subject: [PATCH 13/45] Increment minor version number with DEV suffix --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 73087b7e..1a6fd140 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCMCDiagnosticTools" uuid = "be115224-59cd-429b-ad48-344e309966f0" authors = ["David Widmann"] -version = "0.2.3" +version = "0.3.0-DEV" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" From 7dac85e1b4cdab07c1721b0731a866779dabff2d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Jan 2023 18:02:29 +0100 Subject: [PATCH 14/45] Increment docs and tests version numbers --- docs/Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index ad15f1d3..fb2fbdbf 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -8,7 +8,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] Documenter = "0.27" -MCMCDiagnosticTools 
= "0.2" +MCMCDiagnosticTools = "0.3" MLJBase = "0.19, 0.20, 0.21" MLJXGBoostInterface = "0.1, 0.2, 0.3" julia = "1.3" diff --git a/test/Project.toml b/test/Project.toml index d571fe59..2e783d1f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,7 +22,7 @@ DynamicHMC = "3" FFTW = "1.1" LogDensityProblems = "0.12, 1, 2" LogExpFunctions = "0.3" -MCMCDiagnosticTools = "0.2" +MCMCDiagnosticTools = "0.3" MLJBase = "0.19, 0.20, 0.21" MLJLIBSVMInterface = "0.1, 0.2" MLJXGBoostInterface = "0.1, 0.2, 0.3" From 94073695ee14dd5998d30a98ea0c9a07e77cb939 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 16 Jan 2023 18:16:57 +0100 Subject: [PATCH 15/45] Add additional citation --- src/mcse.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/mcse.jl b/src/mcse.jl index 762bc564..3d5e73a5 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -83,7 +83,7 @@ end mcse_sbm(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; batch_size) Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples` -using the subsampling bootstrap method (SBM).[^FlegalJones2011] +using the subsampling bootstrap method (SBM).[^FlegalJones2011][^Flegal2012] `samples` has shape `(draws, chains, parameters)`, and `estimator` must accept a vector of the same eltype as `samples` and return a real estimate. @@ -91,9 +91,12 @@ the same eltype as `samples` and return a real estimate. `batch_size` indicates the size of the overlapping batches used to estimate the MCSE, defaulting to `floor(Int, sqrt(draws * chains))`. -[^FlegalJones2011]: Flegal JM, Jones GL. Implementing MCMC: estimating with confidence. - Handbook of Markov Chain Monte Carlo. 2011. 175-97. +[^FlegalJones2011]: Flegal JM, Jones GL. (2011) Implementing MCMC: estimating with confidence. + Handbook of Markov Chain Monte Carlo. pp. 175-97. [pdf](http://faculty.ucr.edu/~jflegal/EstimatingWithConfidence.pdf) +[^Flegal2012]: Flegal JM. (2012) Applicability of subsampling bootstrap methods in Markov chain Monte Carlo. + Monte Carlo and Quasi-Monte Carlo Methods 2010. pp. 363-72. 
+ doi: [10.1007/978-3-642-27440-4_18](https://doi.org/10.1007/978-3-642-27440-4_18) """ function mcse_sbm( f, From 88b6c4188b9df421e41d5e0b6045c9a2b3adec66 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 17 Jan 2023 00:31:58 +0100 Subject: [PATCH 16/45] Update diagnostics to use new mcse --- src/gewekediag.jl | 9 ++++++--- src/heideldiag.jl | 7 +++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/gewekediag.jl b/src/gewekediag.jl index f804b62c..32c7a9a9 100644 --- a/src/gewekediag.jl +++ b/src/gewekediag.jl @@ -22,9 +22,12 @@ function gewekediag(x::AbstractVector{<:Real}; first::Real=0.1, last::Real=0.5, n = length(x) x1 = x[1:round(Int, first * n)] x2 = x[round(Int, n - last * n + 1):n] - z = - (Statistics.mean(x1) - Statistics.mean(x2)) / - hypot(mcse(x1; kwargs...), mcse(x2; kwargs...)) + T = float(eltype(x)) + s = hypot( + Base.first(mcse(Statistics.mean, reshape(x1, :, 1, 1); split_chains=1, kwargs...)), + Base.first(mcse(Statistics.mean, reshape(x2, :, 1, 1); split_chains=1, kwargs...)), + )::T + z = (Statistics.mean(x1) - Statistics.mean(x2)) / s p = SpecialFunctions.erfc(abs(z) / sqrt(2)) return (zscore=z, pvalue=p) diff --git a/src/heideldiag.jl b/src/heideldiag.jl index e93291d8..872f0329 100644 --- a/src/heideldiag.jl +++ b/src/heideldiag.jl @@ -17,7 +17,9 @@ function heideldiag( n = length(x) delta = trunc(Int, 0.10 * n) y = x[trunc(Int, n / 2):end] - S0 = length(y) * mcse(y; kwargs...)^2 + T = float(eltype(x)) + s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...))::T + S0 = length(y) * s^2 i, pvalue, converged, ybar = 1, 1.0, false, NaN while i < n / 2 y = x[i:end] @@ -33,7 +35,8 @@ function heideldiag( end i += delta end - halfwidth = sqrt(2) * SpecialFunctions.erfcinv(alpha) * mcse(y; kwargs...) 
+ s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...))::T + halfwidth = sqrt(2) * SpecialFunctions.erfcinv(alpha) * s passed = halfwidth / abs(ybar) <= eps return ( burnin=i + start - 2, From 899711e69abac4c93efef53a4b2ffa7ee2c12a07 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 17 Jan 2023 01:05:00 +0100 Subject: [PATCH 17/45] Increase tolerance of mcse tests --- test/mcse.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcse.jl b/test/mcse.jl index e3b70f9f..1462d566 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -21,7 +21,7 @@ using StatsBase # account for all but the 2 skipped checks nchecks = nparams * length(φs) * length(estimators) * length(dists) * length(mcse_methods) - α = (0.1 / nchecks) / 2 # multiple correction + α = (0.05 / nchecks) / 2 # multiple correction @testset for mcse in mcse_methods, f in estimators, dist in dists, φ in φs σ = sqrt(1 - φ^2) # ensures stationary distribution is N(0, 1) x = ar1(φ, σ, ndraws, nchains, nparams) From 01b8dbcf3e375803a5e6a181c3bff33768c31ed4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 17 Jan 2023 01:31:23 +0100 Subject: [PATCH 18/45] Increase tolerance more --- test/mcse.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcse.jl b/test/mcse.jl index 1462d566..39639719 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -21,7 +21,7 @@ using StatsBase # account for all but the 2 skipped checks nchecks = nparams * length(φs) * length(estimators) * length(dists) * length(mcse_methods) - α = (0.05 / nchecks) / 2 # multiple correction + α = (0.01 / nchecks) / 2 # multiple correction @testset for mcse in mcse_methods, f in estimators, dist in dists, φ in φs σ = sqrt(1 - φ^2) # ensures stationary distribution is N(0, 1) x = ar1(φ, σ, ndraws, nchains, nparams) From 60d64410c865a3d63744780b7f8db8f428ca061c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 17 Jan 2023 01:36:48 +0100 Subject: [PATCH 19/45] Add mcse_sbm to docs --- docs/src/index.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/index.md b/docs/src/index.md index 54a29e74..31f64e38 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -27,6 +27,7 @@ BDAESSMethod ```@docs mcse +mcse_sbm ``` ## R⋆ diagnostic From 2441bcb06aea7df9e75dec9d439051d47baecb4d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 17 Jan 2023 13:30:30 +0100 Subject: [PATCH 20/45] Skip high autocorrelation tests for mcse_sbm --- test/mcse.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/mcse.jl b/test/mcse.jl index 39639719..853644b0 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -1,4 +1,5 @@ using Test +using Distributions using MCMCDiagnosticTools using Statistics using StatsBase @@ -20,9 +21,11 @@ using StatsBase φs = [-0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9] # account for all but the 2 skipped checks nchecks = - nparams * length(φs) * length(estimators) * length(dists) * length(mcse_methods) - α = (0.01 / nchecks) / 2 # multiple correction + nparams * (length(φs) + count(≤(5), φs)) * length(dists) * length(mcse_methods) + α = (0.1 / nchecks) / 2 # multiple correction @testset for mcse in mcse_methods, f in estimators, dist in dists, φ in φs + # mcse_sbm underestimates the MCSE for highly correlated chains + mcse === mcse_sbm && φ > 0.5 && continue σ = sqrt(1 - φ^2) # ensures stationary distribution is N(0, 1) x = ar1(φ, σ, ndraws, nchains, nparams) x .= quantile.(dist, cdf.(Normal(), x)) # stationary distribution is dist From 8e3b06a3fa7096a1d885eb5353b1842e73f54609 Mon Sep 17 
00:00:00 2001 From: Seth Axen Date: Tue, 17 Jan 2023 13:32:50 +0100 Subject: [PATCH 21/45] Note underestimation for SBM --- src/mcse.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/mcse.jl b/src/mcse.jl index 3d5e73a5..9ed8384f 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -91,6 +91,12 @@ the same eltype as `samples` and return a real estimate. `batch_size` indicates the size of the overlapping batches used to estimate the MCSE, defaulting to `floor(Int, sqrt(draws * chains))`. +!!! note + SBM tends to underestimate the MCSE, especially for highly autocorrelated chains. + SBM should only be used as a fallbeck when a specific [`mcse`](@ref) method for + `estimator` is not available and when the bulk- and tail- [`ess_rhat`](@ref) values + indicate low autocorrelation. + [^FlegalJones2011]: Flegal JM, Jones GL. (2011) Implementing MCMC: estimating with confidence. Handbook of Markov Chain Monte Carlo. pp. 175-97. [pdf](http://faculty.ucr.edu/~jflegal/EstimatingWithConfidence.pdf) From 2af67e95a60276a498b5e7365aa99a9e3a89310c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 11:01:17 +0100 Subject: [PATCH 22/45] Update src/mcse.jl --- src/mcse.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcse.jl b/src/mcse.jl index 9ed8384f..976f9508 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -1,5 +1,5 @@ -Base.@irrational normcdf1 0.8413447460685429486 StatsFuns.normcdf(big(1)) -Base.@irrational normcdfn1 0.1586552539314570514 StatsFuns.normcdf(big(-1)) +const normcdf1 = 0.8413447460685429 # StatsFuns.normcdf(1) +const normcdfn1 = 0.15865525393145705 # StatsFuns.normcdf(-1) """ mcse(estimator, samples::AbstractArray{<:Union{Missing,Real}}; kwargs...) From b7cd4953799cecec063a9c03ad0fc794c7db47e8 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 12:42:50 +0100 Subject: [PATCH 23/45] Don't enforce type --- src/gewekediag.jl | 3 +-- src/heideldiag.jl | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/gewekediag.jl b/src/gewekediag.jl index 32c7a9a9..2fad2458 100644 --- a/src/gewekediag.jl +++ b/src/gewekediag.jl @@ -22,11 +22,10 @@ function gewekediag(x::AbstractVector{<:Real}; first::Real=0.1, last::Real=0.5, n = length(x) x1 = x[1:round(Int, first * n)] x2 = x[round(Int, n - last * n + 1):n] - T = float(eltype(x)) s = hypot( Base.first(mcse(Statistics.mean, reshape(x1, :, 1, 1); split_chains=1, kwargs...)), Base.first(mcse(Statistics.mean, reshape(x2, :, 1, 1); split_chains=1, kwargs...)), - )::T + ) z = (Statistics.mean(x1) - Statistics.mean(x2)) / s p = SpecialFunctions.erfc(abs(z) / sqrt(2)) diff --git a/src/heideldiag.jl b/src/heideldiag.jl index 872f0329..58262127 100644 --- a/src/heideldiag.jl +++ b/src/heideldiag.jl @@ -17,8 +17,7 @@ function heideldiag( n = length(x) delta = trunc(Int, 0.10 * n) y = x[trunc(Int, n / 2):end] - T = float(eltype(x)) - s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...))::T + s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...)) S0 = length(y) * s^2 i, pvalue, converged, ybar = 1, 1.0, false, NaN while i < n / 2 @@ -35,7 +34,7 @@ function heideldiag( end i += delta end - s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...))::T + s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...)) halfwidth = sqrt(2) * SpecialFunctions.erfcinv(alpha) * s passed = halfwidth / abs(ybar) <= eps return ( From 4d557162b74db98914cfd8bc2bf863c28f58257d Mon Sep 17 00:00:00 2001 From: Seth 
Axen Date: Wed, 18 Jan 2023 12:44:20 +0100 Subject: [PATCH 24/45] Document kwargs passed to mcse --- src/gewekediag.jl | 2 ++ src/heideldiag.jl | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/gewekediag.jl b/src/gewekediag.jl index 2fad2458..c858250d 100644 --- a/src/gewekediag.jl +++ b/src/gewekediag.jl @@ -12,6 +12,8 @@ samples are independent. A non-significant test p-value indicates convergence. p-values indicate non-convergence and the possible need to discard initial samples as a burn-in sequence or to simulate additional samples. +`kwargs` are forwarded to [`mcse`](@ref). + [^Geweke1991]: Geweke, J. F. (1991). Evaluating the accuracy of sampling-based approaches to the calculation of posterior moments (No. 148). Federal Reserve Bank of Minneapolis. """ function gewekediag(x::AbstractVector{<:Real}; first::Real=0.1, last::Real=0.5, kwargs...) diff --git a/src/heideldiag.jl b/src/heideldiag.jl index 58262127..8673f292 100644 --- a/src/heideldiag.jl +++ b/src/heideldiag.jl @@ -9,6 +9,8 @@ means are within a target ratio. Stationarity is rejected (0) for significant te Halfwidth tests are rejected (0) if observed ratios are greater than the target, as is the case for `s2` and `beta[1]`. +`kwargs` are forwarded to [`mcse`](@ref). + [^Heidelberger1983]: Heidelberger, P., & Welch, P. D. (1983). Simulation run length control in the presence of an initial transient. Operations Research, 31(6), 1109-1144. """ function heideldiag( From d9f6734f4b1326e61f112ad6825d656dfaf6fead Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 12:45:26 +0100 Subject: [PATCH 25/45] Cross-link mcse and ess_rhat docstrings --- src/ess.jl | 2 +- src/mcse.jl | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ess.jl b/src/ess.jl index df87b029..0646d4bf 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -222,7 +222,7 @@ For a given estimand, it is recommended that the ESS is at least `100 * chains` ``\\widehat{R} < 1.01``.[^VehtariGelman2021] See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref), -[`ess_rhat_bulk`](@ref), [`ess_tail`](@ref), [`rhat_tail`](@ref) +[`ess_rhat_bulk`](@ref), [`ess_tail`](@ref), [`rhat_tail`](@ref), [`mcse`](@ref) ## Estimators diff --git a/src/mcse.jl b/src/mcse.jl index 976f9508..d544d828 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -7,6 +7,8 @@ const normcdfn1 = 0.15865525393145705 # StatsFuns.normcdf(-1) Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples` of shape `(draws, chains, parameters)` +See also: [`ess_rhat`](@ref) + ## Estimators `estimator` must accept a vector of the same eltype as `samples` and return a real estimate. From 1c482665aaad09821898f43ed7b5996b6f031b84 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 13:31:44 +0100 Subject: [PATCH 26/45] Document derivation of mcse for std --- src/mcse.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/mcse.jl b/src/mcse.jl index d544d828..8c1bd491 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -34,8 +34,12 @@ end function mcse( ::typeof(Statistics.std), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... 
) - x = (samples .- Statistics.mean(samples; dims=(1, 2))) .^ 2 + x = (samples .- Statistics.mean(samples; dims=(1, 2))) .^ 2 # expectand proxy S = ess_rhat(Statistics.mean, x; kwargs...)[1] + # asymptotic variance of sample variance estimate is Var[var] = E[μ₄] - E[var]², + # where μ₄ is the 4th central moment + # by the delta method, Var[std] = Var[var] / 4E[var] = (E[μ₄]/E[var] - E[var])/4, + # See e.g. Chapter 3 of Van der Vaart, AW. (200) Asymptotic statistics. Vol. 3. mean_var = dropdims(Statistics.mean(x; dims=(1, 2)); dims=(1, 2)) mean_moment4 = dropdims(Statistics.mean(abs2, x; dims=(1, 2)); dims=(1, 2)) return @. sqrt((mean_moment4 / mean_var - mean_var) / S) / 2 From f072b9e1384de4b271410b9f3e22f3cf533b550c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 16:09:10 +0100 Subject: [PATCH 27/45] Test type-inferrability of ess_rhat --- test/ess.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ess.jl b/test/ess.jl index e6c1c55e..ce36bd20 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -181,7 +181,7 @@ end x .= quantile.(dist, cdf.(Normal(), x)) # stationary distribution is dist μ_mean = dropdims(mapslices(f ∘ vec, x; dims=(1, 2)); dims=(1, 2)) dist = asymptotic_dist(f, dist) - n = ess_rhat(f, x)[1] + n = @inferred(ess_rhat(f, x))[1] μ = mean(dist) mcse = sqrt.(var(dist) ./ n) for i in eachindex(μ_mean, mcse) From bb4788751de8d69e0dc08953bfaf18dbb99c06cb Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 16:09:53 +0100 Subject: [PATCH 28/45] Make sure ess_rhat for quantiles not promoted --- src/ess.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ess.jl b/src/ess.jl index 0646d4bf..1cbb95de 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -480,7 +480,7 @@ function _expectand_proxy(::typeof(StatsBase.mad), x) return _expectand_proxy(Statistics.median, x_folded) end function _expectand_proxy(f::Base.Fix2{typeof(Statistics.quantile),<:Real}, x) - y = similar(x, Bool) + y = similar(x) # currently quantile does not support a dims keyword argument for (xi, yi) in zip(eachslice(x; dims=3), eachslice(y; dims=3)) yi .= xi .≤ f(vec(xi)) From 7f61907bde59e7dff2157023f8adf4dbdf34405d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 16:10:14 +0100 Subject: [PATCH 29/45] Make sure ess_rhat for median type-inferrable --- src/ess.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ess.jl b/src/ess.jl index 1cbb95de..b8b72569 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -470,7 +470,9 @@ rhat_tail(x; kwargs...) = ess_rhat_bulk(_fold_around_median(x); kwargs...)[2] # If no proxy expectand for `f` is known, `nothing` is returned. _expectand_proxy(f, x) = nothing function _expectand_proxy(::typeof(Statistics.median), x) - return x .≤ Statistics.median(x; dims=(1, 2)) + y = similar(x) + y .= x .≤ Statistics.median(x; dims=(1, 2)) + return y end function _expectand_proxy(::typeof(Statistics.std), x) return (x .- Statistics.mean(x; dims=(1, 2))) .^ 2 From a03cc2a4cea65e048a7a9e2ef9e5b978d67d581f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 16:19:00 +0100 Subject: [PATCH 30/45] Implement specific method for median --- src/mcse.jl | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/mcse.jl b/src/mcse.jl index 8c1bd491..6c9856ca 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -59,6 +59,19 @@ function mcse( end return values end +function mcse( + ::typeof(Statistics.median), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... 
+) + S = ess_rhat(Statistics.median, samples; kwargs...)[1] + T = eltype(S) + R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T)))) + values = similar(S, R) + for (i, xi, Si) in zip(eachindex(values), eachslice(samples; dims=3), S) + values[i] = _mcse_quantile(vec(xi), 1//2, Si) + end + return values +end + function _mcse_quantile(x, p, Seff) Seff === missing && return missing S = length(x) @@ -79,11 +92,6 @@ function _mcse_quantile(x, p, Seff) # estimate mcse from quantiles return (xu - xl) / 2 end -function mcse( - ::typeof(Statistics.median), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... -) - return mcse(Base.Fix2(Statistics.quantile, 1//2), samples; kwargs...) -end """ mcse_sbm(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; batch_size) From bac8a3cbb2c4538bd15bb135a767b64922699d66 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 16:19:28 +0100 Subject: [PATCH 31/45] Return missing if any are missing --- src/mcse.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mcse.jl b/src/mcse.jl index 6c9856ca..e4b4bec5 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -131,6 +131,7 @@ function mcse_sbm( return values end function _mcse_sbm(f, x; batch_size) + any(x -> x === missing, x) && return missing n = length(x) i1 = firstindex(x) v = Statistics.var( From 652b86f1c82775492eefc8ed6c3b6d0d014ddb9d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 16:19:37 +0100 Subject: [PATCH 32/45] Add mcse tests --- test/mcse.jl | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/test/mcse.jl b/test/mcse.jl index 853644b0..5283fc3e 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -1,10 +1,61 @@ using Test using Distributions using MCMCDiagnosticTools +using OffsetArrays using Statistics using StatsBase @testset "mcse.jl" begin + @testset "estimator must be provided" begin + x = randn(100, 4, 10) + @test_throws MethodError mcse(x) + end + + @testset "ESS-based methods forward kwargs to ess_rhat" begin + x = randn(100, 4, 10) + @testset for f in [mean, median, std, Base.Fix2(quantile, 0.1)] + @test @inferred(mcse(f, x; split_chains=1)) ≠ mcse(f, x) + end + end + + @testset "mcse falls back to mcse_sbm" begin + x = randn(100, 4, 10) + @test @inferred(mcse(mad, x)) == + mcse_sbm(mad, x) ≠ + mcse_sbm(mad, x; batch_size=16) == + mcse(mad, x; batch_size=16) + end + + @testset "mcse produces similar vectors to inputs" begin + # simultaneously checks that we index correctly and that output types are correct + @testset for T in (Float32, Float64), + f in [mean, median, std, Base.Fix2(quantile, T(0.1)), mad] + + x = randn(T, 100, 4, 5) + y = OffsetArray(x, -5:94, 2:5, 11:15) + se = mcse(f, y) + @test se isa OffsetVector{T} + @test axes(se, 1) == axes(y, 3) + se2 = mcse(f, x) + @test se2 ≈ collect(se) + # quantile errors if data contains missings + f isa Base.Fix2{typeof(quantile)} && continue + y = OffsetArray(similar(x, Missing), -5:94, 2:5, 11:15) + @test mcse(f, y) isa OffsetVector{Missing} + end + end + + @testset "mcse with Union{Missing,Float64} eltype" begin + x = Array{Union{Missing,Float64}}(undef, 1000, 4, 3) + x .= randn.() + x[1, 1, 1] = missing + @testset for f in [mean, std, mad] + se = mcse(f, x) + @test ismissing(se[1]) + @test !any(ismissing, se[2:end]) + end + end + @testset "estimand is within interval defined by MCSE estimate" begin # we check the ESS estimates by simulating uncorrelated, correlated, and # anticorrelated chains, mapping the draws to a target 
distribution, computing the From 8dbae8403f3601c2f99e8158e6d32cbc73bc7b72 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 16:20:19 +0100 Subject: [PATCH 33/45] Decrease the number of checks --- test/ess.jl | 2 +- test/mcse.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ess.jl b/test/ess.jl index ce36bd20..0177c360 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -161,7 +161,7 @@ end # estimand, and estimating the ESS for the chosen estimator, computing the # corresponding MCSE, and checking that the mean estimand is close to the asymptotic # value of the estimand, with a tolerance chosen using the MCSE. - ndraws = 1_000 + ndraws = 100 nchains = 4 nparams = 100 x = randn(ndraws, nchains, nparams) diff --git a/test/mcse.jl b/test/mcse.jl index 5283fc3e..db9f6cc0 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -62,7 +62,7 @@ using StatsBase # estimand, and estimating the ESS for the chosen estimator, computing the # corresponding MCSE, and checking that the mean estimand is close to the asymptotic # value of the estimand, with a tolerance chosen using the MCSE. - ndraws = 1_000 + ndraws = 100 nchains = 4 nparams = 100 estimators = [mean, median, std, Base.Fix2(quantile, 0.25)] From d9aff6169c8abfd7709bb6fb4efcd1aba722a71e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 16:32:01 +0100 Subject: [PATCH 34/45] Make ESS/MCSE for median with with Union{Missing,Real} --- src/ess.jl | 7 ++++++- test/mcse.jl | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/ess.jl b/src/ess.jl index b8b72569..074e1339 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -471,7 +471,12 @@ rhat_tail(x; kwargs...) = ess_rhat_bulk(_fold_around_median(x); kwargs...)[2] _expectand_proxy(f, x) = nothing function _expectand_proxy(::typeof(Statistics.median), x) y = similar(x) - y .= x .≤ Statistics.median(x; dims=(1, 2)) + # avoid using the `dims` keyword for median because it + # - can error for Union{Missing,Real} (https://github.com/JuliaStats/Statistics.jl/issues/8) + # - is type-unstable (https://github.com/JuliaStats/Statistics.jl/issues/39) + for (xi, yi) in zip(eachslice(x; dims=3), eachslice(y; dims=3)) + yi .= xi .≤ Statistics.median(vec(xi)) + end return y end function _expectand_proxy(::typeof(Statistics.std), x) diff --git a/test/mcse.jl b/test/mcse.jl index db9f6cc0..63032474 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -49,7 +49,7 @@ using StatsBase x = Array{Union{Missing,Float64}}(undef, 1000, 4, 3) x .= randn.() x[1, 1, 1] = missing - @testset for f in [mean, std, mad] + @testset for f in [mean, median, std, mad] se = mcse(f, x) @test ismissing(se[1]) @test !any(ismissing, se[2:end]) From cced4be7a66569737b5a2291a74f517a773ff5a7 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 17:14:37 +0100 Subject: [PATCH 35/45] Make _fold_around_median type-inferrable --- src/utils.jl | 8 +++++++- test/utils.jl | 9 ++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 6ee43aeb..a5aba84a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -145,7 +145,13 @@ end Compute the absolute deviation of `x` from `Statistics.median(x)`. 
""" -_fold_around_median(data) = abs.(data .- Statistics.median(data; dims=(1, 2))) +function _fold_around_median(x) + y = similar(x) + for (xi, yi) in zip(eachslice(y; dims=3), eachslice(x; dims=3)) + yi .= abs.(xi .- Statistics.median(vec(xi))) + end + return y +end """ _rank_normalize(x::AbstractArray{<:Any,3}) diff --git a/test/utils.jl b/test/utils.jl index 39006eec..a412df06 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -104,6 +104,13 @@ end @testset "_fold_around_median" begin x = rand(100, 4, 8) - @test_broken @inferred MCMCDiagnosticTools._fold_around_median(x) # fails because median with dims is not type-inferrable + @inferred MCMCDiagnosticTools._fold_around_median(x) @test MCMCDiagnosticTools._fold_around_median(x) ≈ abs.(x .- median(x; dims=(1, 2))) + x = Array{Union{Missing,Float64}}(undef, 100, 4, 8) + x .= randn.() + x[1, 1, 1] = missing + @test isequal( + @inferred(MCMCDiagnosticTools._fold_around_median(x)), + abs.(x .- mapslices(median ∘ vec, x; dims=(1, 2))), + ) end From ce9d427a26ca3f8f352709fab0c286fb14411f61 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 17:15:23 +0100 Subject: [PATCH 36/45] Increase tolerance for exhaustive tests --- test/ess.jl | 2 +- test/mcse.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ess.jl b/test/ess.jl index 0177c360..da2921c8 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -173,7 +173,7 @@ end φs = [-0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9] # account for all but the 2 skipped checks nchecks = nparams * length(φs) * ((length(estimators) - 1) * length(dists) + 1) - α = (0.1 / nchecks) / 2 # multiple correction + α = (0.01 / nchecks) / 2 # multiple correction @testset for f in estimators, dist in dists, φ in φs f === mad && !(dist isa Normal) && continue σ = sqrt(1 - φ^2) # ensures stationary distribution is N(0, 1) diff --git a/test/mcse.jl b/test/mcse.jl index 63032474..5369b321 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -73,7 +73,7 @@ using StatsBase # account for all but the 2 skipped checks nchecks = nparams * (length(φs) + count(≤(5), φs)) * length(dists) * length(mcse_methods) - α = (0.1 / nchecks) / 2 # multiple correction + α = (0.01 / nchecks) / 2 # multiple correction @testset for mcse in mcse_methods, f in estimators, dist in dists, φ in φs # mcse_sbm underestimates the MCSE for highly correlated chains mcse === mcse_sbm && φ > 0.5 && continue From 787a05f45ab672f8847fabb195ee0aa93198666d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 18:10:43 +0100 Subject: [PATCH 37/45] Fix _fold_around_median --- src/utils.jl | 2 +- test/utils.jl | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index a5aba84a..87064951 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -147,7 +147,7 @@ Compute the absolute deviation of `x` from `Statistics.median(x)`. 
""" function _fold_around_median(x) y = similar(x) - for (xi, yi) in zip(eachslice(y; dims=3), eachslice(x; dims=3)) + for (xi, yi) in zip(eachslice(x; dims=3), eachslice(y; dims=3)) yi .= abs.(xi .- Statistics.median(vec(xi))) end return y diff --git a/test/utils.jl b/test/utils.jl index a412df06..72965951 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -109,8 +109,7 @@ end x = Array{Union{Missing,Float64}}(undef, 100, 4, 8) x .= randn.() x[1, 1, 1] = missing - @test isequal( - @inferred(MCMCDiagnosticTools._fold_around_median(x)), - abs.(x .- mapslices(median ∘ vec, x; dims=(1, 2))), - ) + foldx = @inferred(MCMCDiagnosticTools._fold_around_median(x)) + @test all(ismissing, foldx[:, :, 1]) + @test foldx[:, :, 2:end] ≈ abs.(x[:, :, 2:end] .- median(x[:, :, 2:end]; dims=(1, 2))) end From 34f377167ef452a10aa1734d7c25324a34f607aa Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 19:42:47 +0100 Subject: [PATCH 38/45] Fix count of checks --- test/mcse.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/mcse.jl b/test/mcse.jl index 5369b321..902d8032 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -71,8 +71,7 @@ using StatsBase # AR(1) coefficients. 0 is IID, -0.3 is slightly anticorrelated, 0.9 is highly autocorrelated φs = [-0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9] # account for all but the 2 skipped checks - nchecks = - nparams * (length(φs) + count(≤(5), φs)) * length(dists) * length(mcse_methods) + nchecks = nparams * (length(φs) + count(≤(5), φs)) * length(dists) α = (0.01 / nchecks) / 2 # multiple correction @testset for mcse in mcse_methods, f in estimators, dist in dists, φ in φs # mcse_sbm underestimates the MCSE for highly correlated chains From d10740a9dd0de7e19454feee8b0998a09677fbac Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 19:43:53 +0100 Subject: [PATCH 39/45] Increase the number of draws improves the quality of the estimates and reduces random failures --- test/ess.jl | 2 +- test/mcse.jl | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/ess.jl b/test/ess.jl index da2921c8..30a96e9a 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -161,7 +161,7 @@ end # estimand, and estimating the ESS for the chosen estimator, computing the # corresponding MCSE, and checking that the mean estimand is close to the asymptotic # value of the estimand, with a tolerance chosen using the MCSE. - ndraws = 100 + ndraws = 1000 nchains = 4 nparams = 100 x = randn(ndraws, nchains, nparams) diff --git a/test/mcse.jl b/test/mcse.jl index 902d8032..c172388a 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -57,12 +57,12 @@ using StatsBase end @testset "estimand is within interval defined by MCSE estimate" begin - # we check the ESS estimates by simulating uncorrelated, correlated, and + # we check the MCSE estimates by simulating uncorrelated, correlated, and # anticorrelated chains, mapping the draws to a target distribution, computing the - # estimand, and estimating the ESS for the chosen estimator, computing the - # corresponding MCSE, and checking that the mean estimand is close to the asymptotic - # value of the estimand, with a tolerance chosen using the MCSE. - ndraws = 100 + # estimand, estimating the MCSE for the chosen estimator, and checking that the mean + # estimand is close to the asymptotic value of the estimand, with a tolerance chosen + # using the MCSE. 
+ ndraws = 1000 nchains = 4 nparams = 100 estimators = [mean, median, std, Base.Fix2(quantile, 0.25)] From cf09e4e8f406bf13faea08301d40568de8a9ee4a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 22:21:35 +0100 Subject: [PATCH 40/45] Apply suggestions from code review Co-authored-by: David Widmann --- src/mcse.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/mcse.jl b/src/mcse.jl index e4b4bec5..f7de08de 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -5,13 +5,13 @@ const normcdfn1 = 0.15865525393145705 # StatsFuns.normcdf(-1) mcse(estimator, samples::AbstractArray{<:Union{Missing,Real}}; kwargs...) Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples` of -shape `(draws, chains, parameters)` +shape `(draws, chains, parameters)`. See also: [`ess_rhat`](@ref) ## Estimators -`estimator` must accept a vector of the same eltype as `samples` and return a real estimate. +`estimator` must accept a vector of the same `eltype` as `samples` and return a real estimate. For the following estimators, the effective sample size [`ess_rhat`](@ref) and an estimate of the asymptotic variance are used to compute the MCSE, and `kwargs` are forwarded to @@ -21,7 +21,7 @@ of the asymptotic variance are used to compute the MCSE, and `kwargs` are forwar - `Statistics.std` - `Base.Fix2(Statistics.quantile, p::Real)` -For arbitrary estimator, the subsampling bootstrap method [`mcse_sbm`](@ref) is used, and +For arbitrary estimators, the subsampling bootstrap method [`mcse_sbm`](@ref) is used, and `kwargs` are forwarded to that function. """ mcse(f, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) = mcse_sbm(f, x; kwargs...) @@ -100,15 +100,15 @@ Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `s using the subsampling bootstrap method (SBM).[^FlegalJones2011][^Flegal2012] `samples` has shape `(draws, chains, parameters)`, and `estimator` must accept a vector of -the same eltype as `samples` and return a real estimate. +the same `eltype` as `samples` and return a real estimate. `batch_size` indicates the size of the overlapping batches used to estimate the MCSE, defaulting to `floor(Int, sqrt(draws * chains))`. !!! note SBM tends to underestimate the MCSE, especially for highly autocorrelated chains. - SBM should only be used as a fallbeck when a specific [`mcse`](@ref) method for - `estimator` is not available and when the bulk- and tail- [`ess_rhat`](@ref) values + SBM should only be used as a fallback when a specific [`mcse`](@ref) method for + `estimator` is not available and when the bulk- and tail-[`ess_rhat`](@ref) values indicate low autocorrelation. [^FlegalJones2011]: Flegal JM, Jones GL. (2011) Implementing MCMC: estimating with confidence. 
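Since the note above concerns what `mcse_sbm` can and cannot do, a standalone sketch of the computation in `_mcse_sbm` may help (illustrative only; `sbm_sketch` is not part of the package):

```julia
using Statistics

# Subsampling bootstrap: evaluate the estimator on every overlapping batch of length `b`,
# take the (uncorrected) variance of the batch estimates, and rescale by b/n.
function sbm_sketch(f, x::AbstractVector; b::Int=floor(Int, sqrt(length(x))))
    n = length(x)
    batch_estimates = [f(view(x, i:(i + b - 1))) for i in 1:(n - b + 1)]
    return sqrt(var(batch_estimates; corrected=false) * b / n)
end

sbm_sketch(mean, randn(4_000))  # ≈ 1 / sqrt(4_000) ≈ 0.016 for white noise
```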
From 39d74a998e22e9f311db012048dbf16e2baa304e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 22:38:59 +0100 Subject: [PATCH 41/45] Make sure heideldiag and gewekediag preserve input type --- src/MCMCDiagnosticTools.jl | 2 +- src/gewekediag.jl | 2 +- src/heideldiag.jl | 9 +++++---- test/gewekediag.jl | 9 +++++---- test/heideldiag.jl | 7 ++++--- 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index dd8af867..e16ac114 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -7,7 +7,7 @@ using Distributions: Distributions using MLJModelInterface: MLJModelInterface as MMI using SpecialFunctions: SpecialFunctions using StatsBase: StatsBase -using StatsFuns: StatsFuns +using StatsFuns: StatsFuns, sqrt2 using Tables: Tables using LinearAlgebra: LinearAlgebra diff --git a/src/gewekediag.jl b/src/gewekediag.jl index c858250d..38bc788b 100644 --- a/src/gewekediag.jl +++ b/src/gewekediag.jl @@ -29,7 +29,7 @@ function gewekediag(x::AbstractVector{<:Real}; first::Real=0.1, last::Real=0.5, Base.first(mcse(Statistics.mean, reshape(x2, :, 1, 1); split_chains=1, kwargs...)), ) z = (Statistics.mean(x1) - Statistics.mean(x2)) / s - p = SpecialFunctions.erfc(abs(z) / sqrt(2)) + p = SpecialFunctions.erfc(abs(z) / sqrt2) return (zscore=z, pvalue=p) end diff --git a/src/heideldiag.jl b/src/heideldiag.jl index 8673f292..26ec5296 100644 --- a/src/heideldiag.jl +++ b/src/heideldiag.jl @@ -14,14 +14,15 @@ case for `s2` and `beta[1]`. [^Heidelberger1983]: Heidelberger, P., & Welch, P. D. (1983). Simulation run length control in the presence of an initial transient. Operations Research, 31(6), 1109-1144. """ function heideldiag( - x::AbstractVector{<:Real}; alpha::Real=0.05, eps::Real=0.1, start::Int=1, kwargs... + x::AbstractVector{<:Real}; alpha::Real=1//20, eps::Real=0.1, start::Int=1, kwargs... 
) n = length(x) delta = trunc(Int, 0.10 * n) y = x[trunc(Int, n / 2):end] + T = typeof(zero(eltype(x)) / 1) s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...)) S0 = length(y) * s^2 - i, pvalue, converged, ybar = 1, 1.0, false, NaN + i, pvalue, converged, ybar = 1, one(T), false, T(NaN) while i < n / 2 y = x[i:end] m = length(y) @@ -29,7 +30,7 @@ function heideldiag( B = cumsum(y) - ybar * collect(1:m) Bsq = (B .* B) ./ (m * S0) I = sum(Bsq) / m - pvalue = 1.0 - pcramer(I) + pvalue = 1 - T(pcramer(I)) converged = pvalue > alpha if converged break @@ -37,7 +38,7 @@ function heideldiag( i += delta end s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...)) - halfwidth = sqrt(2) * SpecialFunctions.erfcinv(alpha) * s + halfwidth = sqrt2 * SpecialFunctions.erfcinv(T(alpha)) * s passed = halfwidth / abs(ybar) <= eps return ( burnin=i + start - 2, diff --git a/test/gewekediag.jl b/test/gewekediag.jl index 5cd8f211..3e963e1a 100644 --- a/test/gewekediag.jl +++ b/test/gewekediag.jl @@ -1,12 +1,13 @@ @testset "gewekediag.jl" begin - samples = randn(100) - @testset "results" begin - @test @inferred(gewekediag(samples)) isa - NamedTuple{(:zscore, :pvalue),Tuple{Float64,Float64}} + @testset for T in (Float32, Float64) + samples = randn(T, 100) + @inferred NamedTuple{(:zscore, :pvalue),Tuple{T,T}} gewekediag(samples) + end end @testset "exceptions" begin + samples = randn(100) for x in (-0.3, 0, 1, 1.2) @test_throws ArgumentError gewekediag(samples; first=x) @test_throws ArgumentError gewekediag(samples; last=x) diff --git a/test/heideldiag.jl b/test/heideldiag.jl index 96d8adf4..79e80e54 100644 --- a/test/heideldiag.jl +++ b/test/heideldiag.jl @@ -1,7 +1,8 @@ @testset "heideldiag.jl" begin - samples = randn(100) - @testset "results" begin - @test @inferred(heideldiag(samples)) isa NamedTuple + @testset for T in (Float32, Float64) + samples = randn(T, 100) + @test @inferred(heideldiag(samples)) isa NamedTuple + end end end From a575fc8b289b57ffbca5b27b4ab2ca07189230f1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 22:41:30 +0100 Subject: [PATCH 42/45] Consistently use first and last for ess_rhat --- src/ess.jl | 6 +++--- src/mcse.jl | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/ess.jl b/src/ess.jl index 074e1339..241fc673 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -435,8 +435,8 @@ function ess_tail( # workaround for https://github.com/JuliaStats/Statistics.jl/issues/136 T = Base.promote_eltype(x, tail_prob) return min.( - ess_rhat(Base.Fix2(Statistics.quantile, T(tail_prob / 2)), x; kwargs...)[1], - ess_rhat(Base.Fix2(Statistics.quantile, T(1 - tail_prob / 2)), x; kwargs...)[1], + first(ess_rhat(Base.Fix2(Statistics.quantile, T(tail_prob / 2)), x; kwargs...)), + first(ess_rhat(Base.Fix2(Statistics.quantile, T(1 - tail_prob / 2)), x; kwargs...)), ) end @@ -464,7 +464,7 @@ See also: [`ess_tail`](@ref), [`ess_rhat_bulk`](@ref) doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221) arXiv: [1903.08008](https://arxiv.org/abs/1903.08008) """ -rhat_tail(x; kwargs...) = ess_rhat_bulk(_fold_around_median(x); kwargs...)[2] +rhat_tail(x; kwargs...) = last(ess_rhat_bulk(_fold_around_median(x); kwargs...)) # Compute an expectand `z` such that ``\\textrm{mean-ESS}(z) ≈ \\textrm{f-ESS}(x)``. # If no proxy expectand for `f` is known, `nothing` is returned. 
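The motivation for the change above is that `ess_rhat` returns an `(ess, rhat)` pair, so `first`/`last` read more clearly than integer indexing. A small sketch of the convention (the input array and estimator are illustrative, and it assumes the `ess_rhat(estimator, samples)` methods introduced earlier in this series):

```julia
using MCMCDiagnosticTools, Statistics

x = randn(1_000, 4, 2)                    # (draws, chains, parameters)
ess, rhat = ess_rhat(Statistics.mean, x)  # destructure the returned pair
S = first(ess_rhat(Statistics.mean, x))   # ESS only, as used in mcse
R = last(ess_rhat(Statistics.mean, x))    # R-hat only, as used in rhat_tail
```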
diff --git a/src/mcse.jl b/src/mcse.jl index f7de08de..a8cbfead 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -28,14 +28,14 @@ mcse(f, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) = mcse_sbm(f, x; k function mcse( ::typeof(Statistics.mean), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) - S = ess_rhat(Statistics.mean, samples; kwargs...)[1] + S = first(ess_rhat(Statistics.mean, samples; kwargs...)) return dropdims(Statistics.std(samples; dims=(1, 2)); dims=(1, 2)) ./ sqrt.(S) end function mcse( ::typeof(Statistics.std), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) x = (samples .- Statistics.mean(samples; dims=(1, 2))) .^ 2 # expectand proxy - S = ess_rhat(Statistics.mean, x; kwargs...)[1] + S = first(ess_rhat(Statistics.mean, x; kwargs...)) # asymptotic variance of sample variance estimate is Var[var] = E[μ₄] - E[var]², # where μ₄ is the 4th central moment # by the delta method, Var[std] = Var[var] / 4E[var] = (E[μ₄]/E[var] - E[var])/4, @@ -50,7 +50,7 @@ function mcse( kwargs..., ) p = f.x - S = ess_rhat(f, samples; kwargs...)[1] + S = first(ess_rhat(f, samples; kwargs...)) T = eltype(S) R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T)))) values = similar(S, R) @@ -62,7 +62,7 @@ end function mcse( ::typeof(Statistics.median), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) - S = ess_rhat(Statistics.median, samples; kwargs...)[1] + S = first(ess_rhat(Statistics.median, samples; kwargs...)) T = eltype(S) R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T)))) values = similar(S, R) From b4eea8d33eba260ed7df2f4478e4b32f801c06ff Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 22:42:39 +0100 Subject: [PATCH 43/45] Copy comment to _fold_around_median --- src/utils.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 87064951..6ebfcb64 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -147,6 +147,9 @@ Compute the absolute deviation of `x` from `Statistics.median(x)`. 
""" function _fold_around_median(x) y = similar(x) + # avoid using the `dims` keyword for median because it + # - can error for Union{Missing,Real} (https://github.com/JuliaStats/Statistics.jl/issues/8) + # - is type-unstable (https://github.com/JuliaStats/Statistics.jl/issues/39) for (xi, yi) in zip(eachslice(x; dims=3), eachslice(y; dims=3)) yi .= abs.(xi .- Statistics.median(vec(xi))) end From 373834745de3ecc1c941ae33ecda880a0e28ec4e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 22:55:33 +0100 Subject: [PATCH 44/45] Make mcse_sbm an internal function --- docs/src/index.md | 1 - src/MCMCDiagnosticTools.jl | 2 +- src/mcse.jl | 49 ++++++++++++++------------------------ 3 files changed, 19 insertions(+), 33 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 31f64e38..54a29e74 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -27,7 +27,6 @@ BDAESSMethod ```@docs mcse -mcse_sbm ``` ## R⋆ diagnostic diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index e16ac114..e237abef 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -20,7 +20,7 @@ export ess_rhat, ess_rhat_bulk, ess_tail, rhat_tail, ESSMethod, FFTESSMethod, BD export gelmandiag, gelmandiag_multivariate export gewekediag export heideldiag -export mcse, mcse_sbm +export mcse export rafterydiag export rstar diff --git a/src/mcse.jl b/src/mcse.jl index a8cbfead..779e4e14 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -21,10 +21,22 @@ of the asymptotic variance are used to compute the MCSE, and `kwargs` are forwar - `Statistics.std` - `Base.Fix2(Statistics.quantile, p::Real)` -For arbitrary estimators, the subsampling bootstrap method [`mcse_sbm`](@ref) is used, and -`kwargs` are forwarded to that function. +For other estimators, the subsampling bootstrap method (SBM)[^FlegalJones2011][^Flegal2012] +is used as a fallback, and the only accepted `kwargs` are `batch_size`, which indicates the +size of the overlapping batches used to estimate the MCSE, defaulting to +`floor(Int, sqrt(draws * chains))`. Note that SBM tends to underestimate the MCSE, +especially for highly autocorrelated chains. One should verify that autocorrelation is low +by checking the bulk- and tail-[`ess_rhat`](@ref) values. + +[^FlegalJones2011]: Flegal JM, Jones GL. (2011) Implementing MCMC: estimating with confidence. + Handbook of Markov Chain Monte Carlo. pp. 175-97. + [pdf](http://faculty.ucr.edu/~jflegal/EstimatingWithConfidence.pdf) +[^Flegal2012]: Flegal JM. (2012) Applicability of subsampling bootstrap methods in Markov chain Monte Carlo. + Monte Carlo and Quasi-Monte Carlo Methods 2010. pp. 363-72. + doi: [10.1007/978-3-642-27440-4_18](https://doi.org/10.1007/978-3-642-27440-4_18) + """ -mcse(f, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) = mcse_sbm(f, x; kwargs...) +mcse(f, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) = _mcse_sbm(f, x; kwargs...) function mcse( ::typeof(Statistics.mean), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs... ) @@ -93,32 +105,7 @@ function _mcse_quantile(x, p, Seff) return (xu - xl) / 2 end -""" - mcse_sbm(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; batch_size) - -Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples` -using the subsampling bootstrap method (SBM).[^FlegalJones2011][^Flegal2012] - -`samples` has shape `(draws, chains, parameters)`, and `estimator` must accept a vector of -the same `eltype` as `samples` and return a real estimate. 
- -`batch_size` indicates the size of the overlapping batches used to estimate the MCSE, -defaulting to `floor(Int, sqrt(draws * chains))`. - -!!! note - SBM tends to underestimate the MCSE, especially for highly autocorrelated chains. - SBM should only be used as a fallback when a specific [`mcse`](@ref) method for - `estimator` is not available and when the bulk- and tail-[`ess_rhat`](@ref) values - indicate low autocorrelation. - -[^FlegalJones2011]: Flegal JM, Jones GL. (2011) Implementing MCMC: estimating with confidence. - Handbook of Markov Chain Monte Carlo. pp. 175-97. - [pdf](http://faculty.ucr.edu/~jflegal/EstimatingWithConfidence.pdf) -[^Flegal2012]: Flegal JM. (2012) Applicability of subsampling bootstrap methods in Markov chain Monte Carlo. - Monte Carlo and Quasi-Monte Carlo Methods 2010. pp. 363-72. - doi: [10.1007/978-3-642-27440-4_18](https://doi.org/10.1007/978-3-642-27440-4_18) -""" -function mcse_sbm( +function _mcse_sbm( f, x::AbstractArray{<:Union{Missing,Real},3}; batch_size::Int=floor(Int, sqrt(size(x, 1) * size(x, 2))), @@ -126,11 +113,11 @@ function mcse_sbm( T = promote_type(eltype(x), typeof(zero(eltype(x)) / 1)) values = similar(x, T, (axes(x, 3),)) for (i, xi) in zip(eachindex(values), eachslice(x; dims=3)) - values[i] = _mcse_sbm(f, vec(xi); batch_size=batch_size) + values[i] = _mcse_sbm(f, vec(xi), batch_size) end return values end -function _mcse_sbm(f, x; batch_size) +function _mcse_sbm(f, x, batch_size) any(x -> x === missing, x) && return missing n = length(x) i1 = firstindex(x) From 69fe0a1336974b172a8e73123d6ef2bc28a144a1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 18 Jan 2023 22:56:51 +0100 Subject: [PATCH 45/45] Update tests --- test/mcse.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/mcse.jl b/test/mcse.jl index c172388a..781143f5 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -18,11 +18,11 @@ using StatsBase end end - @testset "mcse falls back to mcse_sbm" begin + @testset "mcse falls back to _mcse_sbm" begin x = randn(100, 4, 10) @test @inferred(mcse(mad, x)) == - mcse_sbm(mad, x) ≠ - mcse_sbm(mad, x; batch_size=16) == + MCMCDiagnosticTools._mcse_sbm(mad, x) ≠ + MCMCDiagnosticTools._mcse_sbm(mad, x; batch_size=16) == mcse(mad, x; batch_size=16) end @@ -67,15 +67,15 @@ using StatsBase nparams = 100 estimators = [mean, median, std, Base.Fix2(quantile, 0.25)] dists = [Normal(10, 100), Exponential(10), TDist(7) * 10 - 20] - mcse_methods = [mcse, mcse_sbm] + mcse_methods = [mcse, MCMCDiagnosticTools._mcse_sbm] # AR(1) coefficients. 0 is IID, -0.3 is slightly anticorrelated, 0.9 is highly autocorrelated φs = [-0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9] # account for all but the 2 skipped checks nchecks = nparams * (length(φs) + count(≤(5), φs)) * length(dists) α = (0.01 / nchecks) / 2 # multiple correction @testset for mcse in mcse_methods, f in estimators, dist in dists, φ in φs - # mcse_sbm underestimates the MCSE for highly correlated chains - mcse === mcse_sbm && φ > 0.5 && continue + # _mcse_sbm underestimates the MCSE for highly correlated chains + mcse === MCMCDiagnosticTools._mcse_sbm && φ > 0.5 && continue σ = sqrt(1 - φ^2) # ensures stationary distribution is N(0, 1) x = ar1(φ, σ, ndraws, nchains, nparams) x .= quantile.(dist, cdf.(Normal(), x)) # stationary distribution is dist
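The correctness tests above draw autocorrelated samples through an `ar1` helper defined elsewhere in the test suite. A minimal stand-in with the same call signature, assuming `ar1(φ, σ, ndraws, nchains, nparams)` should produce chains whose stationary distribution is approximately N(0, 1) when `σ = sqrt(1 - φ^2)`:

```julia
using Random

# Illustrative stand-in for the `ar1` test helper; not the actual implementation.
function ar1(φ::Real, σ::Real, ndraws::Int, nchains::Int, nparams::Int)
    x = σ .* randn(ndraws, nchains, nparams)  # innovations σ * ε
    for i in 2:ndraws
        # xᵢ = φ * xᵢ₋₁ + σ * εᵢ, applied independently to every chain and parameter
        x[i, :, :] .+= φ .* x[i - 1, :, :]
    end
    return x
end

Random.seed!(42)
x = ar1(0.5, sqrt(1 - 0.5^2), 1_000, 4, 10)  # stationary variance ≈ 1
```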