diff --git a/test/runtests.jl b/test/runtests.jl index d5561bb92..530219c83 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,10 @@ macro timeit_include(path::AbstractString) end @testset "Turing" begin + @testset "Test utils" begin + @timeit_include("test_utils/test_utils.jl") + end + @testset "Aqua" begin @timeit_include("Aqua.jl") end diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index f7358de75..2c01dc524 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/test_utils/ad_utils.jl @@ -229,44 +229,6 @@ function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, sampler, righ return logp, vi end -# Check that the ADTypeCheckContext works as expected. -Test.@testset "ADTypeCheckContext" begin - Turing.@model test_model() = x ~ Turing.Normal(0, 1) - tm = test_model() - adtypes = ( - Turing.AutoForwardDiff(), - Turing.AutoReverseDiff(), - Turing.AutoZygote(), - # TODO: Mooncake - # Turing.AutoMooncake(config=nothing), - ) - for actual_adtype in adtypes - sampler = Turing.HMC(0.1, 5; adtype=actual_adtype) - for expected_adtype in adtypes - if ( - actual_adtype == Turing.AutoForwardDiff() && - expected_adtype == Turing.AutoZygote() - ) - # TODO(mhauru) We are currently unable to check this case. - continue - end - contextualised_tm = DynamicPPL.contextualize( - tm, ADTypeCheckContext(expected_adtype, tm.context) - ) - Test.@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin - if actual_adtype == expected_adtype - # Check that this does not throw an error. - Turing.sample(contextualised_tm, sampler, 2) - else - Test.@test_throws AbstractWrongADBackendError Turing.sample( - contextualised_tm, sampler, 2 - ) - end - end - end - end -end - # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # List of AD backends to test. diff --git a/test/test_utils/test_utils.jl b/test/test_utils/test_utils.jl new file mode 100644 index 000000000..bf9f2b9b8 --- /dev/null +++ b/test/test_utils/test_utils.jl @@ -0,0 +1,50 @@ +"""Module for testing the test utils themselves.""" +module TestUtilsTests + +using ..ADUtils: ADTypeCheckContext, AbstractWrongADBackendError +using ForwardDiff: ForwardDiff +using ReverseDiff: ReverseDiff +using Test: @test, @testset, @test_throws +using Turing: Turing +using Turing: DynamicPPL +using Zygote: Zygote + +# Check that the ADTypeCheckContext works as expected. +@testset "ADTypeCheckContext" begin + Turing.@model test_model() = x ~ Turing.Normal(0, 1) + tm = test_model() + adtypes = ( + Turing.AutoForwardDiff(), + Turing.AutoReverseDiff(), + Turing.AutoZygote(), + # TODO: Mooncake + # Turing.AutoMooncake(config=nothing), + ) + for actual_adtype in adtypes + sampler = Turing.HMC(0.1, 5; adtype=actual_adtype) + for expected_adtype in adtypes + if ( + actual_adtype == Turing.AutoForwardDiff() && + expected_adtype == Turing.AutoZygote() + ) + # TODO(mhauru) We are currently unable to check this case. + continue + end + contextualised_tm = DynamicPPL.contextualize( + tm, ADTypeCheckContext(expected_adtype, tm.context) + ) + @testset "Expected: $expected_adtype, Actual: $actual_adtype" begin + if actual_adtype == expected_adtype + # Check that this does not throw an error. + Turing.sample(contextualised_tm, sampler, 2) + else + @test_throws AbstractWrongADBackendError Turing.sample( + contextualised_tm, sampler, 2 + ) + end + end + end + end +end + +end