From 2ab42cdc1378b9b73bc5fd219451b80534740e46 Mon Sep 17 00:00:00 2001 From: Saransh Date: Sun, 26 Jun 2022 19:22:40 +0530 Subject: [PATCH] Update the doctests of `Flux.reset!` --- src/layers/recurrent.jl | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index baf6190f32..760933bb96 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -150,23 +150,29 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to: rnn.state = hidden(rnn.cell) # Examples -```jldoctest; filter = r"[+-]?([0-9]*[.])?[0-9]+" -julia> r = RNN(1 => 1); +```jldoctest +julia> r = Flux.RNNCell(relu, ones(1,1), zeros(1,1), ones(1,1), zeros(1,1)); # users should use the RNN wrapper struct instead + +julia> y = Flux.Recur(r, ones(1,1)); -julia> a = ones(Float32, 1) -1-element Vector{Float32}: +julia> y.state +1×1 Matrix{Float64}: 1.0 -julia> r.state -1×1 Matrix{Float32}: - 0.0 +julia> y(ones(1,1)) # relu(1*1 + 1) +1×1 Matrix{Float64}: + 2.0 -julia> r(a); r.state -1×1 Matrix{Float32}: - 0.61431444 +julia> y.state +1×1 Matrix{Float64}: + 2.0 + +julia> Flux.reset!(y) +1×1 Matrix{Float64}: + 0.0 -julia> Flux.reset!(r) -1×1 Matrix{Float32}: +julia> y.state +1×1 Matrix{Float64}: 0.0 ``` """