-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Use 70-30 split * Add indices_of_unique * Add split_chain_indices * Add shuffle_split_stratified * Use utility functions in rstar * Update to reflect splits * Perform test without splits * Add test for splitting * Increment patch number * Test methods are type-inferrable * Run formatter * Update subset docs * Rename nsplit to split_chains * From indices_of_unique to unique_indices * Fix bugs in tests * Add comment why we can't use Iterators.partitions * Avoid append! * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Add DataStructures as dependency * Fix variable name * Remove unused variable groups * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Make sure no chain indices are skipped Co-authored-by: David Widmann <[email protected]>
- Loading branch information
Showing
7 changed files
with
199 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
""" | ||
unique_indices(x) -> (unique, indices) | ||
Return the unique elements of `x` in sorted order along with a vector of the same length | ||
whose elements are the vectors of indices in `x` at which the corresponding element of | ||
`unique` is found. | ||
""" | ||
function unique_indices(x)
    # Group the indices of `x` by the value stored at each index. A sorted dictionary is
    # used, so the returned unique values (its keys) come out in sorted order and the
    # index vectors (its values) are aligned with them.
    positions = eachindex(x)
    I = eltype(positions)
    groups = DataStructures.SortedDict{eltype(x),Vector{I}}()
    for pos in positions
        # Fetch the index list for this value, creating an empty one on first sight.
        bucket = get!(() -> I[], groups, x[pos])
        push!(bucket, pos)
    end
    return collect(keys(groups)), collect(values(groups))
end
|
||
""" | ||
split_chain_indices( | ||
chain_inds::AbstractVector{Int}, | ||
split::Int=2, | ||
) -> AbstractVector{Int} | ||
Split each chain in `chain_inds` into `split` chains. | ||
For each chain in `chain_inds`, all entries are assumed to correspond to draws that have | ||
been ordered by iteration number. The result is a vector of the same length as `chain_inds` | ||
where each entry is the new index of the chain that the corresponding draw belongs to. | ||
""" | ||
function split_chain_indices(c::AbstractVector{Int}, split::Int=2)
    # Relabel the chain indices in `c` so that every chain is divided into `split`
    # consecutive sub-chains. Returns a new vector with the same axes as `c`; `c` itself
    # is not modified. New chain indices are assigned consecutively starting at 1, in
    # sorted order of the original chain ids.
    #
    # Throws an `ArgumentError` if `split` is not positive: `split == 0` would otherwise
    # fail inside `divrem`, and a negative `split` would silently produce an unsplit
    # relabelling.
    split > 0 || throw(ArgumentError("split must be positive, got $split"))
    cnew = similar(c)
    if split == 1
        # Nothing to split: return a copy of the input labels.
        copyto!(cnew, c)
        return cnew
    end
    _, indices = unique_indices(c)
    chain_ind = 1
    for inds in indices
        # The first `rem` sub-chains receive one extra draw so all draws are assigned.
        ndraws_per_split, rem = divrem(length(inds), split)
        # Iterators.partition cannot be used here because it takes a fixed chunk size.
        # E.g. to spread 4 items over 3 partitions we want [[1, 2], [3], [4]], but a chunk
        # size of 1 yields 4 partitions and a chunk size of 2 yields [[1, 2], [3, 4]].
        i = j = 0  # i: draws placed in the current sub-chain, j: sub-chains completed
        ndraws_this_split = ndraws_per_split + (j < rem)
        for ind in inds
            cnew[ind] = chain_ind
            if (i += 1) == ndraws_this_split
                # The current sub-chain is full; move on to the next one. Advancing
                # `chain_ind` only here guarantees that no new chain index is skipped,
                # even when a chain has fewer draws than `split`.
                i = 0
                j += 1
                ndraws_this_split = ndraws_per_split + (j < rem)
                chain_ind += 1
            end
        end
    end
    return cnew
