Implement fixes for rstar #52

Merged · 24 commits · Dec 14, 2022
Changes from 10 commits
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
 name = "MCMCDiagnosticTools"
 uuid = "be115224-59cd-429b-ad48-344e309966f0"
 authors = ["David Widmann"]
-version = "0.2.0"
+version = "0.2.1"

 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
1 change: 1 addition & 0 deletions src/MCMCDiagnosticTools.jl
@@ -22,6 +22,7 @@ export mcse
 export rafterydiag
 export rstar

+include("utils.jl")
 include("bfmi.jl")
 include("discretediag.jl")
 include("ess.jl")
26 changes: 14 additions & 12 deletions src/rstar.jl
@@ -4,7 +4,8 @@
     classifier::MLJModelInterface.Supervised,
     samples,
     chain_indices::AbstractVector{Int};
-    subset::Real=0.8,
+    subset::Real=0.7,
+    nsplit::Int=2,
     verbosity::Int=0,
 )
@@ -23,26 +24,25 @@ function rstar(
     classifier::MLJModelInterface.Supervised,
     x,
     y::AbstractVector{Int};
-    subset::Real=0.8,
+    subset::Real=0.7,
+    nsplit::Int=2,
     verbosity::Int=0,
 )
     # checks
     MLJModelInterface.nrows(x) != length(y) && throw(DimensionMismatch())
     0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)"))

+    ysplit = split_chain_indices(y, nsplit)
+
     # randomly sub-select training and testing set
-    N = length(y)
-    Ntrain = round(Int, N * subset)
-    0 < Ntrain < N ||
+    train_ids, test_ids = shuffle_split_stratified(rng, ysplit, subset)
+    0 < length(train_ids) < length(y) ||
         throw(ArgumentError("training and test data subsets must not be empty"))
-    ids = Random.randperm(rng, N)
-    train_ids = view(ids, 1:Ntrain)
-    test_ids = view(ids, (Ntrain + 1):N)

     xtable = _astable(x)

     # train classifier on training data
-    ycategorical = MLJModelInterface.categorical(y)
+    ycategorical = MLJModelInterface.categorical(ysplit)
     xtrain = MLJModelInterface.selectrows(xtable, train_ids)
     fitresult, _ = MLJModelInterface.fit(
         classifier, verbosity, xtrain, ycategorical[train_ids]
@@ -79,7 +79,8 @@ end
     rng::Random.AbstractRNG=Random.default_rng(),
     classifier::MLJModelInterface.Supervised,
     samples::AbstractArray{<:Real,3};
-    subset::Real=0.8,
+    subset::Real=0.7,
+    nsplit::Int=2,
     verbosity::Int=0,
 )
@@ -91,8 +92,9 @@ This implementation is an adaption of algorithms 1 and 2 described by Lambert an

 The `classifier` has to be a supervised classifier of the MLJ framework (see the
 [MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/list_of_supported_models/#model_list)
-for a list of supported models). It is trained with a `subset` of the samples. The training
-of the classifier can be inspected by adjusting the `verbosity` level.
+for a list of supported models). It is trained with a `subset` of the samples. Each chain
+is split into `nsplit` separate chains to additionally check for within-chain convergence.
+The training of the classifier can be inspected by adjusting the `verbosity` level.

 If the classifier is deterministic, i.e., if it predicts a class, the value of the ``R^*``
 statistic is returned (algorithm 1). If the classifier is probabilistic, i.e., if it outputs
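For orientation, a minimal sketch of how the revised keyword arguments fit together. The data are hypothetical, and the classifier import assumes the MLJXGBoostInterface package that the test suite also relies on; any supervised MLJ classifier would work:

using MCMCDiagnosticTools
using Statistics
using MLJXGBoostInterface: XGBoostClassifier  # assumed import; matches the classifier used in the tests

# hypothetical data: 1000 draws of 2 parameters from each of 4 chains, stacked row-wise
samples = randn(4_000, 2)
chain_indices = repeat(1:4; inner=1_000)

# train on 70% of the draws; split each chain in two to also detect
# within-chain convergence failures
dist = rstar(XGBoostClassifier(), samples, chain_indices; subset=0.7, nsplit=2)
mean(dist)  # ≈ 1 for well-mixed chains, approaching the number of split chains otherwise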
80 changes: 80 additions & 0 deletions src/utils.jl
@@ -0,0 +1,80 @@
"""
    indices_of_unique(x) -> Dict

Return a `Dict` whose keys are the unique elements of `x` and whose values are the
corresponding indices in `x`.
"""
function indices_of_unique(x)
    d = Dict{eltype(x), Vector{Int}}()
    for (i, xi) in enumerate(x)
        if haskey(d, xi)
            push!(d[xi], i)
        else
            d[xi] = [i]
        end
Member:
This can be made more efficient by not looking up the key twice. One could e.g. use

Suggested change:
-        if haskey(d, xi)
-            push!(d[xi], i)
-        else
-            d[xi] = [i]
-        end
+        d_xi = get!(d, xi) do
+            return Int[]
+        end
+        push!(d_xi, i)

Apart from that, it seems like a function that could exist e.g. in StatsBase (similar to proportionmap etc.). Did you check that?

Member Author:
I agree this would fit in StatsBase, but there currently is no such method (indexmap is closest). MLUtils has group_indices, which is equivalent, but the dependency is too heavy.

I found a few threads of people looking for this, e.g. https://discourse.julialang.org/t/is-there-a-function-similar-to-numpy-unique-with-inverse/80949, but with no clear answer.

An alternative would be to stick closer to NumPy's very useful return_inverse=True approach and return 2 vectors, basically the sorted keys and corresponding values. Either way, this could later be upstreamed to StatsBase.

Member:
Sounds good, I just want to make sure we use existing functionality. If it doesn't exist yet, that's unfortunate, but then, of course, we should use our own implementation.

Member:
Actually, it seems indicatormat returns the information we are interested in: https://juliastats.org/StatsBase.jl/stable/misc/#StatsBase.indicatormat But maybe it's not the desired output format for our purposes.

Member Author:
It's similar, yes, but a little clunky. E.g., here's how we could get the vector of indices:

using SparseArrays
map(first ∘ findnz ∘ sparse, eachslice(indicatormat(x; sparse=true); dims=1))

But I still think it makes more sense to try to upstream the functionality we want, since often something like what we want will be more convenient for the user.

Member:
Yes, I agree. That seems a bit inconvenient.

    end
    return d
end
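As a quick illustration of the helper's contract (a hypothetical snippet, not part of the diff):

x = [1, 2, 1, 3, 2]
d = MCMCDiagnosticTools.indices_of_unique(x)
# d[1] == [1, 3], d[2] == [2, 5], d[3] == [4]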

"""
split_chain_indices(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't there some existing splitting functionality for ess? Is the plan to merge these eventually?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not quite merge, because there are two different types of splitting we can consider. This approach supports ragged chains and is as a result more complex and doesn't discard any draws (instead dividing the remainder across the earlier splits).

For ess/rhat, we don't support ragged chains so would discard draws if necessary to keep them the same length after splitting. This implementation is much simpler and can be done in a non-allocating way with just reshape and view on a 3d array. This will be part of #22.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing splitting functionality copy_split! will go away.

chain_inds::AbstractVector{Int},
nsplit::Int=2,
) -> AbstractVector{Int}

Split each chain in `chain_inds` into `nsplit` chains.

For each chain in `chain_inds`, all entries are assumed to correspond to draws that have
been ordered by iteration number. The result is a vector of the same length as `chain_inds`
where each entry is the new index of the chain that the corresponding draw belongs to.
"""
function split_chain_indices(c::AbstractVector{<:Int}, nsplit::Int=2)
    cnew = similar(c)
    if nsplit == 1
        copyto!(cnew, c)
        return cnew
    end
    chain_indices = indices_of_unique(c)
    chain_ind = 0
    for chain in sort(collect(keys(chain_indices)))
        inds = chain_indices[chain]
        ndraws_per_split, rem = divrem(length(inds), nsplit)
        ilast = 0
        for j in 1:nsplit
            chain_ind += 1
            ndraws_this_split = ndraws_per_split + (j ≤ rem)
            i = ilast + 1
            ilast = i + ndraws_this_split - 1
            @views cnew[inds[i:ilast]] .= chain_ind
        end
    end
    return cnew
end
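To make the remainder handling concrete, a small hypothetical example: chain 1 has 5 draws and chain 2 has 4, so with nsplit=2 the first split of chain 1 receives the extra draw:

c = [1, 1, 1, 1, 1, 2, 2, 2, 2]
MCMCDiagnosticTools.split_chain_indices(c, 2)
# == [1, 1, 1, 2, 2, 3, 3, 4, 4]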

"""
shuffle_split_stratified(
rng::Random.AbstractRNG,
group_ids::AbstractVector,
frac::Real,
) -> (inds1, inds2)

Randomly split the indices of `group_ids` into two groups, where `frac` indices from each
group are in `inds1` and the remainder are in `inds2`.

This is used, for example, to split data into training and test data while preserving the
class balances.
"""
function shuffle_split_stratified(rng::Random.AbstractRNG, groups::AbstractVector, frac::Real)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
inds1 = Int[]
inds2 = Int[]
group_indices = indices_of_unique(groups)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
for group in keys(group_indices)
inds = group_indices[group]
N = length(inds)
N1 = round(Int, N * frac)
ids = Random.randperm(rng, N)
@views append!(inds1, inds[ids[1:N1]])
@views append!(inds2, inds[ids[(N1 + 1):N]])
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
end
return inds1, inds2
end
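A hypothetical check of the stratified splitting described in the docstring (seed and group sizes chosen for illustration only):

using Random
rng = Random.MersenneTwister(42)
groups = repeat(1:2; inner=10)  # two groups of 10 draws each
train, test = MCMCDiagnosticTools.shuffle_split_stratified(rng, groups, 0.7)
# each group contributes round(Int, 0.7 * 10) == 7 of its indices to `train`
(count(in(1:10), train), count(in(11:20), train))  # (7, 7)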
17 changes: 14 additions & 3 deletions test/rstar.jl
@@ -30,7 +30,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
         @test dist isa LocationScale
         @test dist.ρ isa PoissonBinomial
         @test minimum(dist) == 0
-        @test maximum(dist) == 3
+        @test maximum(dist) == 6
     end
     @test mean(dist) ≈ 1 rtol = 0.2
     wrapper === Vector && break
@@ -48,7 +48,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
         @test dist isa LocationScale
         @test dist.ρ isa PoissonBinomial
         @test minimum(dist) == 0
-        @test maximum(dist) == 4
+        @test maximum(dist) == 8
     end
     @test mean(dist) ≈ 1 rtol = 0.15
@@ -58,7 +58,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
         100 .* cos.(1:N) 100 .* sin.(1:N)
     ])
     chain_indices = repeat(1:2; inner=N)
-    dist = rstar(classifier, samples, chain_indices)
+    dist = rstar(classifier, samples, chain_indices; nsplit=1)

     # Mean of the statistic should be close to 2, i.e., the classifier should be able to
     # learn an almost perfect decision boundary between chains.
@@ -71,6 +71,17 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo
         @test maximum(dist) == 2
     end
     @test mean(dist) ≈ 2 rtol = 0.15
+
+    # Compute the R⋆ statistic for identical chains that individually have not mixed.
+    samples = ones(sz)
+    samples[div(N, 2):end, :] .= 2
+    chain_indices = repeat(1:4; outer=div(N, 4))
+    dist = rstar(classifier, samples, chain_indices; nsplit=1)
+    # without split chains cannot distinguish between chains
+    @test mean(dist) ≈ 1 rtol = 0.15
+    dist = rstar(classifier, samples, chain_indices)
+    # with split chains can learn almost perfect decision boundary
+    @test mean(dist) ≈ 2 rtol = 0.15
     end
     wrapper === Vector && continue
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -10,6 +10,10 @@ using Test
 Random.seed!(1)

 @testset "MCMCDiagnosticTools.jl" begin
+    @testset "utils" begin
+        include("utils.jl")
+    end
+
     @testset "Bayesian fraction of missing information" begin
         include("bfmi.jl")
     end
48 changes: 48 additions & 0 deletions test/utils.jl
@@ -0,0 +1,48 @@
using MCMCDiagnosticTools
using Test
using Random

@testset "indices_of_unique" begin
    inds = [1, 4, 3, 1, 4, 1, 3, 3, 4, 2, 1, 4, 1, 1, 3, 2, 3, 4, 4, 2]
    d = MCMCDiagnosticTools.indices_of_unique(inds)
    @test d isa Dict{Int, Vector{Int}}
    @test issetequal(union(values(d)...), eachindex(inds))
    for k in keys(d)
        @test all(inds[d[k]] .== k)
    end
end

@testset "split_chain_indices" begin
    c = [2, 2, 1, 3, 4, 3, 4, 1, 2, 1, 4, 3, 3, 2, 4, 3, 4, 1, 4, 1]
    @test @inferred(MCMCDiagnosticTools.split_chain_indices(c, 1)) == c

    cnew = @inferred MCMCDiagnosticTools.split_chain_indices(c, 2)
    d = MCMCDiagnosticTools.indices_of_unique(c)
    dnew = MCMCDiagnosticTools.indices_of_unique(cnew)
    for (i, inew) in enumerate(1:2:7)
        @test length(dnew[inew]) ≥ length(dnew[inew + 1])
        @test d[i] == vcat(dnew[inew], dnew[inew + 1])
    end

    cnew = MCMCDiagnosticTools.split_chain_indices(c, 3)
    d = MCMCDiagnosticTools.indices_of_unique(c)
    dnew = MCMCDiagnosticTools.indices_of_unique(cnew)
    for (i, inew) in enumerate(1:3:11)
        @test length(dnew[inew]) ≥ length(dnew[inew + 1]) ≥ length(dnew[inew + 2])
        @test d[i] == vcat(dnew[inew], dnew[inew + 1], dnew[inew + 2])
    end
end

@testset "shuffle_split_stratified" begin
    rng = Random.default_rng()
    c = rand(1:4, 100)
    d = MCMCDiagnosticTools.indices_of_unique(c)
    @testset "frac=$frac" for frac in [0.3, 0.5, 0.7]
        inds1, inds2 = @inferred(MCMCDiagnosticTools.shuffle_split_stratified(rng, c, frac))
        @test issetequal(vcat(inds1, inds2), eachindex(c))
        for i in 1:4
            common_inds = intersect(inds1, d[i])
            @test length(common_inds) == round(frac * length(d[i]))
        end
    end
end