Skip to content

Commit

Permalink
Merge pull request #1983 from theabhirath/pairwise-fusion-2
Browse files Browse the repository at this point in the history
`PairwiseFusion` layer, take 2
  • Loading branch information
ToucheSir authored Jun 6, 2022
2 parents f86b356 + d0f0a29 commit 0b01b77
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 14 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Flux Release Notes

## v0.13.4
* Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)

## v0.13
* After a deprecation cycle, the datasets in `Flux.Data` have
been removed in favour of MLDatasets.jl.
Expand Down
2 changes: 1 addition & 1 deletion src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export gradient
# Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.)
Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zygote.jl's implicit gradients, `Params` & `Grads`")

export Chain, Dense, Maxout, SkipConnection, Parallel,
export Chain, Dense, Maxout, SkipConnection, Parallel, PairwiseFusion,
RNN, LSTM, GRU, GRUv3,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Expand Down
132 changes: 121 additions & 11 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ end

Chain(xs...) = Chain(xs)
function Chain(; kw...)
:layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
:layers in keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
isempty(kw) && return Chain(())
Chain(values(kw))
end
Expand Down Expand Up @@ -67,7 +67,7 @@ end

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
Chain(NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i]))
Chain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))
function Base.show(io::IO, c::Chain)
print(io, "Chain(")
_show_layers(io, c.layers)
Expand Down Expand Up @@ -487,7 +487,7 @@ end
Parallel(connection, layers...) = Parallel(connection, layers)
function Parallel(connection; kw...)
layers = NamedTuple(kw)
if :layers in Base.keys(layers) || :connection in Base.keys(layers)
if :layers in keys(layers) || :connection in keys(layers)
throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
end
isempty(layers) && return Parallel(connection, ())
Expand All @@ -498,28 +498,138 @@ end

(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
(m::Parallel)(xs::Tuple) = m(xs...)
function (m::Parallel)(xs...)
nl = length(m.layers)
nx = length(xs)
if nl != nx

# Validate that `xs` supplies exactly one input per sub-layer of a `Parallel`.
# Returns `nothing` when the counts agree and throws an `ArgumentError` otherwise.
function _parallel_check(layers, xs)
  nl, nx = length(layers), length(xs)
  nl == nx && return nothing
  throw(ArgumentError("Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
end
ChainRulesCore.@non_differentiable _parallel_check(nl, nx)

# Multi-input call: the i-th input is piped into the i-th sub-layer, and the
# per-layer results are then splatted into `connection`.
function (m::Parallel)(xs...)
  # Throws an ArgumentError unless there is exactly one input per sub-layer.
  _parallel_check(m.layers, xs)
  outputs = map(|>, xs, Tuple(m.layers))
  return m.connection(outputs...)
end

Base.getindex(m::Parallel, i) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
Parallel(m.connection, NamedTuple{Base.keys(m)[i]}(Tuple(m.layers)[i]))
Parallel(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))

Base.keys(m::Parallel) = Base.keys(getfield(m, :layers))
Base.keys(m::Parallel) = keys(getfield(m, :layers))

# Compact one-line representation: `Parallel(<connection>, <layers...>)`.
function Base.show(io::IO, m::Parallel)
  print(io, "Parallel(")
  print(io, m.connection)
  print(io, ", ")
  _show_layers(io, m.layers)
  print(io, ")")
end

"""
PairwiseFusion(connection, layers...)
## Arguments
- `connection`: A function taking 2 inputs and combining them into a single output
- `layers`: The layers whose outputs are combined
## Inputs
This layer behaves differently based on input type:
1. If input `x` is a tuple of length N (or the input is `xs` with N `x`'s), matching the number of `layers`,
then each layer receives a new input `x[i]` combined with the previous output `y[i-1]` using `connection`.
Thus `(y1, y2, y3) = PairwiseFusion(connection, layer1, layer2, layer3)((x1, x2, x3))`
may be drawn as:
```
x1 → layer1 → y1 ↘
connection → layer2 → y2 ↘
x2 ↗ connection → layer3 → y3
x3 ↗
```
... or written as:
```julia
y1 = layer1(x1)
y2 = layer2(connection(x2, y1))
y3 = layer3(connection(x3, y2))
```
2. With just one input, each layer receives the same `x` combined with the previous output.
Thus `y = PairwiseFusion(connection, layers...)(x)` obeys:
```julia
y[1] == layers[1](x)
for i in 2:length(layers)
y[i] == connection(x, layers[i](y[i-1]))
end
```
## Returns
A tuple of length N with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
"""
struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}}
connection::F
layers::T
end

