diff --git a/src/Optimizers/KaHyPar.jl b/src/Optimizers/KaHyPar.jl index 2edc06d..757e064 100644 --- a/src/Optimizers/KaHyPar.jl +++ b/src/Optimizers/KaHyPar.jl @@ -7,6 +7,8 @@ using KaHyPar imbalance::Float32 = 0.03 stop::Function = <=(2) ∘ length ∘ Base.Fix1(getfield, :args) configuration::Union{Nothing,Symbol,String} = nothing + edge_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2 + vertex_scaler::Function = Base.Fix1(*, 1000) ∘ Int ∘ round ∘ log2 end function EinExprs.einexpr(config::HyPar, path) @@ -21,8 +23,8 @@ function EinExprs.einexpr(config::HyPar, path) incidence_matrix = sparse(I, J, V) # NOTE indices in `inds` should be in the same order as unique indices appear by iterating on `path.args` because `∪` retains order - edge_weights = map(Base.Fix1(size, path), inds) - vertex_weights = ones(Int, length(path.args)) + edge_weights = map(config.edge_scaler ∘ Base.Fix1(size, path), inds) + vertex_weights = map(config.vertex_scaler ∘ length, path.args) hypergraph = KaHyPar.HyperGraph(incidence_matrix, vertex_weights, edge_weights) diff --git a/test/KaHyPar_test.jl b/test/KaHyPar_test.jl index 559e0ba..f4cdc84 100644 --- a/test/KaHyPar_test.jl +++ b/test/KaHyPar_test.jl @@ -10,7 +10,7 @@ EinExpr([:d, :g, :c], Dict(i => 2 for i in [:d, :g, :c])), ] - path = einexpr(HyPar, EinExpr(Symbol[], tensors)) + path = einexpr(HyPar(imbalance=0.42), EinExpr(Symbol[], tensors)) @test path isa EinExpr @@ -41,10 +41,10 @@ EinExpr([:a, :C, :d], Dict(:a => 3, :d => 6, :C => 4)), ] - path = einexpr(HyPar, EinExpr(Symbol[], tensors)) + path = einexpr(HyPar(imbalance=0.45), EinExpr(Symbol[], tensors)) @test path isa EinExpr - @test mapreduce(flops, +, Branches(path)) == 31653164 + @test mapreduce(flops, +, Branches(path)) == 19099592 end end \ No newline at end of file