Skip to content

Commit

Permalink
Implement fixes for rstar (#52)
Browse files Browse the repository at this point in the history
* Use 70-30 split

* Add indices_of_unique

* Add split_chain_indices

* Add shuffle_split_stratified

* Use utility functions in rstar

* Update to reflect splits

* Perform test without splits

* Add test for splitting

* Increment patch number

* Test methods are type-inferrable

* Run formatter

* Update subset docs

* Rename nsplit to split_chains

* From indices_of_unique to unique_indices

* Fix bugs in tests

* Add comment why we can't use Iterators.partition

* Avoid append!

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Add DataStructures as dependency

* Fix variable name

* Remove unused variable groups

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Make sure no chain indices are skipped

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
sethaxen and devmotion authored Dec 14, 2022
1 parent 8d74357 commit 63e82cf
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 16 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
name = "MCMCDiagnosticTools"
uuid = "be115224-59cd-429b-ad48-344e309966f0"
authors = ["David Widmann"]
version = "0.2.0"
version = "0.2.1"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Expand All @@ -18,6 +19,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
[compat]
AbstractFFTs = "0.5, 1"
DataAPI = "1.6"
DataStructures = "0.18.3"
Distributions = "0.25"
MLJModelInterface = "1.6"
SpecialFunctions = "0.8, 0.9, 0.10, 1, 2"
Expand Down
2 changes: 2 additions & 0 deletions src/MCMCDiagnosticTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module MCMCDiagnosticTools

using AbstractFFTs: AbstractFFTs
using DataAPI: DataAPI
using DataStructures: DataStructures
using Distributions: Distributions
using MLJModelInterface: MLJModelInterface
using SpecialFunctions: SpecialFunctions
Expand All @@ -22,6 +23,7 @@ export mcse
export rafterydiag
export rstar

include("utils.jl")
include("bfmi.jl")
include("discretediag.jl")
include("ess.jl")
Expand Down
27 changes: 15 additions & 12 deletions src/rstar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
classifier::MLJModelInterface.Supervised,
samples,
chain_indices::AbstractVector{Int};
subset::Real=0.8,
subset::Real=0.7,
split_chains::Int=2,
verbosity::Int=0,
)
Expand All @@ -23,26 +24,25 @@ function rstar(
classifier::MLJModelInterface.Supervised,
x,
y::AbstractVector{Int};
subset::Real=0.8,
subset::Real=0.7,
split_chains::Int=2,
verbosity::Int=0,
)
# checks
MLJModelInterface.nrows(x) != length(y) && throw(DimensionMismatch())
0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)"))

ysplit = split_chain_indices(y, split_chains)

# randomly sub-select training and testing set
N = length(y)
Ntrain = round(Int, N * subset)
0 < Ntrain < N ||
train_ids, test_ids = shuffle_split_stratified(rng, ysplit, subset)
0 < length(train_ids) < length(y) ||
throw(ArgumentError("training and test data subsets must not be empty"))
ids = Random.randperm(rng, N)
train_ids = view(ids, 1:Ntrain)
test_ids = view(ids, (Ntrain + 1):N)

xtable = _astable(x)

# train classifier on training data
ycategorical = MLJModelInterface.categorical(y)
ycategorical = MLJModelInterface.categorical(ysplit)
xtrain = MLJModelInterface.selectrows(xtable, train_ids)
fitresult, _ = MLJModelInterface.fit(
classifier, verbosity, xtrain, ycategorical[train_ids]
Expand Down Expand Up @@ -79,7 +79,8 @@ end
rng::Random.AbstractRNG=Random.default_rng(),
classifier::MLJModelInterface.Supervised,
samples::AbstractArray{<:Real,3};
subset::Real=0.8,
subset::Real=0.7,
split_chains::Int=2,
verbosity::Int=0,
)
Expand All @@ -91,8 +92,10 @@ This implementation is an adaption of algorithms 1 and 2 described by Lambert an
The `classifier` has to be a supervised classifier of the MLJ framework (see the
[MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/list_of_supported_models/#model_list)
for a list of supported models). It is trained with a `subset` of the samples. The training
of the classifier can be inspected by adjusting the `verbosity` level.
for a list of supported models). It is trained with a `subset` of the samples from each
chain. Each chain is split into `split_chains` separate chains to additionally check for
within-chain convergence. The training of the classifier can be inspected by adjusting the
`verbosity` level.
If the classifier is deterministic, i.e., if it predicts a class, the value of the ``R^*``
statistic is returned (algorithm 1). If the classifier is probabilistic, i.e., if it outputs
Expand Down
99 changes: 99 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""
    unique_indices(x) -> (unique, indices)