# Positional form: collect the trailing arguments into a tuple of layers.
function PairwiseFusion(connection, layers...)
  return PairwiseFusion(connection, layers)
end

# Keyword form: each keyword names a sub-layer. The names `layers` and
# `connection` are rejected because they would clash with the struct's fields.
function PairwiseFusion(connection; kw...)
  named = NamedTuple(kw)
  names = keys(named)
  if :layers in names || :connection in names
    throw(ArgumentError("a PairwiseFusion layer cannot have a named sub-layer called `connection` or `layers`"))
  end
  # With no keywords, store an empty tuple rather than an empty NamedTuple.
  if isempty(named)
    return PairwiseFusion(connection, ())
  end
  return PairwiseFusion(connection, named)
end

# Validate the input of a `PairwiseFusion` call.
#
# - `x`: the input passed to the layer
# - `layers`: the container of sub-layers
# - `T`: the type of `x`
#
# Only tuple inputs are matched one-to-one against the layers, so only they are
# length-checked. Any other input is a single value shared by every layer and
# is accepted as-is; in particular `length` is never called on it, so inputs
# without a `length` method (custom structs, functions, ...) no longer fail here.
#
# Returns `nothing`; throws an `ArgumentError` when a tuple input has a
# different number of elements than there are layers.
function _pairwise_check(x, layers, T)
  T <: Tuple || return nothing
  lx = length(x)
  N = length(layers)
  if lx != N
    throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
  end
  return nothing
end
ChainRulesCore.@non_differentiable _pairwise_check(lx, N, T)

# Single-argument call. `x` is either a tuple (its elements are indexed per
# layer by `applypairwisefusion`) or any other value, which is shared by all layers.
function (m::PairwiseFusion)(x::T) where {T}
  # The input type is passed along so the check can tell tuples from other inputs.
  _pairwise_check(x, m.layers, T)
  return applypairwisefusion(m.layers, m.connection, x)
end

# Several positional arguments are bundled into a tuple and forwarded to the
# single-argument method above.
function (m::PairwiseFusion)(xs...)
  return m(xs)
end

# Apply `layers` in sequence, fusing each layer's output with the next input:
#
#   y[1] = layers[1](input(1))
#   y[i] = layers[i](connection(y[i-1], input(i)))   for i in 2:N
#
# where `input(i)` is `x[i]` when `x` is a tuple and `x` itself otherwise.
# Returns the tuple `(y[1], ..., y[N])`.
#
# The function is `@generated` so the chain is unrolled into straight-line code
# for the statically known number of layers `N`.
@generated function applypairwisefusion(layers::Tuple{Vararg{Any,N}}, connection, x::T) where {N, T}
  # No layers means no outputs; without this guard the code below would index
  # `y_symbols[0]` and fail while generating the body.
  N == 0 && return :(tuple())
  # One symbol per layer output, plus a final one (index N + 1) used as the
  # running accumulator that is fed into the next layer.
  y_symbols = [gensym() for _ in 1:(N + 1)]
  getinput(i) = T <: Tuple ? :(x[$i]) : :x
  calls = [:($(y_symbols[N + 1]) = $(getinput(1)))]
  for i in 1:N - 1
    push!(calls, quote
      $(y_symbols[i]) = layers[$i]($(y_symbols[N + 1]))
      $(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1)))
    end)
  end
  # The last layer's output is not fused with anything further.
  push!(calls, :($(y_symbols[N]) = layers[$N]($(y_symbols[N + 1]))))
  push!(calls, :(return tuple($(Tuple(y_symbols[1:N])...))))
  return Expr(:block, calls...)
end
# Named layers are applied in their declared order; the names are dropped.
applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x)

@functor PairwiseFusion

# Indexing with a single index (or name) returns the stored sub-layer itself.
function Base.getindex(m::PairwiseFusion, i)
  return m.layers[i]
