From 98cdef05467cfe4c4335d943d027e444ba6a8f39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s=20Muns?= <61060572+jofrevalles@users.noreply.github.com> Date: Fri, 3 Nov 2023 13:17:12 +0100 Subject: [PATCH] Enhance edge and vertex weights for `HyPar` optimizer (#43) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> --- src/Optimizers/KaHyPar.jl | 6 ++++-- test/KaHyPar_test.jl | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl index f2eed6d..fef7ba9 100644 --- a/src/Optimizers/KaHyPar.jl +++ b/src/Optimizers/KaHyPar.jl @@ -7,6 +7,8 @@ using KaHyPar imbalance::Float32 = 0.03 stop::Function = <=(2) ∘ length ∘ Base.Fix2(getfield, :args) configuration::Union{Nothing,Symbol,String} = nothing + edge_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2 + vertex_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2 end function EinExprs.einexpr(config::HyPar, path) @@ -21,8 +23,8 @@ function EinExprs.einexpr(config::HyPar, path) incidence_matrix = sparse(I, J, V) # NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order - edge_weights = map(Base.Fix1(size, path), inds) - vertex_weights = ones(Int, length(path.args)) + edge_weights = map(config.edge_scaler ∘ Base.Fix1(size, path), inds) + vertex_weights = map(config.vertex_scaler ∘ length, path.args) hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights) diff --git a/test/KaHyPar_test.jl b/test/KaHyPar_test.jl index 559e0ba..f4cdc84 100644 --- a/test/KaHyPar_test.jl +++ b/test/KaHyPar_test.jl @@ -10,7 +10,7 @@ EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), ] - path = einexpr(HyPar, EinExpr(Symbol[], tensors)) + path = einexpr(HyPar(imbalance=0.42), EinExpr(Symbol[], tensors)) @test path isa EinExpr @@ -41,10 +41,10 @@ EinExpr([:a, :C, :d], Dict(:a => 3, :d => 6, :C => 4)), ] - path = einexpr(HyPar, EinExpr(Symbol[], tensors)) + path = einexpr(HyPar(imbalance=0.45), EinExpr(Symbol[], tensors)) @test path isa EinExpr - @test mapreduce(flops, +, Branches(path)) == 31653164 + @test mapreduce(flops, +, Branches(path)) == 19099592 end end \ No newline at end of file