From c39a7b1451fe3770411b37fd11ad3f70ebd1dda8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= <jofrevalles99@gmail.com>
Date: Thu, 2 Nov 2023 16:05:43 +0100
Subject: [PATCH 1/8] Enhance edge and vertex weights for HyPar optimizer

---
 src/Optimizers/KaHyPar.jl | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl
index 2edc06d..5d45168 100644
--- a/src/Optimizers/KaHyPar.jl
+++ b/src/Optimizers/KaHyPar.jl
@@ -7,6 +7,8 @@ using KaHyPar
     imbalance::Float32 = 0.03
     stop::Function = <=(2) ∘ length ∘ Base.Fix1(getfield, :args)
     configuration::Union{Nothing,Symbol,String} = nothing
+    edge_weight_scaling::Function = (ind_size) -> 1000 * Int(round(log2(ind_size)))
+    vertex_weight_scaling::Function = (prod_size) -> 1000 * Int(round(log2(prod_size)))
 end
 
 function EinExprs.einexpr(config::HyPar, path)
@@ -21,8 +23,8 @@ function EinExprs.einexpr(config::HyPar, path)
     incidence_matrix = sparse(I, J, V)
 
     # NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order
-    edge_weights = map(Base.Fix1(size, path), inds)
-    vertex_weights = ones(Int, length(path.args))
+    edge_weights = map(ind -> config.edge_weight_scaling(size(path, ind)), inds)
+    vertex_weights = map(tensor -> config.vertex_weight_scaling(prod(size(tensor))), path.args)
 
     hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights)
 

From 5856a5f3879a13cad40f1eed969bf0362f89d9ad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= <jofrevalles99@gmail.com>
Date: Thu, 2 Nov 2023 16:57:08 +0100
Subject: [PATCH 2/8] Enhance syntax

---
 src/Optimizers/KaHyPar.jl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl
index 5d45168..376d289 100644
--- a/src/Optimizers/KaHyPar.jl
+++ b/src/Optimizers/KaHyPar.jl
@@ -7,8 +7,8 @@ using KaHyPar
     imbalance::Float32 = 0.03
     stop::Function = <=(2) ∘ length ∘ Base.Fix1(getfield, :args)
     configuration::Union{Nothing,Symbol,String} = nothing
-    edge_weight_scaling::Function = (ind_size) -> 1000 * Int(round(log2(ind_size)))
-    vertex_weight_scaling::Function = (prod_size) -> 1000 * Int(round(log2(prod_size)))
+    edge_scaler::Function = (ind_size) -> 1000 * (Int ∘ round ∘ log2)(ind_size)
+    vertex_scaler::Function = (prod_size) -> 1000 * (Int ∘ round ∘ log2)(prod_size)
 end
 
 function EinExprs.einexpr(config::HyPar, path)
@@ -23,8 +23,8 @@ function EinExprs.einexpr(config::HyPar, path)
     incidence_matrix = sparse(I, J, V)
 
     # NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order
-    edge_weights = map(ind -> config.edge_weight_scaling(size(path, ind)), inds)
-    vertex_weights = map(tensor -> config.vertex_weight_scaling(prod(size(tensor))), path.args)
+    edge_weights = map(ind -> (config.edge_scaler ∘ size)(path, ind), inds)
+    vertex_weights = map(tensor -> (config.vertex_scaler ∘ length)(tensor), path.args)
 
     hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights)
 

From 69622424d603bd9019367b5bc816b23e96bba3d2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= <jofrevalles99@gmail.com>
Date: Fri, 3 Nov 2023 10:05:52 +0100
Subject: [PATCH 3/8] Use Base.Fix1

---
 src/Optimizers/KaHyPar.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl
index 376d289..80440a9 100644
--- a/src/Optimizers/KaHyPar.jl
+++ b/src/Optimizers/KaHyPar.jl
@@ -23,8 +23,8 @@ function EinExprs.einexpr(config::HyPar, path)
     incidence_matrix = sparse(I, J, V)
 
     # NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order
-    edge_weights = map(ind -> (config.edge_scaler ∘ size)(path, ind), inds)
-    vertex_weights = map(tensor -> (config.vertex_scaler ∘ length)(tensor), path.args)
+        edge_weights = map(ind -> (config.edge_scaler ∘ Base.Fix1(size, path))(ind), inds)
+        vertex_weights = map(config.vertex_scaler ∘ length, path.args)
 
     hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights)
 

From a679316856968e977d22f65fb84d29b6e1f609c5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jofre=20Vall=C3=A8s=20Muns?=
 <61060572+jofrevalles@users.noreply.github.com>
Date: Fri, 3 Nov 2023 10:19:27 +0100
Subject: [PATCH 4/8] Apply @mofeing suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com>
---
 src/Optimizers/KaHyPar.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl
