Skip to content

Commit

Permalink
Allow N inputs
Browse files Browse the repository at this point in the history
Add tests with `vcat`
  • Loading branch information
theabhirath committed Jun 6, 2022
1 parent 78157e3 commit d0f0a29
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ end
This layer behaves differently based on input type:
1. If input `x` is a tuple of length `N`, matching the number of `layers`,
1. If input `x` is a tuple of length N (or the input is `xs` with N `x`'s), matching the number of `layers`,
then each layer receives a new input `x[i]` combined with the previous output `y[i-1]` using `connection`.
Thus `(y1, y2, y3) = PairwiseFusion(connection, layer1, layer2, layer3)((x1, x2, x3))`
may be drawn as:
Expand Down Expand Up @@ -567,7 +567,7 @@ end
## Returns
A tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
A tuple of length N with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
"""
struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}}
connection::F
Expand Down Expand Up @@ -597,6 +597,7 @@ function (m::PairwiseFusion)(x::T) where {T}
_pairwise_check(x, m.layers, T)
applypairwisefusion(m.layers, m.connection, x)
end
(m::PairwiseFusion)(xs...) = m(xs)

@generated function applypairwisefusion(layers::Tuple{Vararg{<:Any,N}}, connection, x::T) where {N, T}
y_symbols = [gensym() for _ in 1:(N + 1)]
Expand Down
3 changes: 3 additions & 0 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,4 +365,7 @@ end
@test length(y) == 2
@test size(y[1]) == (10, 10)
@test size(y[2]) == (1, 10)

@test PairwiseFusion(vcat, x->x.+1, x->x.+2, x->x.^3)(2, 10, 20) == (3, [5, 12], [125, 1728, 8000])
@test PairwiseFusion(vcat, x->x.+1, x->x.+2, x->x.^3)(7) == (8, [10, 9], [1000, 729, 343])
end

0 comments on commit d0f0a29

Please sign in to comment.