-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement fixes for rstar #52
Merged
Merged
Changes from 23 commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
84b1434
Use 70-30 split
sethaxen 0874f73
Add indices_of_unique
sethaxen f76ced4
Add split_chain_indices
sethaxen edf7344
Add shuffle_split_stratified
sethaxen 45117c5
Use utility functions in rstar
sethaxen 2b90e31
Update to reflect splits
sethaxen 72a5fc3
Perform test without splits
sethaxen bfdd461
Add test for splitting
sethaxen 47f5cb4
Increment patch number
sethaxen d07a9cc
Test methods are type-inferrable
sethaxen a98c482
Run formatter
sethaxen bab4deb
Update subset docs
sethaxen e1c939c
Rename nsplit to split_chains
sethaxen d6d72d1
From indices_of_unique to unique_indices
sethaxen 01a3de1
Fix bugs in tests
sethaxen abcbac1
Add comment why we can't use Iterators.partitions
sethaxen 1120768
Avoid append!
sethaxen 9042e19
Apply suggestions from code review
sethaxen f2fcd87
Add DataStructures as dependency
sethaxen f3db4d6
Fix variable name
sethaxen 702ac36
Remove unused variable groups
sethaxen ae38386
Apply suggestions from code review
sethaxen efe0502
Apply suggestions from code review
sethaxen 3a288fa
Make sure no chain indices are skipped
sethaxen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
""" | ||
unique_indices(x) -> (unique, indices) | ||
|
||
Return the results of `unique(collect(x))` along with the a vector of the same length whose | ||
elements are the indices in `x` at which the corresponding unique element in `unique` is | ||
found. | ||
""" | ||
function unique_indices(x) | ||
inds = eachindex(x) | ||
T = eltype(inds) | ||
ind_map = DataStructures.SortedDict{eltype(x),Vector{T}}() | ||
for i in inds | ||
xi = x[i] | ||
inds_xi = get!(ind_map, xi) do | ||
return T[] | ||
end | ||
push!(inds_xi, i) | ||
end | ||
unique = collect(keys(ind_map)) | ||
indices = collect(values(ind_map)) | ||
return unique, indices | ||
end | ||
|
||
""" | ||
split_chain_indices( | ||
chain_inds::AbstractVector{Int}, | ||
split::Int=2, | ||
) -> AbstractVector{Int} | ||
|
||
Split each chain in `chain_inds` into `split` chains. | ||
|
||
For each chain in `chain_inds`, all entries are assumed to correspond to draws that have | ||
been ordered by iteration number. The result is a vector of the same length as `chain_inds` | ||
where each entry is the new index of the chain that the corresponding draw belongs to. | ||
""" | ||
function split_chain_indices(c::AbstractVector{Int}, split::Int=2) | ||
cnew = similar(c) | ||
if split == 1 | ||
copyto!(cnew, c) | ||
return cnew | ||
end | ||
_, indices = unique_indices(c) | ||
chain_ind = 0 | ||
sethaxen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for inds in indices | ||
ndraws_per_split, rem = divrem(length(inds), split) | ||
# here we can't use Iterators.partition because it's greedy. e.g. we can't partition | ||
# 4 items across 3 partitions because Iterators.partition(1:4, 1) == [[1], [2], [3]] | ||
# and Iterators.partition(1:4, 2) == [[1, 2], [3, 4]]. But we would want | ||
# [[1, 2], [3], [4]]. | ||
i = j = 0 | ||
ndraws_this_split = ndraws_per_split + (j < rem) | ||
chain_ind += 1 | ||
sethaxen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for ind in inds | ||
cnew[ind] = chain_ind | ||
if (i += 1) == ndraws_this_split | ||
i = 0 | ||
j += 1 | ||
ndraws_this_split = ndraws_per_split + (j < rem) | ||
chain_ind += 1 | ||
end | ||
end | ||
end | ||
return cnew | ||
end | ||
|
||
""" | ||
shuffle_split_stratified( | ||
rng::Random.AbstractRNG, | ||
group_ids::AbstractVector, | ||
frac::Real, | ||
) -> (inds1, inds2) | ||
|
||
Randomly split the indices of `group_ids` into two groups, where `frac` indices from each | ||
group are in `inds1` and the remainder are in `inds2`. | ||
|
||
This is used, for example, to split data into training and test data while preserving the | ||
class balances. | ||
""" | ||
function shuffle_split_stratified( | ||
rng::Random.AbstractRNG, group_ids::AbstractVector, frac::Real | ||
) | ||
_, indices = unique_indices(group_ids) | ||
T = eltype(eltype(indices)) | ||
N1_tot = sum(x -> round(Int, length(x) * frac), indices) | ||
N2_tot = length(group_ids) - N1_tot | ||
inds1 = Vector{T}(undef, N1_tot) | ||
inds2 = Vector{T}(undef, N2_tot) | ||
items_in_1 = items_in_2 = 0 | ||
for inds in indices | ||
N = length(inds) | ||
N1 = round(Int, N * frac) | ||
N2 = N - N1 | ||
Random.shuffle!(rng, inds) | ||
copyto!(inds1, items_in_1 + 1, inds, 1, N1) | ||
copyto!(inds2, items_in_2 + 1, inds, N1 + 1, N2) | ||
items_in_1 += N1 | ||
items_in_2 += N2 | ||
end | ||
return inds1, inds2 | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
using MCMCDiagnosticTools | ||
using Test | ||
using Random | ||
|
||
@testset "unique_indices" begin | ||
@testset "indices=$(eachindex(inds))" for inds in [ | ||
rand(11:14, 100), transpose(rand(11:14, 10, 10)) | ||
] | ||
unique, indices = @inferred MCMCDiagnosticTools.unique_indices(inds) | ||
@test unique isa Vector{Int} | ||
if eachindex(inds) isa CartesianIndices{2} | ||
@test indices isa Vector{Vector{CartesianIndex{2}}} | ||
else | ||
@test indices isa Vector{Vector{Int}} | ||
end | ||
@test issorted(unique) | ||
@test issetequal(union(indices...), eachindex(inds)) | ||
for i in eachindex(unique, indices) | ||
@test all(inds[indices[i]] .== unique[i]) | ||
end | ||
end | ||
end | ||
|
||
@testset "split_chain_indices" begin | ||
c = [2, 2, 1, 3, 4, 3, 4, 1, 2, 1, 4, 3, 3, 2, 4, 3, 4, 1, 4, 1] | ||
@test @inferred(MCMCDiagnosticTools.split_chain_indices(c, 1)) == c | ||
|
||
cnew = @inferred MCMCDiagnosticTools.split_chain_indices(c, 2) | ||
unique, indices = MCMCDiagnosticTools.unique_indices(c) | ||
uniquenew, indicesnew = MCMCDiagnosticTools.unique_indices(cnew) | ||
for (i, inew) in enumerate(1:2:7) | ||
@test length(indicesnew[inew]) ≥ length(indicesnew[inew + 1]) | ||
@test indices[i] == vcat(indicesnew[inew], indicesnew[inew + 1]) | ||
end | ||
|
||
cnew = MCMCDiagnosticTools.split_chain_indices(c, 3) | ||
unique, indices = MCMCDiagnosticTools.unique_indices(c) | ||
uniquenew, indicesnew = MCMCDiagnosticTools.unique_indices(cnew) | ||
for (i, inew) in enumerate(1:3:11) | ||
@test length(indicesnew[inew]) ≥ | ||
length(indicesnew[inew + 1]) ≥ | ||
length(indicesnew[inew + 2]) | ||
@test indices[i] == | ||
vcat(indicesnew[inew], indicesnew[inew + 1], indicesnew[inew + 2]) | ||
end | ||
end | ||
|
||
@testset "shuffle_split_stratified" begin | ||
rng = Random.default_rng() | ||
c = rand(1:4, 100) | ||
unique, indices = MCMCDiagnosticTools.unique_indices(c) | ||
@testset "frac=$frac" for frac in [0.3, 0.5, 0.7] | ||
inds1, inds2 = @inferred(MCMCDiagnosticTools.shuffle_split_stratified(rng, c, frac)) | ||
@test issetequal(vcat(inds1, inds2), eachindex(c)) | ||
for inds in indices | ||
common_inds = intersect(inds1, inds) | ||
@test length(common_inds) == round(frac * length(inds)) | ||
end | ||
end | ||
end |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't there some existing splitting functionality for
ess
? Is the plan to merge these eventually?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not quite merge, because there are two different types of splitting we can consider. This approach supports ragged chains and is as a result more complex and doesn't discard any draws (instead dividing the remainder across the earlier splits).
For
ess
/rhat
, we don't support ragged chains so would discard draws if necessary to keep them the same length after splitting. This implementation is much simpler and can be done in a non-allocating way with justreshape
andview
on a 3d array. This will be part of #22.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The existing splitting functionality
copy_split!
will go away.