diff --git a/Project.toml b/Project.toml
index 05009511..7d5cd9c9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,11 +1,12 @@
 name = "MCMCDiagnosticTools"
 uuid = "be115224-59cd-429b-ad48-344e309966f0"
 authors = ["David Widmann"]
-version = "0.2.0"
+version = "0.2.1"
 
 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
 DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
+DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
@@ -18,6 +19,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 [compat]
 AbstractFFTs = "0.5, 1"
 DataAPI = "1.6"
+DataStructures = "0.18.3"
 Distributions = "0.25"
 MLJModelInterface = "1.6"
 SpecialFunctions = "0.8, 0.9, 0.10, 1, 2"
diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl
index 6aca4543..9c995f65 100644
--- a/src/MCMCDiagnosticTools.jl
+++ b/src/MCMCDiagnosticTools.jl
@@ -2,6 +2,7 @@ module MCMCDiagnosticTools
 
 using AbstractFFTs: AbstractFFTs
 using DataAPI: DataAPI
+using DataStructures: DataStructures
 using Distributions: Distributions
 using MLJModelInterface: MLJModelInterface
 using SpecialFunctions: SpecialFunctions
@@ -22,6 +23,7 @@ export mcse
 export rafterydiag
 export rstar
 
+include("utils.jl")
 include("bfmi.jl")
 include("discretediag.jl")
 include("ess.jl")
diff --git a/src/rstar.jl b/src/rstar.jl
index 7eb61c9b..820c3a02 100644
--- a/src/rstar.jl
+++ b/src/rstar.jl
@@ -4,7 +4,8 @@
         classifier::MLJModelInterface.Supervised,
         samples,
         chain_indices::AbstractVector{Int};
-        subset::Real=0.8,
+        subset::Real=0.7,
+        split_chains::Int=2,
         verbosity::Int=0,
     )
 
@@ -23,26 +24,25 @@ function rstar(
     classifier::MLJModelInterface.Supervised,
     x,
     y::AbstractVector{Int};
-    subset::Real=0.8,
+    subset::Real=0.7,
+    split_chains::Int=2,
     verbosity::Int=0,
 )
     # checks
     MLJModelInterface.nrows(x) != length(y) && throw(DimensionMismatch())
     0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)"))
 
+    ysplit = split_chain_indices(y, split_chains)
+
     # randomly sub-select training and testing set
-    N = length(y)
-    Ntrain = round(Int, N * subset)
-    0 < Ntrain < N ||
+    train_ids, test_ids = shuffle_split_stratified(rng, ysplit, subset)
+    0 < length(train_ids) < length(y) ||
         throw(ArgumentError("training and test data subsets must not be empty"))
-    ids = Random.randperm(rng, N)
-    train_ids = view(ids, 1:Ntrain)
-    test_ids = view(ids, (Ntrain + 1):N)
 
     xtable = _astable(x)
 
     # train classifier on training data
-    ycategorical = MLJModelInterface.categorical(y)
+    ycategorical = MLJModelInterface.categorical(ysplit)
     xtrain = MLJModelInterface.selectrows(xtable, train_ids)
     fitresult, _ = MLJModelInterface.fit(
         classifier, verbosity, xtrain, ycategorical[train_ids]
@@ -79,7 +79,8 @@ end
         rng::Random.AbstractRNG=Random.default_rng(),
         classifier::MLJModelInterface.Supervised,
         samples::AbstractArray{<:Real,3};
-        subset::Real=0.8,
+        subset::Real=0.7,
+        split_chains::Int=2,
         verbosity::Int=0,
     )
 
@@ -91,8 +92,10 @@ This implementation is an adaption of algorithms 1 and 2 described by Lambert an
 
 The `classifier` has to be a supervised classifier of the MLJ framework (see the
 [MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/list_of_supported_models/#model_list)
-for a list of supported models). It is trained with a `subset` of the samples. The training
-of the classifier can be inspected by adjusting the `verbosity` level.
+for a list of supported models). It is trained with a `subset` of the samples from each
+chain. Each chain is split into `split_chains` separate chains to additionally check for
+within-chain convergence. The training of the classifier can be inspected by adjusting the
+`verbosity` level.
 
 If the classifier is deterministic, i.e., if it predicts a class, the value of the ``R^*``
 statistic is returned (algorithm 1). If the classifier is probabilistic, i.e., if it outputs
diff --git a/src/utils.jl b/src/utils.jl
new file mode 100644
index 00000000..ca186080
--- /dev/null
+++ b/src/utils.jl
@@ -0,0 +1,99 @@
+"""
+    unique_indices(x) -> (unique, indices)
+
+Return the results of `unique(collect(x))` along with a vector of the same length whose
+elements are the indices in `x` at which the corresponding unique element in `unique` is
+found.
+"""
+function unique_indices(x)
+    inds = eachindex(x)
+    T = eltype(inds)
+    ind_map = DataStructures.SortedDict{eltype(x),Vector{T}}()
+    for i in inds
+        xi = x[i]
+        inds_xi = get!(ind_map, xi) do
+            return T[]
+        end
+        push!(inds_xi, i)
+    end
+    unique = collect(keys(ind_map))
+    indices = collect(values(ind_map))
+    return unique, indices
+end
+
+"""
+    split_chain_indices(
+        chain_inds::AbstractVector{Int},
+        split::Int=2,
+    ) -> AbstractVector{Int}
+
+Split each chain in `chain_inds` into `split` chains.
+
+For each chain in `chain_inds`, all entries are assumed to correspond to draws that have
+been ordered by iteration number. The result is a vector of the same length as `chain_inds`
+where each entry is the new index of the chain that the corresponding draw belongs to.
+"""
+function split_chain_indices(c::AbstractVector{Int}, split::Int=2)
+    cnew = similar(c)
+    if split == 1
+        copyto!(cnew, c)
+        return cnew
+    end
+    _, indices = unique_indices(c)
+    chain_ind = 1
+    for inds in indices
+        ndraws_per_split, rem = divrem(length(inds), split)
+        # here we can't use Iterators.partition because it's greedy. e.g. we can't partition
+        # 4 items across 3 partitions because Iterators.partition(1:4, 1) == [[1], [2], [3], [4]]
+        # and Iterators.partition(1:4, 2) == [[1, 2], [3, 4]]. But we would want
+        # [[1, 2], [3], [4]].
+        i = j = 0
+        ndraws_this_split = ndraws_per_split + (j < rem)
+        for ind in inds
+            cnew[ind] = chain_ind
+            if (i += 1) == ndraws_this_split
+                i = 0
+                j += 1
+                ndraws_this_split = ndraws_per_split + (j < rem)
+                chain_ind += 1
+            end
+        end
+    end
+    return cnew
+end
+
+"""
+    shuffle_split_stratified(
+        rng::Random.AbstractRNG,
+        group_ids::AbstractVector,
+        frac::Real,
+    ) -> (inds1, inds2)
+
+Randomly split the indices of `group_ids` into two groups, where `frac` indices from each
+group are in `inds1` and the remainder are in `inds2`.
+
+This is used, for example, to split data into training and test data while preserving the
+class balances.
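+
+For example (an illustrative sketch, not a doctest; the `group_ids` vector below is made
+up), with two groups of four indices each and `frac = 0.5`, two indices from each group
+end up in `inds1` and the remaining two in `inds2`:
+
+```julia
+group_ids = [1, 1, 1, 1, 2, 2, 2, 2]
+inds1, inds2 = shuffle_split_stratified(Random.default_rng(), group_ids, 0.5)
+# length(inds1) == 4 and length(inds2) == 4, stratified across the two groups
+```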
+"""
+function shuffle_split_stratified(
+    rng::Random.AbstractRNG, group_ids::AbstractVector, frac::Real
+)
+    _, indices = unique_indices(group_ids)
+    T = eltype(eltype(indices))
+    N1_tot = sum(x -> round(Int, length(x) * frac), indices)
+    N2_tot = length(group_ids) - N1_tot
+    inds1 = Vector{T}(undef, N1_tot)
+    inds2 = Vector{T}(undef, N2_tot)
+    items_in_1 = items_in_2 = 0
+    for inds in indices
+        N = length(inds)
+        N1 = round(Int, N * frac)
+        N2 = N - N1
+        Random.shuffle!(rng, inds)
+        copyto!(inds1, items_in_1 + 1, inds, 1, N1)
+        copyto!(inds2, items_in_2 + 1, inds, N1 + 1, N2)
+        items_in_1 += N1
+        items_in_2 += N2
+    end
+    return inds1, inds2
+end
diff --git a/test/rstar.jl b/test/rstar.jl
index 928e50a4..92706416 100644
--- a/test/rstar.jl
+++ b/test/rstar.jl
@@ -30,7 +30,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
                 @test dist isa LocationScale
                 @test dist.ρ isa PoissonBinomial
                 @test minimum(dist) == 0
-                @test maximum(dist) == 3
+                @test maximum(dist) == 6
             end
             @test mean(dist) ≈ 1 rtol = 0.2
             wrapper === Vector && break
@@ -48,7 +48,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
                 @test dist isa LocationScale
                 @test dist.ρ isa PoissonBinomial
                 @test minimum(dist) == 0
-                @test maximum(dist) == 4
+                @test maximum(dist) == 8
             end
             @test mean(dist) ≈ 1 rtol = 0.15
 
@@ -58,7 +58,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
                 100 .* cos.(1:N) 100 .* sin.(1:N)
             ])
             chain_indices = repeat(1:2; inner=N)
-            dist = rstar(classifier, samples, chain_indices)
+            dist = rstar(classifier, samples, chain_indices; split_chains=1)
 
             # Mean of the statistic should be close to 2, i.e., the classifier should be able to
             # learn an almost perfect decision boundary between chains.
@@ -71,6 +71,17 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
                 @test maximum(dist) == 2
             end
             @test mean(dist) ≈ 2 rtol = 0.15
+
+            # Compute the R⋆ statistic for identical chains that individually have not mixed.
+            samples = ones(sz)
+            samples[div(N, 2):end, :] .= 2
+            chain_indices = repeat(1:4; outer=div(N, 4))
+            dist = rstar(classifier, samples, chain_indices; split_chains=1)
+            # without split chains cannot distinguish between chains
+            @test mean(dist) ≈ 1 rtol = 0.15
+            dist = rstar(classifier, samples, chain_indices)
+            # with split chains can learn almost perfect decision boundary
+            @test mean(dist) ≈ 2 rtol = 0.15
         end
 
         wrapper === Vector && continue
diff --git a/test/runtests.jl b/test/runtests.jl
index 63fdade0..af19e7be 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -10,6 +10,10 @@ using Test
 Random.seed!(1)
 
 @testset "MCMCDiagnosticTools.jl" begin
+    @testset "utils" begin
+        include("utils.jl")
+    end
+
     @testset "Bayesian fraction of missing information" begin
         include("bfmi.jl")
     end
diff --git a/test/utils.jl b/test/utils.jl
new file mode 100644
index 00000000..de374e8b
--- /dev/null
+++ b/test/utils.jl
@@ -0,0 +1,62 @@
+using MCMCDiagnosticTools
+using Test
+using Random
+
+@testset "unique_indices" begin
+    @testset "indices=$(eachindex(inds))" for inds in [
+        rand(11:14, 100), transpose(rand(11:14, 10, 10))
+    ]
+        unique, indices = @inferred MCMCDiagnosticTools.unique_indices(inds)
+        @test unique isa Vector{Int}
+        if eachindex(inds) isa CartesianIndices{2}
+            @test indices isa Vector{Vector{CartesianIndex{2}}}
+        else
+            @test indices isa Vector{Vector{Int}}
+        end
+        @test issorted(unique)
+        @test issetequal(union(indices...), eachindex(inds))
+        for i in eachindex(unique, indices)
+            @test all(inds[indices[i]] .== unique[i])
+        end
+    end
+end
+
+@testset "split_chain_indices" begin
+    c = [2, 2, 1, 3, 4, 3, 4, 1, 2, 1, 4, 3, 3, 2, 4, 3, 4, 1, 4, 1]
+    @test @inferred(MCMCDiagnosticTools.split_chain_indices(c, 1)) == c
+
+    cnew = @inferred MCMCDiagnosticTools.split_chain_indices(c, 2)
+    @test issetequal(Base.unique(cnew), 1:maximum(cnew))  # check no indices skipped
+    unique, indices = MCMCDiagnosticTools.unique_indices(c)
+    uniquenew, indicesnew = MCMCDiagnosticTools.unique_indices(cnew)
+    for (i, inew) in enumerate(1:2:7)
+        @test length(indicesnew[inew]) ≥ length(indicesnew[inew + 1])
+        @test indices[i] == vcat(indicesnew[inew], indicesnew[inew + 1])
+    end
+
+    cnew = MCMCDiagnosticTools.split_chain_indices(c, 3)
+    @test issetequal(Base.unique(cnew), 1:maximum(cnew))  # check no indices skipped
+    unique, indices = MCMCDiagnosticTools.unique_indices(c)
+    uniquenew, indicesnew = MCMCDiagnosticTools.unique_indices(cnew)
+    for (i, inew) in enumerate(1:3:11)
+        @test length(indicesnew[inew]) ≥
+            length(indicesnew[inew + 1]) ≥
+            length(indicesnew[inew + 2])
+        @test indices[i] ==
+            vcat(indicesnew[inew], indicesnew[inew + 1], indicesnew[inew + 2])
+    end
+end
+
+@testset "shuffle_split_stratified" begin
+    rng = Random.default_rng()
+    c = rand(1:4, 100)
+    unique, indices = MCMCDiagnosticTools.unique_indices(c)
+    @testset "frac=$frac" for frac in [0.3, 0.5, 0.7]
+        inds1, inds2 = @inferred(MCMCDiagnosticTools.shuffle_split_stratified(rng, c, frac))
+        @test issetequal(vcat(inds1, inds2), eachindex(c))
+        for inds in indices
+            common_inds = intersect(inds1, inds)
+            @test length(common_inds) == round(frac * length(inds))
+        end
+    end
+end
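
Usage sketch of the new keyword (illustrative only): each chain is additionally split into `split_chains` segments (default 2) before the classifier is trained, and `split_chains=1` recovers the previous behaviour. The classifier and data below are made up; any MLJ-compatible model works, e.g. `XGBoostClassifier` as exercised in the tests above (assumed to be provided by an installed MLJXGBoostInterface).

```julia
using MCMCDiagnosticTools
using MLJXGBoostInterface  # provides XGBoostClassifier (assumed to be installed)
using Statistics: mean

classifier = XGBoostClassifier()
samples = randn(1_000, 2)               # draws × parameters
chain_indices = repeat(1:4; inner=250)  # chain index of each draw, ordered by iteration

dist = rstar(classifier, samples, chain_indices)                          # split_chains=2 (default)
dist_nosplit = rstar(classifier, samples, chain_indices; split_chains=1)  # previous behaviour
mean(dist)  # point estimate of the R⋆ statistic
```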