From 98cdef05467cfe4c4335d943d027e444ba6a8f39 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jofre=20Vall=C3=A8s=20Muns?=
 <61060572+jofrevalles@users.noreply.github.com>
Date: Fri, 3 Nov 2023 13:17:12 +0100
Subject: [PATCH] Enhance edge and vertex weights for `HyPar` optimizer (#43)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com>
---
 src/Optimizers/KaHyPar.jl | 6 ++++--
 test/KaHyPar_test.jl      | 6 +++---
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl
index f2eed6d..fef7ba9 100644
--- a/src/Optimizers/KaHyPar.jl
+++ b/src/Optimizers/KaHyPar.jl
@@ -7,6 +7,8 @@ using KaHyPar
     imbalance::Float32 = 0.03
     stop::Function = <=(2) ∘ length ∘ Base.Fix2(getfield, :args)
     configuration::Union{Nothing,Symbol,String} = nothing
+    edge_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2
+    vertex_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2
 end
 
 function EinExprs.einexpr(config::HyPar, path)
@@ -21,8 +23,8 @@ function EinExprs.einexpr(config::HyPar, path)
     incidence_matrix = sparse(I, J, V)
 
     # NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order
-    edge_weights = map(Base.Fix1(size, path), inds)
-    vertex_weights = ones(Int, length(path.args))
+        edge_weights = map(config.edge_scaler ∘ Base.Fix1(size, path), inds)
+        vertex_weights = map(config.vertex_scaler ∘ length, path.args)
 
     hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights)
 
diff --git a/test/KaHyPar_test.jl b/test/KaHyPar_test.jl
index 559e0ba..f4cdc84 100644
--- a/test/KaHyPar_test.jl
+++ b/test/KaHyPar_test.jl
@@ -10,7 +10,7 @@
             EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])),
         ]
 
-        path = einexpr(HyPar, EinExpr(Symbol[], tensors))
+        path = einexpr(HyPar(imbalance=0.42), EinExpr(Symbol[], tensors))
 
         @test path isa EinExpr
 
@@ -41,10 +41,10 @@
             EinExpr([:a, :C, :d], Dict(:a => 3, :d => 6, :C => 4)),
         ]
 
-        path = einexpr(HyPar, EinExpr(Symbol[], tensors))
+        path = einexpr(HyPar(imbalance=0.45), EinExpr(Symbol[], tensors))
 
         @test path isa EinExpr
 
-        @test mapreduce(flops, +, Branches(path)) == 31653164
+        @test mapreduce(flops, +, Branches(path)) == 19099592
     end
 end
\ No newline at end of file