Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enzyme: remove closures #59

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ OptimizationZygoteExt = "Zygote"
ADTypes = "1.3"
ArrayInterface = "7.6"
DocStringExtensions = "0.9"
Enzyme = "0.11.11, =0.12.6"
Enzyme = "0.12.12"
FiniteDiff = "2.12"
ForwardDiff = "0.10.26"
LinearAlgebra = "1.9, 1.10"
Expand Down
145 changes: 50 additions & 95 deletions ext/OptimizationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,47 @@ isdefined(Base, :get_extension) ? (using Enzyme) : (using ..Enzyme)
end
end

"""
    inner_grad(θ, bθ, f, p, args...)

Run a deferred reverse-mode Enzyme pass over `firstapply(f, θ, p, args...)`.

`θ` is passed as `Enzyme.Duplicated(θ, bθ)`, so `bθ` is the shadow buffer that
Enzyme accumulates the derivative with respect to `θ` into. `f`, `p` and every
element of `args` are marked `Const`. The function is called for its effect on
`bθ` and always returns `nothing`.

Defined at top level (rather than as a closure) so it can itself be handed to an
outer `Enzyme.autodiff` call without capturing any surrounding state.
"""
function inner_grad(θ, bθ, f, p, args::Vararg{Any, N}) where N
    # The call is a standalone statement; its result is intentionally discarded.
    Enzyme.autodiff_deferred(Enzyme.Reverse,
        Const(firstapply),
        Active,
        Const(f),
        Enzyme.Duplicated(θ, bθ),
        Const(p),
        Const.(args)...)
    return nothing
end

"""
    hv_f2_alloc(x, f, p, args...)

Return a freshly allocated derivative of `firstapply(f, x, p, args...)` with
respect to `x`.

A zeroed shadow of `x` is created with `Enzyme.make_zero`, a deferred
reverse-mode Enzyme pass is run with `x` marked `Duplicated(x, dx)`, and the
shadow `dx` is returned. `f`, `p` and every element of `args` are treated as
`Const`, matching how `inner_grad` annotates the same call.

Because the shadow is allocated on every call, the function keeps no state
between invocations.
"""
function hv_f2_alloc(x, f, p, args...)
    dx = Enzyme.make_zero(x)
    # Annotate the callee and `f` explicitly as `Const`, consistent with
    # `inner_grad`, instead of passing them to Enzyme unannotated.
    Enzyme.autodiff_deferred(Enzyme.Reverse,
        Const(firstapply),
        Active,
        Const(f),
        Enzyme.Duplicated(x, dx),
        Const(p),
        Const.(args)...)
    return dx
end

function inner_cons(x, p, num_cons, i)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function inner_cons(x, p, num_cons, i)
function inner_cons(f, x, p, num_cons, i)

res = zeros(eltype(x), num_cons)
f.cons(res, x, p)
return res[i]
end

function cons_f2(x, dx, fcons, p, num_cons, i)
Enzyme.autodiff_deferred(Enzyme.Reverse, inner_cons, Active, Enzyme.Duplicated(x, dx), Const(p), Const(num_cons), Const(i))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Enzyme.autodiff_deferred(Enzyme.Reverse, inner_cons, Active, Enzyme.Duplicated(x, dx), Const(p), Const(num_cons), Const(i))
Enzyme.autodiff_deferred(Enzyme.Reverse, inner_cons, Active, fcons, Enzyme.Duplicated(x, dx), Const(p), Const(num_cons), Const(i))

Won't we need to zero and duplicate the function if it's a closure?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't follow — can you add more detail about what you mean by zeroing the function? It doesn't need to be duplicated; it should be `Const`, iiuc (I've done that in #60).

return nothing
end

function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
adtype::AutoEnzyme, p,
num_cons = 0)
if f.grad === nothing
grad = let
function (res, θ, args...)
res .= zero(eltype(res))
Enzyme.make_zero!(res)
Enzyme.autodiff(Enzyme.Reverse,
Const(firstapply),
Active,
Expand All @@ -36,24 +70,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
end

if f.hess === nothing
function g(θ, bθ, f, p, args...)
Enzyme.autodiff_deferred(Enzyme.Reverse,
Const(firstapply),
Active,
Const(f),
Enzyme.Duplicated(θ, bθ),
Const(p),
Const.(args)...),
return nothing
end
function hess(res, θ, args...)
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0)))

bθ = zeros(length(θ))
vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ))

