Skip to content

Commit

Permalink
NaNs for linear solvers when failed (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Aug 4, 2023
1 parent 7204b22 commit d97362e
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 90 deletions.
9 changes: 4 additions & 5 deletions ext/ImplicitDifferentiationStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,15 @@ else
using ..StaticArrays: StaticArray, MMatrix
end

import ImplicitDifferentiation: ImplicitDifferentiation, DirectLinearSolver
import ImplicitDifferentiation: ImplicitDifferentiation, DirectLinearSolver, solve
using LinearAlgebra: lu, mul!

function ImplicitDifferentiation.presolve(
::DirectLinearSolver, A, y::StaticArray{S,T,N}
) where {S,T,N}
function ImplicitDifferentiation.presolve(::DirectLinearSolver, A, y::StaticArray)
T = eltype(A)
m = length(y)
A_static = zero(MMatrix{m,m,T})
v = vec(similar(y, T))
for i in axes(A_static, 2)
v = vec(similar(y))
v .= zero(T)
v[i] = one(T)
mul!(@view(A_static[:, i]), A, v)
Expand Down
2 changes: 1 addition & 1 deletion src/ImplicitDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module ImplicitDifferentiation

using Krylov: KrylovStats, gmres
using LinearOperators: LinearOperators, LinearOperator
using LinearAlgebra: lu, SingularException
using LinearAlgebra: lu, SingularException, issuccess
using Requires: @require
using SimpleUnPack: @unpack

Expand Down
40 changes: 24 additions & 16 deletions src/linear_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
All linear solvers used within an `ImplicitFunction` must satisfy this interface.
It can be useful to roll your own solver if you need more fine-grained control over convergence / speed / behavior in case of singularity.
Check out the source code of `IterativeLinearSolver` and `DirectLinearSolver` for implementation examples.
# Required methods
- `presolve(linear_solver, A, y)`: return a matrix-like object `A` for which it is cheaper to solve several linear systems with different vectors `b` (a typical example would be to perform LU factorization).
- `solve(linear_solver, A, b)`: return a tuple `(x, stats)` where `x` satisfies `Ax = b` and `stats.solved ∈ {true, false}`.
- `presolve(linear_solver, A, y)`: Returns a matrix-like object `A` for which it is cheaper to solve several linear systems with different vectors `b` (a typical example would be to perform LU factorization).
- `solve(linear_solver, A, b)`: Returns a vector `x` satisfying `Ax = b`. If the linear system has not been solved to satisfaction, every element of `x` should be a `NaN` of the appropriate floating point type.
"""
abstract type AbstractLinearSolver end

Expand All @@ -20,11 +23,15 @@ struct IterativeLinearSolver <: AbstractLinearSolver end
presolve(::IterativeLinearSolver, A, y) = A

function solve(::IterativeLinearSolver, A, b)
T = float(promote_type(eltype(A), eltype(b)))
x, stats = gmres(A, b)
if !stats.solved
throw(SolverFailureException(gmres, stats))
x_maybenan = similar(x, T)
if stats.solved && !stats.inconsistent
x_maybenan .= x
else
x_maybenan .= convert(T, NaN)
end
return x
return x_maybenan
end

"""
Expand All @@ -34,17 +41,18 @@ An implementation of `AbstractLinearSolver` using the built-in backslash operato
"""
struct DirectLinearSolver <: AbstractLinearSolver end

presolve(::DirectLinearSolver, A, y) = lu(Matrix(A))
solve(::DirectLinearSolver, A, b) = A \ b

struct SolverFailureException{A,B} <: Exception
solver::A
stats::B
function presolve(::DirectLinearSolver, A, y)
    # Materialize `A` as a dense matrix, then LU-factorize it. `check=false` asks `lu`
    # not to raise on an unsuccessful factorization; the caller is then responsible
    # for testing the result's validity before using it.
    A_dense = Matrix(A)
    return lu(A_dense; check=false)
end

function Base.show(io::IO, sfe::SolverFailureException)
return println(
io,
"SolverFailureException: \n Linear solver: $(sfe.solver) \n Solver stats: $(string(sfe.stats))",
)
function solve(::DirectLinearSolver, A_lu, b)
    # The element type is promoted from both factors and the right-hand side, then made
    # floating point so that it can hold a `NaN`.
    # workaround for https://github.com/JuliaArrays/StaticArrays.jl/issues/1190
    T = float(promote_type(eltype(A_lu.L), eltype(A_lu.U), eltype(b)))
    solution = Vector{T}(undef, size(A_lu.L, 2))
    if !issuccess(A_lu)
        # Failed factorization: signal it with an all-NaN vector instead of solving.
        fill!(solution, convert(T, NaN))
        return solution
    end
    solution .= A_lu \ b
    return solution
end
54 changes: 54 additions & 0 deletions test/errors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using ForwardDiff
using ImplicitDifferentiation
using Test
using Zygote

@testset "Byproduct handling" begin
    # The conditions are shared; only the forward mapping differs between the two cases.
    # Both forward mappings return a plain vector, and calling the resulting implicit
    # function built with `HandleByproduct()` is expected to raise an `ArgumentError`
    # (presumably because no byproduct is returned; confirm against `ImplicitFunction`).
    conditions = (_, _) -> [0.0, 0.0]
    forwards = ((_) -> [1.0, 2.0], (_) -> [1.0, 2.0, 3.0])
    for forward in forwards
        imf = ImplicitFunction(forward, conditions, HandleByproduct())
        @test_throws ArgumentError imf(zeros(2))
    end
end

@testset "Only accept one array" begin
    imf = ImplicitFunction((_) -> [1.0], (_, _) -> [0.0])
    # A string argument is expected to raise a `MethodError`.
    @test_throws MethodError imf("hello")
    # So is a call with two positional arrays.
    @test_throws MethodError imf([1.0], [1.0])
end

@testset verbose = true "Derivative NaNs" begin
    x = zeros(Float32, 2)
    linear_solvers = (IterativeLinearSolver(), DirectLinearSolver())
    # Each scenario is (testset name, forward mapping, conditions). In every scenario
    # both Jacobians are expected to be all-NaN while keeping the Float32 element type.
    scenarios = (
        # nondifferentiable at 0
        ("Infinite derivative", x -> sqrt.(x), (x, y) -> y .^ 2 .- x),
        # wrong solver
        ("Singular linear system", x -> x, (x, y) -> (x .+ 1) .^ 2 .- y .^ 2),
    )
    for (scenario_name, f, c) in scenarios
        @testset "$scenario_name" begin
            for linear_solver in linear_solvers
                @testset "$(typeof(linear_solver))" begin
                    implicit = ImplicitFunction(f, c; linear_solver)
                    J_forward = ForwardDiff.jacobian(implicit, x)
                    J_reverse = Zygote.jacobian(implicit, x)[1]
                    @test all(isnan, J_forward) && eltype(J_forward) == Float32
                    @test all(isnan, J_reverse) && eltype(J_reverse) == Float32
                end
            end
        end
    end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ EXAMPLES_DIR_JL = joinpath(dirname(@__DIR__), "examples")
@testset verbose = true "Systematic" begin
include("systematic.jl")
end
@testset verbose = true "Errors" begin
include("errors.jl")
end
@testset verbose = true "Examples" begin
for file in readdir(EXAMPLES_DIR_JL)
path = joinpath(EXAMPLES_DIR_JL, file)
Expand Down
153 changes: 85 additions & 68 deletions test/systematic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,95 +62,122 @@ function make_implicit_sqrt_byproduct(; kwargs...)
end

function test_implicit_call(implicit, x; y_true)
@test_throws MethodError implicit("hello")
@test_throws MethodError implicit(x, x)
y1 = @inferred implicit(x)
y2, z2 = @inferred implicit(x, ReturnByproduct())
@test y1 ≈ y_true
@test y2 ≈ y_true
@testset "Exact value" begin
@test y1 ≈ y_true
@test y2 ≈ y_true
end
@testset "Byproduct" begin
if handles_byproduct(implicit)
@test z2 == 2
else
@test z2 == 0
end
end
if typeof(x) <: StaticArray
@test is_static_array(y1)
@test is_static_array(y2)
@testset "Static arrays" begin
@test is_static_array(y1)
@test is_static_array(y2)
end
end
if handles_byproduct(implicit)
@test z2 == 2
else
@test z2 == 0
@testset "JET" begin
@test_opt target_modules = (ImplicitDifferentiation,) implicit(x)
@test_call target_modules = (ImplicitDifferentiation,) implicit(x)
end
@test_opt target_modules = (ImplicitDifferentiation,) implicit(x)
@test_call target_modules = (ImplicitDifferentiation,) implicit(x)
end

function test_implicit_forward(implicit, x; y_true, J_true)
# High-level
J1 = ForwardDiff.jacobian(implicit, x)
J2 = ForwardDiff.jacobian(x -> implicit(x, ReturnByproduct())[1], x)
@test J1 ≈ J_true
@test J2 ≈ J_true
@testset "Exact Jacobian" begin
@test J1 ≈ J_true
@test J2 ≈ J_true
end
# Low-level
x_and_dx = ForwardDiff.Dual.(x, ((0, 0),))
y_and_dy1 = @inferred implicit(x_and_dx)
y_and_dy2, z2 = @inferred implicit(x_and_dx, ReturnByproduct())
@test size(y_and_dy1) == size(y_true)
@test size(y_and_dy2) == size(y_true)
@test ForwardDiff.value.(y_and_dy1) ≈ y_true
@test ForwardDiff.value.(y_and_dy2) ≈ y_true
@testset "Dual numbers" begin
@test size(y_and_dy1) == size(y_true)
@test size(y_and_dy2) == size(y_true)
@test ForwardDiff.value.(y_and_dy1) ≈ y_true
@test ForwardDiff.value.(y_and_dy2) ≈ y_true
end
@testset "Byproduct" begin
if handles_byproduct(implicit)
@test z2 == 2
else
@test z2 == 0
end
end
if typeof(x) <: StaticArray
@test is_static_array(y_and_dy1)
@test is_static_array(y_and_dy2)
@testset "Static arrays" begin
@test is_static_array(y_and_dy1)
@test is_static_array(y_and_dy2)
end
end
if handles_byproduct(implicit)
@test z2 == 2
else
@test z2 == 0
@testset "JET" begin
@test_opt target_modules = (ImplicitDifferentiation,) implicit(x_and_dx)
@test_call target_modules = (ImplicitDifferentiation,) implicit(x_and_dx)
end
@test_opt target_modules = (ImplicitDifferentiation,) implicit(x_and_dx)
@test_call target_modules = (ImplicitDifferentiation,) implicit(x_and_dx)
end

function test_implicit_reverse(implicit, x; y_true, J_true)
# High-level
J1 = Zygote.jacobian(implicit, x)[1]
J2 = Zygote.jacobian(x -> implicit(x, ReturnByproduct())[1], x)[1]
@test J1 ≈ J_true
@test J2 ≈ J_true
@testset "Exact Jacobian" begin
@test J1 ≈ J_true
@test J2 ≈ J_true
end
# Low-level
y1, pb1 = @inferred rrule(ZygoteRuleConfig(), implicit, x)
(y2, z2), pb2 = @inferred rrule(ZygoteRuleConfig(), implicit, x, ReturnByproduct())
@test y1 ≈ y_true
@test y2 ≈ y_true
dy1 = zeros(eltype(y1), size(y1)...)
dy2 = zeros(eltype(y2), size(y2)...)
dz2 = nothing
dimp1, dx1 = @inferred pb1(dy1)
dimp2, dx2, drp = @inferred pb2((dy2, dz2))
@test size(dx1) == size(x)
@test size(dx2) == size(x)
@testset "Pullbacks" begin
@test y1 ≈ y_true
@test y2 ≈ y_true
@test size(dx1) == size(x)
@test size(dx2) == size(x)
@test dimp1 isa NoTangent
@test dimp2 isa NoTangent
@test drp isa NoTangent
end
@testset "Byproduct" begin
if handles_byproduct(implicit)
@test z2 == 2
else
@test z2 == 0
end
end
if typeof(x) <: StaticArray
@test is_static_array(y1)
@test is_static_array(y2)
@test is_static_array(dx1)
@test is_static_array(dx2)
end
@test dimp1 isa NoTangent
@test dimp2 isa NoTangent
@test drp isa NoTangent
if handles_byproduct(implicit)
@test z2 == 2
else
@test z2 == 0
end
@test_skip @test_opt target_modules = (ImplicitDifferentiation,) rrule(
ZygoteRuleConfig(), implicit, x
)
@test_skip @test_opt target_modules = (ImplicitDifferentiation,) pb1(dy1)
@test_call target_modules = (ImplicitDifferentiation,) rrule(
ZygoteRuleConfig(), implicit, x
)
@test_call target_modules = (ImplicitDifferentiation,) pb1(dy1)
# Skipped because of https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/232 and because it detects weird type instabilities
@test_skip test_rrule(implicit, x)
@test_skip test_rrule(implicit, x, ReturnByproduct())
@testset "Static arrays" begin
@test is_static_array(y1)
@test is_static_array(y2)
@test is_static_array(dx1)
@test is_static_array(dx2)
end
end
@testset "JET" begin
@test_skip @test_opt target_modules = (ImplicitDifferentiation,) rrule(
ZygoteRuleConfig(), implicit, x
)
@test_skip @test_opt target_modules = (ImplicitDifferentiation,) pb1(dy1)
@test_call target_modules = (ImplicitDifferentiation,) rrule(
ZygoteRuleConfig(), implicit, x
)
@test_call target_modules = (ImplicitDifferentiation,) pb1(dy1)
end
@testset "ChainRulesTestUtils" begin
# Skipped because of https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/232 and because it detects weird type instabilities
@test_skip test_rrule(implicit, x)
@test_skip test_rrule(implicit, x, ReturnByproduct())
end
end

x_candidates = (
Expand All @@ -170,6 +197,7 @@ for linear_solver in linear_solver_candidates, x in x_candidates
implicit_sqrt = make_implicit_sqrt(; linear_solver)
implicit_sqrt_byproduct = make_implicit_sqrt_byproduct(; linear_solver)

@info "Systematic tests - $testsetname"
@testset verbose = true "$testsetname" begin
@testset "Call" begin
test_implicit_call(implicit_sqrt, x; y_true)
Expand All @@ -185,14 +213,3 @@ for linear_solver in linear_solver_candidates, x in x_candidates
end
end
end

@testset "Correct by-product handling" begin
    # Two implicit functions share the same conditions and differ only in the length of
    # the vector returned by the forward mapping; each call is expected to raise an
    # `ArgumentError`.
    conditions = (_, _) -> [0.0, 0.0]
    imf_short = ImplicitFunction((_) -> [1.0, 2.0], conditions, HandleByproduct())
    imf_long = ImplicitFunction((_) -> [1.0, 2.0, 3.0], conditions, HandleByproduct())
    @test_throws ArgumentError imf_short(zeros(2))
    @test_throws ArgumentError imf_long(zeros(2))
end

0 comments on commit d97362e

Please sign in to comment.