From a1e83656b98738662b03d9b8cb197784787cf8e8 Mon Sep 17 00:00:00 2001 From: CarloLucibello <carlo.lucibello@gmail.com> Date: Sun, 5 Mar 2023 12:34:35 +0100 Subject: [PATCH] cleanup --- test.jl | 37 ------------------------------------- 1 file changed, 37 deletions(-) delete mode 100644 test.jl diff --git a/test.jl b/test.jl deleted file mode 100644 index c51a899fd5..0000000000 --- a/test.jl +++ /dev/null @@ -1,37 +0,0 @@ -using Flux, Test - - -@testset "attention" begin - dim = 4; nheads = 2; len = 3; batch_size = 5 - mha = MultiHeadAttention(dim, nheads) - q = rand(Float32, (dim, len, batch_size)) - k = rand(Float32, (dim, len, batch_size)) - v = rand(Float32, (dim, len, batch_size)) - - y, α = mha(q, k, v, withscores=true) - @test y isa Array{Float32, 3} - @test size(y) == (dim, len, batch_size) - @test α isa Array{Float32, 4} - @test size(α) == (len, len, nheads, batch_size) - - @testset "self-attention" begin - y1 = mha(q) - y2 = mha(q, q, q) - @test y1 ≈ y2 - end - - @testset "key and value are the same" begin - y1 = mha(q, k) - y2 = mha(q, k, k) - @test y1 ≈ y2 - end - - @testset "change dims" begin - dims = 4 => 10 => 5 - nhead = 5 - mha2 = MultiHeadAttention(dims, nheads) - y2 = mha2(q, k, v) - @test size(y2) == (dims.second.second, len, batch_size) - end -end -