index 80440a9..4a9f8f4 100644
--- a/src/Optimizers/KaHyPar.jl
+++ b/src/Optimizers/KaHyPar.jl
@@ -7,8 +7,8 @@ using KaHyPar
     imbalance::Float32 = 0.03
     stop::Function = <=(2) ∘ length ∘ Base.Fix1(getfield, :args)
     configuration::Union{Nothing,Symbol,String} = nothing
-    edge_scaler::Function = (ind_size) -> 1000 * (Int ∘ round ∘ log2)(ind_size)
-    vertex_scaler::Function = (prod_size) -> 1000 * (Int ∘ round ∘ log2)(prod_size)
+    edge_scaler::Function = *(1000) ∘ Int ∘ round ∘ log2
+    vertex_scaler::Function = *(1000) ∘ Int ∘ round ∘ log2
 end
 
 function EinExprs.einexpr(config::HyPar, path)
@@ -23,7 +23,7 @@ function EinExprs.einexpr(config::HyPar, path)
     incidence_matrix = sparse(I, J, V)
 
     # NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order
-        edge_weights = map(ind -> (config.edge_scaler ∘ Base.Fix1(size, path))(ind), inds)
+        edge_weights = map(config.edge_scaler ∘ Base.Fix1(size, path), inds)
         vertex_weights = map(config.vertex_scaler ∘ length, path.args)
 
     hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights)

From c92f8bfa798155e6e31a6c787ac2be38bd9f53b1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= <jofrevalles99@gmail.com>
Date: Fri, 3 Nov 2023 10:24:45 +0100
Subject: [PATCH 5/8] Fix code from code review

---
 src/Optimizers/KaHyPar.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl
index 4a9f8f4..757e064 100644
--- a/src/Optimizers/KaHyPar.jl
+++ b/src/Optimizers/KaHyPar.jl
@@ -7,8 +7,8 @@ using KaHyPar
     imbalance::Float32 = 0.03
     stop::Function = <=(2) ∘ length ∘ Base.Fix1(getfield, :args)
     configuration::Union{Nothing,Symbol,String} = nothing
-    edge_scaler::Function = *(1000) ∘ Int ∘ round ∘ log2
-    vertex_scaler::Function = *(1000) ∘ Int ∘ round ∘ log2
+    edge_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2
+    vertex_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2
 end
 
 function EinExprs.einexpr(config::HyPar, path)

From 628671a5bfe794f1f0f900975890155a4030c4a7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= <jofrevalles99@gmail.com>
Date: Fri, 3 Nov 2023 10:47:05 +0100
Subject: [PATCH 6/8] Fix KaHyPar test

---
 test/KaHyPar_test.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/KaHyPar_test.jl b/test/KaHyPar_test.jl
index 559e0ba..66ff887 100644
--- a/test/KaHyPar_test.jl
+++ b/test/KaHyPar_test.jl
@@ -41,10 +41,10 @@
             EinExpr([:a, :C, :d], Dict(:a => 3, :d => 6, :C => 4)),
         ]
 
-        path = einexpr(HyPar, EinExpr(Symbol[], tensors))
+        path = einexpr(HyPar(imbalance=0.45), EinExpr(Symbol[], tensors))
 
         @test path isa EinExpr
 
-        @test mapreduce(flops, +, Branches(path)) == 31653164
+        @test mapreduce(flops, +, Branches(path)) == 19099592
     end
 end
\ No newline at end of file

From 72c3562e954bb55931ea0aeff3ba6dcd5c5a6a29 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= <jofrevalles99@gmail.com>
Date: Fri, 3 Nov 2023 11:25:36 +0100
Subject: [PATCH 7/8] Fix kahypar test

---
 test/KaHyPar_test.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/KaHyPar_test.jl b/test/KaHyPar_test.jl
index 66ff887..4e04050 100644
--- a/test/KaHyPar_test.jl
+++ b/test/KaHyPar_test.jl
@@ -10,7 +10,7 @@
             EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])),
         ]
 
-        path = einexpr(HyPar, EinExpr(Symbol[], tensors))
+        path = einexpr(HyPar(imbakance=0.42), EinExpr(Symbol[], tensors))
 
         @test path isa EinExpr
 

From 694b3b2de1f9e6cbf1dbcada8cae6a02c63b3c28 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= <jofrevalles99@gmail.com>
Date: Fri, 3 Nov 2023 11:36:05 +0100
Subject: [PATCH 8/8] Fix typo

---
 test/KaHyPar_test.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/KaHyPar_test.jl b/test/KaHyPar_test.jl
index 4e04050..f4cdc84 100644
--- a/test/KaHyPar_test.jl
+++ b/test/KaHyPar_test.jl
@@ -10,7 +10,7 @@
             EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])),
         ]
 
-        path = einexpr(HyPar(imbakance=0.42), EinExpr(Symbol[], tensors))
+        path = einexpr(HyPar(imbalance=0.42), EinExpr(Symbol[], tensors))
 
         @test path isa EinExpr