From fccdd11b3e5a8dd0fe26a327ca2719d5d33fc1d0 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 13 Apr 2024 16:48:30 +0200 Subject: [PATCH 1/3] cl/comp --- src/Optimisers.jl | 3 +- src/trainables.jl | 105 ++++++++++++++++++++++++++++++++++++++++++++- test/trainables.jl | 19 ++++++++ 3 files changed, 124 insertions(+), 3 deletions(-) diff --git a/src/Optimisers.jl b/src/Optimisers.jl index 2e115c4..54c8df9 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -5,6 +5,7 @@ using Functors: functor, fmap, fmap_with_path, isleaf, @functor, fmapstructure, children, AbstractWalk using LinearAlgebra + include("interface.jl") export AbstractRule @@ -16,7 +17,7 @@ include("destructure.jl") export destructure include("trainables.jl") -export trainables +export trainables, trainables_nt export KeyPath, haskeypath, getkeypath # from Functors.jl include("rules.jl") diff --git a/src/trainables.jl b/src/trainables.jl index 370de0f..1159b26 100644 --- a/src/trainables.jl +++ b/src/trainables.jl @@ -76,7 +76,7 @@ end function ∇trainables(x, Δ) i = 0 - return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _ + return fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do _ return Δ[i+=1] end end @@ -113,7 +113,7 @@ end function ∇trainables_with_path(x, Δ) i = 0 - return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _ + return fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do _ Δi = Δ[i+=1] if isnothing(Δi) return nothing @@ -122,3 +122,104 @@ function ∇trainables_with_path(x, Δ) end end end + + +### trainables_nt ###################### + +""" + trainables_nt(model) -> ps, re + +Return a pair `(ps, re)` where `ps` is a nested named tuple with the same structure as +the trainable part of `model` and with leaves the trainable parameters. + +Parameters are not copied, but the returned `ps` is a view into the original model. + +The `re` is a function that reconstructs a model from the parameters, +i.e. `re(ps)` is the same as the origin `model` but with the trainable parameters replaced by `ps`. + +# Examples + +```jldoctest +julia> using Flux, Optimisers + +julia> model = Chain(Dense(784, 32, relu), Dense(32, 10)); + +julia> ps, re = trainables_nt(model); + +julia> ps.layers._1.weight === model.layers[1].weight +true +``` + +```jldoctest + +julia> v = ComponentVector(ps) + +julia> model2 = re(2 * v) + +``` +""" +function trainables_nt(x) + walknt = TrainableNamedTupleWalk() + ps = fmap(identity, x; exclude=isnumeric, walk=walknt, cache=nothing) + re = RestructureFromNT(x) + return ps, re +end + + +struct RestructureFromNT{T} + x::T +end + +function (re::RestructureFromNT)(ps) + walk = RestructureFromNamedTupleWalk() + return fmap(re.x, ps; exclude=isnumeric, walk, cache=nothing) do y, p + return p + end +end + +struct TrainableNamedTupleWalk <: AbstractWalk end + +function (::TrainableNamedTupleWalk)(recurse, x) + ch = trainable(x) + y = map(recurse, make_named_tuple(ch)) + return y +end + +struct RestructureFromNamedTupleWalk <: AbstractWalk end + +function (::RestructureFromNamedTupleWalk)(recurse, x, nt) + children, re = functor(x) + newchildren = map_commons(recurse, children, nt) + return re(newchildren) +end + +function map_commons(f, x::NamedTuple{xkeys}, y) where {xkeys} + ykeys = propertynames(y) + vals = map(k -> k in ykeys ? 
f(x[k], getproperty(y, k)) : x[k], xkeys) + return NamedTuple{xkeys}(vals) +end + +function map_commons(f, x::Tuple, y) + ykeys = propertynames(y) + vals = ntuple(length(x)) do i + k = Symbol("_", i) + k in ykeys ? f(x[i], getproperty(y, k)) : x[i] + end + return vals +end + +function map_commons(f, x::Vector, y) + ykeys = propertynames(y) + vals = map(1:length(x)) do i + k = Symbol("_", i) + k in ykeys ? f(x[i], getproperty(y, k)) : x[i] + end + return vals +end + +make_named_tuple(x::NamedTuple) = x +make_named_tuple(x::AbstractDict{Symbol}) = NamedTuple(x) +make_named_tuple(x::AbstractDict) = NamedTuple(Symbol("_", k) => v for (k, v) in pairs(x)) +make_named_tuple(x::Tuple) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x) +make_named_tuple(x::Vector) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x) + diff --git a/test/trainables.jl b/test/trainables.jl index e1aa011..8a2c54d 100644 --- a/test/trainables.jl +++ b/test/trainables.jl @@ -139,3 +139,22 @@ end @test g.y == [2.0, 4.0, 6.0] @test g.z === nothing end + +using Flux, Optimisers +using ComponentArrays +using Test + + +model0 = Chain( + Dense(784, 32, relu), + Dense(32, 10)) + +ps, re = trainables_nt(model0) +@test ps.layers._1.weight === model0[1].weight +model1 = re(ps) +@test model1[1].weight === ps.layers._1.weight + +v = ComponentVector(ps) +v2 = 2 * v +model2 = re(v2) +@test model2[1].weight === v2.layers._1.weight From 6f053f31be5484203a126b8f9eab00266ca45ca0 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 14 Apr 2024 09:50:41 +0200 Subject: [PATCH 2/3] gradients and more tests --- Project.toml | 3 +- src/trainables.jl | 93 ++++++++-- src/utils.jl | 4 + test/runtests.jl | 1 + test/trainables.jl | 441 ++++++++++++++++++++++++++++++++++++++------- 5 files changed, 465 insertions(+), 77 deletions(-) diff --git a/Project.toml b/Project.toml index dab1f6e..f7cc8f7 100644 --- a/Project.toml +++ b/Project.toml @@ -18,9 +18,10 @@ Zygote = "0.6.40" julia = "1.6" [extras] +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "StaticArrays", "Zygote"] +test = ["Test", "ComponentArrays", "StaticArrays", "Zygote"] diff --git a/src/trainables.jl b/src/trainables.jl index 1159b26..91e2e87 100644 --- a/src/trainables.jl +++ b/src/trainables.jl @@ -155,28 +155,88 @@ true julia> v = ComponentVector(ps) julia> model2 = re(2 * v) - +Chain( + Dense(784 => 32, relu), # 25_120 parameters + Dense(32 => 10), # 330 parameters +) # Total: 4 arrays, 25_450 parameters, 100.281 KiB. 
``` """ -function trainables_nt(x) +function trainables_nt(model) + ps = _trainables_nt(model) + re = RestructureFromNT(model) + return ps, re +end + +function _trainables_nt(x) walknt = TrainableNamedTupleWalk() ps = fmap(identity, x; exclude=isnumeric, walk=walknt, cache=nothing) - re = RestructureFromNT(x) - return ps, re + return ps end +function ChainRulesCore.rrule(::typeof(_trainables_nt), model) + ps = _trainables_nt(model) + function _trainables_nt_back(Δps) + walk = TrainableNamedTupleBackWalk() + Δmodel = fmap(model, Δps; exclude=isnumeric, walk, cache=nothing) do x, Δ + return Δ + end + return (NoTangent(), Δmodel) + end + return ps, _trainables_nt_back +end struct RestructureFromNT{T} x::T end -function (re::RestructureFromNT)(ps) +(re::RestructureFromNT)(ps) = rebuild_from_nt(re.x, ps) + +function rebuild_from_nt(model, ps) walk = RestructureFromNamedTupleWalk() - return fmap(re.x, ps; exclude=isnumeric, walk, cache=nothing) do y, p + return fmap(model, ps; exclude=isnumeric, walk, cache=nothing) do x, p return p end end +struct RestructureFromNamedTupleWalk <: AbstractWalk end + +function (::RestructureFromNamedTupleWalk)(recurse, x, nt) + @show 1 x nt + children, re = functor(x) + @show 2 children + newchildren = map_commons(recurse, children, nt) + @show 3 x nt children newchildren + return re(newchildren) +end + +function ChainRulesCore.rrule(::typeof(rebuild_from_nt), x, ps) + model = rebuild_from_nt(x, ps) + function rebuild_from_nt_back(Δmodel_raw) + Δmodel = unthunk(Δmodel_raw) + walk = RestructureFromNamedTupleWalk() + Δps = fmap(ps, Δmodel; exclude=isnumeric, walk, cache=nothing) do p, Δ + return Δ + end + return (NoTangent(), NoTangent(), Δps) + end + return model, rebuild_from_nt_back +end + + +struct TrainableNamedTupleBackWalk <: AbstractWalk end + +function (::TrainableNamedTupleBackWalk)(recurse, model, Δps) + @show 1 typeof(model) typeof(Δps) + ch = trainable(model) + Δ = unmake_named_tuple(ch, Δps) + @show 2 typeof(ch) typeof(Δ) + Δ === nothing && return nothing + Δ === ZeroTangent() && return ZeroTangent() + y = mapvalue(recurse, ch, Δ) + @show 3 typeof(model) typeof(ch) typeof(Δ) typeof(y) + return y +end + struct TrainableNamedTupleWalk <: AbstractWalk end function (::TrainableNamedTupleWalk)(recurse, x) @@ -185,13 +245,6 @@ function (::TrainableNamedTupleWalk)(recurse, x) return y end -struct RestructureFromNamedTupleWalk <: AbstractWalk end - -function (::RestructureFromNamedTupleWalk)(recurse, x, nt) - children, re = functor(x) - newchildren = map_commons(recurse, children, nt) - return re(newchildren) -end function map_commons(f, x::NamedTuple{xkeys}, y) where {xkeys} ykeys = propertynames(y) @@ -223,3 +276,17 @@ make_named_tuple(x::AbstractDict) = NamedTuple(Symbol("_", k) => v for (k, v) in make_named_tuple(x::Tuple) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x) make_named_tuple(x::Vector) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x) + +unmake_named_tuple(x::NamedTuple, ps) = ps + +function unmake_named_tuple(x::Tuple, ps) + return ntuple(length(x)) do i + ps[Symbol("_", i)] + end +end + +function unmake_named_tuple(x::Vector, ps) + return map(1:length(x)) do i + ps[Symbol("_", i)] + end +end diff --git a/src/utils.jl b/src/utils.jl index 7c6c95b..a8f3eb8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,6 +2,10 @@ mapvalue(f, x...) = map(f, x...) mapvalue(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) 
for (k,v) in x) +# without theses, tuples are returned instead of NamedTuples +mapvalue(f, x::NamedTuple{Ks}, y::Tangent{<:Any,<:NamedTuple}) where {Ks} = + NamedTuple{Ks}((f(v, y[k]) for (k,v) in pairs(x))) + mapkey(f, x::NamedTuple{Ks}) where Ks = NamedTuple{Ks}(map(f, Ks)) mapkey(f, x::Dict) = Dict(k => f(k) for k in keys(x)) mapkey(f, x::Tuple) = ntuple(i -> f(i), length(x)) diff --git a/test/runtests.jl b/test/runtests.jl index fc0fe57..2706e1f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using ChainRulesCore, Functors, StaticArrays, Zygote using LinearAlgebra, Statistics, Test, Random using Optimisers: @.., @lazy using Base.Broadcast: broadcasted, instantiate, Broadcasted +using ComponentArrays: ComponentArrays, ComponentVector Random.seed!(1) diff --git a/test/trainables.jl b/test/trainables.jl index 8a2c54d..9fb3fa5 100644 --- a/test/trainables.jl +++ b/test/trainables.jl @@ -60,85 +60,385 @@ m9 = (a = m1, b = mat, c = [mat, m1]) @test length(ps) == 2 @test ps[1] == 1:3 @test ps[2] == mat -end - -@testset "gradient" begin - loss(m) = sum([sum(abs2, p) for p in trainables(m)]) - g = gradient(loss, m1)[1] - @test g == [2.0, 4.0, 6.0] - - g = gradient(loss, m2)[1] - @test g == ([2.0, 4.0, 6.0], [8.0, 10.0, 12.0]) - g = gradient(loss, m3)[1] - @test g.x == [2.0, 4.0, 6.0] - @test g.y === nothing - @test g.z == [8.0, 10.0, 12.0] + @testset "gradient" begin + loss(m) = sum([sum(abs2, p) for p in trainables(m)]) + g = gradient(loss, m1)[1] + @test g == [2.0, 4.0, 6.0] + + g = gradient(loss, m2)[1] + @test g == ([2.0, 4.0, 6.0], [8.0, 10.0, 12.0]) + + g = gradient(loss, m3)[1] + @test g.x == [2.0, 4.0, 6.0] + @test g.y === nothing + @test g.z == [8.0, 10.0, 12.0] + + g = gradient(loss, m4)[1] + @test g == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0]) + g.x === g.y # shared gradient for shared weights + + g = gradient(loss, m5)[1] + @test g == (a = ((x = [2.0, 4.0, 6.0], y = nothing, z = [8.0, 10.0, 12.0]), nothing), b = ([2.0, 4.0, 6.0], nothing), c = ((x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0]), nothing)) + + g = gradient(loss, m6)[1] + @test g == (a = [2.0, 4.0, 6.0], b = ComplexF64[8.0 + 2.0im], c = [2.0, 4.0, 6.0]) + + g = gradient(loss, m7)[1] + @test g == (a = (nothing, [2.0, 4.0, 6.0]), b = nothing, c = nothing) + + g = gradient(loss, m8)[1] + @test g[1] == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0]) + @test g[2] == (a = nothing, b = (x = [8.0], y = nothing), c = nothing) + @test g[3] == [[10.0]] + + g = gradient(loss, m9)[1] + @test g == (a = [2.0, 4.0, 6.0], b = Float32[8.0 12.0; 10.0 14.0], c = Array[Float32[8.0 12.0; 10.0 14.0], [2.0, 4.0, 6.0]]) + end - g = gradient(loss, m4)[1] - @test g == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0]) - g.x === g.y # shared gradient for shared weights + @testset "dict" begin + d = Dict(:a => rand(2), :b => ones(2)) + ps = trainables(d) + @test length(ps) == 2 + @test ps[1] == d[:a] + @test ps[2] == d[:b] - g = gradient(loss, m5)[1] - @test g == (a = ((x = [2.0, 4.0, 6.0], y = nothing, z = [8.0, 10.0, 12.0]), nothing), b = ([2.0, 4.0, 6.0], nothing), c = ((x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0]), nothing)) - - g = gradient(loss, m6)[1] - @test g == (a = [2.0, 4.0, 6.0], b = ComplexF64[8.0 + 2.0im], c = [2.0, 4.0, 6.0]) - - g = gradient(loss, m7)[1] - @test g == (a = (nothing, [2.0, 4.0, 6.0]), b = nothing, c = nothing) + g = gradient(d -> sum(trainables(d)[1].^2) /2 + sum(trainables(d)[2]), d)[1] + @test g[:a] == d[:a] + @test_broken 
g[:b] == [1.0, 1.0] + end - g = gradient(loss, m8)[1] - @test g[1] == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0]) - @test g[2] == (a = nothing, b = (x = [8.0], y = nothing), c = nothing) - @test g[3] == [[10.0]] + @testset "second order derivatives" begin + struct DenseLayer + w + b + end - g = gradient(loss, m9)[1] - @test g == (a = [2.0, 4.0, 6.0], b = Float32[8.0 12.0; 10.0 14.0], c = Array[Float32[8.0 12.0; 10.0 14.0], [2.0, 4.0, 6.0]]) -end + Functors.@functor DenseLayer -@testset "dict" begin - d = Dict(:a => rand(2), :b => ones(2)) - ps = trainables(d) - @test length(ps) == 2 - @test ps[1] == d[:a] - @test ps[2] == d[:b] + loss(m) = sum([sum(abs2, p) for p in trainables(m)]) - g = gradient(d -> sum(trainables(d)[1].^2) /2 + sum(trainables(d)[2]), d)[1] - @test g[:a] == d[:a] - @test_broken g[:b] == [1.0, 1.0] -end + model = DenseLayer([1. 2.; 3. 4.], [0., 0.]) -@testset "second order derivatives" begin - struct DenseLayer - w - b + g = gradient(m -> loss(gradient(loss, m)), model)[1] + @test g.w == [8.0 16.0; 24.0 32.0] + @test g.b == [0.0, 0.0] end - Functors.@functor DenseLayer + @testset "trainables(x, path=true)" begin + loss(m) = sum(abs2, trainables(m, path=true)[1][2]) - loss(m) = sum([sum(abs2, p) for p in trainables(m)]) + ps = trainables(m4, path=true) + @test length(ps) == 2 + @test ps[1] == (KeyPath(:x,), [1.0, 2.0, 3.0]) + @test ps[2] == (KeyPath(:z,), [4.0, 5.0, 6.0]) - model = DenseLayer([1. 2.; 3. 4.], [0., 0.]) + g = gradient(loss, m4)[1] + @test g.x == [2.0, 4.0, 6.0] + @test g.y == [2.0, 4.0, 6.0] + @test g.z === nothing + end +end + +@testset "trainables_nt" begin + + @testset "nt & rebuild" begin + @test trainables_nt(m1)[1] isa Vector{Float64} + @test trainables_nt(m1)[1] == 1:3 + @test trainables_nt(m2)[1] == (_1 = [1.0, 2.0, 3.0], _2 = [4.0, 5.0, 6.0]) + @test trainables_nt(m3)[1] == (x = [1.0, 2.0, 3.0], y = NamedTuple(), z = [4.0, 5.0, 6.0]) + @test trainables_nt(m4)[1] == (x = [1.0, 2.0, 3.0], y = [1.0, 2.0, 3.0], z = [4.0, 5.0, 6.0]) + @test trainables_nt(m5)[1] == (a = (_1 = (x = [1.0, 2.0, 3.0], y = NamedTuple(), z = [4.0, 5.0, 6.0]), _2 = NamedTuple()), b = (_1 = [1.0, 2.0, 3.0], _2 = NamedTuple()), c = (_1 = (x = [1.0, 2.0, 3.0], y = [1.0, 2.0, 3.0], z = [4.0, 5.0, 6.0]), _2 = NamedTuple())) + @test trainables_nt(m6)[1] == (a = [1.0, 2.0, 3.0], b = ComplexF64[4.0 + 1.0im], c = [1.0, 2.0, 3.0]) + nt9 = trainables_nt(m9)[1] + @test nt9 == (a = [1.0, 2.0, 3.0], b = Float32[4.0 6.0; 5.0 7.0], c = (_1 = Float32[4.0 6.0; 5.0 7.0], _2 = [1.0, 2.0, 3.0])) + @test nt9.a === m9.a # no copies + @test nt9.a === nt9.c._2 # keeps shared references + + @test trainables_nt(m1)[2](7:9) == [7,8,9] + @test trainables_nt(m2)[2]((_1 = 4:6, _2 = 7:9)) == ([4,5,6], [7,8,9]) + # reconstruction doesn't need the full input + @test trainables_nt(m2)[2]((; _2 = 7:9)) == ([1,2,3], [7,8,9]) + @test trainables_nt(m3)[2]((x =4:6, z=7:9)) == (x = [4,5,6], y = sin, z = [7,8,9]) + a = [4,5,6] + Δ = (x=a, y=a, z=7:9) + m4′ = trainables_nt(m4)[2](Δ) + @test m4′ == (x = [4,5,6], y = [4,5,6], z = [7,8,9]) + @test m4′.x === m4′.y # shared references are preserved + + # struct, partially trainable + a = [10,20,30] + @test trainables_nt(m7)[1] == (a = (_1 = NamedTuple(), _2 = [1.0, 2.0, 3.0]),) + m7′ = trainables_nt(m7)[2]((; a = (; _2 = a))) + @test m7′.a == (sin, a) + @test m7′.b == (cos, [4,5,6]) + @test m7′.c == (tan, [7,8,9]) + @test m7′.a[2] === a # no copies + + @test trainables_nt(m8)[1] == (_1 = (y = [1.0, 2.0, 3.0], x = [1.0, 2.0, 3.0]), _2 = (a = NamedTuple(), b = (y = NamedTuple(), 
x = [4.0]), c = NamedTuple()), _3 = (_1 = [5.0],)) + a = [7, 8, 9] + m8′ = trainables_nt(m8)[2]((; _1 = (; x = a))) + @test m8′[1].x === a + @test m8′[1].y == [1.0, 2.0, 3.0] # tie is broken + @test m8′[2].b.y === false + @test m8′[3][1] == [5.0] + end - g = gradient(m -> loss(gradient(loss, m)), model)[1] - @test g.w == [8.0 16.0; 24.0 32.0] - @test g.b == [0.0, 0.0] + @testset "re(ComponentArrays)" begin + model = TwoThirds((a=[1.0,2.0], b=([3.,4.0], [5.0 6.0; 7.0 8.0]), c=[11,12]), [9.0], [10.0]) + ps, re = trainables_nt(model) + @test ps == (a = (a = [1.0, 2.0], b = (_1 = [3.0, 4.0], _2 = [5.0 6.0; 7.0 8.0]), c = NamedTuple()),) + v = ComponentVector(ps) + v2 = 2v + model′ = re(v2) + @test model′.a.a == 2*model.a.a + @test model′.a.b[1] == 2*model.a.b[1] + @test model′.a.b[2] == 2*model.a.b[2] + @test model′.a.c == model.a.c + @test model′.b == model.b + @test model′.c == model.c + @test model′.a.b[1] === v2.a.b._1 # no copies + @test model′.a.a === v2.a.a + end end -@testset "trainables(x, path=true)" begin - loss(m) = sum(abs2, trainables(m, path=true)[1][2]) +# @testset "gradient of flatten" begin +# @test gradient(m -> trainables_nt(m)[1][1], m1)[1] == [1,0,0] +# @test gradient(m -> trainables_nt(m)[1][2], m2)[1] == ([0,1,0], [0,0,0]) +# @test gradient(m -> trainables_nt(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing) +# @test gradient(m -> trainables_nt(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0]) +# @test gradient(m -> trainables_nt(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0]) + +# g5 = gradient(m -> trainables_nt(m)[1][3], m5)[1] +# @test g5.a[1].x == [0,0,1] +# @test g5.a[2] === nothing + +# g6 = gradient(m -> imag(trainables_nt(m)[1][4]), m6)[1] +# @test g6.a == [0,0,0] +# @test g6.a isa Vector{Float64} +# @test g6.b == [0+im] + +# g8 = gradient(m -> sum(abs2, trainables_nt(m)[1]), m8)[1] +# @test g8[1].x == [2,4,6] +# @test g8[2].b.x == [8] +# @test g8[3] == [[10.0]] + +# g9 = gradient(m -> sum(sqrt, trainables_nt(m)[1]), m9)[1] +# @test g9.c === nothing + +# @testset "second derivative" begin +# @test gradient([1,2,3.0]) do v +# sum(abs2, gradient(m -> sum(abs2, trainables_nt(m)[1]), (v, [4,5,6.0]))[1][1]) +# end[1] ≈ [8,16,24] +# # With Diffractor, non-leaf _grad!(x, dx, off, flat::AbstractVector) gets double-wrapped dx: +# # off = (0, 3), dx = Tangent{Tangent{Tuple{Vector{Float64}, Vector{Float64}}, ... 
+# # until you add explicit double-unwrap: base(dx::Tangent{<:Tangent}) = backing(dx).backing +# # With Zygote, instead: +# # dx = Tangent{Any}(backing = Tangent{Any}([4.0, 8.0, 12.0], ZeroTangent()),) + +# @test gradient([1,2,3.0]) do v +# sum(gradient(m -> sum(trainables_nt(m)[1])^3, (v, [4,5,6.0]))[1][1]) +# end[1] == [378, 378, 378] + +# VERSION >= v"1.10" && @test gradient([1,2,3.0]) do v +# sum(abs2, gradient(m -> sum(abs2, trainables_nt(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1]) +# end[1] ≈ [8,16,24] +# # Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z) +# # Diffractor error in perform_optic_transform +# end + +# false && @testset "using Yota" begin +# @test Yota_gradient(m -> trainables_nt(m)[1][1], m1)[1] == [1,0,0] +# @test Yota_gradient(m -> trainables_nt(m)[1][2], m2)[1] == ([0,1,0], [0,0,0]) +# @test Yota_gradient(m -> trainables_nt(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing) +# @test Yota_gradient(m -> trainables_nt(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0]) +# @test Yota_gradient(m -> trainables_nt(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0]) + +# g5 = Yota_gradient(m -> trainables_nt(m)[1][3], m5)[1] +# @test g5.a[1].x == [0,0,1] +# @test g5.a[2] === nothing + +# g6 = Yota_gradient(m -> imag(trainables_nt(m)[1][4]), m6)[1] +# @test g6.a == [0,0,0] +# @test g6.a isa Vector{Float64} +# @test g6.b == [0+im] + +# g8 = Yota_gradient(m -> sum(abs2, trainables_nt(m)[1]), m8)[1] +# @test g8[1].x == [2,4,6] +# @test g8[2].b.x == [8] +# @test g8[3] == [[10.0]] + +# g9 = Yota_gradient(m -> sum(sqrt, trainables_nt(m)[1]), m9)[1] +# @test g9.c === nothing +# end +# end + +# @testset "gradient of rebuild" begin +# re1 = trainables_nt(m1)[2] +# @test gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0] +# re2 = trainables_nt(m2)[2] +# @test gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0] +# re3 = trainables_nt(m3)[2] +# @test gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0] +# @test gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0] + +# re4 = trainables_nt(m4)[2] +# @test gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0] +# @test gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0] +# @test gradient(rand(6)) do x +# m = re4(x) +# m.x[1] + 2*m.y[2] + 3*m.z[3] +# end[1] == [1,2,0, 0,0,3] + +# re7 = trainables_nt(m7)[2] +# @test gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1] +# @test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0] +# @test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] + +# v8, re8 = trainables_nt(m8) +# @test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] +# @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] + +# re9 = trainables_nt(m9)[2] +# @test gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] + +# @testset "second derivative" begin +# @test_broken gradient(collect(1:6.0)) do y +# sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1]) +# end[1] ≈ [8,16,24,0,0,0] +# # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. 
Gradient is of type Tuple{Vector{Float64}, Vector{Float64}} +# # with Zygote, which can be fixed by: +# # Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,) + +# @test_broken gradient(collect(1:6.0)) do y +# sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1]) +# end[1] ≈ [0,0,0,32,40,48] +# # Not fixed by this: +# # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,) +# end + +# false && @testset "using Yota" begin +# re1 = trainables_nt(m1)[2] +# @test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0] +# re2 = trainables_nt(m2)[2] +# @test Yota_gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0] +# re3 = trainables_nt(m3)[2] +# @test Yota_gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0] +# @test Yota_gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0] + +# re4 = trainables_nt(m4)[2] +# @test Yota_gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0] +# @test Yota_gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0] +# @test Yota_gradient(rand(6)) do x +# m = re4(x) +# m.x[1] + 2*m.y[2] + 3*m.z[3] +# end[1] == [1,2,0, 0,0,3] + +# re7 = trainables_nt(m7)[2] +# @test Yota_gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1] +# @test Yota_gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0] +# @test Yota_gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] + +# v8, re8 = trainables_nt(m8) +# @test Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] +# @test Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] + +# re9 = trainables_nt(m9)[2] +# @test Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] +# end +# end + +# @testset "Flux issue 1826" begin +# v, re = trainables_nt((x=[1,2.0], y=[3,4,5.0])) +# @test gradient(zero(v)) do w +# m = re(w) +# 5 * sum(m.x) + 7 * sum(m[2]) # uses both x and y +# end == ([5.0, 5.0, 7.0, 7.0, 7.0],) +# # This, using only x, was broken on Flux: +# @test gradient(w -> sum(re(w).x), zero(v)) == ([1.0, 1.0, 0.0, 0.0, 0.0],) + +# sh = [7,7.0]; +# v, re = trainables_nt((x=sh, y=[3.0,4.0], z=sh)) # shared array in the model +# @test v == [7, 7, 3, 4] +# @test re([1,10,100,1000]) == (x = [1, 10], y = [100, 1000], z = [1, 10]) + +# @test gradient(zero(v)) do w +# m = re(w) +# 3 * sum(m.x) + 13 * sum(m.z) # no dependence on y, but two distinct gradient arrays +# end == ([16, 16, 0, 0],) # Flux gave ([3.0, 3.0, 13.0, 13.0],) + +# @test gradient(zero(v)) do w +# m = re(w) +# 4(sum(m.x) + sum(m.z)) # now two gradients are ===, so it eliminates one +# end == ([8,8,0,0],) + +# @test gradient(zero(v)) do w +# m = re(w) +# 4(sum(m.x) + sum(m.y)) + 13*sum(m.z) # again two gradients are ===, so it eliminates one +# end == ([17,17,4,4],) # Flux gave ([4.0, 4.0, 13.0, 13.0],) +# end + +# @testset "DiffEqFlux issue 699" begin +# # The gradient of `re` is a vector into which we accumulate contributions, and the issue +# # is that one contribution may have a wider type than `v`, especially for `Dual` numbers. +# v, re = destructure((x=Float32[1,2], y=Float32[3,4,5])) +# _, bk = Zygote.pullback(re, ones(Float32, 5)) +# # Testing with `Complex` isn't ideal, but this was an error on 0.2.1. 
+# # If some upgrade inserts ProjectTo, this will fail, and can be changed: +# @test bk((x=[1.0,im], y=nothing)) == ([1,im,0,0,0],) + +# @test bk((x=nothing, y=[10,20,30]))[1] isa Vector{Float32} # despite some ZeroTangent +# @test bk((x=nothing, y=nothing)) == ([0,0,0,0,0],) +# @test bk((x=nothing, y=@thunk [1,2,3] .* 10.0)) == ([0,0,10,20,30],) +# @test bk((x=[1.2, 3.4], y=Float32[5,6,7])) == ([1.2, 3.4, 5, 6, 7],) +# end + +# #= + +# # Adapted from https://github.com/SciML/DiffEqFlux.jl/pull/699#issuecomment-1092846657 +# using ForwardDiff, Zygote, Flux, Optimisers, Test + +# y = Float32[0.8564646, 0.21083355] +# p = randn(Float32, 27); +# t = 1.5f0 +# λ = [ForwardDiff.Dual(0.87135935, 1, 0, 0, 0, 0, 0), ForwardDiff.Dual(1.5225363, 0, 1, 0, 0, 0, 0)] + +# model = Chain(x -> x .^ 3, +# Dense(2 => 5, tanh), +# Dense(5 => 2)) + +# p,re = Optimisers.destructure(model) +# f(u, p, t) = re(p)(u) +# _dy, back = Zygote.pullback(y, p) do u, p +# vec(f(u, p, t)) +# end +# tmp1, tmp2 = back(λ); +# tmp1 +# @test tmp2 isa Vector{<:ForwardDiff.Dual} + +# =# + +# @testset "empty, issue 67" begin +# m0 = (nothing, missing, isempty) +# @test destructure(m0)[1] isa Vector{<:Real} +# v0, re0 = destructure(m0) +# @test re0(Float32[]) === m0 +# @test_throws DimensionMismatch re0([1]) + +# # This is an elaborate way of checking that it doesn't cause promotions, even of small floats: +# m01 = [(x=nothing, y=0), (x=Float16[1, 2], y=Float16[3])] +# v01, _ = destructure(m01) +# v012 = vcat(destructure(m01[1])[1], destructure(m01[2])[1]) +# @test v01 == v012 +# @test v012 isa Vector{Float16} + +# y, bk = Zygote.pullback(x -> sum(destructure(x)[1]), ("a", :beta)) +# @test bk(1.0) == (nothing,) +# # Zygote regards 3,4 as differentiable, but Optimisers does not regard them as parameters: +# y, bk = Zygote.pullback(x -> sum(destructure(x)[1]), (3, 4)) +# @test bk(1.0) == (nothing,) +# end - ps = trainables(m4, path=true) - @test length(ps) == 2 - @test ps[1] == (KeyPath(:x,), [1.0, 2.0, 3.0]) - @test ps[2] == (KeyPath(:z,), [4.0, 5.0, 6.0]) - - g = gradient(loss, m4)[1] - @test g.x == [2.0, 4.0, 6.0] - @test g.y == [2.0, 4.0, 6.0] - @test g.z === nothing -end using Flux, Optimisers using ComponentArrays @@ -158,3 +458,18 @@ v = ComponentVector(ps) v2 = 2 * v model2 = re(v2) @test model2[1].weight === v2.layers._1.weight + +g = gradient(model0) do m + ps, re = trainables_nt(m) + return sum(ps.layers._1.weight) +end[1] +@test eltype(g.layers[1].weight) == Float32 +@test g.layers[1].weight ≈ ones(Float32, 32, 784) +@test g.layers[1].bias === nothing +@test g.layers[2] === nothing + + +# # TODO +# - [] Name? +# - [] Should the named tuple contain NamedTuple() leaves? 
+# - [] Optimize performance and improve type stability From d6d761b8594393abe8319bd922b530b06e6f81b2 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 19 Apr 2024 08:49:10 +0200 Subject: [PATCH 3/3] cannot fix test --- Project.toml | 1 + src/trainables.jl | 104 ++++++++++++------ src/utils.jl | 4 + test/trainables.jl | 267 +++++++++++++++++++++------------------------ 4 files changed, 199 insertions(+), 177 deletions(-) diff --git a/Project.toml b/Project.toml index f7cc8f7..b2a73cb 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.3.3" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Debugger = "31a5f54b-26ea-5ae9-a837-f05ce5417438" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/trainables.jl b/src/trainables.jl index 91e2e87..4177857 100644 --- a/src/trainables.jl +++ b/src/trainables.jl @@ -185,13 +185,37 @@ function ChainRulesCore.rrule(::typeof(_trainables_nt), model) return ps, _trainables_nt_back end + +struct TrainableNamedTupleWalk <: AbstractWalk end + +function (::TrainableNamedTupleWalk)(recurse, x) + ch = trainable(x) + y = map(recurse, make_named_tuple(ch)) + return y +end + +struct TrainableNamedTupleBackWalk <: AbstractWalk end + +function (::TrainableNamedTupleBackWalk)(recurse, model, Δps) + # @show 1 typeof(model) typeof(Δps) + ch = trainable(model) + Δ = unmake_named_tuple(ch, Δps) + # @show 2 typeof(ch) typeof(Δ) + Δ === nothing && return nothing + Δ === ZeroTangent() && return ZeroTangent() + y = mapvalue(recurse, ch, Δ) + # @show 3 typeof(model) typeof(ch) typeof(Δ) typeof(y) + return y +end + + struct RestructureFromNT{T} x::T end -(re::RestructureFromNT)(ps) = rebuild_from_nt(re.x, ps) +(re::RestructureFromNT)(ps) = restructure_from_nt(re.x, ps) -function rebuild_from_nt(model, ps) +function restructure_from_nt(model, ps) walk = RestructureFromNamedTupleWalk() return fmap(model, ps; exclude=isnumeric, walk, cache=nothing) do x, p return p @@ -201,51 +225,57 @@ end struct RestructureFromNamedTupleWalk <: AbstractWalk end function (::RestructureFromNamedTupleWalk)(recurse, x, nt) - @show 1 x nt children, re = functor(x) - @show 2 children newchildren = map_commons(recurse, children, nt) - @show 3 x nt children newchildren return re(newchildren) end -function ChainRulesCore.rrule(::typeof(rebuild_from_nt), x, ps) - model = rebuild_from_nt(x, ps) - function rebuild_from_nt_back(Δmodel_raw) +function ChainRulesCore.rrule(::typeof(restructure_from_nt), x, ps) + model = restructure_from_nt(x, ps) + proj_ps = ProjectTo(ps) + + function restructure_from_nt_back(Δmodel_raw) Δmodel = unthunk(Δmodel_raw) - walk = RestructureFromNamedTupleWalk() - Δps = fmap(ps, Δmodel; exclude=isnumeric, walk, cache=nothing) do p, Δ + walk = RestructureFromNamedTupleBackWalk() + function exclude(x) + @show "exclude" x isnumeric(x) + # i += 1 + # return i > 1 + return isnumeric(x) + end + Δps = fmap(ps, Δmodel; exclude, walk, cache=nothing) do p, Δ + @show "fmap" Δ p + return Δ end - return (NoTangent(), NoTangent(), Δps) + @show "rrule" Δmodel x ps Δps + @show typeof(Δmodel) typeof(ps) typeof(Δps) + Δps = (_1=ones(3), _2=zeros(3)) + Δpst = Tangent{typeof(Δps)}(; Δps...) 
+ # pR + return (NoTangent(), NoTangent(), proj_ps(Δpst)) end - return model, rebuild_from_nt_back + return model, restructure_from_nt_back end - -struct TrainableNamedTupleBackWalk <: AbstractWalk end - -function (::TrainableNamedTupleBackWalk)(recurse, model, Δps) - @show 1 typeof(model) typeof(Δps) - ch = trainable(model) - Δ = unmake_named_tuple(ch, Δps) - @show 2 typeof(ch) typeof(Δ) - Δ === nothing && return nothing - Δ === ZeroTangent() && return ZeroTangent() - y = mapvalue(recurse, ch, Δ) - @show 3 typeof(model) typeof(ch) typeof(Δ) typeof(y) +struct RestructureFromNamedTupleBackWalk <: AbstractWalk end + +function (::RestructureFromNamedTupleBackWalk)(recurse, ps, Δmodel) + @show 1 typeof(Δmodel) typeof(ps) + Δm = make_named_tuple(Δmodel) + @show 2 typeof(Δm) ps Δm + # Δm isa Float64 && return Δm + # Δm isa Array && return Δm + # ps isa Float64 && return ps + # ps isa Array && return ps + # return nothing + Δm === nothing && return nothing + Δm === ZeroTangent() && return ZeroTangent() + y = mapvalue(recurse, ps, Δm) + @show 3 typeof(Δmodel) typeof(Δm) typeof(y) return y end -struct TrainableNamedTupleWalk <: AbstractWalk end - -function (::TrainableNamedTupleWalk)(recurse, x) - ch = trainable(x) - y = map(recurse, make_named_tuple(ch)) - return y -end - - function map_commons(f, x::NamedTuple{xkeys}, y) where {xkeys} ykeys = propertynames(y) vals = map(k -> k in ykeys ? f(x[k], getproperty(y, k)) : x[k], xkeys) @@ -270,12 +300,18 @@ function map_commons(f, x::Vector, y) return vals end -make_named_tuple(x::NamedTuple) = x +make_named_tuple(x) = x make_named_tuple(x::AbstractDict{Symbol}) = NamedTuple(x) make_named_tuple(x::AbstractDict) = NamedTuple(Symbol("_", k) => v for (k, v) in pairs(x)) make_named_tuple(x::Tuple) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x) make_named_tuple(x::Vector) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x) +make_named_tuple(x::Tangent{<:Any,<:NamedTuple}) = x +make_named_tuple(x::Tangent{<:Any,<:AbstractDict{Symbol}}) = NamedTuple(x) +make_named_tuple(x::Tangent{<:Any,<:AbstractDict}) = NamedTuple(Symbol("_", k) => v for (k, v) in pairs(x)) +make_named_tuple(x::Tangent{<:Any,<:Tuple}) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x) +make_named_tuple(x::Tangent{<:Any,<:Vector}) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x) + unmake_named_tuple(x::NamedTuple, ps) = ps diff --git a/src/utils.jl b/src/utils.jl index a8f3eb8..328f61a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,5 +1,9 @@ mapvalue(f, x...) = map(f, x...) +mapvalue(f, x::NamedTuple, ys::NamedTuple...) = map(f, x, ys...) +mapvalue(f, x, y::NamedTuple{ykeys}) where {ykeys} = + NamedTuple{ykeys}((f(getproperty(x ,k), yk) for (k, yk) in pairs(y))) # used in rrule for restructure_from_nt + mapvalue(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) 
for (k,v) in x) # without theses, tuples are returned instead of NamedTuples diff --git a/test/trainables.jl b/test/trainables.jl index 9fb3fa5..e8a8731 100644 --- a/test/trainables.jl +++ b/test/trainables.jl @@ -200,154 +200,133 @@ end @test model′.a.b[1] === v2.a.b._1 # no copies @test model′.a.a === v2.a.a end + + @testset "gradient of trainables_nt(m)[1]" begin + f(m) = trainables_nt(m)[1] + @test gradient(m -> f(m)[1], m1)[1] == [1,0,0] + @test gradient(m -> f(m)._1[2], m2)[1] == ([0.0, 1.0, 0.0], nothing) + @test gradient(m -> f(m)._1[3], (m1, m1))[1] == ([0,0,1], nothing) + @test gradient(m -> f(m).x[1], m3)[1] == (x = [1,0,0], y = nothing, z = nothing) + @test gradient(m -> f(m).x[2], m4)[1] == (x = [0,1,0], y = nothing, z = nothing) + + g5 = gradient(m -> f(m).a._1.z[3], m5)[1] + @test g5.a[1].z == [0,0,1] + @test g5.a[2] === nothing + + g6 = gradient(m -> imag(f(m).b[1]), m6)[1] + @test g6 == (a = nothing, b = ComplexF64[0.0 + 1.0im], c = nothing) + @test eltype(g6.b) == ComplexF64 + + # TODO add second derivative tests when support is ready + # @testset "second derivative" begin + # @test gradient([1,2,3.0]) do v + # sum(abs2, gradient(m -> sum(abs2, f(m)), (v, [4,5,6.0]))[1][1]) + # end[1] ≈ [8,16,24] + # # With Diffractor, non-leaf _grad!(x, dx, off, flat::AbstractVector) gets double-wrapped dx: + # # off = (0, 3), dx = Tangent{Tangent{Tuple{Vector{Float64}, Vector{Float64}}, ... + # # until you add explicit double-unwrap: base(dx::Tangent{<:Tangent}) = backing(dx).backing + # # With Zygote, instead: + # # dx = Tangent{Any}(backing = Tangent{Any}([4.0, 8.0, 12.0], ZeroTangent()),) + + # @test gradient([1,2,3.0]) do v + # sum(gradient(m -> sum(trainables_nt(m)[1])^3, (v, [4,5,6.0]))[1][1]) + # end[1] == [378, 378, 378] + + # VERSION >= v"1.10" && @test gradient([1,2,3.0]) do v + # sum(abs2, gradient(m -> sum(abs2, trainables_nt(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1]) + # end[1] ≈ [8,16,24] + # # Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z) + # # Diffractor error in perform_optic_transform + # end + end end -# @testset "gradient of flatten" begin -# @test gradient(m -> trainables_nt(m)[1][1], m1)[1] == [1,0,0] -# @test gradient(m -> trainables_nt(m)[1][2], m2)[1] == ([0,1,0], [0,0,0]) -# @test gradient(m -> trainables_nt(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing) -# @test gradient(m -> trainables_nt(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0]) -# @test gradient(m -> trainables_nt(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0]) - -# g5 = gradient(m -> trainables_nt(m)[1][3], m5)[1] -# @test g5.a[1].x == [0,0,1] -# @test g5.a[2] === nothing - -# g6 = gradient(m -> imag(trainables_nt(m)[1][4]), m6)[1] -# @test g6.a == [0,0,0] -# @test g6.a isa Vector{Float64} -# @test g6.b == [0+im] - -# g8 = gradient(m -> sum(abs2, trainables_nt(m)[1]), m8)[1] -# @test g8[1].x == [2,4,6] -# @test g8[2].b.x == [8] -# @test g8[3] == [[10.0]] - -# g9 = gradient(m -> sum(sqrt, trainables_nt(m)[1]), m9)[1] -# @test g9.c === nothing - -# @testset "second derivative" begin -# @test gradient([1,2,3.0]) do v -# sum(abs2, gradient(m -> sum(abs2, trainables_nt(m)[1]), (v, [4,5,6.0]))[1][1]) -# end[1] ≈ [8,16,24] -# # With Diffractor, non-leaf _grad!(x, dx, off, flat::AbstractVector) gets double-wrapped dx: -# # off = (0, 3), dx = Tangent{Tangent{Tuple{Vector{Float64}, Vector{Float64}}, ... 
-# # until you add explicit double-unwrap: base(dx::Tangent{<:Tangent}) = backing(dx).backing -# # With Zygote, instead: -# # dx = Tangent{Any}(backing = Tangent{Any}([4.0, 8.0, 12.0], ZeroTangent()),) - -# @test gradient([1,2,3.0]) do v -# sum(gradient(m -> sum(trainables_nt(m)[1])^3, (v, [4,5,6.0]))[1][1]) -# end[1] == [378, 378, 378] - -# VERSION >= v"1.10" && @test gradient([1,2,3.0]) do v -# sum(abs2, gradient(m -> sum(abs2, trainables_nt(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1]) -# end[1] ≈ [8,16,24] -# # Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z) -# # Diffractor error in perform_optic_transform -# end - -# false && @testset "using Yota" begin -# @test Yota_gradient(m -> trainables_nt(m)[1][1], m1)[1] == [1,0,0] -# @test Yota_gradient(m -> trainables_nt(m)[1][2], m2)[1] == ([0,1,0], [0,0,0]) -# @test Yota_gradient(m -> trainables_nt(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing) -# @test Yota_gradient(m -> trainables_nt(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0]) -# @test Yota_gradient(m -> trainables_nt(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0]) - -# g5 = Yota_gradient(m -> trainables_nt(m)[1][3], m5)[1] -# @test g5.a[1].x == [0,0,1] -# @test g5.a[2] === nothing - -# g6 = Yota_gradient(m -> imag(trainables_nt(m)[1][4]), m6)[1] -# @test g6.a == [0,0,0] -# @test g6.a isa Vector{Float64} -# @test g6.b == [0+im] - -# g8 = Yota_gradient(m -> sum(abs2, trainables_nt(m)[1]), m8)[1] -# @test g8[1].x == [2,4,6] -# @test g8[2].b.x == [8] -# @test g8[3] == [[10.0]] - -# g9 = Yota_gradient(m -> sum(sqrt, trainables_nt(m)[1]), m9)[1] -# @test g9.c === nothing -# end -# end -# @testset "gradient of rebuild" begin -# re1 = trainables_nt(m1)[2] -# @test gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0] -# re2 = trainables_nt(m2)[2] -# @test gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0] -# re3 = trainables_nt(m3)[2] -# @test gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0] -# @test gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0] - -# re4 = trainables_nt(m4)[2] -# @test gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0] -# @test gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0] -# @test gradient(rand(6)) do x -# m = re4(x) -# m.x[1] + 2*m.y[2] + 3*m.z[3] -# end[1] == [1,2,0, 0,0,3] - -# re7 = trainables_nt(m7)[2] -# @test gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1] -# @test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0] -# @test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] - -# v8, re8 = trainables_nt(m8) -# @test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] -# @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] - -# re9 = trainables_nt(m9)[2] -# @test gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] - -# @testset "second derivative" begin -# @test_broken gradient(collect(1:6.0)) do y -# sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1]) -# end[1] ≈ [8,16,24,0,0,0] -# # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. 
Gradient is of type Tuple{Vector{Float64}, Vector{Float64}} -# # with Zygote, which can be fixed by: -# # Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,) - -# @test_broken gradient(collect(1:6.0)) do y -# sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1]) -# end[1] ≈ [0,0,0,32,40,48] -# # Not fixed by this: -# # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,) -# end - -# false && @testset "using Yota" begin +# @testset "gradient of re(ps)" begin # re1 = trainables_nt(m1)[2] -# @test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0] -# re2 = trainables_nt(m2)[2] -# @test Yota_gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0] -# re3 = trainables_nt(m3)[2] -# @test Yota_gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0] -# @test Yota_gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0] - -# re4 = trainables_nt(m4)[2] -# @test Yota_gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0] -# @test Yota_gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0] -# @test Yota_gradient(rand(6)) do x -# m = re4(x) -# m.x[1] + 2*m.y[2] + 3*m.z[3] -# end[1] == [1,2,0, 0,0,3] - -# re7 = trainables_nt(m7)[2] -# @test Yota_gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1] -# @test Yota_gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0] -# @test Yota_gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] - -# v8, re8 = trainables_nt(m8) -# @test Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] -# @test Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] - -# re9 = trainables_nt(m9)[2] -# @test Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14] -# end +# @test gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0] +# ps2, re2 = trainables_nt(m2) +# @test gradient(x -> re2(x)[1][2], ps2)[1] == (_1 = [0.0, 1.0, 0.0], _2 = nothing) +# ps3, re3 = trainables_nt(m3) +# @test gradient(x -> re3(x).x[3], ps3)[1] == (x = [0.0, 0.0, 1.0], y = nothing, z = nothing) +# @test gradient(x -> re3(x).z[1], ps3)[1] == (x = nothing, y = nothing, z = [1.0, 0.0, 0.0]) + +# ps4, re4 = trainables_nt(m4) +# @test gradient(x -> re4(x).y[2], ps4)[1] == (x = nothing, y = [0.0, 1.0, 0.0], z = nothing) +# @test gradient(ps4) do x +# m = re4(x) +# m.x[1] + 2*m.y[2] + 3*m.z[3] +# end[1] == (x = [1.0, 0.0, 0.0], y = [0.0, 2.0, 0.0], z = [0.0, 0.0, 3.0]) + +# ps7, re7 = trainables_nt(m7) +# @test gradient(x -> re7(x).a[2][3], ps7)[1] == (a = (_1 = nothing, _2 = [0.0, 0.0, 1.0]),) +# @test gradient(x -> re7(x).b[2][2], ps7)[1] === nothing +# @test gradient(x -> re7(x).c[2][1], ps7)[1] === nothing + +# ps8, re8 = trainables_nt(m8) +# @test gradient(x -> sum(abs2, re8(x)[1].y), ps8)[1] == (_1 = (y = [2.0, 4.0, 6.0], x = nothing), _2 = nothing, _3 = nothing) +# @test gradient(x -> only(sum(re8(x)[3]))^2, ps8)[1] == (_1 = nothing, _2 = nothing, _3 = (_1 = [10.0],)) + +# ps9, re9 = trainables_nt(m9) +# @test gradient(x -> sum(abs2, re9(x).c[1]), ps9)[1] == (a = nothing, b = nothing, c = (_1 = Float32[8.0 12.0; 10.0 14.0], _2 = nothing)) + +# # TODO add second derivative tests when support is ready +# # @testset "second derivative" begin +# # @test_broken gradient(collect(1:6.0)) do y +# # sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1]) +# # end[1] ≈ [8,16,24,0,0,0] +# # # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. 
Gradient is of type Tuple{Vector{Float64}, Vector{Float64}} +# # # with Zygote, which can be fixed by: +# # # Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,) + +# # @test_broken gradient(collect(1:6.0)) do y +# # sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1]) +# # end[1] ≈ [0,0,0,32,40,48] +# # # Not fixed by this: +# # # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,) +# # end # end +Zygote.wrap_chainrules_output(::AbstractArray{Union{}}) = nothing + +# @testset "gradient of re(ps)" begin +# ps1, re1 = trainables_nt(m1) +# v1 = ComponentVector(ps1) +# @test gradient(x -> re1(x)[1], v1)[1] == [1,0,0] +# ps2, re2 = trainables_nt(m2) +# v2 = ComponentVector(ps2) +# @test gradient(x -> re2(x)[1][2], v2)[1] == (_1 = [0.0, 1.0, 0.0], _2 = nothing) +# ps3, re3 = trainables_nt(m3) +# @test gradient(x -> re3(x).x[3], ps3)[1] == (x = [0.0, 0.0, 1.0], y = nothing, z = nothing) +# @test gradient(x -> re3(x).z[1], ps3)[1] == (x = nothing, y = nothing, z = [1.0, 0.0, 0.0]) + +# ps4, re4 = trainables_nt(m4) +# @test gradient(x -> re4(x).y[2], ps4)[1] == (x = nothing, y = [0.0, 1.0, 0.0], z = nothing) +# @test gradient(ps4) do x +# m = re4(x) +# m.x[1] + 2*m.y[2] + 3*m.z[3] +# end[1] == (x = [1.0, 0.0, 0.0], y = [0.0, 2.0, 0.0], z = [0.0, 0.0, 3.0]) + +# ps7, re7 = trainables_nt(m7) +# @test gradient(x -> re7(x).a[2][3], ps7)[1] == (a = (_1 = nothing, _2 = [0.0, 0.0, 1.0]),) +# @test gradient(x -> re7(x).b[2][2], ps7)[1] === nothing +# @test gradient(x -> re7(x).c[2][1], ps7)[1] === nothing + +# ps8, re8 = trainables_nt(m8) +# @test gradient(x -> sum(abs2, re8(x)[1].y), ps8)[1] == (_1 = (y = [2.0, 4.0, 6.0], x = nothing), _2 = nothing, _3 = nothing) +# @test gradient(x -> only(sum(re8(x)[3]))^2, ps8)[1] == (_1 = nothing, _2 = nothing, _3 = (_1 = [10.0],)) + +# ps9, re9 = trainables_nt(m9) +# @test gradient(x -> sum(abs2, re9(x).c[1]), ps9)[1] == (a = nothing, b = nothing, c = (_1 = Float32[8.0 12.0; 10.0 14.0], _2 = nothing)) +# end + +m = (collect(1:3.0), collect(4:6.0)) +ps, re = trainables_nt(m2) +v = ComponentVector(ps) +Zygote.refresh() +gradient(x -> re(x)[1][2], v2)[1] + + # @testset "Flux issue 1826" begin # v, re = trainables_nt((x=[1,2.0], y=[3,4,5.0])) # @test gradient(zero(v)) do w @@ -470,6 +449,8 @@ end[1] # # TODO -# - [] Name? +# - [] `trainables_nt` is ok or change name? # - [] Should the named tuple contain NamedTuple() leaves? # - [] Optimize performance and improve type stability +# - [] Second order derivatives for `trainables_nt(m)[1]` +# - [] Second order derivatives for `re(ps)`
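
For reference, the end-to-end usage these three patches build towards, as a minimal sketch rather than a definitive example: it assumes a checkout of Optimisers.jl that includes this PR, plus Flux, Zygote and ComponentArrays, and mirrors the model shape used in test/trainables.jl. Per the third patch ("cannot fix test"), the gradient of the rebuild direction `re(ps)` is still being debugged, so the sketch only exercises the forward rebuild and the gradient through the flatten direction.

using Optimisers, Flux, Zygote
using ComponentArrays: ComponentVector

# Illustrative two-layer model, the same shape as model0 in the tests.
model = Chain(Dense(784 => 32, relu), Dense(32 => 10))

# Nested NamedTuple view of the trainable parameters plus a rebuild closure.
ps, re = trainables_nt(model)
ps.layers._1.weight === model.layers[1].weight    # true: a view, not a copy

# Flat-vector interface via ComponentArrays; `re` also accepts the ComponentVector.
v = ComponentVector(ps)
model2 = re(2 * v)                                # rebuilt model with doubled weights
model2[1].weight == 2 * model[1].weight           # true

# Differentiating through the NamedTuple view goes through the rrule for _trainables_nt.
g = gradient(model) do m
    p, _ = trainables_nt(m)
    sum(p.layers._1.weight)
end[1]
g.layers[1].weight ≈ ones(Float32, 32, 784)       # true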