diff --git a/Project.toml b/Project.toml index 31dd2a08..f78d6d38 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCMCDiagnosticTools" uuid = "be115224-59cd-429b-ad48-344e309966f0" authors = ["David Widmann", "Seth Axen"] -version = "0.3.9" +version = "0.3.10" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index e42d16c4..05446464 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -342,19 +342,26 @@ function rhat(samples::AbstractArray{<:Union{Missing,Real}}; kind::Symbol=:rank, return throw(ArgumentError("the `kind` `$kind` is not supported by `rhat`")) end end -function _rhat( - ::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real}}; split_chains::Int=2 -) - # compute size of matrices (each chain may be split!) - niter = size(chains, 1) ÷ split_chains - nchains = split_chains * size(chains, 2) +function _rhat(::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real}}; kwargs...) + # define output array axes_out = _param_axes(chains) T = promote_type(eltype(chains), typeof(zero(eltype(chains)) / 1)) - - # define output arrays rhat = similar(chains, T, axes_out) - T === Missing && return rhat + if T !== Missing + _rhat_basic!(rhat, chains; kwargs...) + end + + return _maybescalar(rhat) +end +function _rhat_basic!( + rhat::AbstractArray{T}, + chains::AbstractArray{<:Union{Missing,Real}}; + split_chains::Int=2, +) where {T<:Union{Missing,Real}} + # compute size of matrices (each chain may be split!) + niter = size(chains, 1) ÷ split_chains + nchains = split_chains * size(chains, 2) # define caches for mean and variance chain_mean = Array{T}(undef, 1, nchains) @@ -393,8 +400,7 @@ function _rhat( # estimate rhat rhat[i] = sqrt(var₊ / W) end - - return _maybescalar(rhat) + return rhat end function _rhat(::Val{:bulk}, x::AbstractArray{<:Union{Missing,Real}}; kwargs...) return _rhat(Val(:basic), _rank_normalize(x); kwargs...) @@ -445,33 +451,48 @@ end function _ess_rhat( ::Val{:basic}, chains::AbstractArray{<:Union{Missing,Real}}; - relative::Bool=false, - autocov_method::AbstractAutocovMethod=AutocovMethod(), split_chains::Int=2, maxlag::Int=250, + kwargs..., ) - # compute size of matrices (each chain may be split!) - niter = size(chains, 1) ÷ split_chains - nchains = split_chains * size(chains, 2) - ntotal = niter * nchains + # define output arrays axes_out = _param_axes(chains) T = promote_type(eltype(chains), typeof(zero(eltype(chains)) / 1)) + ess = similar(chains, T, axes_out) + rhat = similar(chains, T, axes_out) + + # compute number of iterations (each chain may be split!) + niter = size(chains, 1) ÷ split_chains - # discard the last pair of autocorrelations, which are poorly estimated and only matter - # when chains have mixed poorly anyways. - # leave the last even autocorrelation as a bias term that reduces variance for - # case of antithetical chains, see below if !(niter > 4) - throw(ArgumentError("number of draws after splitting must >4 but is $niter.")) + # discard the last pair of autocorrelations, which are poorly estimated and only matter + # when chains have mixed poorly anyways. + # leave the last even autocorrelation as a bias term that reduces variance for + # case of antithetical chains, see below + @warn "number of draws after splitting must be >4 but is $niter. ESS cannot be computed." + fill!(ess, NaN) + _rhat_basic!(rhat, chains; split_chains) + elseif T !== Missing + maxlag > 0 || throw(DomainError(maxlag, "maxlag must be >0.")) + maxlag = min(maxlag, niter - 4) + _ess_rhat_basic!(ess, rhat, chains; split_chains, maxlag, kwargs...) end - maxlag > 0 || throw(DomainError(maxlag, "maxlag must be >0.")) - maxlag = min(maxlag, niter - 4) - # define output arrays - ess = similar(chains, T, axes_out) - rhat = similar(chains, T, axes_out) - - T === Missing && return (; ess, rhat) + return (; ess=_maybescalar(ess), rhat=_maybescalar(rhat)) +end +function _ess_rhat_basic!( + ess::TA, + rhat::TA, + chains::AbstractArray{<:Union{Missing,Real}}; + relative::Bool=false, + autocov_method::AbstractAutocovMethod=AutocovMethod(), + split_chains::Int=2, + maxlag::Int=250, +) where {T<:Union{Missing,Real},TA<:AbstractArray{T}} + # compute size of matrices (each chain may be split!) + niter = size(chains, 1) ÷ split_chains + nchains = split_chains * size(chains, 2) + ntotal = niter * nchains # define caches for mean and variance chain_mean = Array{T}(undef, 1, nchains) @@ -573,7 +594,7 @@ function _ess_rhat( ess .*= ntotal end - return (; ess=_maybescalar(ess), rhat=_maybescalar(rhat)) + return (; ess, rhat) end function _ess_rhat(::Val{:bulk}, x::AbstractArray{<:Union{Missing,Real}}; kwargs...) return _ess_rhat(Val(:basic), _rank_normalize(x); kwargs...) diff --git a/test/ess_rhat.jl b/test/ess_rhat.jl index 9a0b3801..a620c1c1 100644 --- a/test/ess_rhat.jl +++ b/test/ess_rhat.jl @@ -67,12 +67,27 @@ mymean(x) = mean(x) x = rand(4, 3, 5) x2 = rand(5, 3, 5) x3 = rand(100, 3, 5) + x4 = rand(1, 3, 5) @testset for f in (ess, ess_rhat) @testset for kind in (:rank, :bulk, :tail, :basic) f === ess && kind === :rank && continue - @test_throws ArgumentError f(x; split_chains=1, kind=kind) + if f === ess + @test all(isnan, f(x; split_chains=1, kind=kind)) + @test all(isnan, f(x4; split_chains=2, kind=kind)) + else + @test all(isnan, f(x; split_chains=1, kind=kind).ess) + @test f(x; split_chains=1, kind=kind).rhat == + rhat(x; split_chains=1, kind=kind) + @test all(isnan, f(x4; split_chains=2, kind=kind).ess) + end f(x2; split_chains=1, kind=kind) - @test_throws ArgumentError f(x2; split_chains=2, kind=kind) + if f === ess + @test all(isnan, f(x2; split_chains=2, kind=kind)) + else + @test all(isnan, f(x2; split_chains=2, kind=kind).ess) + @test f(x2; split_chains=2, kind=kind).rhat == + rhat(x2; split_chains=2, kind=kind) + end f(x3; maxlag=1, kind=kind) @test_throws DomainError f(x3; maxlag=0, kind=kind) end