Skip to content

Commit

Permalink
Performance fixes (#100)
Browse files Browse the repository at this point in the history
* Performance fixes

* Simplify vectorization and devectorization

* Switch signs

* Fix type inference

* Remove conversion

* Remove type piracy

* Uncomment stuff
  • Loading branch information
gdalle authored Aug 9, 2023
1 parent fd643ea commit da589a0
Show file tree
Hide file tree
Showing 14 changed files with 98 additions and 194 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
ImplicitDifferentiationChainRulesCoreExt = "ChainRulesCore"
ImplicitDifferentiationComponentArraysExt = "ComponentArrays"
ImplicitDifferentiationForwardDiffExt = "ForwardDiff"
ImplicitDifferentiationStaticArraysExt = "StaticArrays"
ImplicitDifferentiationZygoteExt = "Zygote"
Expand Down
6 changes: 2 additions & 4 deletions benchmark/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -416,14 +416,12 @@ version = "0.5.0-DEV"

[deps.ImplicitDifferentiation.extensions]
ImplicitDifferentiationChainRulesCoreExt = "ChainRulesCore"
ImplicitDifferentiationComponentArraysExt = "ComponentArrays"
ImplicitDifferentiationForwardDiffExt = "ForwardDiff"
ImplicitDifferentiationStaticArraysExt = "StaticArrays"
ImplicitDifferentiationZygoteExt = "Zygote"

[deps.ImplicitDifferentiation.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand Down Expand Up @@ -479,9 +477,9 @@ version = "2.1.91+0"

[[deps.Krylov]]
deps = ["LinearAlgebra", "Printf", "SparseArrays"]
git-tree-sha1 = "6dc4ad9cd74ad4ca0a8e219e945dbd22039f2125"
git-tree-sha1 = "fbda7c58464204d92f3b158578fb0b3d4224cea5"
uuid = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
version = "0.9.2"
version = "0.9.3"

[[deps.LAME_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
Expand Down
8 changes: 3 additions & 5 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.9.2"
manifest_format = "2.0"
project_hash = "237e72a31eedf37d2697a165f32eb8d2aada085c"
project_hash = "e53e426683e9d72288e035d7fd7b4528169a5566"

[[deps.AMD]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse_jll"]
Expand Down Expand Up @@ -324,14 +324,12 @@ version = "0.5.0-DEV"

[deps.ImplicitDifferentiation.extensions]
ImplicitDifferentiationChainRulesCoreExt = "ChainRulesCore"
ImplicitDifferentiationComponentArraysExt = "ComponentArrays"
ImplicitDifferentiationForwardDiffExt = "ForwardDiff"
ImplicitDifferentiationStaticArraysExt = "StaticArrays"
ImplicitDifferentiationZygoteExt = "Zygote"

[deps.ImplicitDifferentiation.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand Down Expand Up @@ -364,9 +362,9 @@ version = "0.21.4"

[[deps.Krylov]]
deps = ["LinearAlgebra", "Printf", "SparseArrays"]
git-tree-sha1 = "6dc4ad9cd74ad4ca0a8e219e945dbd22039f2125"
git-tree-sha1 = "fbda7c58464204d92f3b158578fb0b3d4224cea5"
uuid = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
version = "0.9.2"
version = "0.9.3"

[[deps.LDLFactorizations]]
deps = ["AMD", "LinearAlgebra", "SparseArrays", "Test"]
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Expand Down
7 changes: 6 additions & 1 deletion examples/3_tricks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ We demonstrate several features that may come in handy for some users.
using ComponentArrays
using ForwardDiff
using ImplicitDifferentiation
using Krylov
using LinearAlgebra
using Random
using Test #src
Expand Down Expand Up @@ -48,7 +49,11 @@ end

implicit_components = ImplicitFunction(forward_components, conditions_components)

# This is how it behaves.
# Since `ComponentVector`s are not yet compatible with iterative solvers from Krylov.jl, we (temporarily) need a bit of type piracy to make it work

Krylov.ktypeof(::ComponentVector{T,V}) where {T,V} = V

# Now we're good to go.

a, b, m = rand(2), rand(3), 7
x = ComponentVector(; a=a, b=b, m=m)
Expand Down
45 changes: 21 additions & 24 deletions ext/ImplicitDifferentiationChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and settin
Positional and keyword arguments are passed to both `implicit.forward` and `implicit.conditions`.
"""
function ChainRulesCore.rrule(
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}, args...; kwargs...
) where {R}
rc::RuleConfig, implicit::ImplicitFunction, x::X, args...; kwargs...
) where {R,X<:AbstractArray{R}}
y_or_yz = implicit(x, args...; kwargs...)
backend = reverse_conditions_backend(rc, implicit)
Aᵀ_op, Bᵀ_op = reverse_operators(backend, implicit, x, y_or_yz, args; kwargs)
Aᵀ_vec, pbBᵀ = reverse_operators(backend, implicit, x, y_or_yz, args; kwargs)
byproduct = y_or_yz isa Tuple
nbargs = length(args)
implicit_pullback = ImplicitPullback{byproduct,nbargs}(
Aᵀ_op, Bᵀ_op, implicit.linear_solver, x
implicit_pullback = ImplicitPullback{byproduct,nbargs,X}(
Aᵀ_vec, pbBᵀ, implicit.linear_solver
)
return y_or_yz, implicit_pullback
end
Expand All @@ -43,16 +43,15 @@ function reverse_conditions_backend(
return implicit.conditions_backend
end

struct ImplicitPullback{byproduct,nbargs,A,B,L,X}
Aᵀ_op::A
Bᵀ_op::B
struct ImplicitPullback{byproduct,nbargs,X,A,B,L}
Aᵀ_vec::A
pbBᵀ::B
linear_solver::L
x::X

function ImplicitPullback{byproduct,nbargs}(
Aᵀ_op::A, Bᵀ_op::B, linear_solver::L, x::X
) where {byproduct,nbargs,A,B,L,X}
return new{byproduct,nbargs,A,B,L,X}(Aᵀ_op, Bᵀ_op, linear_solver, x)
function ImplicitPullback{byproduct,nbargs,X}(
Aᵀ_vec::A, pbBᵀ::B, linear_solver::L
) where {byproduct,nbargs,X,A,B,L}
return new{byproduct,nbargs,X,A,B,L}(Aᵀ_vec, pbBᵀ, linear_solver)
end
end

Expand All @@ -64,23 +63,21 @@ function (implicit_pullback::ImplicitPullback{false})(dy)
return _apply(implicit_pullback, dy)
end

function unimplemented_tangent(i)
function unimplemented_tangent(_)
return @not_implemented(
"Tangents for positional arguments of an ImplicitFunction beyond x (the first one) are not implemented"
)
end

function _apply(
implicit_pullback::ImplicitPullback{byproduct,nbargs}, dy
) where {byproduct,nbargs}
@unpack Aᵀ_op, Bᵀ_op, linear_solver, x = implicit_pullback
R = eltype(x)
dy_vec = convert(AbstractVector{R}, vec(unthunk(dy)))
dc_vec = solve(linear_solver, Aᵀ_op, dy_vec)
dx_vec = similar(vec(x))
mul!(dx_vec, Bᵀ_op, dc_vec)
lmul!(-one(R), dx_vec)
dx = reshape(dx_vec, size(x))
implicit_pullback::ImplicitPullback{byproduct,nbargs,X}, dy_thunk
) where {byproduct,nbargs,X}
@unpack Aᵀ_vec, pbBᵀ, linear_solver = implicit_pullback
dy = unthunk(dy_thunk)
dy_vec = vec(dy)
dc_vec = solve(linear_solver, Aᵀ_vec, -dy_vec)
dc = reshape(dc_vec, size(dy))
dx = only(pbBᵀ(dc)) # TODO: type inference fails here
return (NoTangent(), dx, ntuple(unimplemented_tangent, nbargs)...)
end

Expand Down
13 changes: 0 additions & 13 deletions ext/ImplicitDifferentiationComponentArraysExt.jl

This file was deleted.

15 changes: 6 additions & 9 deletions ext/ImplicitDifferentiationForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Overload an [`ImplicitFunction`](@ref) on dual numbers to ensure compatibility w
This is only available if ForwardDiff.jl is loaded (extension).
We compute the Jacobian-vector product `Jv` by solving `Au = Bv` and setting `Jv = u`.
We compute the Jacobian-vector product `Jv` by solving `Au = -Bv` and setting `Jv = u`.
Positional and keyword arguments are passed to both `implicit.forward` and `implicit.conditions`.
"""
function (implicit::ImplicitFunction)(
Expand All @@ -30,16 +30,13 @@ function (implicit::ImplicitFunction)(
y = _output(y_or_yz)

backend = forward_conditions_backend(implicit)
A_op, B_op = forward_operators(backend, implicit, x, y_or_yz, args; kwargs)

x_and_dx_vec = vec(x_and_dx)
A_vec, pfB = forward_operators(backend, implicit, x, y_or_yz, args; kwargs)

dy = ntuple(Val(N)) do k
dₖx_vec = partials.(x_and_dx_vec, k)
Bdₖx = similar(vec(y))
mul!(Bdₖx, B_op, dₖx_vec)
dₖy_vec = solve(implicit.linear_solver, A_op, Bdₖx)
lmul!(-one(R), dₖy_vec)
dₖx = partials.(x_and_dx, k)
dₖc = only(pfB(dₖx))
dₖc_vec = vec(dₖc)
dₖy_vec = solve(implicit.linear_solver, A_vec, -dₖc_vec)
reshape(dₖy_vec, size(y))
end

Expand Down
3 changes: 0 additions & 3 deletions ext/ImplicitDifferentiationStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ else
end

import ImplicitDifferentiation: ImplicitDifferentiation, DirectLinearSolver
using Krylov: Krylov
using LinearAlgebra: lu, mul!

function ImplicitDifferentiation.presolve(::DirectLinearSolver, A, y::StaticArray)
Expand All @@ -23,6 +22,4 @@ function ImplicitDifferentiation.presolve(::DirectLinearSolver, A, y::StaticArra
return lu(A_static)
end

Krylov.ktypeof(::StaticVector{S,T}) where {S,T} = Vector{T} # TODO: type piracy

end
3 changes: 0 additions & 3 deletions src/ImplicitDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ export AbstractLinearSolver, IterativeLinearSolver, DirectLinearSolver
include("../ext/ImplicitDifferentiationChainRulesCoreExt.jl")
function __init__()
# Loaded conditionally on Julia < 1.9
@require ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" begin
include("../ext/ImplicitDifferentiationComponentArraysExt.jl")
end
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
include("../ext/ImplicitDifferentiationForwardDiffExt.jl")
end
Expand Down
4 changes: 2 additions & 2 deletions src/implicit_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ Wrapper for an implicit function defined by a forward mapping `y` and a set of c
An `ImplicitFunction` object behaves like a function, and every call is differentiable with respect to the first argument `x`.
When a derivative is queried, the Jacobian of `y` is computed using the implicit function theorem:
∂/∂y c(x, y(x)) * -∂/∂x y(x) = -∂/∂x c(x, y(x))
∂/∂y c(x, y(x)) * ∂/∂x y(x) = -∂/∂x c(x, y(x))
This requires solving a linear system `A * J = -B`.
This requires solving a linear system `A * J = -B` where `A = ∂c/∂y`, `B = ∂c/∂x` and `J = ∂y/∂x`.
# Fields
Expand Down
Loading

0 comments on commit da589a0

Please sign in to comment.