
Unexpected return values of pullback on GPU #1424

Open
YichengDWu opened this issue May 23, 2023 · 8 comments

@YichengDWu commented May 23, 2023

using CUDA, Zygote

x = CUDA.rand(2)
y = CUDA.rand(2)

f(x,y) = broadcast(tuple,x,y)

pullback(f, x, y)[1]
2-element CuArray{Tuple{ForwardDiff.Dual{Nothing, Float32, 2}, ForwardDiff.Dual{Nothing, Float32, 2}}, 1, CUDA.Mem.DeviceBuffer}:
 (Dual{Nothing}(0.19380774,1.0,0.0), Dual{Nothing}(0.0026825257,0.0,1.0))
 (Dual{Nothing}(0.28045696,1.0,0.0), Dual{Nothing}(0.62378126,0.0,1.0))

I'm not sure how ForwardDiff.Dual gets involved here. On the CPU the result is as expected:

pullback(f, Vector(x), Vector(y))[1]
2-element Vector{Tuple{Float32, Float32}}:
 (0.19380774, 0.0026825257)
 (0.28045696, 0.62378126)

This causes the following bug:

julia> o, back = pullback(f, x, y)
(Tuple{ForwardDiff.Dual{Nothing, Float32, 2}, ForwardDiff.Dual{Nothing, Float32, 2}}[(Dual{Nothing}(0.8242247,1.0,0.0), Dual{Nothing}(0.5431607,0.0,1.0)), (Dual{Nothing}(0.83507776,1.0,0.0), Dual{Nothing}(0.4801908,0.0,1.0))], Zygote.var"#68#69"{typeof((f))}((f)))

julia> back(o)
(nothing, nothing)

julia> o, back = pullback(f, Vector(x), Vector(y))
(Tuple{Float32, Float32}[(0.8242247, 0.5431607), (0.83507776, 0.4801908)], Zygote.var"#68#69"{typeof((f))}((f)))

julia> back(o)
(Float32[0.8242247, 0.83507776], Float32[0.5431607, 0.4801908])

YichengDWu changed the title from "Unexpected return values of pullback" to "Unexpected return values of pullback on GPU" on May 23, 2023

@ToucheSir (Member)

Broadcasting on GPU unconditionally takes the ForwardDiff path, which is why you see Duals. But those Duals should not make it to the user, so that's a bug. Evidently https://github.com/FluxML/Zygote.jl/blob/v0.6.61/src/lib/broadcast.jl#L295 is not smart enough to recurse into the Tuples to remove any Duals there.
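
A hypothetical sketch of such recursion (the helper names below are made up for illustration and are not Zygote's actual code):

using ForwardDiff: Dual, value

# Strip Duals recursively so only primal values reach the user.
strip_dual(x::Dual) = value(x)             # take the primal value of a Dual
strip_dual(x::Tuple) = map(strip_dual, x)  # recurse into tuples
strip_dual(x) = x                          # leave everything else untouched

strip_duals(xs::AbstractArray) = map(strip_dual, xs)  # element-wise, also works on CuArrays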

@YichengDWu (Author)

What about a custom struct? It just throws an error

struct Point{T}
       x::T
       y::T
       Point(x::T,y::T) where T = new{T}(x,y)
end

import Adapt
Adapt.@adapt_structure Point

f(x,y) = Point.(x,y)

julia> pullback(f, x, y)[1]
ERROR: GPU broadcast resulted in non-concrete element type Any.
This probably means that the function you are broadcasting contains an error or type instability.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] copy
    @ ~/.julia/packages/GPUArrays/g2pOV/src/host/broadcast.jl:34 [inlined]
  [3] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, Zygote.var"#1550#1551"{UnionAll}, Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}})
    @ Base.Broadcast ./broadcast.jl:860
  [4] broadcast_forward(::Type, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/lib/broadcast.jl:269
  [5] adjoint
    @ ~/.julia/packages/Zygote/g2w9o/src/lib/broadcast.jl:348 [inlined]
  [6] _pullback(::Zygote.Context{false}, ::typeof(Base.Broadcast.broadcasted), ::CUDA.CuArrayStyle{1}, ::Type, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66
  [7] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:816
  [8] adjoint
    @ ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:203 [inlined]
  [9] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
 [10] _pullback
    @ ./broadcast.jl:1304 [inlined]
 [11] _pullback
    @ ./REPL[262]:1 [inlined]
 [12] _pullback(::Zygote.Context{false}, ::typeof(f), ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [13] pullback(::Function, ::Zygote.Context{false}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:44
 [14] pullback(::Function, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:42
 [15] top-level scope
    @ REPL[263]:1
 [16] top-level scope
    @ ~/.julia/packages/CUDA/BbliS/src/initialization.jl:52

But on the CPU it's fine:

julia> pullback(f, Vector(x), Vector(y))[1]
2-element Vector{Point{Float32}}:
 Point{Float32}(0.26743817f0, 0.2564943f0)
 Point{Float32}(0.34023497f0, 0.41681844f0)

@ToucheSir (Member)

That's because of JuliaGPU/CUDA.jl#1761, which Zygote has no control over. Defining a differently-named constructor function (as tuple is to Tuple) and using that should work. Filling out the type parameters of Point (i.e.
f(x::AbstractArray{A}, y::AbstractArray{B}) where {A, B} = Point{A, B}.(x, y)) might also work, as it avoids the UnionAll.
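
A hedged sketch of the first suggestion, using the Point struct from above; makepoint is a hypothetical helper name, not part of any package:

# Broadcast a plain function instead of the UnionAll type `Point`.
makepoint(x, y) = Point(x, y)   # ordinary function wrapping the constructor
f(x, y) = makepoint.(x, y)      # should avoid the non-concrete eltype error on GPU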

@YichengDWu (Author)

I don't understand why it uses forward-mode AD here. Is there a way to force reverse-mode AD on the GPU, say by writing a custom rrule?

@ToucheSir (Member) commented May 25, 2023

Only by defining a rule for broadcasted(::typeof(myfunc), ...), which doesn't exist for most functions; otherwise it won't be GPU compatible. You could look at how ChainRules.jl's broadcast rule handles this. If it does a better job on GPU, that's another argument for trying to replace Zygote's broadcasting machinery with it.
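
A hedged sketch of such a rule for the tuple broadcast in this issue, written with Zygote.@adjoint; it is untested, and the pullback assumes the cotangent Δ is an array of 2-tuples matching the output:

using CUDA, Zygote

# Intercept the broadcasted call itself so Zygote's ForwardDiff-based GPU
# fallback is never reached for this function.
Zygote.@adjoint function Base.Broadcast.broadcasted(::typeof(tuple), x::CuArray, y::CuArray)
    out = broadcast(tuple, x, y)
    back(Δ) = (nothing, map(first, Δ), map(last, Δ))  # gradients for (tuple, x, y)
    return out, back
end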

@YichengDWu (Author)

Thanks, I will try writing an rrule then.

@YichengDWu (Author)

pullback still uses forward mode even when there is an rrule.

julia> function rrule(::typeof(f), x, y)
       o = f(x,y)
       function f_pullback(x̄)
           return NoTangent(), x, y
       end
       return o, f_pullback
       end
rrule (generic function with 2 methods)

julia> o, back = pullback(f, x, y)
(Tuple{ForwardDiff.Dual{Nothing, Float32, 2}, ForwardDiff.Dual{Nothing, Float32, 2}}[(Dual{Nothing}(0.61599565,1.0,0.0), Dual{Nothing}(0.16058706,0.0,1.0)), (Dual{Nothing}(0.7189002,1.0,0.0), Dual{Nothing}(0.69142073,0.0,1.0))], Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(f), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#4155#back#1388"{Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(tuple), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#2157#back#289"{Zygote.var"#287#288"{Tuple{NTuple{4, Nothing}, Tuple{}}, Zygote.var"#4165#back#1428"{Zygote.var"#1394#1396"}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#2001#back#200"{typeof(identity)}, Zygote.var"#2001#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#2157#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1170"{Tuple{Nothing, Nothing, Nothing}}}}, Zygote.var"#2865#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}}}}}}}(∂(f)))

julia> o, back = rrule(f, x, y)
(Tuple{Float32, Float32}[(0.61599565, 0.16058706), (0.7189002, 0.69142073)], var"#f_pullback#34"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}(Float32[0.61599565, 0.7189002], Float32[0.16058706, 0.69142073]))
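
For what it's worth, one common reason a rule like this is not picked up is that it defines a new local function named rrule rather than extending ChainRulesCore.rrule. A hedged sketch of the latter, keeping the placeholder cotangents from the snippet above (untested):

import ChainRulesCore
using ChainRulesCore: NoTangent

# Extend ChainRulesCore.rrule so Zygote can discover the rule for f.
function ChainRulesCore.rrule(::typeof(f), x, y)
    o = f(x, y)
    # Placeholder cotangents, as above; a real rule would map x̄ back to x and y.
    f_pullback(x̄) = (NoTangent(), x, y)
    return o, f_pullback
end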

@YichengDWu (Author)

OK, this works:

f(x,y) = ChainRulesCore.@ignore_derivatives broadcast(tuple,x,y)
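
One hedged caveat about this workaround: @ignore_derivatives removes the call from the AD trace, so the Duals disappear, but no gradient flows back to x and y through it. A minimal usage sketch:

using CUDA, Zygote, ChainRulesCore

# The broadcast is hidden from AD: the primal output is a plain CuArray of
# Float32 tuples, but x and y receive no gradient through this call.
f(x, y) = ChainRulesCore.@ignore_derivatives broadcast(tuple, x, y)

x, y = CUDA.rand(2), CUDA.rand(2)
o, back = Zygote.pullback(f, x, y)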
