Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EmbeddingBag #2031

Merged
merged 23 commits into from
Apr 18, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add _splitat
mcognetta committed Sep 21, 2022
commit 6c04ecde17d35ef17caa66c3ac42a0c570d42eef
29 changes: 21 additions & 8 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
@@ -693,6 +693,25 @@ function Base.show(io::IO, m::Embedding)
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
end


"""
    _splitat(data::AbstractVector, offsets::AbstractVector{Int})

Split a vector of data into a vector of vectors based on offsets. Each offset
specifies the starting index of the next sub-vector in the `data` vector. In other words,
the `data` vector is chunked into vectors from `offsets[1]` to `offsets[2]` (not including the element at `offsets[2]`), `offsets[2]` to `offsets[3]`, etc.
The last offset specifies a bag that contains everything to the right of it.

The `offsets` vector must begin with `1` and be monotonically increasing. The last element of `offsets` must be at most `length(data)`.

# Throws
- `ArgumentError` if `offsets` does not begin with `1`, is not strictly increasing,
  or its last element exceeds `length(data)`.
"""
function _splitat(data::AbstractVector, offsets::AbstractVector{Int})
    # Validate at the API boundary so the comprehension below can assume well-formed offsets.
    offsets[firstindex(offsets)] == 1 || throw(ArgumentError("`offsets` must begin with 1."))
    offsets[end] <= length(data) || throw(ArgumentError("The last element in `offsets` must be at most the length of `data`."))
    # `lt = <=` makes issorted reject duplicates as well as decreases (strictly increasing).
    issorted(offsets, lt = <=) || throw(ArgumentError("`offsets` must be monotonically increasing with no duplicates."))
    # Bag i spans offsets[i] up to (but excluding) offsets[i+1]; the final bag runs to `end`.
    # (The original also built an unused `newoffsets` vector here; it has been removed.)
    return [data[offsets[i]:(i+1 > lastindex(offsets) ? end : offsets[i+1]-1)] for i in eachindex(offsets)]
end

"""
EmbeddingBag(in => out, reduction=mean; init=Flux.randn32)

@@ -709,8 +728,7 @@ The inputs can take several forms:
- An input vector and offset vector := Explained below.

The `input`/`offset` input type is similar to PyTorch's implementation. `input` should be
a vector of class indices and `offset` should be a vector representing the starting index of a bag in the `inputs` vector. The first element of `offsets` must be `1`, and `offsets` should
be monotonically increasing, but the second condition is not checked.
a vector of class indices and `offset` should be a vector representing the starting index of a bag in the `inputs` vector. The first element of `offsets` must be `1`, and `offsets` must be monotonically increasing with no duplicates.

This format is useful for dealing with flattened representations of "ragged" tensors. E.g., if you have a flat vector of class labels that need to be grouped in a non-uniform way. However, under the hood, it is just syntactic sugar for the vector-of-vectors input style.

@@ -762,12 +780,7 @@ EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = mean;
EmbeddingBag(weight) = EmbeddingBag(weight, mean)

function (m::EmbeddingBag)(inputs::AbstractVector, offsets::AbstractVector)
    # PyTorch-style bagging: `offsets` marks the starting index of each bag inside
    # `inputs`. `_splitat` performs all offset validation (starts at 1, strictly
    # increasing, last offset within bounds) and returns the vector-of-vectors form,
    # which the vector-of-bags method of `EmbeddingBag` already handles.
    return m(_splitat(inputs, offsets))
end
# A single integer index is treated as a bag of one element: return its embedding column.
(m::EmbeddingBag)(idx::Integer) = m.weight[:, idx]
# A vector of indices is one bag: gather the corresponding columns of the weight matrix
# and reduce across them (dims=2), returning a flat embedding vector.
(m::EmbeddingBag)(bag::AbstractVector{<:Integer}) = vec(m.reduction(NNlib.gather(m.weight, bag), dims=2))
27 changes: 25 additions & 2 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
@@ -313,6 +313,29 @@ import Flux: activations
end

@testset "EmbeddingBag" begin

# test _splitat
inputs = [1, 2, 3, 4, 5, 6, 7, 8, 9]
offsets_good = [1, 3, 6]
offsets_each = [1,2,3,4,5,6,7,8,9]
offsets_just_one = [1]
offsets_all_but_last = [1, 9]

@test Flux._splitat(inputs, offsets_good) == [[1, 2], [3, 4, 5], [6, 7, 8, 9]]
@test Flux._splitat(inputs, offsets_each) == [[1], [2], [3], [4], [5], [6], [7], [8], [9]]
@test Flux._splitat(inputs, offsets_just_one) == [[1,2,3,4,5,6,7,8,9]]
@test Flux._splitat(inputs, offsets_all_but_last) == [[1,2,3,4,5,6,7,8], [9]]

offsets_non_monotonic = [1, 2, 2, 5]
offsets_non_sorted = [1, 5, 2]
offsets_non_one = [2, 3, 5]
offsets_too_large = [1, 5, 11]

@test_throws ArgumentError Flux._splitat(inputs, offsets_non_monotonic)
@test_throws ArgumentError Flux._splitat(inputs, offsets_non_sorted)
@test_throws ArgumentError Flux._splitat(inputs, offsets_non_one)
@test_throws ArgumentError Flux._splitat(inputs, offsets_too_large)

for reduction in [sum, Statistics.mean, maximum]
vocab_size, embed_size = 10, 4
emb_bag = Flux.EmbeddingBag(vocab_size => embed_size, reduction)
@@ -333,8 +356,8 @@ import Flux: activations
# PyTorch style `input`/`offset` bagging
@test emb_bag([1,3,2,4,5,7], [1,3,5]) ≈ emb_bag([[1,3], [2,4], [5,7]])
@test emb_bag([1,3,2,4,5,7], [1,3,5]) ≈ emb_bag([1 2 5; 3 4 7])
@test_throws ArgumentError emb_bag([1,2,3,4,5,6], [2,4])
@test_throws BoundsError emb_bag([1,2,3,4,5,6], [1,12])
@test_throws ArgumentError emb_bag([1,2,3,4,5,6], [2, 4])
@test_throws ArgumentError emb_bag([1,2,3,4,5,6], [1, 12])

# docstring example
@test emb_bag([1,2,3,4,5,6,7,8,9,10], [1,5,6,8]) ≈ emb_bag([[1,2,3,4], [5], [6,7], [8,9,10]])