From bcecc6301685a8e34a8c51b02f1bdecd696f44e1 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Sun, 16 Jun 2024 11:13:08 -0400 Subject: [PATCH] wip --- test/MatrixFields/field_matrix_solvers.jl | 2 +- test/MatrixFields/matrix_field_test_utils.jl | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/MatrixFields/field_matrix_solvers.jl b/test/MatrixFields/field_matrix_solvers.jl index 28474db552..95dae9f708 100644 --- a/test/MatrixFields/field_matrix_solvers.jl +++ b/test/MatrixFields/field_matrix_solvers.jl @@ -55,7 +55,7 @@ function test_field_matrix_solver(; test_name, alg, A, b, use_rel_error = false) # from CUBLAS (norm), KrylovKit (eigsolve), and CoreLogging (@debug). ignored = ( ignore_cuda..., - using_cuda ? AnyFrameModule(CUDA.CUBLAS) : + using_cuda ? AnyFrameModule(eval(Meta.parse("CUDA.CUBLAS"))) : AnyFrameModule(MatrixFields.KrylovKit), AnyFrameModule(Base.CoreLogging), ) diff --git a/test/MatrixFields/matrix_field_test_utils.jl b/test/MatrixFields/matrix_field_test_utils.jl index e60b157109..3ed042e323 100644 --- a/test/MatrixFields/matrix_field_test_utils.jl +++ b/test/MatrixFields/matrix_field_test_utils.jl @@ -1,9 +1,9 @@ using Test using JET -import CUDA import Random: seed! import ClimaComms +ClimaComms.@import_required_backends import ClimaCore: Geometry, Domains, @@ -45,7 +45,8 @@ const comms_device = ClimaComms.device() # comms_device = ClimaComms.CPUSingleThreaded() @show comms_device const using_cuda = comms_device isa ClimaComms.CUDADevice -const ignore_cuda = using_cuda ? (AnyFrameModule(CUDA),) : () +const ignore_cuda = + using_cuda ? (AnyFrameModule(eval(Meta.parse("CUDA"))),) : () # Test the allocating and non-allocating versions of a field broadcast against # a reference non-allocating implementation. Ensure that they are performant, @@ -63,7 +64,7 @@ function test_field_broadcast(; ) where {F1, F2, F3} @testset "$test_name" begin if test_broken_with_cuda && using_cuda - @test_throws CUDA.InvalidIRError get_result() + @test_throws eval(Meta.parse("CUDA")).InvalidIRError get_result() @warn "$test_name:\n\tCUDA.InvalidIRError" return end @@ -133,7 +134,7 @@ function test_field_broadcast_against_array_reference(; ) where {F1, F2, F3} @testset "$test_name" begin if test_broken_with_cuda && using_cuda - @test_throws CUDA.InvalidIRError get_result() + @test_throws eval(Meta.parse("CUDA")).InvalidIRError get_result() @warn "$test_name:\n\tCUDA.InvalidIRError" return end