end
|
||
""" | ||
shuffle_split_stratified( | ||
rng::Random.AbstractRNG, | ||
group_ids::AbstractVector, | ||
frac::Real, | ||
) -> (inds1, inds2) | ||
Randomly split the indices of `group_ids` into two groups, where a fraction `frac` of the | ||
indices from each group (rounded to the nearest integer) are in `inds1` and the remainder | ||
are in `inds2`. | ||
This is used, for example, to split data into training and test data while preserving the | ||
class balances. | ||
""" | ||
function shuffle_split_stratified(
    rng::Random.AbstractRNG, group_ids::AbstractVector, frac::Real
)
    # Randomly partition the indices of `group_ids` into `(inds1, inds2)`. Within each
    # group of equal ids, the indices are shuffled with `rng`; the first
    # `round(Int, group_size * frac)` of them go to `inds1` and the rest to `inds2`.
    # `frac` is not validated here; it is expected to lie in `[0, 1]`.
    _, indices = unique_indices(group_ids)
    T = eltype(eltype(indices))
    # Compute each group's share for `inds1` once. The typed comprehension keeps the sum
    # well-defined (zero) when `group_ids` is empty, so two empty vectors are returned
    # instead of an error from reducing over an empty collection.
    n1_per_group = Int[round(Int, length(inds) * frac) for inds in indices]
    N1_tot = sum(n1_per_group)
    N2_tot = length(group_ids) - N1_tot
    inds1 = Vector{T}(undef, N1_tot)
    inds2 = Vector{T}(undef, N2_tot)
    items_in_1 = items_in_2 = 0
    for (inds, N1) in zip(indices, n1_per_group)
        N2 = length(inds) - N1
        # `inds` is owned by this function (freshly built above), so shuffling in place
        # does not affect the caller's data.
        Random.shuffle!(rng, inds)
        copyto!(inds1, items_in_1 + 1, inds, 1, N1)
        copyto!(inds2, items_in_2 + 1, inds, N1 + 1, N2)
        items_in_1 += N1
        items_in_2 += N2
    end
    return inds1, inds2
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
using MCMCDiagnosticTools | ||
using Test | ||
using Random | ||
|
||
@testset "unique_indices" begin
    # One input with linear indices (a vector) and one with Cartesian indices (a lazily
    # transposed matrix).
    inputs = [rand(11:14, 100), transpose(rand(11:14, 10, 10))]
    @testset "indices=$(eachindex(inds))" for inds in inputs
        uniq, groups = @inferred MCMCDiagnosticTools.unique_indices(inds)
        @test uniq isa Vector{Int}
        # The index vectors must use the native index type of the input.
        expected_type = if eachindex(inds) isa CartesianIndices{2}
            Vector{Vector{CartesianIndex{2}}}
        else
            Vector{Vector{Int}}
        end
        @test groups isa expected_type
        @test issorted(uniq)
        # Together the groups must cover every index of the input.
        @test issetequal(reduce(vcat, groups), eachindex(inds))
        for k in eachindex(uniq, groups)
            @test all(==(uniq[k]), inds[groups[k]])
        end
    end
end
|
||
@testset "split_chain_indices" begin
    c = [2, 2, 1, 3, 4, 3, 4, 1, 2, 1, 4, 3, 3, 2, 4, 3, 4, 1, 4, 1]
    @test @inferred(MCMCDiagnosticTools.split_chain_indices(c, 1)) == c

    # Indices of the draws of each of the 4 original chains.
    _, indices = MCMCDiagnosticTools.unique_indices(c)

    # Split every chain in two: new chains (2i - 1, 2i) come from original chain i.
    cnew = @inferred MCMCDiagnosticTools.split_chain_indices(c, 2)
    @test issetequal(Base.unique(cnew), 1:maximum(cnew)) # check no indices skipped
    _, indicesnew = MCMCDiagnosticTools.unique_indices(cnew)
    for i in 1:4
        first_part, second_part = indicesnew[2i - 1], indicesnew[2i]
        @test length(first_part) ≥ length(second_part)
        @test indices[i] == vcat(first_part, second_part)
    end

    # Split every chain in three: new chains (3i - 2):(3i) come from original chain i.
    cnew = MCMCDiagnosticTools.split_chain_indices(c, 3)
    @test issetequal(Base.unique(cnew), 1:maximum(cnew)) # check no indices skipped
    _, indicesnew = MCMCDiagnosticTools.unique_indices(cnew)
    for i in 1:4
        parts = indicesnew[(3i - 2):(3i)]
        # Earlier sub-chains are never shorter than later ones.
        @test issorted(map(length, parts); rev=true)
        @test indices[i] == reduce(vcat, parts)
    end
end
|
||
@testset "shuffle_split_stratified" begin
    rng = Random.default_rng()
    c = rand(1:4, 100)
    # Indices of the members of each group; used to check the per-group split sizes.
    _, indices = MCMCDiagnosticTools.unique_indices(c)
    @testset "frac=$frac" for frac in [0.3, 0.5, 0.7]
        result = @inferred MCMCDiagnosticTools.shuffle_split_stratified(rng, c, frac)
        inds1, inds2 = result
        # Every index of `c` must land in one of the two outputs.
        @test issetequal([inds1; inds2], eachindex(c))
        for inds in indices
            # The share of each group in `inds1` must be the rounded fraction.
            @test length(inds1 ∩ inds) == round(frac * length(inds))
        end
    end
end
63e82cf
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
63e82cf
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/74081
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: