From 6d4f6d79b0a0f93d6a8b91c2f87f2cbeb1279368 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Fri, 2 Aug 2024 10:49:10 -0400 Subject: [PATCH] hessian of lagrangian --- ext/OptimizationEnzymeExt.jl | 49 +++++++++++++++++++++++-- src/OptimizationDISparseExt.jl | 65 +++++++++++----------------------- 2 files changed, 68 insertions(+), 46 deletions(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 862804d..c6e3a34 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -3,7 +3,7 @@ module OptimizationEnzymeExt import OptimizationBase, OptimizationBase.ArrayInterface import OptimizationBase.SciMLBase: OptimizationFunction import OptimizationBase.SciMLBase -import OptimizationBase.LinearAlgebra: I +import OptimizationBase.LinearAlgebra: I, dot import OptimizationBase.ADTypes: AutoEnzyme using Enzyme using Core: Vararg @@ -76,6 +76,18 @@ function cons_f2_oop(x, dx, fcons, p, i) return nothing end +function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))::Float64 + res = zeros(eltype(x), length(λ)) + cons(res, x, p) + return σ * _f(x, p) + dot(λ, res) +end + +function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ) + Enzyme.autodiff_deferred(Enzyme.Reverse, lagrangian, Active, Enzyme.Duplicated(x, dx), + Const(_f), Const(cons), Const(p), Const(λ), Const(σ)) + return nothing +end + function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoEnzyme, p, num_cons = 0) @@ -219,7 +231,40 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, end if f.lag_h === nothing - lag_h = nothing # Consider implementing this + lag_vdθ = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x))))) + lag_bθ = zeros(eltype(x), length(x)) + + if f.hess_prototype === nothing + lag_vdbθ = Tuple(zeros(eltype(x), length(x)) for i in eachindex(x)) + else + #useless right now, looks like there is no way to tell Enzyme the sparsity pattern? + lag_vdbθ = Tuple((copy(r) for r in eachrow(f.hess_prototype))) + end + + function lag_h(h, θ, σ, μ) + Enzyme.make_zero!.(lag_vdθ) + Enzyme.make_zero!(lag_bθ) + Enzyme.make_zero!.(lag_vdbθ) + + Enzyme.autodiff(Enzyme.Forward, + lag_grad, + Enzyme.BatchDuplicated(θ, lag_vdθ), + Enzyme.BatchDuplicatedNoNeed(lag_bθ, lag_vdbθ), + Const(lagrangian), + Const(f.f), + Const(f.cons), + Const(p), + Const(σ), + Const(μ) + ) + k = 0 + + for i in eachindex(θ) + vec_lagv = lag_vdbθ[i] + h[k+1:k+i] .= @view(vec_lagv[1:i]) + k += i + end + end else lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) end diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index b1cb8da..bf1ebb2 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -174,12 +174,12 @@ function instantiate_function( conshess_colors = f.cons_hess_colorvec if cons !== nothing && f.cons_h === nothing fncs = [@closure (x) -> cons_oop(x)[i] for i in 1:num_cons] - extras_cons_hess = Vector{DifferentiationInterface.SparseHessianExtras}(undef, length(fncs)) - for ind in 1:num_cons - extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x) - end - conshess_sparsity = [sum(sparse, cons)] - conshess_colors = getfield.(extras_cons_hess, Ref(:colors)) + # extras_cons_hess = Vector(undef, length(fncs)) + # for ind in 1:num_cons + # extras_cons_hess[ind] = prepare_hessian(fncs[ind], soadtype, x) + # end + # conshess_sparsity = getfield.(extras_cons_hess, :sparsity) + # conshess_colors = getfield.(extras_cons_hess, :colors) function cons_h(H, θ) for i in 1:num_cons hessian!(fncs[i], H[i], soadtype, θ) @@ -189,56 +189,33 @@ function instantiate_function( cons_h = (res, θ) -> f.cons_h(res, θ, p) end - function lagrangian(x, σ = one(eltype(x))) - θ = x[1:end-num_cons] - λ = x[end-num_cons+1:end] - return σ * _f(θ) + dot(λ, cons_oop(θ)) + function lagrangian(x, σ = one(eltype(x)), λ = ones(eltype(x), num_cons)) + return σ * _f(x) + dot(λ, cons_oop(x)) end + lag_hess_prototype = f.lag_hess_prototype if f.lag_h === nothing - lag_extras = prepare_hessian(lagrangian, soadtype, vcat(x, ones(eltype(x), num_cons))) + lag_extras = prepare_hessian(lagrangian, soadtype, x) lag_hess_prototype = lag_extras.sparsity - - function lag_h(H::Matrix, θ, σ, λ) - @show size(H) - @show size(θ) - @show size(λ) + + function lag_h(H::AbstractMatrix, θ, σ, λ) if σ == zero(eltype(θ)) cons_h(H, θ) H *= λ else - hessian!(lagrangian, H, soadtype, vcat(θ, λ), lag_extras) + hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) end end function lag_h(h, θ, σ, λ) - # @show h - sparseHproto = findnz(lag_extras.sparsity) - H = sparse(sparseHproto[1], sparseHproto[2], zeros(eltype(θ), length(sparseHproto[1]))) - if σ == zero(eltype(θ)) - cons_h(H, θ) - H *= λ - else - hessian!(lagrangian, H, soadtype, vcat(θ, λ), lag_extras) - k = 0 - rows, cols, _ = findnz(H) - for (i, j) in zip(rows, cols) - if i <= j - k += 1 - h[k] = σ * H[i, j] - end - end - k = 0 - for λi in λ - if Hi isa SparseMatrixCSC - rows, cols, _ = findnz(Hi) - for (i, j) in zip(rows, cols) - if i <= j - k += 1 - h[k] += λi * Hi[i, j] - end - end - end + H = eltype(θ).(lag_hess_prototype) + hessian!(x -> lagrangian(x, σ, λ), H, soadtype, θ, lag_extras) + k = 0 + rows, cols, _ = findnz(H) + for (i, j) in zip(rows, cols) + if i <= j + k += 1 + h[k] = H[i, j] end end end