Skip to content

Commit

Permalink
Fix wrong nb of pullbacks (#91)
Browse files Browse the repository at this point in the history
* Fix wrong nb of pullbacks

* Remove bad backends

* More systematic tests for all 3 variants, JET still skipped
  • Loading branch information
gdalle authored Aug 5, 2023
1 parent 450d8aa commit 4d08067
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 23 deletions.
2 changes: 1 addition & 1 deletion ext/ImplicitDifferentiationChainRulesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ function _apply(implicit_pullback::ImplicitPullback, dy)
mul!(dx_vec, Bᵀ_op, dF_vec)
lmul!(-one(R), dx_vec)
dx = reshape(dx_vec, size(x))
return (NoTangent(), dx, NoTangent())
return (NoTangent(), dx)
end

end
81 changes: 59 additions & 22 deletions test/systematic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,26 @@ function is_static_array(a)
end

function mysqrt(x::AbstractArray)
return sqrt.(identity_break_autodiff(x))
return sqrt.(identity_break_autodiff(abs.(x)))
end

function make_implicit_sqrt(; kwargs...)
forward(x) = mysqrt(x)
conditions(x, y) = y .^ 2 .- x
conditions(x, y) = y .^ 2 .- abs.(x)
implicit = ImplicitFunction(forward, conditions; kwargs...)
return implicit
end

function make_implicit_sqrt_byproduct(; kwargs...)
forward(x) = mysqrt(x), 0.5
conditions(x, y, z) = y .^ (1 / z) .- x
conditions(x, y, z) = y .^ (1 / z) .- abs.(x)
implicit = ImplicitFunction(forward, conditions; kwargs...)
return implicit
end

function make_implicit_power_kwargs(; kwargs...)
forward(x; p) = x .^ p
conditions(x, y; p) = y .^ (1 / p) .- x
forward(x; p) = abs.(x) .^ p
conditions(x, y; p) = y .^ (1 / p) .- abs.(x)
implicit = ImplicitFunction(forward, conditions; kwargs...)
return implicit
end
Expand All @@ -77,16 +77,23 @@ function test_implicit_call(x; kwargs...)
@test y3 y_true
@test z2 0.5
end

if typeof(x) <: StaticArray
@testset "Static arrays" begin
@test is_static_array(y1)
@test is_static_array(y2)
@test is_static_array(y3)
end
end

@testset "JET" begin
@test_opt target_modules = (ID,) imf1(x)
@test_opt target_modules = (ID,) imf2(x)
@test_opt target_modules = (ID,) imf3(x; p=0.5)

@test_call target_modules = (ID,) imf1(x)
@test_call target_modules = (ID,) imf2(x)
@test_call target_modules = (ID,) imf3(x; p=0.5)
end
end

Expand Down Expand Up @@ -119,8 +126,13 @@ function test_implicit_duals(x; kwargs...)
end

@testset "JET" begin
@test_opt target_modules = (ID,) imf1(x_and_dx)
@test_opt target_modules = (ID,) imf2(x_and_dx)
@test_opt target_modules = (ID,) imf3(x_and_dx; p=0.5)

@test_call target_modules = (ID,) imf1(x_and_dx)
@test_call target_modules = (ID,) imf2(x_and_dx)
@test_call target_modules = (ID,) imf3(x_and_dx; p=0.5)
end
end

Expand All @@ -137,18 +149,20 @@ function test_implicit_rrule(rc, x; kwargs...)
(y2, z2), pb2 = @inferred rrule(rc, imf2, x)
y3, pb3 = @inferred rrule(rc, imf3, x; p=0.5)

dimp1, dx1 = @inferred pb1(dy)
dimp2, dx2 = @inferred pb2((dy, dz))
dimp3, dx3 = @inferred pb3(dy)
dimf1, dx1 = @inferred pb1(dy)
dimf2, dx2 = @inferred pb2((dy, dz))
dimf3, dx3 = @inferred pb3(dy)

@testset "Pullbacks" begin
@test y1 y_true
@test y2 y_true
@test y3 y_true
@test z2 0.5
@test dimp1 isa NoTangent
@test dimp2 isa NoTangent
@test dimp3 isa NoTangent

@test dimf1 isa NoTangent
@test dimf2 isa NoTangent
@test dimf3 isa NoTangent

@test size(dx1) == size(x)
@test size(dx2) == size(x)
@test size(dx3) == size(x)
Expand All @@ -159,22 +173,35 @@ function test_implicit_rrule(rc, x; kwargs...)
@test is_static_array(y1)
@test is_static_array(y2)
@test is_static_array(y3)

@test is_static_array(dx1)
@test is_static_array(dx2)
@test is_static_array(dx3)
end
end

@testset "JET" begin
@test_skip @test_opt target_modules = (ID,) rrule(rc, imf1, x)
@test_skip @test_opt target_modules = (ID,) rrule(rc, imf2, x)
@test_skip @test_opt target_modules = (ID,) pb2(dy)
@test_skip @test_opt target_modules = (ID,) rrule(rc, imf3, x; p=0.5)

@test_skip @test_opt target_modules = (ID,) pb1(dy)
@test_skip @test_opt target_modules = (ID,) pb2((dy, dz))
@test_skip @test_opt target_modules = (ID,) pb3(dy)

@test_call target_modules = (ID,) rrule(rc, imf1, x)
@test_call target_modules = (ID,) rrule(rc, imf2, x)
@test_call target_modules = (ID,) pb2(dy)
@test_call target_modules = (ID,) rrule(rc, imf3, x; p=0.5)

@test_call target_modules = (ID,) pb1(dy)
@test_call target_modules = (ID,) pb2((dy, dz))
@test_call target_modules = (ID,) pb3(dy)
end

@testset "ChainRulesTestUtils" begin
# Skipped because of https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/232 and because it detects weird type instabilities
@test_skip test_rrule(imf2, x)
test_rrule(rc, imf1, x; atol=1e-2)
test_rrule(rc, imf2, x; atol=1e-2)
test_rrule(rc, imf3, x; atol=1e-2, fkwargs=(p=0.5,))
end
end

Expand Down Expand Up @@ -234,15 +261,25 @@ end

## Actual loop

x_candidates = (
rand(2), rand(2, 3, 4), SVector{2}(rand(2)), SArray{Tuple{2,3,4}}(rand(2, 3, 4))
linear_solver_candidates = (
IterativeLinearSolver(), #
DirectLinearSolver(), #
)

conditions_backend_candidates = (
nothing, #
AD.ForwardDiffBackend(), #
# AD.ZygoteBackend(), # TODO: failing
# AD.ReverseDiffBackend() # TODO: failing
# AD.FiniteDifferencesBackend() # TODO: failing
);

linear_solver_candidates = (IterativeLinearSolver(), DirectLinearSolver())
conditions_backend_candidates = (nothing, AD.ForwardDiffBackend());
# conditions_backend_failing_candidates = (
# AD.ZygoteBackend(), AD.FiniteDifferencesBackend, AD.ReverseDiffBackend()()
# ) # TODO: understand why
x_candidates = (
rand(2), #
rand(2, 3, 4), #
SVector{2}(rand(2)), #
SArray{Tuple{2,3,4}}(rand(2, 3, 4)), #
);

for linear_solver in linear_solver_candidates,
conditions_backend in conditions_backend_candidates,
Expand Down

0 comments on commit 4d08067

Please sign in to comment.