Return the sorted unique elements of `x` along with a vector of the same length whose
elements are the indices in `x` at which the corresponding unique element in `unique` is
found.
"""
function unique_indices(x)
    inds = eachindex(x)
    T = eltype(inds)
    # Group the indices by value in a plain `Dict` and sort the keys afterwards; this
    # yields the same sorted result as a sorted dictionary without an extra dependency.
    ind_map = Dict{eltype(x),Vector{T}}()
    for i in inds
        inds_xi = get!(() -> T[], ind_map, x[i])
        push!(inds_xi, i)
    end
    unique = sort!(collect(keys(ind_map)))
    # values in the same (sorted) order as the keys
    indices = [ind_map[xi] for xi in unique]
    return unique, indices
end

"""
    split_chain_indices(
        chain_inds::AbstractVector{Int},
        split::Int=2,
    ) -> AbstractVector{Int}

Split each chain in `chain_inds` into `split` chains.

For each chain in `chain_inds`, all entries are assumed to correspond to draws that have
been ordered by iteration number. The result is a vector of the same length as `chain_inds`
where each entry is the new index of the chain that the corresponding draw belongs to.
"""
function split_chain_indices(c::AbstractVector{Int}, split::Int=2)
    cnew = similar(c)
    split == 1 && return copyto!(cnew, c)
    _, indices = unique_indices(c)
    new_ind = 1
    for inds in indices
        nbase, nextra = divrem(length(inds), split)
        # We cannot use `Iterators.partition` here because it is greedy: e.g. 4 items
        # cannot be distributed over 3 partitions with it, since
        # `Iterators.partition(1:4, 1) == [[1], [2], [3]]` and
        # `Iterators.partition(1:4, 2) == [[1, 2], [3, 4]]`, whereas we want
        # `[[1, 2], [3], [4]]`.
        offset = firstindex(inds)
        for k in 1:split
            # the first `nextra` sub-chains receive one extra draw each
            len = nbase + (k <= nextra)
            # skip empty sub-chains so that no chain indices are skipped in the result
            len == 0 && continue
            for l in offset:(offset + len - 1)
                cnew[inds[l]] = new_ind
            end
            offset += len
            new_ind += 1
        end
    end
    return cnew
end

"""
    shuffle_split_stratified(
        rng::Random.AbstractRNG,
        group_ids::AbstractVector,
        frac::Real,
    ) -> (inds1, inds2)

Randomly split the indices of `group_ids` into two groups, where `frac` indices from each
group are in `inds1` and the remainder are in `inds2`.

