diff --git a/src/ess_rhat.jl b/src/ess_rhat.jl index 039ae3ed..a07f8633 100644 --- a/src/ess_rhat.jl +++ b/src/ess_rhat.jl @@ -280,7 +280,7 @@ function _ess(estimator, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs return _ess(Val(:basic), x; kwargs...) end function _ess(kind::Val, samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...) - return first(_ess_rhat(kind, samples; kwargs...)) + return _ess_rhat(kind, samples; kwargs...).ess end function _ess( ::Val{:tail}, @@ -411,7 +411,11 @@ function _rhat(::Val{:rank}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs.. end """ - ess_rhat(samples::AbstractArray{<:Union{Missing,Real},3}; kind::Symbol=:rank, kwargs...) + ess_rhat( + samples::AbstractArray{<:Union{Missing,Real},3}; + kind::Symbol=:rank, + kwargs..., + ) -> NamedTuple{(:ess, :rhat)} Estimate the effective sample size and ``\\widehat{R}`` of the `samples` of shape `(draws, chains, parameters)`. @@ -468,7 +472,7 @@ function _ess_rhat( ess = similar(chains, T, axes_out) rhat = similar(chains, T, axes_out) - T === Missing && return ess, rhat + T === Missing && return (; ess, rhat) # define caches for mean and variance chain_mean = Array{T}(undef, 1, nchains) @@ -565,7 +569,7 @@ function _ess_rhat( ess[i] = min(ntotal / τ, ess_max) end - return ess, rhat + return (; ess, rhat) end function _ess_rhat(::Val{:bulk}, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) return _ess_rhat(Val(:basic), _rank_normalize(x); kwargs...) @@ -578,7 +582,7 @@ function _ess_rhat( ) S = _ess(kind, x; split_chains=split_chains, kwargs...) R = _rhat(kind, x; split_chains=split_chains) - return S, R + return (ess=S, rhat=R) end function _ess_rhat( ::Val{:rank}, x::AbstractArray{<:Union{Missing,Real},3}; split_chains::Int=2, kwargs... @@ -586,7 +590,7 @@ function _ess_rhat( Sbulk, Rbulk = _ess_rhat(Val(:bulk), x; split_chains=split_chains, kwargs...) Rtail = _rhat(Val(:tail), x; split_chains=split_chains) Rrank = map(max, Rtail, Rbulk) - return Sbulk, Rrank + return (ess=Sbulk, rhat=Rrank) end # Compute an expectand `z` such that ``\\textrm{mean-ESS}(z) ≈ \\textrm{f-ESS}(x)``. diff --git a/test/ess_rhat.jl b/test/ess_rhat.jl index 14c8fcfc..dc251db5 100644 --- a/test/ess_rhat.jl +++ b/test/ess_rhat.jl @@ -45,14 +45,16 @@ mymean(x) = mean(x) TV = Vector{T} kind === :rank || @test @inferred(ess(x; kind=kind)) isa TV @test @inferred(rhat(x; kind=kind)) isa TV - @test @inferred(ess_rhat(x; kind=kind)) isa Tuple{TV,TV} + @test @inferred(ess_rhat(x; kind=kind)) isa + NamedTuple{(:ess, :rhat),Tuple{TV,TV}} end @testset "Int" begin x = rand(1:10, 100, 4, 2) TV = Vector{Float64} kind === :rank || @test @inferred(ess(x; kind=kind)) isa TV @test @inferred(rhat(x; kind=kind)) isa TV - @test @inferred(ess_rhat(x; kind=kind)) isa Tuple{TV,TV} + @test @inferred(ess_rhat(x; kind=kind)) isa + NamedTuple{(:ess, :rhat),Tuple{TV,TV}} end end @testset for kind in (mean, median, mad, std, Base.Fix2(quantile, 0.25))