From c39a7b1451fe3770411b37fd11ad3f70ebd1dda8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= <jofrevalles99@gmail.com> Date: Thu, 2 Nov 2023 16:05:43 +0100 Subject: [PATCH 1/8] Enhance edge and vertex weights for HyPar optimizer --- src/Optimizers/KaHyPar.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl index 2edc06d..5d45168 100644 --- a/src/Optimizers/KaHyPar.jl +++ b/src/Optimizers/KaHyPar.jl @@ -7,6 +7,8 @@ using KaHyPar imbalance::Float32 = 0.03 stop::Function = <=(2) ∘ length ∘ Base.Fix1(getfield, :args) configuration::Union{Nothing,Symbol,String} = nothing + edge_weight_scaling::Function = (ind_size) -> 1000 * Int(round(log2(ind_size))) + vertex_weight_scaling::Function = (prod_size) -> 1000 * Int(round(log2(prod_size))) end function EinExprs.einexpr(config::HyPar, path) @@ -21,8 +23,8 @@ function EinExprs.einexpr(config::HyPar, path) incidence_matrix = sparse(I, J, V) # NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order - edge_weights = map(Base.Fix1(size, path), inds) - vertex_weights = ones(Int, length(path.args)) + edge_weights = map(ind -> config.edge_weight_scaling(size(path, ind)), inds) + vertex_weights = map(tensor -> config.vertex_weight_scaling(prod(size(tensor))), path.args) hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights) From 5856a5f3879a13cad40f1eed969bf0362f89d9ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= <jofrevalles99@gmail.com> Date: Thu, 2 Nov 2023 16:57:08 +0100 Subject: [PATCH 2/8] Enhance syntax --- src/Optimizers/KaHyPar.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl index 5d45168..376d289 100644 --- a/src/Optimizers/KaHyPar.jl +++ b/src/Optimizers/KaHyPar.jl @@ -7,8 +7,8 @@ using KaHyPar imbalance::Float32 = 0.03 stop::Function = <=(2) ∘ length ∘ Base.Fix1(getfield, :args) configuration::Union{Nothing,Symbol,String} = nothing - edge_weight_scaling::Function = (ind_size) -> 1000 * Int(round(log2(ind_size))) - vertex_weight_scaling::Function = (prod_size) -> 1000 * Int(round(log2(prod_size))) + edge_scaler::Function = (ind_size) -> 1000 * (Int ∘ round ∘ log2)(ind_size) + vertex_scaler::Function = (prod_size) -> 1000 * (Int ∘ round ∘ log2)(prod_size) end function EinExprs.einexpr(config::HyPar, path) @@ -23,8 +23,8 @@ function EinExprs.einexpr(config::HyPar, path) incidence_matrix = sparse(I, J, V) # NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order - edge_weights = map(ind -> config.edge_weight_scaling(size(path, ind)), inds) - vertex_weights = map(tensor -> config.vertex_weight_scaling(prod(size(tensor))), path.args) + edge_weights = map(ind -> (config.edge_scaler ∘ size)(path, ind), inds) + vertex_weights = map(tensor -> (config.vertex_scaler ∘ length)(tensor), path.args) hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights) From 69622424d603bd9019367b5bc816b23e96bba3d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= <jofrevalles99@gmail.com> Date: Fri, 3 Nov 2023 10:05:52 +0100 Subject: [PATCH 3/8] Use Base.Fix1 --- src/Optimizers/KaHyPar.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl index 376d289..80440a9 100644 --- a/src/Optimizers/KaHyPar.jl +++ b/src/Optimizers/KaHyPar.jl @@ -23,8 +23,8 @@ function EinExprs.einexpr(config::HyPar, path) incidence_matrix = sparse(I, J, V) # NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order - edge_weights = map(ind -> (config.edge_scaler ∘ size)(path, ind), inds) - vertex_weights = map(tensor -> (config.vertex_scaler ∘ length)(tensor), path.args) + edge_weights = map(ind -> (config.edge_scaler ∘ Base.Fix1(size, path))(ind), inds) + vertex_weights = map(config.vertex_scaler ∘ length, path.args) hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights) From a679316856968e977d22f65fb84d29b6e1f609c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s=20Muns?= <61060572+jofrevalles@users.noreply.github.com> Date: Fri, 3 Nov 2023 10:19:27 +0100 Subject: [PATCH 4/8] Apply @mofeing suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> --- src/Optimizers/KaHyPar.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl index 80440a9..4a9f8f4 100644 --- a/src/Optimizers/KaHyPar.jl +++ b/src/Optimizers/KaHyPar.jl @@ -7,8 +7,8 @@ using KaHyPar imbalance::Float32 = 0.03 stop::Function = <=(2) ∘ length ∘ Base.Fix1(getfield, :args) configuration::Union{Nothing,Symbol,String} = nothing - edge_scaler::Function = (ind_size) -> 1000 * (Int ∘ round ∘ log2)(ind_size) - vertex_scaler::Function = (prod_size) -> 1000 * (Int ∘ round ∘ log2)(prod_size) + edge_scaler::Function = *(1000) ∘ Int ∘ round ∘ log2 + vertex_scaler::Function = *(1000) ∘ Int ∘ round ∘ log2 end function EinExprs.einexpr(config::HyPar, path) @@ -23,7 +23,7 @@ function EinExprs.einexpr(config::HyPar, path) incidence_matrix = sparse(I, J, V) # NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order - edge_weights = map(ind -> (config.edge_scaler ∘ Base.Fix1(size, path))(ind), inds) + edge_weights = map(config.edge_scaler ∘ Base.Fix1(size, path), inds) vertex_weights = map(config.vertex_scaler ∘ length, path.args) hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights) From c92f8bfa798155e6e31a6c787ac2be38bd9f53b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= <jofrevalles99@gmail.com> Date: Fri, 3 Nov 2023 10:24:45 +0100 Subject: [PATCH 5/8] Fix code from code review --- src/Optimizers/KaHyPar.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl index 4a9f8f4..757e064 100644 --- a/src/Optimizers/KaHyPar.jl +++ b/src/Optimizers/KaHyPar.jl @@ -7,8 +7,8 @@ using KaHyPar imbalance::Float32 = 0.03 stop::Function = <=(2) ∘ length ∘ Base.Fix1(getfield, :args) configuration::Union{Nothing,Symbol,String} = nothing - edge_scaler::Function = *(1000) ∘ Int ∘ round ∘ log2 - vertex_scaler::Function = *(1000) ∘ Int ∘ round ∘ log2 + edge_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2 + vertex_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2 end function EinExprs.einexpr(config::HyPar, path) From 628671a5bfe794f1f0f900975890155a4030c4a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= <jofrevalles99@gmail.com> Date: Fri, 3 Nov 2023 10:47:05 +0100 Subject: [PATCH 6/8] Fix KaHyPar test --- test/KaHyPar_test.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/KaHyPar_test.jl b/test/KaHyPar_test.jl index 559e0ba..66ff887 100644 --- a/test/KaHyPar_test.jl +++ b/test/KaHyPar_test.jl @@ -41,10 +41,10 @@ EinExpr([:a, :C, :d], Dict(:a => 3, :d => 6, :C => 4)), ] - path = einexpr(HyPar, EinExpr(Symbol[], tensors)) + path = einexpr(HyPar(imbalance=0.45), EinExpr(Symbol[], tensors)) @test path isa EinExpr - @test mapreduce(flops, +, Branches(path)) == 31653164 + @test mapreduce(flops, +, Branches(path)) == 19099592 end end \ No newline at end of file From 72c3562e954bb55931ea0aeff3ba6dcd5c5a6a29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= <jofrevalles99@gmail.com> Date: Fri, 3 Nov 2023 11:25:36 +0100 Subject: [PATCH 7/8] Fix kahypar test --- test/KaHyPar_test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/KaHyPar_test.jl b/test/KaHyPar_test.jl index 66ff887..4e04050 100644 --- a/test/KaHyPar_test.jl +++ b/test/KaHyPar_test.jl @@ -10,7 +10,7 @@ EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), ] - path = einexpr(HyPar, EinExpr(Symbol[], tensors)) + path = einexpr(HyPar(imbakance=0.42), EinExpr(Symbol[], tensors)) @test path isa EinExpr From 694b3b2de1f9e6cbf1dbcada8cae6a02c63b3c28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= <jofrevalles99@gmail.com> Date: Fri, 3 Nov 2023 11:36:05 +0100 Subject: [PATCH 8/8] Fix typo --- test/KaHyPar_test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/KaHyPar_test.jl b/test/KaHyPar_test.jl index 4e04050..f4cdc84 100644 --- a/test/KaHyPar_test.jl +++ b/test/KaHyPar_test.jl @@ -10,7 +10,7 @@ EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), ] - path = einexpr(HyPar(imbakance=0.42), EinExpr(Symbol[], tensors)) + path = einexpr(HyPar(imbalance=0.42), EinExpr(Symbol[], tensors)) @test path isa EinExpr