This is used, for example, to split data into training and test data while preserving the
class balances.
"""
function shuffle_split_stratified(
    rng::Random.AbstractRNG, group_ids::AbstractVector, frac::Real
)
    _, groups = unique_indices(group_ids)
    T = eltype(eltype(groups))
    # preallocate both output vectors from the per-group counts
    nfirst_total = sum(g -> round(Int, length(g) * frac), groups)
    inds1 = Vector{T}(undef, nfirst_total)
    inds2 = Vector{T}(undef, length(group_ids) - nfirst_total)
    offset1 = offset2 = 0
    for g in groups
        ngroup = length(g)
        nfirst = round(Int, ngroup * frac)
        nsecond = ngroup - nfirst
        # shuffle the group's indices in-place, then copy each part to its output
        Random.shuffle!(rng, g)
        copyto!(inds1, offset1 + 1, g, 1, nfirst)
        copyto!(inds2, offset2 + 1, g, nfirst + 1, nsecond)
        offset1 += nfirst
        offset2 += nsecond
    end
    return inds1, inds2
end
17 changes: 14 additions & 3 deletions test/rstar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
@test dist isa LocationScale
@test dist.ρ isa PoissonBinomial
@test minimum(dist) == 0
@test maximum(dist) == 3
@test maximum(dist) == 6
end
@test mean(dist) 1 rtol = 0.2
wrapper === Vector && break
Expand All @@ -48,7 +48,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
@test dist isa LocationScale
@test dist.ρ isa PoissonBinomial
@test minimum(dist) == 0
@test maximum(dist) == 4
@test maximum(dist) == 8
end
@test mean(dist) 1 rtol = 0.15

Expand All @@ -58,7 +58,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
100 .* cos.(1:N) 100 .* sin.(1:N)
])
chain_indices = repeat(1:2; inner=N)
dist = rstar(classifier, samples, chain_indices)
dist = rstar(classifier, samples, chain_indices; split_chains=1)

# Mean of the statistic should be close to 2, i.e., the classifier should be able to
# learn an almost perfect decision boundary between chains.
Expand All @@ -71,6 +71,17 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
@test maximum(dist) == 2
end
@test mean(dist) 2 rtol = 0.15

# Compute the R⋆ statistic for identical chains that individually have not mixed.
samples = ones(sz)
samples[div(N, 2):end, :] .= 2
chain_indices = repeat(1:4; outer=div(N, 4))
dist = rstar(classifier, samples, chain_indices; split_chains=1)
# without split chains cannot distinguish between chains
@test mean(dist) 1 rtol = 0.15
dist = rstar(classifier, samples, chain_indices)
# with split chains can learn almost perfect decision boundary
@test mean(dist) 2 rtol = 0.15
end
wrapper === Vector && continue

Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ using Test
Random.seed!(1)

@testset "MCMCDiagnosticTools.jl" begin
@testset "utils" begin
include("utils.jl")
end

@testset "Bayesian fraction of missing information" begin
include("bfmi.jl")
end
Expand Down
62 changes: 62 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using MCMCDiagnosticTools
using Test
using Random

@testset "unique_indices" begin
    @testset "indices=$(eachindex(inds))" for inds in [
        rand(11:14, 100), transpose(rand(11:14, 10, 10))
    ]
        uvals, idx_groups = @inferred MCMCDiagnosticTools.unique_indices(inds)
        @test uvals isa Vector{Int}
        # linear-indexed input yields integer index groups, Cartesian-indexed input
        # (the transposed matrix) yields CartesianIndex groups
        expected_type = if eachindex(inds) isa CartesianIndices{2}
            Vector{Vector{CartesianIndex{2}}}
        else
            Vector{Vector{Int}}
        end
        @test idx_groups isa expected_type
        @test issorted(uvals)
        # every index of `inds` appears in some group
        @test issetequal(union(idx_groups...), eachindex(inds))
        # each group collects exactly the positions of its unique value
        for k in eachindex(uvals, idx_groups)
            @test all(==(uvals[k]), inds[idx_groups[k]])
        end
    end
end

@testset "split_chain_indices" begin
    c = [2, 2, 1, 3, 4, 3, 4, 1, 2, 1, 4, 3, 3, 2, 4, 3, 4, 1, 4, 1]
    @test @inferred(MCMCDiagnosticTools.split_chain_indices(c, 1)) == c

    cnew = @inferred MCMCDiagnosticTools.split_chain_indices(c, 2)
    @test issetequal(Base.unique(cnew), 1:maximum(cnew)) # check no indices skipped
    unique, indices = MCMCDiagnosticTools.unique_indices(c)
    uniquenew, indicesnew = MCMCDiagnosticTools.unique_indices(cnew)
    for (i, inew) in enumerate(1:2:7)
        # earlier sub-chains receive the extra draws, so lengths are non-increasing
        @test length(indicesnew[inew]) ≥ length(indicesnew[inew + 1])
        @test indices[i] == vcat(indicesnew[inew], indicesnew[inew + 1])
    end

    cnew = MCMCDiagnosticTools.split_chain_indices(c, 3)
    @test issetequal(Base.unique(cnew), 1:maximum(cnew)) # check no indices skipped
    unique, indices = MCMCDiagnosticTools.unique_indices(c)
    uniquenew, indicesnew = MCMCDiagnosticTools.unique_indices(cnew)
    for (i, inew) in enumerate(1:3:11)
        # earlier sub-chains receive the extra draws, so lengths are non-increasing
        @test length(indicesnew[inew]) ≥
            length(indicesnew[inew + 1]) ≥
            length(indicesnew[inew + 2])
        @test indices[i] ==
            vcat(indicesnew[inew], indicesnew[inew + 1], indicesnew[inew + 2])
    end
end

@testset "shuffle_split_stratified" begin
    rng = Random.default_rng()
    group_ids = rand(1:4, 100)
    _, groups = MCMCDiagnosticTools.unique_indices(group_ids)
    @testset "frac=$frac" for frac in [0.3, 0.5, 0.7]
        splits = @inferred(
            MCMCDiagnosticTools.shuffle_split_stratified(rng, group_ids, frac)
        )
        part1, part2 = splits
        # together the two parts must cover all indices
        @test issetequal(vcat(part1, part2), eachindex(group_ids))
        # the class balance of each group is preserved in the first part
        for grp in groups
            @test length(intersect(part1, grp)) == round(frac * length(grp))
        end
    end
end

2 comments on commit 63e82cf

@devmotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/74081

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.1 -m "<description of version>" 63e82cfa0a464259d87c1e1670e289f3be9ab9c6
git push origin v0.2.1

Please sign in to comment.