diff --git a/ext/ComponentArraysZygoteExt.jl b/ext/ComponentArraysZygoteExt.jl index 32b0e37f..22b4ec92 100644 --- a/ext/ComponentArraysZygoteExt.jl +++ b/ext/ComponentArraysZygoteExt.jl @@ -10,4 +10,10 @@ function Zygote.accum(x::ComponentArray, ys::ComponentArray...) return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x)) end +function Zygote.seed(x::ComponentArray, ::Val{N}, offset = 0) where{N} + data = Zygote.seed(getdata(x), Val(N), offset) + + ComponentArray(data, getaxes(x)) +end + end