
Commit

Merge pull request #13 from LuxDL/ap/batched_nonlinearproblem
Add Batched Nonlinear Solvers and Forward AD rules
avik-pal authored Apr 25, 2024
2 parents b1f3f8f + a37f694 commit a2fb973
Showing 15 changed files with 384 additions and 34 deletions.
19 changes: 15 additions & 4 deletions Project.toml
@@ -14,6 +14,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"

[weakdeps]
@@ -23,16 +24,18 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BatchedRoutinesCUDAExt = ["CUDA"]
BatchedRoutinesComponentArraysForwardDiffExt = ["ComponentArrays", "ForwardDiff"]
BatchedRoutinesCUDALinearSolveExt = ["CUDA", "LinearSolve"]
BatchedRoutinesComponentArraysForwardDiffExt = ["ComponentArrays", "ForwardDiff"]
BatchedRoutinesFiniteDiffExt = ["FiniteDiff"]
BatchedRoutinesForwardDiffExt = ["ForwardDiff"]
BatchedRoutinesLinearSolveExt = ["LinearSolve"]
BatchedRoutinesReverseDiffExt = ["ReverseDiff"]
BatchedRoutinesSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity", "Zygote"]
BatchedRoutinesZygoteExt = ["Zygote"]

[compat]
@@ -42,6 +45,7 @@ Aqua = "0.8.4"
ArrayInterface = "7.8.1"
CUDA = "5.2.0"
ChainRulesCore = "1.23"
Chairmarks = "1.2"
ComponentArrays = "0.15.10"
ConcreteStructs = "0.2.3"
ExplicitImports = "1.4.0"
@@ -56,18 +60,22 @@ LuxCUDA = "0.3.2"
LuxDeviceUtils = "0.1.17"
LuxTestUtils = "0.1.15"
PrecompileTools = "1.2.0"
Random = "<0.0.1, 1"
Random = "1.10"
ReTestItems = "1.23.1"
ReverseDiff = "1.15"
SciMLBase = "2.31"
SciMLOperators = "0.3.8"
SciMLSensitivity = "7.56"
SimpleNonlinearSolve = "1.7"
StableRNGs = "1.0.1"
Statistics = "1.11.1"
Test = "<0.0.1, 1"
Test = "1.10"
Zygote = "0.6.69"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
@@ -80,10 +88,13 @@ LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ComponentArrays", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "Lux", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "StableRNGs", "Statistics", "Test", "Zygote"]
test = ["Aqua", "Chairmarks", "ComponentArrays", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "Lux", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "SciMLBase", "SciMLSensitivity", "SimpleNonlinearSolve", "StableRNGs", "Statistics", "Test", "Zygote"]
24 changes: 24 additions & 0 deletions ext/BatchedRoutinesForwardDiffExt/BatchedRoutinesForwardDiffExt.jl
@@ -0,0 +1,24 @@
module BatchedRoutinesForwardDiffExt

using ADTypes: AutoForwardDiff
using ArrayInterface: parameterless_type
using BatchedRoutines: BatchedRoutines, AbstractBatchedNonlinearAlgorithm,
UniformBlockDiagonalOperator, batched_jacobian, batched_mul,
batched_pickchunksize, _assert_type
using ChainRulesCore: ChainRulesCore
using FastClosures: @closure
using ForwardDiff: ForwardDiff, Dual
using LinearAlgebra: LinearAlgebra
using LuxDeviceUtils: LuxDeviceUtils, get_device
using SciMLBase: SciMLBase, NonlinearProblem

const CRC = ChainRulesCore

@inline BatchedRoutines._is_extension_loaded(::Val{:ForwardDiff}) = true

@inline BatchedRoutines.__can_forwarddiff_dual(::Type{T}) where {T} = ForwardDiff.can_dual(T)

include("jacobian.jl")
include("nonlinearsolve_ad.jl")

end
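
The `_is_extension_loaded(::Val{:ForwardDiff})` override flips the conservative `false` fallback defined in src/BatchedRoutines.jl (see below), so core code can branch on backend availability through dispatch. A minimal sketch of the pattern; `require_forwarddiff` is hypothetical, not part of this commit:

```julia
# Core package: conservative fallback that every extension can override.
@inline _is_extension_loaded(::Val) = false

# Extension: advertise that the ForwardDiff backend is available.
@inline _is_extension_loaded(::Val{:ForwardDiff}) = true

# Hypothetical caller: the Val argument makes the check resolve at compile
# time, so the guard costs nothing at runtime.
function require_forwarddiff()
    _is_extension_loaded(Val(:ForwardDiff)) ||
        error("Load ForwardDiff.jl to enable this code path.")
    return nothing
end
```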
ext/BatchedRoutinesForwardDiffExt.jl renamed to ext/BatchedRoutinesForwardDiffExt/jacobian.jl
@@ -1,18 +1,3 @@
module BatchedRoutinesForwardDiffExt

using ADTypes: AutoForwardDiff
using ArrayInterface: parameterless_type
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, batched_jacobian,
batched_mul, batched_pickchunksize, _assert_type
using ChainRulesCore: ChainRulesCore
using FastClosures: @closure
using ForwardDiff: ForwardDiff
using LuxDeviceUtils: LuxDeviceUtils, get_device

const CRC = ChainRulesCore

@inline BatchedRoutines._is_extension_loaded(::Val{:ForwardDiff}) = true

# api.jl
function BatchedRoutines.batched_pickchunksize(
X::AbstractArray{T, N}, n::Int=ForwardDiff.DEFAULT_CHUNK_THRESHOLD) where {T, N}
@@ -242,5 +227,3 @@ end
partials = ForwardDiff.Partials{1, T}.(tuple.(u))
return ForwardDiff.Dual{Tag, T, 1}.(x, reshape(partials, size(x)))
end

end
44 changes: 44 additions & 0 deletions ext/BatchedRoutinesForwardDiffExt/nonlinearsolve_ad.jl
@@ -0,0 +1,44 @@
function SciMLBase.solve(
prob::NonlinearProblem{<:AbstractArray, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::AbstractBatchedNonlinearAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end

function __nlsolve_ad(prob::NonlinearProblem, alg, args...; kwargs...)
p = ForwardDiff.value.(prob.p)
u0 = ForwardDiff.value.(prob.u0)
newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...)

sol = SciMLBase.solve(newprob, alg, args...; kwargs...)

uu = sol.u
Jₚ = ForwardDiff.jacobian(Base.Fix1(prob.f, uu), p)
Jᵤ = if prob.f.jac === nothing
BatchedRoutines.batched_jacobian(AutoForwardDiff(), prob.f, uu, p)
else
BatchedRoutines._wrap_batched_operator(prob.f.jac(uu, p))
end

Jᵤ_fact = LinearAlgebra.lu!(Jᵤ)

map_fn = @closure zp -> begin
Jₚᵢ, p = zp
LinearAlgebra.ldiv!(Jᵤ_fact, Jₚᵢ)
Jₚᵢ .*= -1
return map(Base.Fix2(*, ForwardDiff.partials(p)), Jₚᵢ)
end

return sol, sum(map_fn, zip(eachcol(Jₚ), prob.p))
end

@inline function __nlsolve_dual_soln(u::AbstractArray, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
_partials = reshape(partials, size(u))
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials))
end
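
`__nlsolve_ad` is implicit differentiation: it strips the `Dual` parts of `p` and `u0`, solves the primal problem once, and reconstructs the parameter sensitivities from the implicit function theorem instead of differentiating through the solver iterations. For the solution u*(p) of f(u*(p), p) = 0, differentiating both sides in p gives

    0 = Jᵤ · ∂u*/∂p + Jₚ   ⟹   ∂u*/∂p = -Jᵤ⁻¹ · Jₚ

which is what the code computes: `lu!` factors Jᵤ once, `ldiv!` applies Jᵤ⁻¹ to each column Jₚᵢ, the `.*= -1` supplies the sign, and the final `sum` contracts the result with the incoming `ForwardDiff.partials(p)` to build the output duals.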
5 changes: 3 additions & 2 deletions ext/BatchedRoutinesLinearSolveExt.jl
@@ -5,7 +5,8 @@ using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, getdata
using ChainRulesCore: ChainRulesCore, NoTangent
using FastClosures: @closure
using LinearAlgebra: LinearAlgebra
using LinearSolve: LinearSolve, SciMLBase
using LinearSolve: LinearSolve
using SciMLBase: SciMLBase

const CRC = ChainRulesCore

@@ -113,7 +114,7 @@ function LinearSolve.solve!(cache::LinearSolve.LinearCache{<:UniformBlockDiagona
y = LinearAlgebra.ldiv!(
cache.u, LinearSolve.@get_cacheval(cache, :NormalCholeskyFactorization),
A' * cache.b)
return LinearSolve.SciMLBase.build_linear_solution(alg, y, nothing, cache)
return SciMLBase.build_linear_solution(alg, y, nothing, cache)
end

# SVDFactorization
18 changes: 17 additions & 1 deletion ext/BatchedRoutinesReverseDiffExt.jl
@@ -2,7 +2,8 @@ module BatchedRoutinesReverseDiffExt

using ADTypes: AutoReverseDiff, AutoForwardDiff
using ArrayInterface: ArrayInterface
using BatchedRoutines: BatchedRoutines, batched_pickchunksize, _assert_type
using BatchedRoutines: BatchedRoutines, batched_pickchunksize, _assert_type,
UniformBlockDiagonalOperator, getdata
using ChainRulesCore: ChainRulesCore, NoTangent
using ConcreteStructs: @concrete
using FastClosures: @closure
@@ -30,6 +31,21 @@ function BatchedRoutines._batched_gradient(::AutoReverseDiff, f::F, u) where {F}
return ∂u
end

# ReverseDiff compatible `UniformBlockDiagonalOperator`
@inline function ReverseDiff.track(
op::UniformBlockDiagonalOperator, tp::ReverseDiff.InstructionTape)
return UniformBlockDiagonalOperator(ReverseDiff.track(getdata(op), tp))
end

@inline function ReverseDiff.deriv(x::UniformBlockDiagonalOperator)
return UniformBlockDiagonalOperator(ReverseDiff.deriv(getdata(x)))
end

@inline function ReverseDiff.value!(
op::UniformBlockDiagonalOperator, val::UniformBlockDiagonalOperator)
ReverseDiff.value!(getdata(op), getdata(val))
end

# Chain rules integration
function BatchedRoutines.batched_jacobian(
ad, f::F, x::AbstractMatrix{<:ReverseDiff.TrackedReal}) where {F}
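
With `track`, `deriv`, and `value!` delegating to the wrapped array via `getdata`, a `UniformBlockDiagonalOperator` can live on a ReverseDiff tape like a plain array. A hedged usage sketch of the gradient entry point this extension backs; the shapes and the toy loss are illustrative, not from this commit:

```julia
using BatchedRoutines, ReverseDiff

u = randn(4, 16)                       # 4 states × 16 independent batches
loss(x) = sum(abs2, x)                 # any scalar-valued batched loss

# Dispatches to the AutoReverseDiff method shown in this diff.
∂u = batched_gradient(AutoReverseDiff(), loss, u)
@assert size(∂u) == size(u)            # gradient keeps the batched layout
```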
13 changes: 13 additions & 0 deletions ext/BatchedRoutinesSciMLSensitivityExt/BatchedRoutinesSciMLSensitivityExt.jl
@@ -0,0 +1,13 @@
module BatchedRoutinesSciMLSensitivityExt

using ADTypes: AutoForwardDiff, AutoFiniteDiff
using BatchedRoutines: BatchedRoutines, BatchedNonlinearSolution
using FastClosures: @closure
using LinearSolve: LinearSolve
using SciMLBase: SciMLBase, NonlinearProblem, NonlinearSolution
using SciMLSensitivity: SciMLSensitivity, SteadyStateAdjoint, ZygoteVJP
using Zygote: Zygote

include("steadystateadjoint.jl")

end
92 changes: 92 additions & 0 deletions ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl
@@ -0,0 +1,92 @@
import SciMLSensitivity: SteadyStateAdjointProblem, SteadyStateAdjointSensitivityFunction

function SteadyStateAdjointProblem(
sol::BatchedNonlinearSolution, sensealg::SteadyStateAdjoint, alg,
dgdu::DG1=nothing, dgdp::DG2=nothing, g::G=nothing; kwargs...) where {DG1, DG2, G}
@assert sol.prob isa NonlinearProblem
(; f, p, u0) = sol.prob
f = SciMLBase.ODEFunction(f)

@assert !SciMLBase.isinplace(sol.prob) "Adjoint for Batched Problems does not support \
inplace problems."
@assert ndims(u0)==2 "u0 must be a matrix."
@assert dgdu !== nothing "`dgdu` must be specified. Automatic differentiation is not \
currently implemented for this part."
@assert sensealg.autojacvec isa ZygoteVJP

dgdu === nothing &&
dgdp === nothing &&
g === nothing &&
error("Either `dgdu`, `dgdp`, or `g` must be specified.")

needs_jac = ifelse(SciMLBase.has_adjoint(f),
false,
ifelse(sensealg.linsolve === nothing, size(u0, 1) ≤ 50,
SciMLSensitivity.__needs_concrete_A(sensealg.linsolve)))

p === SciMLBase.NullParameters() &&
error("Your model does not have parameters, and thus it is impossible to calculate \
the derivative of the solution with respect to the parameters. Your model \
must have parameters to use parameter sensitivity calculations!")

y = sol.u

if needs_jac
if SciMLBase.has_jac(f)
J = BatchedRoutines._wrap_batched_operator(f.jac(y, p, nothing))
else
uf = SciMLBase.UJacobianWrapper(f, nothing, p)
if SciMLSensitivity.alg_autodiff(sensealg)
J = BatchedRoutines.batched_jacobian(AutoFiniteDiff(), uf, y)
else
J = BatchedRoutines.batched_jacobian(AutoForwardDiff(), uf, y)
end
end
end

if dgdp === nothing && g === nothing
dgdu_val = similar(u0, length(u0))
dgdp_val = nothing
else
dgdu_val, dgdp_val = similar(u0, length(u0)), similar(u0, length(p))
end

if dgdu !== nothing
dgdu(dgdu_val, y, p, nothing, nothing)
else
error("Not implemented yet")
end

if !needs_jac # Construct an operator and use Jacobian-Free Linear Solve
linsolve = if sensealg.linsolve === nothing
LinearSolve.SimpleGMRES(; blocksize=size(u0, 1))
else
sensealg.linsolve
end
usize = size(y)
__f = @closure y -> vec(f(reshape(y, usize), p, nothing))
operator = SciMLSensitivity.VecJac(__f, vec(y);
autodiff=SciMLSensitivity.get_autodiff_from_vjp(sensealg.autojacvec))
linear_problem = SciMLBase.LinearProblem(operator, dgdu_val)
linsol = SciMLBase.solve(
linear_problem, linsolve; alias_A=true, sensealg.linsolve_kwargs...)
else
linear_problem = SciMLBase.LinearProblem(J', dgdu_val)
linsol = SciMLBase.solve(
linear_problem, sensealg.linsolve; alias_A=true, sensealg.linsolve_kwargs...)
end
λ = linsol.u

_, pb_f = Zygote.pullback(@closure(p->vec(f(y, p, nothing))), p)
∂p = only(pb_f(λ))
∂p === nothing &&
!sensealg.autojacvec.allow_nothing &&
throw(SciMLSensitivity.ZygoteVJPNothingError())

if g !== nothing || dgdp !== nothing
error("Not implemented yet")
else
SciMLSensitivity.recursive_neg!(∂p)
return ∂p
end
end
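
The control flow above is the standard steady-state adjoint, specialized to batched Jacobians. With g the functional whose parameter sensitivity is requested, the adjoint variable λ solves one linear system and is then pulled back through the parameters:

    Jᵤᵀ · λ = ∂g/∂u,   then   dg/dp = -λᵀ · ∂f/∂p

The `LinearProblem` (built on `J'` when a concrete Jacobian is needed, or on the matrix-free `SciMLSensitivity.VecJac` operator with `SimpleGMRES` blocked at the batch size otherwise) computes λ; the `Zygote.pullback` over `p` evaluates λᵀ · ∂f/∂p; and `recursive_neg!` applies the minus sign.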
19 changes: 16 additions & 3 deletions src/BatchedRoutines.jl
@@ -3,8 +3,9 @@ module BatchedRoutines
import PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
using ADTypes: AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoSparseForwardDiff,
AutoSparsePolyesterForwardDiff, AutoPolyesterForwardDiff, AutoZygote
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
AutoReverseDiff, AutoSparseForwardDiff, AutoSparsePolyesterForwardDiff,
AutoPolyesterForwardDiff, AutoZygote
using Adapt: Adapt
using ArrayInterface: ArrayInterface, parameterless_type
using ChainRulesCore: ChainRulesCore, HasReverseMode, NoTangent, RuleConfig
@@ -14,6 +15,8 @@ import PrecompileTools: @recompile_invalidations
using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero,
mul!, pinv
using LuxDeviceUtils: LuxDeviceUtils, get_device
using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem,
NonlinearSolution, ReturnCode
using SciMLOperators: SciMLOperators, AbstractSciMLOperator
end

@@ -40,23 +43,33 @@ const AutoAllForwardDiff{CK} = Union{<:AutoForwardDiff{CK}, <:AutoSparseForwardDiff{CK}
const BatchedVector{T} = AbstractMatrix{T}
const BatchedMatrix{T} = AbstractArray{T, 3}

abstract type AbstractBatchedNonlinearAlgorithm <: SciMLBase.AbstractNonlinearAlgorithm end

@inline _is_extension_loaded(::Val) = false

include("operator.jl")

include("api.jl")
include("helpers.jl")

include("operator.jl")
include("factorization.jl")

include("nlsolve/utils.jl")
include("nlsolve/batched_raphson.jl")

include("impl/batched_mul.jl")
include("impl/batched_gmres.jl")

include("chainrules.jl")

# Core
export AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoZygote
export batched_adjoint, batched_gradient, batched_jacobian, batched_pickchunksize,
batched_mul, batched_pinv, batched_transpose
export batchview, nbatches
export UniformBlockDiagonalOperator

# Nonlinear Solvers
export BatchedSimpleNewtonRaphson, BatchedSimpleGaussNewton

end
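
Putting the pieces together: a hedged end-to-end sketch of the batched nonlinear solve this commit adds. The residual, the shapes, and the zero-argument `BatchedSimpleNewtonRaphson()` constructor are assumptions for illustration; each column of `u0` is treated as an independent system:

```julia
using BatchedRoutines, SciMLBase

f(u, p) = u .^ 2 .- p                  # columnwise-independent residual
u0 = ones(3, 8)                        # 3 states × 8 batches
p = rand(3, 8) .+ 1.0                  # batched parameters

prob = NonlinearProblem(f, u0, p)
sol = SciMLBase.solve(prob, BatchedSimpleNewtonRaphson())

@assert sol.u ≈ sqrt.(p)               # the positive root, batch by batch
```

Because the commit also adds the `SciMLBase.solve` overload for `Dual`-valued problems, the same call differentiates with ForwardDiff without unrolling the Newton iterations.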
(The remaining files in this commit are not shown.)
