MultiHeadAttention implementation #2146

Merged · 21 commits · Mar 11, 2023
add cuda tests
CarloLucibello committed Mar 5, 2023
commit 2ecf19ba1d4a0381e50f5b3d8c176eb6e171ef28
26 changes: 26 additions & 0 deletions test/cuda/layers.jl
@@ -338,3 +338,29 @@ end
    @test eltype(pool(reshape(gx,3,4,1))) == Float16
  end
end

@testset "MultiHeadAttention" begin
dim = 4; nheads = 2; len = 3; batch_size = 5
mha_cpu = MultiHeadAttention(dim; nheads)
x_cpu = rand(Float32, (dim, len, batch_size))
y_cpu, α_cpu = mha_cpu(x_cpu, withscores=true)

mha_gpu = mha_cpu |> gpu
x_gpu = x_cpu |> gpu
y_gpu, α_gpu = mha_gpu(x_gpu, withscores=true)
@test y_gpu isa CuArray{Float32}
@test α_gpu isa CuArray{Float32}
@test Array(y_gpu) ≈ y_cpu atol=1e-4
@test Array(α_gpu) ≈ α_cpu atol=1e-4

gm_cpu, gx_cpu = gradient(mha_cpu, x_cpu) do mha, x
y, α = mha(x, withscores=true)
return sum(y.^2) + sum(α.^2)
end
gm_gpu, gx_gpu = gradient(mha_gpu, x_gpu) do mha, x
y, α = mha(x, withscores=true)
return sum(y.^2) + sum(α.^2)
end
test_grad_equal(gm_gpu, gm_cpu)
test_grad_equal(gx_gpu, gx_cpu)
end
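For reference, here is a minimal CPU-only sketch of the call pattern this testset exercises. The constructor, the `withscores=true` keyword, and the `(features, length, batch)` input layout are taken directly from the tests above; the `(len, len, nheads, batch_size)` shape of `α` is an assumption based on NNlib's attention-score convention, not something asserted in this commit.

```julia
using Flux  # assumes the MultiHeadAttention layer added by this PR is available

dim, nheads, len, batch_size = 4, 2, 3, 5
mha = MultiHeadAttention(dim; nheads)
x = rand(Float32, dim, len, batch_size)

# Self-attention: the single input serves as query, key, and value.
y, α = mha(x, withscores=true)
size(y)  # (dim, len, batch_size) -- same layout as the input
size(α)  # assumed (len, len, nheads, batch_size), per NNlib's score layout
```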
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -33,6 +33,7 @@ Random.seed!(0)
end

@testset "Layers" begin
include("layers/attention.jl")
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
13 changes: 13 additions & 0 deletions test/test_utils.jl
@@ -96,3 +96,16 @@ function test_grad_type(g::NamedTuple, x::T) where T
    test_grad_type(g[f], getfield(x, f))
  end
end

# Gradients of non-differentiable fields are `nothing` on both devices.
test_grad_equal(g1::Nothing, g2::Nothing) = nothing

# Compare a GPU gradient against its CPU counterpart elementwise.
function test_grad_equal(g1::AnyCuArray{T}, g2::Array{T}; atol=1e-4) where T
  @test Array(g1) ≈ g2 atol=atol
end

# Structured gradients (e.g. for a layer) come back as NamedTuples:
# check that the fields match, then recurse into each one.
function test_grad_equal(g1::T1, g2::T2) where {T1 <: NamedTuple, T2 <: NamedTuple}
  @test fieldnames(T1) == fieldnames(T2)
  for f in fieldnames(T1)
    test_grad_equal(g1[f], g2[f])
  end
end
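To see how these three methods cooperate, here is a hypothetical usage sketch (the `Dense` layer and the loss are illustrative, not part of the PR): Zygote returns the gradient of a layer as a NamedTuple, so the NamedTuple method recurses field by field until it bottoms out at arrays, handled by the `AnyCuArray` method, or at `nothing` for non-differentiable fields such as the activation function.

```julia
using CUDA, Flux, Zygote

m_cpu = Dense(3 => 2)
m_gpu = m_cpu |> gpu
x_cpu = rand(Float32, 3, 4)
x_gpu = x_cpu |> gpu

g_cpu, = gradient(m -> sum(abs2, m(x_cpu)), m_cpu)
g_gpu, = gradient(m -> sum(abs2, m(x_gpu)), m_gpu)

# g_* is a NamedTuple (weight = ..., bias = ..., σ = nothing); the recursion
# dispatches to the array method for weight/bias and the Nothing method for σ.
test_grad_equal(g_gpu, g_cpu)
```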