diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 7d7b1f3849..961f653f68 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -538,7 +538,7 @@ end This layer behaves differently based on input type: -1. If input `x` is a tuple of length `N`, matching the number of `layers`, +1. If input `x` is a tuple of length N (or the input is `xs` with N `x`'s), matching the number of `layers`, then each layer receives a new input `x[i]` combined with the previous output `y[i-1]` using `connection`. Thus `(y1, y2, y3) = PairwiseFusion(connection, layer1, layer2, layer3)((x1, x2, x3))` may be drawn as: @@ -567,7 +567,7 @@ end ## Returns -A tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above). +A tuple of length N with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above). """ struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}} connection::F @@ -597,6 +597,7 @@ function (m::PairwiseFusion)(x::T) where {T} _pairwise_check(x, m.layers, T) applypairwisefusion(m.layers, m.connection, x) end +(m::PairwiseFusion)(xs...) = m(xs) @generated function applypairwisefusion(layers::Tuple{Vararg{<:Any,N}}, connection, x::T) where {N, T} y_symbols = [gensym() for _ in 1:(N + 1)] diff --git a/test/layers/basic.jl b/test/layers/basic.jl index d4a43dbdf3..d66aad4f56 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -365,4 +365,7 @@ end @test length(y) == 2 @test size(y[1]) == (10, 10) @test size(y[2]) == (1, 10) + + @test PairwiseFusion(vcat, x->x.+1, x->x.+2, x->x.^3)(2, 10, 20) == (3, [5, 12], [125, 1728, 8000]) + @test PairwiseFusion(vcat, x->x.+1, x->x.+2, x->x.^3)(7) == (8, [10, 9], [1000, 729, 343]) end