From 534ccd811a45e90375c4853863b98785776bcf15 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Wed, 25 Oct 2023 13:25:30 -0400 Subject: [PATCH] add method for zygote.seed(::CA) --- ext/ComponentArraysZygoteExt.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ext/ComponentArraysZygoteExt.jl b/ext/ComponentArraysZygoteExt.jl index 32b0e37f..22b4ec92 100644 --- a/ext/ComponentArraysZygoteExt.jl +++ b/ext/ComponentArraysZygoteExt.jl @@ -10,4 +10,10 @@ function Zygote.accum(x::ComponentArray, ys::ComponentArray...) return ComponentArray(Zygote.accum(getdata(x), getdata.(ys)...), getaxes(x)) end +function Zygote.seed(x::ComponentArray, ::Val{N}, offset = 0) where{N} + data = Zygote.seed(getdata(x), Val(N), offset) + + ComponentArray(data, getaxes(x)) +end + end