Skip to content

Commit

Permalink
Merge pull request #670 from ldeso/add-cartesian-embedding-methods
Browse files Browse the repository at this point in the history
Add Cartesian Embedding methods
  • Loading branch information
avik-pal authored Jun 1, 2024
2 parents 60c595e + fc020c9 commit ca23485
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 28 deletions.
36 changes: 29 additions & 7 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,8 @@ end
Embedding(in_dims => out_dims; init_weight=randn32)
A lookup table that stores embeddings of dimension `out_dims` for a vocabulary of size
`in_dims`.
`in_dims`. When the vocabulary is multi-dimensional, the input is expected to be a tuple
of Cartesian indices.
This layer is often used to store word embeddings and retrieve them using indices.
Expand All @@ -461,19 +462,22 @@ This layer is often used to store word embeddings and retrieve them using indice
## Arguments
- `in_dims`: number of input dimensions
- `in_dims`: number(s) of input dimensions
- `out_dims`: number of output dimensions
## Keyword Arguments
- `init_weight`: initializer for the weight matrix
(`weight = init_weight(rng, out_dims, in_dims)`)
(`weight = init_weight(rng, out_dims, in_dims...)`)
## Input
- Integer OR
- Abstract Vector of Integers OR
- Abstract Array of Integers
- Abstract Array of Integers OR
- Tuple of Integers OR
- Tuple of Abstract Vectors of Integers OR
- Tuple of Abstract Arrays of Integers
## Returns
Expand All @@ -482,17 +486,19 @@ This layer is often used to store word embeddings and retrieve them using indice
- Empty `NamedTuple()`
"""
@concrete struct Embedding <: AbstractExplicitLayer
in_dims::Int
in_dims
out_dims::Int
init_weight
end

function Embedding((in_dims, out_dims)::Pair{<:Integer, <:Integer}; init_weight=randn32)
function Embedding(
(in_dims, out_dims)::Pair{<:Union{Integer, NTuple{<:Any, <:Integer}}, <:Integer};
init_weight=randn32)
return Embedding(in_dims, out_dims, init_weight)
end

function initialparameters(rng::AbstractRNG, e::Embedding)
return (weight=e.init_weight(rng, e.out_dims, e.in_dims),)
return (weight=e.init_weight(rng, e.out_dims, e.in_dims...),)
end

# Single integer index: return a view of the corresponding column of the
# embedding weight (no copy), along with the unchanged layer state.
function (e::Embedding)(x::Integer, ps, st::NamedTuple)
    return view(ps.weight, :, x), st
end
Expand All @@ -502,6 +508,22 @@ end
# Array of integer indices: look up all elements via the vector method on
# `vec(x)`, then reshape so the embedding dimension leads and the input
# array's original shape follows.
function (e::Embedding)(x::AbstractArray{<:Integer}, ps, st::NamedTuple)
    embedded = first(e(vec(x), ps, st))
    return reshape(embedded, :, size(x)...), st
end
# Tuple of integers = one Cartesian index into a multi-dimensional vocabulary:
# return a view of the embedding vector stored at `ps.weight[:, x...]`, plus
# the unchanged layer state.
# Fix: add the explicit `return` used by every other method of this layer
# (file convention; behavior is unchanged).
function (e::Embedding)(x::NTuple{<:Any, <:Integer}, ps, st::NamedTuple)
    return view(ps.weight, :, x...), st
end
# Tuple of index vectors = a batch of Cartesian indices: gather one embedding
# column per position. All vectors must have the same shape, since the i-th
# entries of each vector together form one Cartesian index.
function (e::Embedding)(x::NTuple{<:Any, <:AbstractVector{<:Integer}}, ps, st::NamedTuple)
    allequal(size.(x)) ||
        throw(DimensionMismatch("Input vectors must have the same shape"))
    return NNlib.gather(ps.weight, x...), st
end
# Tuple of index arrays: flatten each array, reuse the vector-tuple method for
# the actual gather, then restore the common input shape behind a leading
# embedding dimension. All arrays must share one shape (element-wise Cartesian
# indexing).
function (e::Embedding)(x::NTuple{<:Any, <:AbstractArray{<:Integer}}, ps, st::NamedTuple)
    shape = size(first(x))
    all(==(shape), size.(x)) ||
        throw(DimensionMismatch("Input arrays must have the same shape"))
    flat = first(e(vec.(x), ps, st))
    return reshape(flat, :, shape...), st
end
# Guard method for an empty index tuple: a lookup needs at least one index,
# so fail fast with a descriptive error instead of hitting an obscure
# zero-argument `view`/`gather` failure in the sibling tuple methods.
function (e::Embedding)(x::Tuple{}, ps, st::NamedTuple)
throw(ArgumentError("Input tuple must contain at least one element"))
end

function Base.show(io::IO, e::Embedding)
return print(io, "Embedding(", e.in_dims, " => ", e.out_dims, ")")
Expand Down
85 changes: 64 additions & 21 deletions test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,34 +334,77 @@ end
rng = get_stable_rng(12345)

@testset "$mode" for (mode, aType, device, ongpu) in MODES
vocab_size, embed_size = 10, 4
layer = Embedding(vocab_size => embed_size)
__display(layer)
ps, st = Lux.setup(rng, layer) .|> device
@testset "Linear indices" begin
vocab_size, embed_size = 10, 4
layer = Embedding(vocab_size => embed_size)
__display(layer)
ps, st = Lux.setup(rng, layer) .|> device

@test size(ps.weight) == (embed_size, vocab_size)

@test LuxCore.outputsize(layer) == (4,)

x = rand(1:vocab_size, 1)[1]
y, st_ = layer(x, ps, st)
@test size(layer(x, ps, st)[1]) == (embed_size,)
@test y == ps.weight[:, x]

@jet layer(x, ps, st)

x = rand(1:vocab_size, 3) |> aType
y, st_ = layer(x, ps, st)
@test y isa aType{Float32}
@test y == ps.weight[:, x]

@jet layer(x, ps, st)

x = rand(1:vocab_size, 3, 4) |> aType
y, st_ = layer(x, ps, st)
@test y isa aType{Float32, 3}
@test size(y) == (embed_size, 3, 4)

@jet layer(x, ps, st)
end

@testset "Cartesian indices" begin
vocab_size, embed_size = (5, 2), 4
layer = Embedding(vocab_size => embed_size)
__display(layer)
ps, st = Lux.setup(rng, layer) .|> device

@test size(ps.weight) == (embed_size, vocab_size...)

@test size(ps.weight) == (embed_size, vocab_size)
@test LuxCore.outputsize(layer) == (4,)

@test LuxCore.outputsize(layer) == (4,)
x = (rand(1:vocab_size[1], 1)[1], rand(1:vocab_size[2], 1)[1])
y, st_ = layer(x, ps, st)
@test size(layer(x, ps, st)[1]) == (embed_size,)
@test y == ps.weight[:, x...]

x = rand(1:vocab_size, 1)[1]
y, st_ = layer(x, ps, st)
@test size(layer(x, ps, st)[1]) == (embed_size,)
@test y == ps.weight[:, x]
@jet layer(x, ps, st)

x = (rand(1:vocab_size[1], 3), rand(1:vocab_size[2], 3)) .|> aType
y, st_ = layer(x, ps, st)
@test y isa aType{Float32}
@test y == ps.weight[:, CartesianIndex.(x...)]

@jet layer(x, ps, st)

@jet layer(x, ps, st)
x = (rand(1:vocab_size[1], 3, 4), rand(1:vocab_size[2], 3, 4)) .|> aType
y, st_ = layer(x, ps, st)
@test y isa aType{Float32, 3}
@test size(y) == (embed_size, 3, 4)

x = rand(1:vocab_size, 3) |> aType
y, st_ = layer(x, ps, st)
@test y isa aType{Float32}
@test y == ps.weight[:, x]
@jet layer(x, ps, st)

@jet layer(x, ps, st)
x = (rand(1:vocab_size[1], 3), rand(1:vocab_size[2], 4)) .|> aType
@test_throws DimensionMismatch layer(x, ps, st)

x = rand(1:vocab_size, 3, 4) |> aType
y, st_ = layer(x, ps, st)
@test y isa aType{Float32, 3}
@test size(y) == (embed_size, 3, 4)
x = (rand(1:vocab_size[1], 3, 4), rand(1:vocab_size[2], 4, 5)) .|> aType
@test_throws DimensionMismatch layer(x, ps, st)

@jet layer(x, ps, st)
x = ()
@test_throws ArgumentError layer(x, ps, st)
end
end
end

1 comment on commit ca23485

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: ca23485 Previous: 60c595e Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3668.125 ns 3646.75 ns 1.01
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7213.5 ns 7285 ns 0.99
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20719 ns 21210 ns 0.98
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9491.8 ns 9781.666666666666 ns 0.97
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8806 ns 9087.2 ns 0.97
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4459.5 ns 4453.888888888889 ns 1.00
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1168.8964285714287 ns 1176.2706766917292 ns 0.99
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1114.5125 ns 1112.28025477707 ns 1.00
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1189.8149606299212 ns 1189.374074074074 ns 1.00
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1789.4912280701753 ns 1814.3181818181818 ns 0.99
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 180.64739069111425 ns 179.93324061196105 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17293 ns 17212 ns 1.00
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 17072 ns 17463 ns 0.98
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 37099 ns 36689 ns 1.01
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 28267.5 ns 28147.5 ns 1.00
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 19726 ns 20058 ns 0.98
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17041 ns 16921 ns 1.01
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4381 ns 4310.5 ns 1.02
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3919.75 ns 3867.25 ns 1.01
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3968.75 ns 3951.125 ns 1.00
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4980.714285714285 ns 4787.571428571428 ns 1.04
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1651.1 ns 1659.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 38465553 ns 38839150 ns 0.99
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 57659332.5 ns 57478179 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 75839555 ns 68637336 ns 1.10
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 88548014.5 ns 80248739.5 ns 1.10
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 72339313 ns 66510498 ns 1.09
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11961965 ns 11601127 ns 1.03
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 17669780 ns 8302158.5 ns 2.13
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 6995246 ns 6958814.5 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 6962483.5 ns 6935871 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 9966425 ns 9886349 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6382304 ns 6377484 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 697509537 ns 711495815.5 ns 0.98
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2845090431 ns 2802293498 ns 1.02
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 145155933 ns 158450926 ns 0.92
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 758659457 ns 745197995 ns 1.02
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 2623555098 ns 2536517155 ns 1.03
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 200798179 ns 186814591 ns 1.07
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 699711870.5 ns 698620045 ns 1.00
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2744166224 ns 2703329300 ns 1.02
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 126860588 ns 122294200.5 ns 1.04
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 172959021 ns 172044480 ns 1.01
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 641219872 ns 643441503 ns 1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 34131572 ns 45114156 ns 0.76
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 163961135.5 ns 163454975.5 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 640443712 ns 628139701 ns 1.02
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 44002993 ns 29335904 ns 1.50
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 187456866.5 ns 207955667.5 ns 0.90
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 713774150 ns 722173872 ns 0.99
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 35709005 ns 37423155 ns 0.95
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1018397448 ns 1242027523.5 ns 0.82
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1855972375.5 ns 1847309072 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2022754963.5 ns 1988297584 ns 1.02
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2335834774.5 ns 2337208631 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1986661818 ns 1825164998 ns 1.09
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 553501401 ns 347875405.5 ns 1.59
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 315992941 ns 318366365 ns 0.99
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 317023478 ns 319738018 ns 0.99
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 347738717 ns 452781616 ns 0.77
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11815270 ns 11803413 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 17834636 ns 17882962 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19192128 ns 19018033 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23877834 ns 23755630 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 17885797.5 ns 17832966.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1159948 ns 1148767 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 5756297 ns 2512938 ns 2.29
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2045846 ns 2035570 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2025049 ns 2023578.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2063594 ns 2055760 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 196457 ns 195727.5 ns 1.00
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 293588 ns 288322 ns 1.02
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 266698 ns 262603 ns 1.02
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 372466.5 ns 354936.5 ns 1.05
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 411890 ns 400938 ns 1.03
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 275955 ns 270257 ns 1.02
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 409134 ns 397421 ns 1.03
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 83476 ns 83306 ns 1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 81713 ns 80271 ns 1.02
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 81983 ns 80581 ns 1.02
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 87163 ns 85480 ns 1.02
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104885 ns 104617 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 194040756 ns 187932820.5 ns 1.03
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 322670128 ns 321827872.5 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 389461276 ns 393773632.5 ns 0.99
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 444888078.5 ns 454117809 ns 0.98
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 370487396.5 ns 366877761 ns 1.01
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 322523445.5 ns 309426428 ns 1.04
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 100716976.5 ns 51303991 ns 1.96
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 43623492 ns 43675671.5 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 43429429 ns 43447693 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 49450779 ns 49289683 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 29173706 ns 28489085 ns 1.02
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 18903962 ns 18511523 ns 1.02
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19514972 ns 19373919.5 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23289075 ns 22860858 ns 1.02
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24083888 ns 23821494.5 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19573855.5 ns 19452776.5 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6522431 ns 6471809.5 ns 1.01
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6516955 ns 6467840.5 ns 1.01
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6494018.5 ns 6458192 ns 1.01
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6498216.5 ns 6475071.5 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.