Add EmbeddingBag #2031

Merged · 23 commits · Apr 18, 2023

Changes from 1 commit
1 change: 1 addition & 0 deletions docs/src/models/layers.md
@@ -61,6 +61,7 @@ Parallel
Flux.Bilinear
Flux.Scale
Flux.Embedding
Flux.EmbeddingBag
```

## Normalisation & Regularisation
81 changes: 81 additions & 0 deletions src/layers/basic.jl
@@ -692,3 +692,84 @@ end
function Base.show(io::IO, m::Embedding)
  print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end

"""
    EmbeddingBag(in => out, reduction=Statistics.mean; init=randn32)

A lookup table that stores embeddings of dimension `out` for a vocabulary of size
`in`. Similar to [`Embedding`](@ref), but it can take multiple indices grouped in a "bag".
The embeddings of the items in a bag are then reduced to a single embedding by `reduction`.
Typically `reduction` is `Statistics.mean`, `sum`, or `maximum`.

This layer is often used to store word embeddings and retrieve them using indices.
The inputs can take several forms:
- A scalar: a single bag with a single item
- A vector: a single bag with multiple items
- A matrix: multiple bags with multiple items (each column is a bag)
- A vector of vectors: multiple bags with multiple items (each vector is a bag)
- An `input` vector and `offsets` vector: explained below

The `input`/`offsets` input type is similar to PyTorch's implementation. `input` should be
a vector of class indices and `offsets` should be a vector of offsets from the
first index of `input`. The first element of `offsets` must be `0`, and `offsets` should
be monotonically increasing; only the first condition is checked.

For example, the `input`/`offsets` pair `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]`/`[0, 4, 5, 7]`
is equivalent to the bags `[[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]]`.

# Examples
```jldoctest
julia> vocab_size, embed_size = 1000, 4;

julia> model = Flux.EmbeddingBag(vocab_size => embed_size)
EmbeddingBag(1000 => 4)  # 4_000 parameters

julia> bags = [[1, 200, 25, 789], [2, 5, 10, 999]];

julia> bags_mtx = [1 2; 200 5; 25 10; 789 999];

julia> model(bags) |> summary
"4×2 Matrix{Float32}"

julia> model(bags) ≈ model(bags_mtx)
true
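
julia> model([1, 200, 25, 789], [0, 2]) ≈ model([[1, 200], [25, 789]])  # the same bags, in input/offsets form
true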
```
"""
struct EmbeddingBag{F, W}
  weight::W
  reduction::F
end

@functor EmbeddingBag

EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = Statistics.mean; init = randn32) = EmbeddingBag(init(out, in), reduction)
EmbeddingBag(weight) = EmbeddingBag(weight, Statistics.mean)

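# In the `input`/`offsets` form, bag `i` covers `inputs[offsets[i]+1:offsets[i+1]]`
# (1-based), and the last bag runs from `offsets[end]+1` to `lastindex(inputs)`.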
function (m::EmbeddingBag)(inputs::AbstractVector, offsets::AbstractVector)
  offsets[1] == 0 || throw(ArgumentError("`offsets` must begin with 0."))
  out = zeros(eltype(m.weight), size(m.weight, 1), length(offsets))
  start = firstindex(inputs)
  for i in eachindex(offsets)[1:end-1]
    out[:, i] = m(inputs[start:offsets[i+1]])
    start = offsets[i+1] + 1
  end
  out[:, end] = m(inputs[offsets[end]+1:end])
  out
end
(m::EmbeddingBag)(idx::Integer) = m.weight[:, idx]
(m::EmbeddingBag)(bag::AbstractVector) = vec(m.reduction(NNlib.gather(m.weight, bag), dims=2))
(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
(m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, m.(eachcol(bags)))
Member:
After reading the PyTorch docstring, it seems the main advantage of this layer is memory efficiency. So, shouldn't these be mapreduce instead of a broadcast to achieve the same feature?

Contributor Author:

Unfortunately, mapreduce(f, hcat, collection) is not optimized. But yes, I agree. I will add a todo for when specialized mapreduce functions are added. See: https://discourse.julialang.org/t/different-performance-between-reduce-map-and-mapreduce/85149 and JuliaLang/julia#31137.

julia> (m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
julia> (m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, m.(eachcol(bags)))

julia> test(m::EmbeddingBag, bags::AbstractVector{<:AbstractVector})  = mapreduce(m, hcat, bags)
julia> test(m::EmbeddingBag, bags::AbstractMatrix) = mapreduce(m, hcat, eachcol(bags))
julia> e = Flux.EmbeddingBag(100=>64)
julia> bags = [[rand(1:100) for _ in 1:3] for _ in 1:1000]
julia> @btime e(bags);
  709.630 μs (14004 allocations: 2.16 MiB)

julia> @btime test(e, bags);
  14.700 ms (15935 allocations: 124.18 MiB)

Member:

> Unfortunately, `mapreduce(f, hcat, collection)` is not optimized

If this is the hurdle, then `stack(f, collection)` might be the solution, assuming `f` returns vectors. Needs `using Compat`, which is certainly already loaded downstream.
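
A minimal sketch of that suggestion (assuming `m(bag)` returns a plain `Vector`; `stack` needs Julia 1.9 or Compat):

```julia
# `stack` writes each bag's embedding directly into one output matrix,
# avoiding the intermediate vector-of-vectors that `m.(bags)` allocates.
(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = stack(m, bags)
(m::EmbeddingBag)(bags::AbstractMatrix) = stack(m, eachcol(bags))
```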

Member:

The really big memory cost is going to be the gradient of gather. For every column / vector, ∇gather_src is going to allocate like a copy of the weights.

https://github.com/FluxML/NNlib.jl/blob/6f74fad0a2a24e3594fc5229cc515fa25e80f877/src/gather.jl#L80

One could write a more efficient combined rule for this. Or add some thunks to the one in NNlib & wait for AD to learn to exploit them.
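
For the `mean` reduction, such a combined rule might look roughly like this (a sketch only; `bag_mean` is a hypothetical helper written against ChainRulesCore, not NNlib's actual API):

```julia
using ChainRulesCore

# Forward pass: sum the used columns through a view, so the gathered
# (out × bag-length) matrix is never materialised.
bag_mean(weight::AbstractMatrix, bag) = vec(sum(view(weight, :, bag); dims=2)) ./ length(bag)

function ChainRulesCore.rrule(::typeof(bag_mean), weight, bag)
  function bag_mean_pullback(ȳ)
    # The weight-sized gradient hides behind a thunk, so an AD that
    # exploits thunks only pays for this allocation when it is needed.
    dweight = @thunk begin
      dw = zero(weight)
      g = unthunk(ȳ) ./ length(bag)
      for idx in bag
        dw[:, idx] .+= g   # accumulation handles repeated indices
      end
      dw
    end
    return (NoTangent(), dweight, NoTangent())
  end
  return bag_mean(weight, bag), bag_mean_pullback
end
```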

Contributor Author:

This can be done after this PR, right?

Member:

Yes. I just mean these concerns will dwarf the hcat cost. (Even on the forward pass, the array you make to call `mean` on will also be much larger.)


function (m::EmbeddingBag)(x::OneHotVector{T,L}) where {T,L}
  size(m.weight, 2) == L || throw(DimensionMismatch("Matrix columns must match the OneHot size: $(size(m.weight, 2)) != $L"))
  return m(onecold(x))
end
function (m::EmbeddingBag)(x::OneHotMatrix{T,L}) where {T,L}
  size(m.weight, 2) == L || throw(DimensionMismatch("Matrix columns must match the OneHot size: $(size(m.weight, 2)) != $L"))
  return m(LinearAlgebra.Transpose(onecold(x)))
end

function Base.show(io::IO, m::EmbeddingBag)
  print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end
2 changes: 1 addition & 1 deletion src/layers/show.jl
@@ -57,7 +57,7 @@ _show_children(p::Parallel) = (p.connection, p.layers...)
_show_children(f::PairwiseFusion) = (f.connection, f.layers...)

for T in [
-  :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
+  :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding, :EmbeddingBag,
:BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
]
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
57 changes: 57 additions & 0 deletions test/layers/basic.jl
@@ -311,6 +311,63 @@ import Flux: activations
    @test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3]
    @test_throws DimensionMismatch m(OneHotVector(3, 1000))
  end

  @testset "EmbeddingBag" begin
    for reduction in [sum, Statistics.mean, maximum]
      vocab_size, embed_size = 10, 4
      emb_bag = Flux.EmbeddingBag(vocab_size => embed_size, reduction)
      emb = Flux.Embedding(emb_bag.weight)
      @test size(emb_bag.weight) == (embed_size, vocab_size)

      # scalar bag
      @test emb_bag(2) ≈ emb_bag.weight[:,2]
      @test emb_bag(3) ≈ emb(3)

      # single bag (input as a vector)
      x = rand(1:vocab_size, 3)
      y = emb_bag(x)
      z = vec(reduction(emb(x), dims=2))
      @test y isa Vector{Float32}
      @test y ≈ z

      # PyTorch-style `input`/`offsets` bagging
      @test emb_bag([1,3,2,4,5,7], [0,2,4]) ≈ emb_bag([[1,3], [2,4], [5,7]])
      @test emb_bag([1,3,2,4,5,7], [0,2,4]) ≈ emb_bag([1 2 5; 3 4 7])
      @test_throws ArgumentError emb_bag([1,2,3,4,5,6], [2,4])
      @test_throws BoundsError emb_bag([1,2,3,4,5,6], [0,12])

      # docstring example
      @test emb_bag([1,2,3,4,5,6,7,8,9,10], [0,4,5,7]) ≈ emb_bag([[1,2,3,4], [5], [6,7], [8,9,10]])

      # multiple bags (input as a vector of vectors)
      x = [rand(1:vocab_size, 3) for _ in 1:4]
      y = emb_bag(x)
      z = reduce(hcat, reduction.(emb.(x), dims=2))
      @test y isa Matrix{Float32}
      @test y ≈ z

      # multiple bags (input as a matrix)
      x = rand(1:vocab_size, (3, 5))
      xvec = collect(eachcol(x))
      y = emb_bag(x)
      z = reduce(hcat, reduction.(emb.(xvec), dims=2))
      @test y ≈ emb_bag(xvec)
      @test y ≈ z

      # one-hot bags; should match Embedding, since each bag has size 1
      @test emb_bag(Flux.OneHotVector(3, vocab_size)) ≈ emb_bag.weight[:,3]
      @test emb_bag(Flux.OneHotVector(4, vocab_size)) ≈ emb(Flux.OneHotVector(4, vocab_size))
      @test_throws DimensionMismatch emb_bag(Flux.OneHotVector(3, 1000))

      x2 = Flux.OneHotMatrix(rand(1:vocab_size, 3), vocab_size)
      y2 = emb_bag(x2)
      z2 = emb(x2)
      @test y2 isa Matrix{Float32}
      @test y2 ≈ z2
      @test_throws DimensionMismatch emb_bag(Flux.OneHotMatrix(1:5, 1000))
    end
  end
end

@testset "second derivatives" begin