This repository has been archived by the owner on Nov 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from LuxDL/ap/batched_nonlinearproblem
Add Batched Nonlinear Solvers and Forward AD rules
- Loading branch information
Showing
15 changed files
with
384 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
24 changes: 24 additions & 0 deletions
24
ext/BatchedRoutinesForwardDiffExt/BatchedRoutinesForwardDiffExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Extension module wiring ForwardDiff into BatchedRoutines: it flips the
# "extension loaded" flag, forwards the dual-compatibility query to ForwardDiff,
# and pulls in the batched Jacobian and nonlinear-solve AD rules.
module BatchedRoutinesForwardDiffExt

using ADTypes: AutoForwardDiff
using ArrayInterface: parameterless_type
using BatchedRoutines: BatchedRoutines, AbstractBatchedNonlinearAlgorithm,
                       UniformBlockDiagonalOperator, batched_jacobian, batched_mul,
                       batched_pickchunksize, _assert_type
using ChainRulesCore: ChainRulesCore
using FastClosures: @closure
using ForwardDiff: ForwardDiff, Dual
using LinearAlgebra: LinearAlgebra
using LuxDeviceUtils: LuxDeviceUtils, get_device
using SciMLBase: SciMLBase, NonlinearProblem

const CRC = ChainRulesCore

# Signal to BatchedRoutines that the ForwardDiff-backed methods are available.
@inline function BatchedRoutines._is_extension_loaded(::Val{:ForwardDiff})
    return true
end

# Delegate the "can this type be wrapped in a Dual?" query to ForwardDiff.
@inline function BatchedRoutines.__can_forwarddiff_dual(::Type{T}) where {T}
    return ForwardDiff.can_dual(T)
end

include("jacobian.jl")
include("nonlinearsolve_ad.jl")

end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Forward-mode AD entry point for batched nonlinear solves.
#
# Dispatches when the problem's parameters are a `Dual` or an array of `Dual`s
# (all sharing tag `T`, value type `V` and partials length `P`). The work is
# delegated to `__nlsolve_ad`, which returns the solution of the problem with
# duals stripped together with the partials of that solution; these are then
# recombined into a dual-valued `u` and wrapped in a fresh solution object that
# reuses the residual, return code, stats and original solution of the inner solve.
function SciMLBase.solve(
        prob::NonlinearProblem{<:AbstractArray, iip,
            <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
        alg::AbstractBatchedNonlinearAlgorithm,
        args...;
        kwargs...) where {T, V, P, iip}
    primal_sol, u_partials = __nlsolve_ad(prob, alg, args...; kwargs...)
    u_dual = __nlsolve_dual_soln(primal_sol.u, u_partials, prob.p)
    return SciMLBase.build_solution(
        prob, alg, u_dual, primal_sol.resid;
        retcode=primal_sol.retcode,
        stats=primal_sol.stats,
        original=primal_sol.original)
end
|
||
# Solve `prob` with the dual numbers stripped from `u0` and `p`, then compute the
# partials of the solution with respect to the dual-valued parameters.
#
# Steps:
#   1. Build and solve a value-only copy of the problem.
#   2. At the solution `uu`, form `Jₚ` (Jacobian of `prob.f(uu, ·)` w.r.t. `p`) and
#      `Jᵤ` (the user-supplied `prob.f.jac` if present, otherwise a batched
#      ForwardDiff Jacobian w.r.t. `u`).
#   3. For every column `Jₚᵢ` of `Jₚ`, compute `-(Jᵤ \ Jₚᵢ)` and scale each entry by
#      the partials of the matching entry of `prob.p`; the per-parameter
#      contributions are summed.
#
# Returns `(sol, partials)` where `sol` is the value-only solution and `partials`
# is the summed collection from step 3.
function __nlsolve_ad(prob::NonlinearProblem, alg, args...; kwargs...)
    p = ForwardDiff.value.(prob.p)
    u0 = ForwardDiff.value.(prob.u0)
    newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...)

    sol = SciMLBase.solve(newprob, alg, args...; kwargs...)

    uu = sol.u
    Jₚ = ForwardDiff.jacobian(Base.Fix1(prob.f, uu), p)
    Jᵤ = if prob.f.jac === nothing
        BatchedRoutines.batched_jacobian(AutoForwardDiff(), prob.f, uu, p)
    else
        BatchedRoutines._wrap_batched_operator(prob.f.jac(uu, p))
    end

    # Factorize once (in place) and reuse the factorization for every column of Jₚ.
    Jᵤ_fact = LinearAlgebra.lu!(Jᵤ)

    # NOTE: the dual parameter is bound to `pᵢ`, not `p`. Assigning to `p` inside
    # the closure would rebind the enclosing function's `p`, forcing it into a
    # `Core.Box` and making every use of `p` above type-unstable.
    map_fn = @closure zp -> begin
        Jₚᵢ, pᵢ = zp
        # Both operations below overwrite the column view of Jₚ in place.
        LinearAlgebra.ldiv!(Jᵤ_fact, Jₚᵢ)
        Jₚᵢ .*= -1
        return map(Base.Fix2(*, ForwardDiff.partials(pᵢ)), Jₚᵢ)
    end

    return sol, sum(map_fn, zip(eachcol(Jₚ), prob.p))
