From 8285d3f5bfe10d15fecd0176153970350265a687 Mon Sep 17 00:00:00 2001 From: Bhavay Malhotra Date: Sun, 31 Dec 2023 00:02:48 +0530 Subject: [PATCH 1/4] Added test for Alpha Dropout --- test/dropout.jl | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 test/dropout.jl diff --git a/test/dropout.jl b/test/dropout.jl new file mode 100644 index 0000000000..459be4f7fa --- /dev/null +++ b/test/dropout.jl @@ -0,0 +1,44 @@ + +using Statistics +using Flux +using Test + +#initial x value +# x = randn32(1000,1); +# x = [1,2,3,4,5] + +# Mean +# E(xd + alpha(1-d)) = qu + (1-q)alpha +a_ = -1.7580993408473766 +d = 0.2 +q = 0.2 +u = mean(x) + +function mean_test(x) + # LHS + mean_left = (x*d) .+ (a_*(1-d)) + mean_left = mean(mean_left) + # println(mean_left) + + # RHS + mean_right = (q*u) .+ ((1-q)*a_) + # println(mean_right) + @test isapprox(mean_left, mean_right, atol=0.2) +end + +x = randn(2000,1); +@testset "Alphadropout Tests" begin + mean_test(x); +end + + +# Variance +# Var(xd + alpha(1-d)) = q((1-q)(alpha-u)^2 + v) +# v = var(x) + +# var_left = (x*d) .+ a_*(1-d) +# var_left = var(var_left) + +# var_right = q*((1-q)*(a_-u).^2 + v) + +# @test isapprox(var_left, var_right, atol=0.1) From 2125a95ab25f8c721ef2afbdfffd250fa70da122 Mon Sep 17 00:00:00 2001 From: Bhavay Malhotra Date: Sun, 31 Dec 2023 17:48:52 +0530 Subject: [PATCH 2/4] Added additional tests for AlphaDropout --- test/layers/normalisation.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 35f11a4adc..7307f3ae85 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -1,7 +1,7 @@ using Flux, Test, Statistics, Random using Zygote: pullback, ForwardDiff -evalwgrad(f, x...) = pullback(f, x...)[1] +global evalwgrad(f, x...) = pullback(f, x...)[1] @testset "Dropout" begin @testset for rng_kwargs in ((), (; rng = MersenneTwister())) @@ -88,12 +88,19 @@ end x = randn(1000) # large enough to prevent flaky test m = AlphaDropout(0.5; rng_kwargs...) + q = 0.5 + u = mean(x) + α′ = -1.7580993408473766 y = evalwgrad(m, x) # Should preserve unit mean and variance @test mean(y) ≈ 0 atol=0.2 @test var(y) ≈ 1 atol=0.2 + # Should check that the mean and variance matches the formula + # E(xd + α′(1-d)) = qu + (1-q)α′ + @test mean(y) ≈ (q*u) + ((1-q)*α′) + testmode!(m, true) # should override istraining @test evalwgrad(m, x) == x From cb8317bb4bfb6296fcfbd4d72688d8b546d4f456 Mon Sep 17 00:00:00 2001 From: Bhavay Malhotra Date: Thu, 4 Jan 2024 19:33:20 +0530 Subject: [PATCH 3/4] Changed normalisation.jl --- test/layers/normalisation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 7307f3ae85..5e68a22350 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -1,7 +1,7 @@ using Flux, Test, Statistics, Random using Zygote: pullback, ForwardDiff -global evalwgrad(f, x...) = pullback(f, x...)[1] +evalwgrad(f, x...) = pullback(f, x...)[1] @testset "Dropout" begin @testset for rng_kwargs in ((), (; rng = MersenneTwister())) From 82cdf421133051384adaa53d86d3001a09fb452f Mon Sep 17 00:00:00 2001 From: Bhavay Malhotra <56443877+Bhavay-2001@users.noreply.github.com> Date: Sat, 6 Jan 2024 19:08:38 +0530 Subject: [PATCH 4/4] Delete test/dropout.jl --- test/dropout.jl | 44 -------------------------------------------- 1 file changed, 44 deletions(-) delete mode 100644 test/dropout.jl diff --git a/test/dropout.jl b/test/dropout.jl deleted file mode 100644 index 459be4f7fa..0000000000 --- a/test/dropout.jl +++ /dev/null @@ -1,44 +0,0 @@ - -using Statistics -using Flux -using Test - -#initial x value -# x = randn32(1000,1); -# x = [1,2,3,4,5] - -# Mean -# E(xd + alpha(1-d)) = qu + (1-q)alpha -a_ = -1.7580993408473766 -d = 0.2 -q = 0.2 -u = mean(x) - -function mean_test(x) - # LHS - mean_left = (x*d) .+ (a_*(1-d)) - mean_left = mean(mean_left) - # println(mean_left) - - # RHS - mean_right = (q*u) .+ ((1-q)*a_) - # println(mean_right) - @test isapprox(mean_left, mean_right, atol=0.2) -end - -x = randn(2000,1); -@testset "Alphadropout Tests" begin - mean_test(x); -end - - -# Variance -# Var(xd + alpha(1-d)) = q((1-q)(alpha-u)^2 + v) -# v = var(x) - -# var_left = (x*d) .+ a_*(1-d) -# var_left = var(var_left) - -# var_right = q*((1-q)*(a_-u).^2 + v) - -# @test isapprox(var_left, var_right, atol=0.1)