Missing ChainRule for ComponentVector(a; b...)
#207
Here is a (failing) attempt of mine:
using ComponentArrays
import ChainRulesCore
import Zygote
# -----------
# rule
function ChainRulesCore.rrule(::typeof(ComponentArrays.ComponentVector),
                              x::ComponentVector; kwargs...)
    res = ComponentVector(x; kwargs...)
    function pullback(Δ)
        one_x = zero(similar(x, eltype(Δ))) .+ 1
        one_y = zero(ComponentVector{eltype(Δ)}(kwargs)) .+ 1
        return ChainRulesCore.NoTangent(), one_x, one_y
    end
    return res, pullback
end
# -----------
# test
function mymerge(x::ComponentVector, y::ComponentVector)
    z = ComponentVector(x; y...)
    z
end
x = ComponentVector(a=1.0, b=2, c=(e=3, f=4))
y = ComponentVector(a = 11, e=4.0, d=5.0)
mymerge(x, y)
Zygote.gradient(a -> sum(mymerge(a, y)), x)[1] # fails with StackOverflowError
Zygote.gradient(a -> sum(mymerge(x, a)), y)[1] # fails with StackOverflowError
Not sure why this is causing a StackOverflowError. A test version without the |
It looks like you were super close. You just needed to splat out the keyword arguments in the pullback.
function ChainRulesCore.rrule(::typeof(ComponentArrays.ComponentVector),
                              x::ComponentVector; kwargs...)
    res = ComponentVector(x; kwargs...)
    function pullback(Δ)
        one_x = zero(similar(x, eltype(Δ))) .+ 1
        one_y = zero(ComponentVector{eltype(Δ)}(; kwargs...)) .+ 1
        return ChainRulesCore.NoTangent(), one_x, one_y
    end
    return res, pullback
end
Thanks, though! I'll add it as soon as I get a chance. |
Wait no, that gives the wrong answer. |
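One plausible reason, offered as an assumption rather than a diagnosis from this thread: the pullback returns all ones for `x` and never uses `Δ`, yet if keyword entries override same-named entries of `x` (as `a` does in the test), those overridden entries of `x` should receive a zero gradient. Below is a minimal sketch of that bookkeeping, with flat NamedTuples standing in for ComponentVectors. The names `merge_nt` and `merge_nt_pullback` are hypothetical and not part of ComponentArrays or ChainRules.

```julia
# Hypothetical stand-in for ComponentVector(x; y...): entries of `y` override
# same-named entries of `x`. Flat NamedTuples only; nesting is not handled.
merge_nt(x::NamedTuple, y::NamedTuple) = merge(x, y)

# Hand-written pullback: split the cotangent Δ of the merged result back
# into cotangents for `x` and `y`.
function merge_nt_pullback(x::NamedTuple, y::NamedTuple, Δ::NamedTuple)
    # Keys of `x` that `y` overrides never reach the output, so they get zero.
    dx_vals = map(k -> haskey(y, k) ? zero(Δ[k]) : Δ[k], keys(x))
    # Every key of `y` appears in the output, so it receives Δ unchanged.
    dy_vals = map(k -> Δ[k], keys(y))
    return NamedTuple{keys(x)}(dx_vals), NamedTuple{keys(y)}(dy_vals)
end

x = (a = 1.0, b = 2.0)
y = (a = 11.0, d = 5.0)
z = merge_nt(x, y)
Δ = map(one, z)                       # cotangent of sum(z) is all ones
dx, dy = merge_nt_pullback(x, y, Δ)
```

With these inputs, `dx` is `(a = 0.0, b = 1.0)` and `dy` is `(a = 1.0, d = 5.0 / 5.0)`, i.e. `(a = 1.0, d = 1.0)`: the overridden `a` in `x` gets no gradient, which an all-ones `one_x` would not reproduce.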
Interesting: ChainRules doesn't work with keyword arguments. We may want to instead define the behavior in a |
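For context on the keyword-argument point: by ChainRulesCore's convention, a pullback returns one tangent for the function itself plus one per positional argument, so a rule has no slot in which to return a gradient for a keyword argument. Here is a small sketch using a made-up function `scale` (hypothetical, not from ComponentArrays):

```julia
import ChainRulesCore
using ChainRulesCore: rrule, NoTangent

# Hypothetical example function with one positional and one keyword argument.
scale(x; factor = 2.0) = factor .* x

function ChainRulesCore.rrule(::typeof(scale), x; factor = 2.0)
    y = scale(x; factor = factor)
    # Two return values: one for `scale` itself, one for `x`.
    # There is no position for a tangent with respect to `factor`.
    scale_pullback(Δ) = (NoTangent(), factor .* Δ)
    return y, scale_pullback
end

y, back = rrule(scale, [1.0, 2.0]; factor = 3.0)
grads = back([1.0, 1.0])
```

`grads[2]` holds the gradient with respect to `x`. Getting a gradient for what is passed as keywords would require routing those values through positional arguments instead.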
I've tried to use Zygote with ComponentArrays, but it cannot get through this code:
I think it is just a missing ChainRules rule. I've tried, but unfortunately writing rules is black magic for me...