end

# Indexing with a vector of indices returns a new `PairwiseFusion` that keeps
# the same `connection` and only the selected sub-layers.
function Base.getindex(m::PairwiseFusion, i::AbstractVector)
  return PairwiseFusion(m.connection, m.layers[i])
end

# Same as above for named layers: the selected names are kept alongside the layers.
function Base.getindex(m::PairwiseFusion{<:Any, <:NamedTuple}, i::AbstractVector)
  names = keys(m)[i]
  selected = Tuple(m.layers)[i]
  return PairwiseFusion(m.connection, NamedTuple{names}(selected))
end

# `getfield` reads the `layers` field directly.
function Base.keys(m::PairwiseFusion)
  return keys(getfield(m, :layers))
end

# Compact one-line representation: `PairwiseFusion(<connection>, <layers...>)`.
function Base.show(io::IO, m::PairwiseFusion)
  print(io, "PairwiseFusion(")
  print(io, m.connection, ", ")
  _show_layers(io, m.layers)
  print(io, ")")
end

"""
Embedding(in => out; init=randn)
Expand Down Expand Up @@ -556,7 +666,7 @@ end
@functor Embedding

Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))

(m::Embedding)(x::Integer) = m.weight[:, x]
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
Expand All @@ -565,7 +675,7 @@ function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
return m(onecold(x))
end

function Base.show(io::IO, m::Embedding)
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end
5 changes: 3 additions & 2 deletions src/layers/show.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

for T in [
:Chain, :Parallel, :SkipConnection, :Recur, :Maxout # container types
:Chain, :Parallel, :SkipConnection, :Recur, :Maxout, :PairwiseFusion # container types
]
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
Expand All @@ -25,7 +25,7 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
for k in Base.keys(obj)
_big_show(io, obj[k], indent+2, k)
end
elseif obj isa Parallel{<:Any, <:NamedTuple}
elseif obj isa Parallel{<:Any, <:NamedTuple} || obj isa PairwiseFusion{<:Any, <:NamedTuple}
_big_show(io, obj.connection, indent+2)
for k in Base.keys(obj)
_big_show(io, obj[k], indent+2, k)
Expand Down Expand Up @@ -53,6 +53,7 @@ _show_children(x) = trainable(x) # except for layers which hide their Tuple:
_show_children(c::Chain) = c.layers
_show_children(m::Maxout) = m.layers
_show_children(p::Parallel) = (p.connection, p.layers...)
_show_children(f::PairwiseFusion) = (f.connection, f.layers...)

for T in [
:Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
Expand Down
19 changes: 19 additions & 0 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,22 @@ end
@test Flux.destructure(m1)[2](z1)[1].weight Flux.destructure(m1v)[2](z1)[1].weight
# Note that Flux.destructure(m1v)[2](z) has a Chain{Tuple}, as does m1v[1:2]
end

@testset "PairwiseFusion" begin
  # Tuple input: one element per layer, fused with `+`.
  xs = (rand(1, 10), rand(30, 10))
  fusion = PairwiseFusion(+, Dense(1, 30), Dense(30, 10))
  ys = fusion(xs)
  @test length(ys) == 2
  @test size(ys[1]) == (30, 10)
  @test size(ys[2]) == (10, 10)

  # Single input shared by every layer; broadcasting `.+` reconciles the
  # (10, 10) output of the first layer with the (1, 10) input.
  x = rand(1, 10)
  fusion = PairwiseFusion(.+, Dense(1, 10), Dense(10, 1))
  ys = fusion(x)
  @test length(ys) == 2
  @test size(ys[1]) == (10, 10)
  @test size(ys[2]) == (1, 10)

  # Exact values pin the order of operations: `connection(previous_output, next_input)`
  # is applied first, then the next layer.
  @test PairwiseFusion(vcat, x->x.+1, x->x.+2, x->x.^3)(2, 10, 20) == (3, [5, 12], [125, 1728, 8000])
  @test PairwiseFusion(vcat, x->x.+1, x->x.+2, x->x.^3)(7) == (8, [10, 9], [1000, 729, 343])
end

0 comments on commit 0b01b77

Please sign in to comment.