Skip to content

Commit

Permalink
Widen in _grad! (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Apr 20, 2022
1 parent 2bf0efa commit db3cf91
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 5 deletions.
17 changes: 12 additions & 5 deletions src/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,22 @@ function _grad!(x, dx, off, flat::AbstractVector)
x′, _ = functor(typeof(x), x)
dx′, _ = functor(typeof(x), base(dx))
off′, _ = functor(typeof(x), off)
foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
for (xᵢ, dxᵢ, oᵢ) in zip(x′, dx′, off′)
flat = _grad!(xᵢ, dxᵢ, oᵢ, flat)
end
flat
end
function _grad!(x, dx, off::Integer, flat::AbstractVector)
@views flat[off .+ (1:length(x))] .+= vec(dx) # must visit all tied nodes
function _grad!(x, dx, off::Integer, flat::AbstractVector{T}) where T
dx_un = unthunk(dx)
T2 = promote_type(T, eltype(dx_un))
if T != T2 # then we must widen the type
flat = copyto!(similar(flat, T2), flat)
end
@views flat[off .+ (1:length(x))] .+= vec(dx_un) # must visit all tied nodes
flat
end
_grad!(x, dx::Zero, off, flat::AbstractVector) = dx
_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = dx # ambiguity
_grad!(x, dx::Zero, off, flat::AbstractVector) = flat
_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = flat # ambiguity

# These are only needed for 2nd derivatives:
function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
Expand Down
40 changes: 40 additions & 0 deletions test/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,43 @@ end
4(sum(m.x) + sum(m.y)) + 13*sum(m.z) # again two gradients are ===, so it eliminates one
end == ([17,17,4,4],) # Flux gave ([4.0, 4.0, 13.0, 13.0],)
end

@testset "DiffEqFlux issue 699" begin
# The gradient of `re` is a vector into which we accumulate contributions, and the issue
# is that one contribution may have a wider type than `v`, especially for `Dual` numbers.
v, re = destructure((x=Float32[1,2], y=Float32[3,4,5]))
_, bk = Zygote.pullback(re, ones(Float32, 5))
# Testing with `Complex` isn't ideal, but this was an error on 0.2.1.
# If some upgrade inserts ProjectTo, this will fail, and can be changed:
@test bk((x=[1.0,im], y=nothing)) == ([1,im,0,0,0],)

@test bk((x=nothing, y=[10,20,30]))[1] isa Vector{Float32} # despite some ZeroTangent
@test bk((x=nothing, y=nothing)) == ([0,0,0,0,0],)
@test bk((x=nothing, y=@thunk [1,2,3] .* 10.0)) == ([0,0,10,20,30],)
@test bk((x=[1.2, 3.4], y=Float32[5,6,7])) == ([1.2, 3.4, 5, 6, 7],)
end

#=
# Adapted from https://github.com/SciML/DiffEqFlux.jl/pull/699#issuecomment-1092846657
using ForwardDiff, Zygote, Flux, Optimisers, Test
y = Float32[0.8564646, 0.21083355]
p = randn(Float32, 27);
t = 1.5f0
λ = [ForwardDiff.Dual(0.87135935, 1, 0, 0, 0, 0, 0), ForwardDiff.Dual(1.5225363, 0, 1, 0, 0, 0, 0)]
model = Chain(x -> x .^ 3,
Dense(2 => 5, tanh),
Dense(5 => 2))
p,re = Optimisers.destructure(model)
f(u, p, t) = re(p)(u)
_dy, back = Zygote.pullback(y, p) do u, p
vec(f(u, p, t))
end
tmp1, tmp2 = back(λ);
tmp1
@test tmp2 isa Vector{<:ForwardDiff.Dual}
=#

0 comments on commit db3cf91

Please sign in to comment.