Skip to content

Commit

Permalink
Better document test_rrule tweak (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Aug 10, 2023
1 parent f5b1efb commit e725c2f
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions test/systematic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,22 @@ function make_implicit_sqrt(; kwargs...)
end

function make_implicit_sqrt_byproduct(; kwargs...)
forward(x) = mysqrt(x), eltype(x)(2) # fails rrule tests with 2 integer, probably due to https://juliadiff.org/ChainRulesTestUtils.jl/dev/index.html#Specifying-Tangents
conditions(x, y, z::Number) = y .^ z .- abs.(change_shape(x))
forward(x) = mysqrt(x), 2
conditions(x, y, z::Integer) = y .^ z .- abs.(change_shape(x))
implicit = ImplicitFunction(forward, conditions; kwargs...)
return implicit
end

function make_implicit_power_args(; kwargs...)
forward(x, p::Number) = mypower(x, one(eltype(x)) / p)
conditions(x, y, p::Number) = y .^ p .- abs.(change_shape(x))
forward(x, p::Integer) = mypower(x, one(eltype(x)) / p)
conditions(x, y, p::Integer) = y .^ p .- abs.(change_shape(x))
implicit = ImplicitFunction(forward, conditions; kwargs...)
return implicit
end

function make_implicit_power_kwargs(; kwargs...)
forward(x; p::Number) = mypower(x, one(eltype(x)) / p)
conditions(x, y; p::Number) = y .^ p .- abs.(change_shape(x))
forward(x; p::Integer) = mypower(x, one(eltype(x)) / p)
conditions(x, y; p::Integer) = y .^ p .- abs.(change_shape(x))
implicit = ImplicitFunction(forward, conditions; kwargs...)
return implicit
end
Expand Down Expand Up @@ -243,7 +243,7 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T}

@testset "ChainRulesTestUtils" begin
test_rrule(rc, imf1, x; atol=1e-2)
test_rrule(rc, imf2, x; atol=5e-2)
test_rrule(rc, imf2, x; atol=5e-2, output_tangent=(dy, 0)) # see issue https://github.com/gdalle/ImplicitDifferentiation.jl/issues/112
test_rrule(rc, imf3, x, 2; atol=1e-2)
test_rrule(rc, imf4, x; atol=1e-2, fkwargs=(p=2,))
end
Expand Down

0 comments on commit e725c2f

Please sign in to comment.