diff --git a/src/DataStructures.jl b/src/DataStructures.jl index 5e5d767e..cc99eccc 100644 --- a/src/DataStructures.jl +++ b/src/DataStructures.jl @@ -68,6 +68,7 @@ module DataStructures include("queue.jl") include("accumulator.jl") include("disjoint_set.jl") + export PCRecursive, PCIterative, PCHalving, PCSplitting include("heaps.jl") include("default_dict.jl") diff --git a/src/disjoint_set.jl b/src/disjoint_set.jl index 36464547..6b00608a 100644 --- a/src/disjoint_set.jl +++ b/src/disjoint_set.jl @@ -60,13 +60,64 @@ function _find_root_impl!(parents::Vector{T}, x::Integer) where {T<:Integer} return p end +# iterative path compression: makes every node on the path point directly to the root +@inline function find_root_iterative!(parents::Vector{T}, x::Integer) where {T<:Integer} + current = x + # find the root of the tree + @inbounds while parents[current] != current + current = parents[current] + end + root = current + # compress the path: make every node point directly to the root + current = x + @inbounds while parents[current] != root + p = parents[current] # temporarily store the parent + parents[current] = root # point directly to the root + current = p # move to the next node in the original path + end + return root +end + +# path-halving and path-splitting are a one-pass forms of path compression with inverse-ackerman complexity +# e.g., see p.19 of https://www.cs.princeton.edu/courses/archive/spr11/cos423/Lectures/PathCompressionAnalysisII.pdf + +# path-halving: every node on the path points to its grandparent +@inline function find_root_halving!(parents::Vector{T}, x::Integer) where {T<:Integer} + current = x # use a separate variable 'current' to track traversal + @inbounds while parents[current] != current + @inbounds parents[current] = parents[parents[current]] # point to grandparent + @inbounds current = parents[current] # move to grandparent + end + return current +end + +# path-splitting: every node on the path points to its grandparent +@inline function find_root_splitting!(parents::Vector{T}, x::Integer) where {T<:Integer} + @inbounds while parents[x] != x + p = parents[x] # store the current parent + parents[x] = parents[p] # point to grandparent + x = p # move to parent + end + return x +end + + +struct PCRecursive end # path compression types +struct PCIterative end # path compression types +struct PCHalving end # path compression types +struct PCSplitting end # path compression types + """ find_root!(s::IntDisjointSet{T}, x::T) Find the root element of the subset that contains an member `x`. Path compression happens here. """ -find_root!(s::IntDisjointSet{T}, x::T) where {T<:Integer} = find_root_impl!(s.parents, x) +@inline find_root!(s::IntDisjointSet{T}, x::T) where {T<:Integer} = find_root_impl!(s.parents, x) # default +@inline find_root!(s::IntDisjointSet{T}, x::T, ::PCRecursive) where {T<:Integer} = find_root_impl!(s.parents, x) +@inline find_root!(s::IntDisjointSet{T}, x::T, ::PCIterative) where {T<:Integer} = find_root_iterative!(s.parents, x) +@inline find_root!(s::IntDisjointSet{T}, x::T, ::PCHalving) where {T<:Integer} = find_root_halving!(s.parents, x) +@inline find_root!(s::IntDisjointSet{T}, x::T, ::PCSplitting) where {T<:Integer} = find_root_splitting!(s.parents, x) """ in_same_set(s::IntDisjointSet{T}, x::T, y::T) @@ -191,6 +242,10 @@ end Find the root element of the subset in `s` which has the element `x` as a member. """ find_root!(s::DisjointSet{T}, x::T) where {T} = s.revmap[find_root!(s.internal, s.intmap[x])] +find_root!(s::DisjointSet{T}, x::T, ::PCIterative) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCIterative())] +find_root!(s::DisjointSet{T}, x::T, ::PCRecursive) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCRecursive())] +find_root!(s::DisjointSet{T}, x::T, ::PCHalving) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCHalving())] +find_root!(s::DisjointSet{T}, x::T, ::PCSplitting) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCSplitting())] """ in_same_set(s::DisjointSet{T}, x::T, y::T) diff --git a/test/bench_disjoint_set.jl b/test/bench_disjoint_set.jl index d4712b5e..f7c0796f 100644 --- a/test/bench_disjoint_set.jl +++ b/test/bench_disjoint_set.jl @@ -1,6 +1,6 @@ # Benchmark on disjoint set forests -using DataStructures +using DataStructures, BenchmarkTools # do 10^6 random unions over 10^6 element set @@ -29,3 +29,43 @@ x = rand(1:n, T) y = rand(1:n, T) @time batch_union!(s, x, y) + +#= +benchmark `find` operation +=# + +function create_disjoint_set_struct(n::Int) + parents = [1; collect(1:n-1)] # each element's parent is its predecessor + ranks = zeros(Int, n) # ranks are all zero + IntDisjointSet(parents, ranks, n) +end + +# benchmarking function +function benchmark_find_root(n::Int) + println("Benchmarking recursive path compression implementation (find_root_impl!):") + if n >= 10^5 + println("Recursive may path compression may encounter stack-overflow; skipping") + else + s = create_disjoint_set_struct(n) + @btime find_root!($s, $n, PCRecursive()) + end + + println("Benchmarking iterative path compression implementation (find_root_iterative!):") + s = create_disjoint_set_struct(n) # reset parents + @btime find_root!($s, $n, PCIterative()) + + println("Benchmarking path-halving implementation (find_root_halving!):") + s = create_disjoint_set_struct(n) # reset parents + @btime find_root!($s, $n, PCHalving()) + + println("Benchmarking path-splitting implementation (find_root_path_splitting!):") + s = create_disjoint_set_struct(n) # reset parents + @btime find_root!($s, $n, PCSplitting()) +end + +# run benchmark tests +benchmark_find_root(1_000) +benchmark_find_root(10_000) +benchmark_find_root(100_000) +benchmark_find_root(1_000_000) +benchmark_find_root(10_000_000) \ No newline at end of file diff --git a/test/test_disjoint_set.jl b/test/test_disjoint_set.jl index a0346547..d146f309 100644 --- a/test/test_disjoint_set.jl +++ b/test/test_disjoint_set.jl @@ -29,10 +29,16 @@ @test num_groups(s) == T(9) @test in_same_set(s, T(2), T(3)) @test find_root!(s, T(3)) == T(2) + @test find_root!(s, T(3), PCIterative()) == T(2) + @test find_root!(s, T(3), PCHalving()) == T(2) + @test find_root!(s, T(3), PCSplitting()) == T(2) union!(s, T(3), T(2)) @test num_groups(s) == T(9) @test in_same_set(s, T(2), T(3)) @test find_root!(s, T(3)) == T(2) + @test find_root!(s, T(3), PCIterative()) == T(2) + @test find_root!(s, T(3), PCHalving()) == T(2) + @test find_root!(s, T(3), PCSplitting()) == T(2) end @testset "more tests" begin @@ -48,10 +54,19 @@ @test union!(s, T(8), T(5)) == T(8) @test num_groups(s) == T(7) @test find_root!(s, T(6)) == T(8) + @test find_root!(s, T(6), PCIterative()) == T(8) + @test find_root!(s, T(6), PCHalving()) == T(8) + @test find_root!(s, T(6), PCSplitting()) == T(8) union!(s, T(2), T(6)) @test find_root!(s, T(2)) == T(8) root1 = find_root!(s, T(6)) + root1 = find_root!(s, T(6), PCIterative()) + root1 = find_root!(s, T(6), PCHalving()) + root1 = find_root!(s, T(6), PCSplitting()) root2 = find_root!(s, T(2)) + root2 = find_root!(s, T(2), PCIterative()) + root2 = find_root!(s, T(2), PCHalving()) + root2 = find_root!(s, T(2), PCSplitting()) @test root_union!(s, T(root1), T(root2)) == T(8) @test union!(s, T(5), T(6)) == T(8) end @@ -98,6 +113,12 @@ r = [find_root!(s, i) for i in 1 : 10] @test isequal(r, collect(1:10)) + r = [find_root!(s, i, PCIterative()) for i in 1 : 10] + @test isequal(r, collect(1:10)) + r = [find_root!(s, i, PCHalving()) for i in 1 : 10] + @test isequal(r, collect(1:10)) + r = [find_root!(s, i, PCSplitting()) for i in 1 : 10] + @test isequal(r, collect(1:10)) end @testset "union!" begin @@ -117,6 +138,57 @@ @test num_groups(s) == 2 end + @testset "union! PCIterative" begin + for i = 1 : 5 + x = 2 * i - 1 + y = 2 * i + union!(s, x, y) + @test find_root!(s, x, PCIterative()) == find_root!(s, y, PCIterative()) + end + + + @test union!(s, 1, 4) == find_root!(s, 1, PCIterative()) + @test union!(s, 3, 5) == find_root!(s, 1, PCIterative()) + @test union!(s, 7, 9) == find_root!(s, 7, PCIterative()) + + @test length(s) == 10 + @test num_groups(s) == 2 + end + + @testset "union! PCHalving" begin + for i = 1 : 5 + x = 2 * i - 1 + y = 2 * i + union!(s, x, y) + @test find_root!(s, x, PCHalving()) == find_root!(s, y, PCHalving()) + end + + + @test union!(s, 1, 4) == find_root!(s, 1, PCHalving()) + @test union!(s, 3, 5) == find_root!(s, 1, PCHalving()) + @test union!(s, 7, 9) == find_root!(s, 7, PCHalving()) + + @test length(s) == 10 + @test num_groups(s) == 2 + end + + @testset "union! PCSplitting" begin + for i = 1 : 5 + x = 2 * i - 1 + y = 2 * i + union!(s, x, y) + @test find_root!(s, x, PCSplitting()) == find_root!(s, y, PCSplitting()) + end + + + @test union!(s, 1, 4) == find_root!(s, 1, PCSplitting()) + @test union!(s, 3, 5) == find_root!(s, 1, PCSplitting()) + @test union!(s, 7, 9) == find_root!(s, 7, PCSplitting()) + + @test length(s) == 10 + @test num_groups(s) == 2 + end + @testset "r0" begin r0 = [ find_root!(s,i) for i in 1:10 ] # Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11 @@ -130,6 +202,45 @@ @test isequal(r, r0) end + @testset "r0 Iterative" begin + r0 = [ find_root!(s,i) for i in 1:10 ] + # Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11 + push!(s, 17) + + @test length(s) == 11 + @test num_groups(s) == 3 + + r0 = [ r0 ; 17] + r = [find_root!(s, i, PCIterative()) for i in [1 : 10; 17] ] + @test isequal(r, r0) + end + + @testset "r0 Splitting" begin + r0 = [ find_root!(s,i) for i in 1:10 ] + # Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11 + push!(s, 17) + + @test length(s) == 11 + @test num_groups(s) == 3 + + r0 = [ r0 ; 17] + r = [find_root!(s, i, PCSplitting()) for i in [1 : 10; 17] ] + @test isequal(r, r0) + end + + @testset "r0 Halving" begin + r0 = [ find_root!(s,i) for i in 1:10 ] + # Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11 + push!(s, 17) + + @test length(s) == 11 + @test num_groups(s) == 3 + + r0 = [ r0 ; 17] + r = [find_root!(s, i, PCHalving()) for i in [1 : 10; 17] ] + @test isequal(r, r0) + end + @testset "root_union!" begin root1 = find_root!(s, 7) root2 = find_root!(s, 3)