Enzyme.autodiff(Enzyme.Forward,
g,
inner_grad,
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(f.f),
Expand All @@ -69,19 +93,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
end

if f.hv === nothing
function f2(x, f, p, args...)
dx = zeros(length(x))
Enzyme.autodiff_deferred(Enzyme.Reverse,
firstapply,
Active,
f,
Enzyme.Duplicated(x, dx),
Const(p),
Const.(args)...)
return dx
end
hv = function (H, θ, v, args...)
H .= Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v),
H .= Enzyme.autodiff(Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
Const(_f), Const(f.f), Const(p),
Const.(args)...)[1]
end
Expand Down Expand Up @@ -109,19 +122,6 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
end

if cons !== nothing && f.cons_h === nothing
fncs = map(1:num_cons) do i
function (x)
res = zeros(eltype(x), num_cons)
f.cons(res, x, p)
return res[i]
end
end

function f2(x, dx, fnc)
Enzyme.autodiff_deferred(Enzyme.Reverse, fnc, Enzyme.Duplicated(x, dx))
return nothing
end

cons_h = function (res, θ)
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0)))
bθ = zeros(length(θ))
Expand All @@ -132,10 +132,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
el .= zeros(length(θ))
end
Enzyme.autodiff(Enzyme.Forward,
f2,
cons_f2,
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(fncs[i]))
Const(f.cons),
Const(p),
Const(num_cons),
Const(i)
)

for j in eachindex(θ)
res[i][j, :] .= vdbθ[j]
Expand All @@ -161,7 +165,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},

if f.grad === nothing
function grad(res, θ, args...)
res .= zero(eltype(res))
Enzyme.make_zero!(res)
Enzyme.autodiff(Enzyme.Reverse,
Const(firstapply),
Active,
Expand All @@ -175,21 +179,13 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
end

if f.hess === nothing
function g(θ, bθ, f, p, args...)
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(firstapply), Active, Const(f),
Enzyme.Duplicated(θ, bθ),
Const(p),
Const.(args)...)
return nothing
end
function hess(res, θ, args...)
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0)))

bθ = zeros(length(θ))
vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ))

Enzyme.autodiff(Enzyme.Forward,
g,
inner_grad,
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(f.f),
Expand All @@ -205,17 +201,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
end

if f.hv === nothing
function f2(x, f, p, args...)
dx = zeros(length(x))
Enzyme.autodiff_deferred(Enzyme.Reverse, firstapply, Active,
f,
Enzyme.Duplicated(x, dx),
Const(p),
Const.(args)...)
return dx
end
hv = function (H, θ, v, args...)
H .= Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v),
H .= Enzyme.autodiff(Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
Const(f.f), Const(p),
Const.(args)...)[1]
end
Expand Down Expand Up @@ -294,24 +281,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
end

if f.hess === nothing
function g(θ, bθ, f, p, args...)
Enzyme.autodiff_deferred(Enzyme.Reverse,
Const(firstapply),
Active,
Const(f),
Enzyme.Duplicated(θ, bθ),
Const(p),
Const.(args)...),
return nothing
end
function hess(θ, args...)
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0)))

bθ = zeros(length(θ))
vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ))

Enzyme.autodiff(Enzyme.Forward,
g,
inner_grad,
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(f.f),
Expand Down Expand Up @@ -418,7 +395,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
res = zeros(eltype(x), size(x))
grad = let res = res
function (θ, args...)
res .= zero(eltype(res))
Enzyme.make_zero!(res)
Enzyme.autodiff(Enzyme.Reverse,
Const(firstapply),
Active,
Expand All @@ -434,24 +411,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
end

if f.hess === nothing
function g(θ, bθ, f, p, args...)
Enzyme.autodiff_deferred(Enzyme.Reverse,
Const(firstapply),
Active,
Const(f),
Enzyme.Duplicated(θ, bθ),
Const(p),
Const.(args)...),
return nothing
end
function hess(θ, args...)
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0)))

bθ = zeros(length(θ))
vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ))

Enzyme.autodiff(Enzyme.Forward,
g,
inner_grad,
Enzyme.BatchDuplicated(θ, vdθ),
Enzyme.BatchDuplicated(bθ, vdbθ),
Const(f.f),
Expand All @@ -465,20 +432,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
end

if f.hv === nothing
dx = zeros(length(x))
function f2(x, f, p, args...)
dx .= zero(eltype(dx))
Enzyme.autodiff_deferred(Enzyme.Reverse,
firstapply,
Active,
f,
Enzyme.Duplicated(x, dx),
Const(p),
Const.(args)...)
return dx
end
hv = function (θ, v, args...)
Enzyme.autodiff(Enzyme.Forward, f2, DuplicatedNoNeed, Duplicated(θ, v),
Enzyme.autodiff(Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
Const(_f), Const(f.f), Const(p),
Const.(args)...)[1]
end
Expand Down
Loading