From 2c270a9caf4bf035c6f77ceef033b5f40ae4985c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 16 Oct 2022 10:49:25 -0400 Subject: [PATCH 1/3] make outputsize work with Embedding --- src/layers/basic.jl | 13 ++++++++++--- src/outputsize.jl | 6 ++++++ test/outputsize.jl | 15 +++++++++++++++ 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 2a3bc9131c..7ec34b2454 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -655,13 +655,14 @@ or the corresponding [`onehot encoding`](@ref OneHotArrays.onehotbatch). For indices `x`, the result is of size `(out, size(x)...)`, allowing several batch dimensions. For one-hot `ohx`, the result is of size `(out, size(ohx)[2:end]...)`. +Note that [`outputsize`](@ref Flux.outputsize) expects `size(x)`, the indices not the one-hot array. # Examples ```jldoctest julia> emb = Embedding(26 => 4, init=Flux.identity_init(gain=22)) Embedding(26 => 4) # 104 parameters -julia> emb(2) # one column of e.weight (here not random!) +julia> emb(2) # one column of emb.weight (here not random!) 4-element Vector{Float32}: 0.0 22.0 @@ -680,6 +681,9 @@ true julia> emb(rand(1:26, (10, 1, 12))) |> size # three batch dimensions (4, 10, 1, 12) + +julia> Flux.outputsize(emb, (10, 1, 12)) # outputsize wants indices, not OneHotArray +(4, 10, 1, 12) ``` """ struct Embedding{W<:AbstractMatrix} @@ -691,8 +695,11 @@ end Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in)) (m::Embedding)(x::Integer) = m.weight[:, x] -(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x) -(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...) +(m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x) +(m::Embedding)(x::AbstractArray{<:Integer}) = reshape(m(vec(x)), :, size(x)...) 
+ +(m::Embedding)(x::Nil) = similar(m.weight, Nil, size(m.weight, 1)) +(m::Embedding)(x::AbstractArray{Nil}) = similar(m.weight, Nil, size(m.weight, 1), size(x)...) (m::Embedding)(x::AbstractVector{Bool}) = m.weight * x # usually OneHotVector (m::Embedding)(x::AbstractMatrix{Bool}) = m.weight * x # usually OneHotMatrix diff --git a/src/outputsize.jl b/src/outputsize.jl index 9fd9545b5f..20c062499a 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -87,6 +87,12 @@ DimensionMismatch("Input channels must match! (7 vs. 3)") julia> outputsize([Dense(10 => 4), Dense(4 => 2)], (10, 1)) # Vector of layers becomes a Chain (2, 1) ``` + +Limitations: +* `Embedding` accepts either integers or one-hot arrays, and `ohx = onehotbatch(x, ...)` + has one more dimension than `x`. Here `outputsize` uses `size(x)`. +* At present `outputsize` does not work with recurrent layers, + `outputsize(RNN(2 => 3), (2, 1))` gives an error. This is a bug. """ function outputsize(m, inputsizes::Tuple...; padbatch=false) x = nil_input(padbatch, inputsizes...) 
diff --git a/test/outputsize.jl b/test/outputsize.jl index 0e5b807a60..ed6dc872ac 100644 --- a/test/outputsize.jl +++ b/test/outputsize.jl @@ -5,6 +5,8 @@ m = Dense(10, 5) @test_throws DimensionMismatch outputsize(m, (5, 2)) == (5, 1) @test outputsize(m, (10,); padbatch=true) == (5, 1) + @test outputsize(m, (10,)) == (5,) + @test outputsize(m, (10, 6, 7)) == (5, 6, 7) m = Chain(Dense(10, 8, σ), Dense(8, 5), Dense(5, 2)) @test outputsize(m, (10,); padbatch=true) == (2, 1) @@ -41,6 +43,19 @@ @test outputsize(m, (10, 10, 3, 1)) == (10, 10, 19, 1) end +@testset "embeddings" begin + # Here outputsize expects indices, not one-hot representation: + m = Embedding(3 => 4) + @test outputsize(m, (3, 7)) == (4, 3, 7) == size(m(rand(1:3, 3, 7))) + @test outputsize(m, (5, 6, 7)) == (4, 5, 6, 7) == size(m(rand(1:3, 5, 6, 7))) + + m = Chain(x -> Flux.onehotbatch(x, 1:5), Embedding(5 => 7)) + @test size(m([3,4])) == (7, 2) + @test outputsize(m, (2,)) == (7, 2) + # This works because Flux.onehotbatch([nil, nil], 1:5) makes a 5×2 OneHotMatrix + # But e.g. Flux.onehotbatch([nil, nil], 'a':'e') will not work. +end + @testset "multiple inputs" begin m = Parallel(vcat, Dense(2, 4, relu), Dense(3, 6, relu)) @test outputsize(m, (2,), (3,)) == (10,) From fa9279c934d01432859d3926c9ef429c41a51032 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 16 Oct 2022 11:06:22 -0400 Subject: [PATCH 2/3] move code --- src/layers/basic.jl | 3 --- src/outputsize.jl | 3 +++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 7ec34b2454..a6b2340cb4 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -698,9 +698,6 @@ Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(ini (m::Embedding)(x::AbstractVector{<:Integer}) = NNlib.gather(m.weight, x) (m::Embedding)(x::AbstractArray{<:Integer}) = reshape(m(vec(x)), :, size(x)...) 
-(m::Embedding)(x::Nil) = similar(m.weight, Nil, size(m.weight, 1)) -(m::Embedding)(x::AbstractArray{Nil}) = similar(m.weight, Nil, size(m.weight, 1), size(x)...) - (m::Embedding)(x::AbstractVector{Bool}) = m.weight * x # usually OneHotVector (m::Embedding)(x::AbstractMatrix{Bool}) = m.weight * x # usually OneHotMatrix (m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x,1), :)), :, size(x)[2:end]...) diff --git a/src/outputsize.jl b/src/outputsize.jl index 20c062499a..6778910c84 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -163,6 +163,9 @@ end ## fixes for layers that don't work out of the box +(m::Embedding)(x::Nil) = similar(m.weight, Nil, size(m.weight, 1)) +(m::Embedding)(x::AbstractArray{Nil}) = similar(m.weight, Nil, size(m.weight, 1), size(x)...) + for (fn, Dims) in ((:conv, DenseConvDims),) @eval begin function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{Nil}, dims::$Dims) From 86dc920e651b5955957ab3f213c0b991936e006f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 16 Oct 2022 11:33:04 -0400 Subject: [PATCH 3/3] Embedding and autosize --- src/outputsize.jl | 4 +++- test/outputsize.jl | 10 +++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/outputsize.jl b/src/outputsize.jl index 6778910c84..e2f06f5713 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -288,9 +288,11 @@ is needed to make `@autosize (2,3,4) Dense(_ => 5)` return """ autosizefor(::Type, x::AbstractArray) = size(x, max(1, ndims(x)-1)) autosizefor(::Type{<:Dense}, x::AbstractArray) = size(x, 1) autosizefor(::Type{<:LayerNorm}, x::AbstractArray) = size(x, 1) +autosizefor(::Type{<:Embedding}, x::AbstractArray) = error( + "@autosize Embedding(_ => n) cannot work, as this _ is the size of the vocabulary, not an array size") + _replaceunderscore(e, s) = e === :_ ?
s : e _replaceunderscore(ex::Expr, s) = Expr(ex.head, map(a -> _replaceunderscore(a, s), ex.args)...) diff --git a/test/outputsize.jl b/test/outputsize.jl index ed6dc872ac..a98658bea6 100644 --- a/test/outputsize.jl +++ b/test/outputsize.jl @@ -190,11 +190,6 @@ end m = @autosize (2, 3, 4, 5) Dense(_ => 10) # goes by first dim, not 2nd-last @test randn(2, 3, 4, 5) |> m |> size == (10, 3, 4, 5) - @test_broken begin # outputsize fails on Embedding - m = @autosize (2, 3, 4, 5) Embedding(_ => 10) # goes by first dim, not 2nd-last - @test randn(2, 3, 4, 5) |> m |> size == (10, 3, 4, 5) - end - m = @autosize (9,) Dense(_ => div(_,2)) @test randn(9) |> m |> size == (4,) @@ -249,6 +244,11 @@ end # https://github.com/FluxML/Flux.jl/issues/2086 m = @autosize (3, 1) Chain(; c = Dense(_ => 2, sigmoid), b = BatchNorm(_, affine=false)) @test randn(Float32, 3, 32) |> m |> size == (2, 32) + + # Embedding takes a vocab size, not an array size + @test_throws ErrorException @autosize (2, 3) Embedding(_ => 10) + m = @autosize (3,) Chain(Embedding(26 => 10), Dense(_, 4)) + @test rand(1:26, 3) |> m |> size == (4, 3) end @testset "LazyLayer" begin