From 02643d063fd2ffae56824197b1b4d8d738d11390 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 10 Mar 2024 16:07:52 -0400 Subject: [PATCH 1/7] let Parallel(+, f)(x, y, z) work like broadcasting --- src/layers/basic.jl | 26 +++++++++++++++++++------- test/layers/basic.jl | 9 +++++++-- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 254f06db0c..963cda162d 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -475,8 +475,11 @@ end Create a layer which passes an input array to each path in `layers`, before reducing the output with `connection`. -Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`. -If called with multiple inputs, one is passed to each layer, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`. +Obeys the similar rules to broadcasting: +* Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`. +* With multiple `inputs` and just one layer, it is instead `connection([layer(x) for x in inputs]...)`. +* With multiple inputs and multiple layers, one input is passed to each layer, + thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`. Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor. These can be accessed by indexing: `m[1] == m[:name]` is the first layer. @@ -528,23 +531,32 @@ end @layer :expand Parallel -(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) -(m::Parallel)(xs::Tuple) = m(xs...) +(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) # one argument function _parallel_check(layers, xs) nl = length(layers) nx = length(xs) if (nl != nx) - throw(ArgumentError(lazy"Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs")) + throw(ArgumentError(lazy"Parallel with $nl > 1 sub-layers can take one input or $nl inputs, but got $nx inputs")) end end ChainRulesCore.@non_differentiable _parallel_check(nl, nx) +(m::Parallel)(xs::Tuple) = m(xs...) + function (m::Parallel)(xs...) - _parallel_check(m.layers, xs) - m.connection(map(|>, xs, Tuple(m.layers))...) + if length(m.layers) == 1 + f = only(m.layers) + m.connection(map(x -> f(x), xs)...) # multiple arguments, one layer + else + _parallel_check(m.layers, xs) + m.connection(map(|>, xs, Tuple(m.layers))...) # multiple arguments & multiple layers + end end +# (m::Parallel{<:Any, <:Union{Tuple{Any}, NamedTuple{<:Any, <:Tuple{Any}}}})(xs...) = +# m.connection(map(x -> only(m.layers)(x), xs)...) # multiple arguments, one layer + Base.getindex(m::Parallel, i) = m.layers[i] Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]) Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) = diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 95da13f0c9..83795a09a0 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -234,11 +234,14 @@ using Flux: activations end @testset "vararg input" begin - inputs = randn(10), randn(5), randn(4) + inputs = randn32(10), randn32(5), randn32(4) @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,) @test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,) @test_throws ArgumentError Parallel(+, sin, cos)(1,2,3) # wrong number of inputs - @test Parallel(+, sin, cos)(pi/2) ≈ 1 + @test Parallel(+, sin, cos)(pi/2) ≈ 1 # one input, several layers + @test Parallel(/, abs)(3, -4) ≈ 3/4 # one layer, several inputs + @test Parallel(/, abs)((3, -4)) ≈ 3/4 + @test Parallel(/; f=abs)(3, -4) ≈ 3/4 end @testset "named access" begin @@ -270,6 +273,8 @@ using Flux: activations @test CNT[] == 2 Parallel(f_cnt, sin)(1) @test CNT[] == 3 + Parallel(f_cnt, sin)(1,2,3) + @test CNT[] == 4 end # Ref https://github.com/FluxML/Flux.jl/issues/1673 From a9c9a21f4a58b3345e6c7303f16fa83aedf53129 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 13 Mar 2024 08:40:11 -0400 Subject: [PATCH 2/7] add (::Chain)(xs...) method --- src/layers/basic.jl | 18 +++++++++++++++++- test/layers/basic.jl | 2 ++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 963cda162d..17050daadf 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -28,6 +28,20 @@ julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x) true ``` +A chain may be called with multiple arguments, which is equivalent to calling it +with one tuple of these arguments. Such a tuple is understood by [`Parallel`](@ref) +to mean the same as several arguments: + +```jldoctest +julia> Chain(println, println)(1, 2, 3) # three arguments become a tuple +(1, 2, 3) +nothing + +julia> Chain(x->@show(x), Parallel(+, inv, abs2))(4, 5) # returns 1/4 + 5^2 +x = (4, 5) +25.25 +``` + For large models, there is a special type-unstable path which can reduce compilation times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`. This feature is somewhat experimental, beware! @@ -46,9 +60,10 @@ end @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, Base.iterate, Base.lastindex, Base.keys, Base.firstindex -@layer :expand Chain # the + opts-in to container-style pretty-printing +@layer :expand Chain # the option :expand opts-in to container-style pretty-printing (c::Chain)(x) = _applychain(c.layers, x) +(c::Chain)(x, ys...) = _applychain(c.layers, (x, ys...)) @generated function _applychain(layers::Tuple{Vararg{Any,N}}, x) where {N} symbols = vcat(:x, [gensym() for _ in 1:N]) @@ -68,6 +83,7 @@ end Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]) Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) = Chain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i])) + function Base.show(io::IO, c::Chain) print(io, "Chain(") _show_layers(io, c.layers) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 83795a09a0..b203eaf099 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -35,6 +35,8 @@ using Flux: activations c = Chain(Dense(10, 5, σ), Dense(5, 2), Dense(2, 1, relu)) @test c[1] == c[begin] @test c[3] == c[end] + + @test Chain(identity)(1,2,3) == (1,2,3) # multiple args become a tuple end @testset "Activations" begin From 44f5746f2d2f849460f3fb4c4f9d4590a4fd1af6 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 13 Mar 2024 08:40:20 -0400 Subject: [PATCH 3/7] more examples --- src/layers/basic.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 17050daadf..77e8d45cd5 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -505,6 +505,25 @@ and [`Maxout`](@ref) which reduces by broadcasting `max`. # Examples +```jldoctest +julia> p = Parallel(+, abs2, sqrt); + +julia> p(3, 4) # == 3^2 + √4, two functions two inputs +11.0 + +julia> p((3, 4)) # tuple is always splatted +11.0 + +julia> p(4, 4) # == 4^2 + √4, one input used twice +18.0 + +julia> Parallel(hcat, inv)(1, 2, 4) # one function three inputs +1×3 Matrix{Float64}: + 1.0 0.5 0.25 +``` + +With Flux layers: + ```jldoctest julia> model = Chain(Dense(3 => 5), Parallel(vcat, Dense(5 => 4), Chain(Dense(5 => 7), Dense(7 => 4))), From 2945667c496174dd96475e7d8a65a550ad566725 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 13 Mar 2024 08:59:11 -0400 Subject: [PATCH 4/7] correction --- src/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 77e8d45cd5..56e5af5d69 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -514,7 +514,7 @@ julia> p(3, 4) # == 3^2 + √4, two functions two inputs julia> p((3, 4)) # tuple is always splatted 11.0 -julia> p(4, 4) # == 4^2 + √4, one input used twice +julia> p(4) # == 4^2 + √4, one input used twice 18.0 julia> Parallel(hcat, inv)(1, 2, 4) # one function three inputs From 5fae6a97cc9df1c5d498073633bd2fb68f48e42c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 13 Mar 2024 09:14:01 -0400 Subject: [PATCH 5/7] change implementation to dispatch --- src/layers/basic.jl | 23 +++++++++++------------ test/layers/basic.jl | 2 +- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 56e5af5d69..3fb03e1621 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -554,6 +554,8 @@ struct Parallel{F, T<:Union{Tuple, NamedTuple}} layers::T end +_ParallelONE{T} = Parallel{T, <:Union{Tuple{Any}, NamedTuple{<:Any, <:Tuple{Any}}}} + Parallel(connection, layers...) = Parallel(connection, layers) function Parallel(connection; kw...) layers = NamedTuple(kw) @@ -577,20 +579,17 @@ function _parallel_check(layers, xs) end ChainRulesCore.@non_differentiable _parallel_check(nl, nx) -(m::Parallel)(xs::Tuple) = m(xs...) - -function (m::Parallel)(xs...) - if length(m.layers) == 1 - f = only(m.layers) - m.connection(map(x -> f(x), xs)...) # multiple arguments, one layer - else - _parallel_check(m.layers, xs) - m.connection(map(|>, xs, Tuple(m.layers))...) # multiple arguments & multiple layers - end +function (m::Parallel)(x, ys...) + xs = (x, ys...) + _parallel_check(m.layers, xs) + m.connection(map(|>, xs, Tuple(m.layers))...) # multiple arguments & multiple layers end -# (m::Parallel{<:Any, <:Union{Tuple{Any}, NamedTuple{<:Any, <:Tuple{Any}}}})(xs...) = -# m.connection(map(x -> only(m.layers)(x), xs)...) # multiple arguments, one layer +(m::_ParallelONE)(x, ys...) = + m.connection(map(z -> only(m.layers)(z), (x, ys...))...) # multiple arguments, one layer + +(m::Parallel)(xs::Tuple) = m(xs...) # tuple is always splatted +(m::_ParallelONE)(xs::Tuple) = m(xs...) # solves an ambiguity Base.getindex(m::Parallel, i) = m.layers[i] Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index b203eaf099..cd7918e487 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -230,7 +230,7 @@ using Flux: activations end @testset "concat size" begin - input = randn(10, 2) + input = randn32(10, 2) @test size(Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), identity)(input)) == (10, 4) @test size(Parallel(hcat, one = Dense(10, 10), two = identity)(input)) == (10, 4) end From ff73b44f5af732b2617bff0ab0b7c3f0b390c26c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 13 Mar 2024 10:28:38 -0400 Subject: [PATCH 6/7] nicer errors when called on zero inputs --- src/layers/basic.jl | 5 ++++- test/layers/basic.jl | 7 +++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 3fb03e1621..a8cb0a9d1c 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -572,7 +572,8 @@ end function _parallel_check(layers, xs) nl = length(layers) - nx = length(xs) + @assert nl > 1 # dispatch handles nl==1 cases + nx = length(xs) if (nl != nx) throw(ArgumentError(lazy"Parallel with $nl > 1 sub-layers can take one input or $nl inputs, but got $nx inputs")) end @@ -591,6 +592,8 @@ end (m::Parallel)(xs::Tuple) = m(xs...) # tuple is always splatted (m::_ParallelONE)(xs::Tuple) = m(xs...) # solves an ambiguity +(m::Parallel)() = throw(ArgumentError("Parallel layer cannot take 0 inputs")) + Base.getindex(m::Parallel, i) = m.layers[i] Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]) Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) = diff --git a/test/layers/basic.jl b/test/layers/basic.jl index cd7918e487..d09579c257 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -261,8 +261,15 @@ using Flux: activations end @testset "trivial cases" begin + # zero inputs, always an error + @test_throws ArgumentError Parallel(hcat)() + @test_throws ArgumentError Parallel(hcat, inv)() + @test_throws ArgumentError Parallel(hcat, inv, sqrt)() + + # zero layers -- not useful... can we make this an error without a breaking change? @test Parallel(hcat) isa Parallel{typeof(hcat), Tuple{}} # not a NamedTuple @test Parallel(hcat)(1) == hcat() + @test Parallel(hcat, inv)(2) == hcat(1/2) # still calls connection once. end From 288f8d50bae7fb193c070d0e1f57a9567dbd3106 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 13 Mar 2024 11:55:12 -0400 Subject: [PATCH 7/7] disallow zero layers, let's try this out --- src/layers/basic.jl | 3 ++- test/layers/basic.jl | 7 ++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index a8cb0a9d1c..3c615ae06d 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -562,9 +562,10 @@ function Parallel(connection; kw...) if :layers in keys(layers) || :connection in keys(layers) throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`")) end - isempty(layers) && return Parallel(connection, ()) Parallel(connection, layers) end +Parallel(connection, layers::Union{Tuple{}, @NamedTuple{}}) = + throw(ArgumentError("cannot construct a Parallel layer with no sub-layers")) @layer :expand Parallel diff --git a/test/layers/basic.jl b/test/layers/basic.jl index d09579c257..8e33340611 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -266,11 +266,8 @@ using Flux: activations @test_throws ArgumentError Parallel(hcat, inv)() @test_throws ArgumentError Parallel(hcat, inv, sqrt)() - # zero layers -- not useful... can we make this an error without a breaking change? - @test Parallel(hcat) isa Parallel{typeof(hcat), Tuple{}} # not a NamedTuple - @test Parallel(hcat)(1) == hcat() - - @test Parallel(hcat, inv)(2) == hcat(1/2) # still calls connection once. + # zero layers -- not useful... now made an error + @test_throws ArgumentError Parallel(hcat) end @testset "connection is called once" begin