From d060a84fbbc9481cd467e7540a198a1eff75bf7d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 7 Oct 2024 22:29:58 -0400 Subject: [PATCH] test: mark truncated normal on Metal as unbroken --- Project.toml | 2 +- test/initializers_tests.jl | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 308235c..dd2e473 100644 --- a/Project.toml +++ b/Project.toml @@ -37,7 +37,7 @@ ConcreteStructs = "0.2.3" GPUArraysCore = "0.1.6" GPUArrays = "10.2" LinearAlgebra = "1.10" -Metal = "1.1.0" +Metal = "1.3.0" Random = "1.10" SpecialFunctions = "2.4" Statistics = "1.10" diff --git a/test/initializers_tests.jl b/test/initializers_tests.jl index f3a5a0e..8f09f3a 100644 --- a/test/initializers_tests.jl +++ b/test/initializers_tests.jl @@ -154,7 +154,7 @@ end init === randn32) && continue - if (backend == "oneapi" || backend == "metal") && init === truncated_normal + if backend == "oneapi" && init === truncated_normal @test_broken size(init(rng, 3)) == (3,) # `erfinv` not implemented continue end @@ -229,9 +229,7 @@ end init === truncated_normal && !(T <: Real) && continue - if (backend == "oneapi" || backend == "metal") && - init === truncated_normal && - T == Float32 + if backend == "oneapi" && init === truncated_normal && T == Float32 @test_broken init(rng, T, 3) isa AbstractArray{T, 1} # `erfinv` not implemented continue end @@ -261,7 +259,7 @@ end @testset "Closure: $init" for init in [ kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, truncated_normal, identity_init] - if (backend == "oneapi" || backend == "metal") && init === truncated_normal + if backend == "oneapi" && init === truncated_normal @test_broken size(init(rng, 3)) == (3,) # `erfinv` not implemented continue end