end
|
||
# Pair every entry of the primal solution `u` with its partials to rebuild a
# dual-valued solution. The third argument is used only for dispatch: it supplies
# the `Dual` tag `T`, value type `V` and partials length `P`.
@inline function __nlsolve_dual_soln(u::AbstractArray, partials,
        ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
    # Lay the partials out with the same shape as `u` so the two iterate in lockstep.
    shaped_partials = reshape(partials, size(u))
    return [Dual{T, V, P}(uᵢ, ∂uᵢ) for (uᵢ, ∂uᵢ) in zip(u, shaped_partials)]
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
13 changes: 13 additions & 0 deletions
13
ext/BatchedRoutinesSciMLSensitivityExt/BatchedRoutinesSciMLSensitivityExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
module BatchedRoutinesSciMLSensitivityExt | ||
|
||
using ADTypes: AutoForwardDiff, AutoFiniteDiff | ||
using BatchedRoutines: BatchedRoutines, BatchedNonlinearSolution | ||
using FastClosures: @closure | ||
using LinearSolve: LinearSolve | ||
using SciMLBase: SciMLBase, NonlinearProblem, NonlinearSolution | ||
using SciMLSensitivity: SciMLSensitivity, SteadyStateAdjoint, ZygoteVJP | ||
using Zygote: Zygote | ||
|
||
include("steadystateadjoint.jl") | ||
|
||
end |
92 changes: 92 additions & 0 deletions
92
ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import SciMLSensitivity: SteadyStateAdjointProblem, SteadyStateAdjointSensitivityFunction | ||
|
||
# Steady-state adjoint for batched nonlinear solutions.
#
# Given a `BatchedNonlinearSolution`, solves the adjoint linear system
# `Jᵀ λ = dgdu` (either with a concrete batched Jacobian or matrix-free through a
# `VecJac` operator) and pulls `λ` back through `f` with respect to the
# parameters using Zygote. The negated pullback `∂p` is returned.
#
# Requirements enforced below:
#   * `sol.prob` is an out-of-place `NonlinearProblem` whose `u0` is a matrix;
#   * `dgdu` is supplied as an in-place callable
#     `dgdu(out, u, p, nothing, nothing)`;
#   * `sensealg.autojacvec` is a `ZygoteVJP`;
#   * the problem has parameters (not `NullParameters`).
#
# Supplying `g` or `dgdp` is not implemented yet and raises an error after the
# linear solve. `alg` and `kwargs` are accepted for interface compatibility and
# are not used in this method.
function SteadyStateAdjointProblem(
        sol::BatchedNonlinearSolution, sensealg::SteadyStateAdjoint, alg,
        dgdu::DG1=nothing, dgdp::DG2=nothing, g::G=nothing; kwargs...) where {DG1, DG2, G}
    @assert sol.prob isa NonlinearProblem
    (; f, p, u0) = sol.prob
    f = SciMLBase.ODEFunction(f)

    @assert !SciMLBase.isinplace(sol.prob) "Adjoint for Batched Problems does not support \
                                            inplace problems."
    @assert ndims(u0)==2 "u0 must be a matrix."
    @assert dgdu!==nothing "`dgdu` must be specified. Automatic differentiation is not \
                            currently implemented for this part."
    @assert sensealg.autojacvec isa ZygoteVJP

    # Kept as a guard for builds where `@assert` is compiled out.
    dgdu === nothing &&
        dgdp === nothing &&
        g === nothing &&
        error("Either `dgdu`, `dgdp`, or `g` must be specified.")

    # A concrete Jacobian is never built when `f` provides an adjoint. Otherwise,
    # with no user-chosen linear solver it is built only for small per-sample
    # state sizes (first dimension of `u0` ≤ 50); with a user-chosen solver the
    # decision is delegated to SciMLSensitivity.
    needs_jac = ifelse(SciMLBase.has_adjoint(f),
        false,
        ifelse(sensealg.linsolve === nothing, size(u0, 1) ≤ 50,
            SciMLSensitivity.__needs_concrete_A(sensealg.linsolve)))

    p === SciMLBase.NullParameters() &&
        error("Your model does not have parameters, and thus it is impossible to calculate \
               the derivative of the solution with respect to the parameters. Your model \
               must have parameters to use parameter sensitivity calculations!")

    y = sol.u

    if needs_jac
        if SciMLBase.has_jac(f)
            J = BatchedRoutines._wrap_batched_operator(f.jac(y, p, nothing))
        else
            uf = SciMLBase.UJacobianWrapper(f, nothing, p)
            # `alg_autodiff(sensealg) == true` means automatic differentiation was
            # requested, so use ForwardDiff; fall back to finite differences only
            # when it is disabled. (These branches were previously swapped.)
            if SciMLSensitivity.alg_autodiff(sensealg)
                J = BatchedRoutines.batched_jacobian(AutoForwardDiff(), uf, y)
            else
                J = BatchedRoutines.batched_jacobian(AutoFiniteDiff(), uf, y)
            end
        end
    end

    # Uninitialized buffers; `dgdu_val` is filled by the `dgdu` call below.
    if dgdp === nothing && g === nothing
        dgdu_val = similar(u0, length(u0))
        dgdp_val = nothing
    else
        dgdu_val, dgdp_val = similar(u0, length(u0)), similar(u0, length(p))
    end

    if dgdu !== nothing
        dgdu(dgdu_val, y, p, nothing, nothing)
    else
        error("Not implemented yet")
    end

    if !needs_jac  # Construct an operator and use Jacobian-Free Linear Solve
        linsolve = if sensealg.linsolve === nothing
            LinearSolve.SimpleGMRES(; blocksize=size(u0, 1))
        else
            sensealg.linsolve
        end
        usize = size(y)
        # Flattened view of `f` so the VecJac operator works on vectors.
        __f = @closure y -> vec(f(reshape(y, usize), p, nothing))
        operator = SciMLSensitivity.VecJac(__f, vec(y);
            autodiff=SciMLSensitivity.get_autodiff_from_vjp(sensealg.autojacvec))
        linear_problem = SciMLBase.LinearProblem(operator, dgdu_val)
        linsol = SciMLBase.solve(
            linear_problem, linsolve; alias_A=true, sensealg.linsolve_kwargs...)
    else
        linear_problem = SciMLBase.LinearProblem(J', dgdu_val)
        linsol = SciMLBase.solve(
            linear_problem, sensealg.linsolve; alias_A=true, sensealg.linsolve_kwargs...)
    end
    λ = linsol.u

    # Vector-Jacobian product of `f` w.r.t. the parameters, seeded with λ.
    _, pb_f = Zygote.pullback(@closure(p->vec(f(y, p, nothing))), p)
    ∂p = only(pb_f(λ))
    ∂p === nothing &&
        !sensealg.autojacvec.allow_nothing &&
        throw(SciMLSensitivity.ZygoteVJPNothingError())

    if g !== nothing || dgdp !== nothing
        error("Not implemented yet")
    else
        SciMLSensitivity.recursive_neg!(∂p)
        return ∂p
    end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.