diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index cbd896337..bb894b8a0 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -20,6 +20,7 @@ using Test: @test, @test_broken, @test_logs, @testset, @test_throws using Turing @testset "Testing hmc.jl with $adbackend" for adbackend in ADUtils.adbackends + @info "Running HMC tests with $adbackend" # Set a seed rng = StableRNG(123) @testset "constrained bounded" begin @@ -332,12 +333,16 @@ using Turing end @testset "Check ADType" begin - alg = HMC(0.1, 10; adtype=adbackend) - m = DynamicPPL.contextualize( - gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context) - ) - # These will error if the adbackend being used is not the one set. - sample(rng, m, alg, 10) + # These tests don't make sense for Enzyme, since it does not use a particular element + # type. + if !(adbackend isa AutoEnzyme) + alg = HMC(0.1, 10; adtype=adbackend) + m = DynamicPPL.contextualize( + gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context) + ) + # These will error if the adbackend being used is not the one set. + @test (sample(rng, m, alg, 10); true) + end end end diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index 3ba6a73d6..231c9ec88 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/test_utils/ad_utils.jl @@ -1,11 +1,11 @@ module ADUtils +using Enzyme: Enzyme using ForwardDiff: ForwardDiff using Pkg: Pkg using Random: Random using ReverseDiff: ReverseDiff using Mooncake: Mooncake -using Test: Test using Turing: Turing using Turing: DynamicPPL using Zygote: Zygote @@ -239,7 +239,10 @@ adbackends = [ Turing.AutoForwardDiff(; chunksize=0), Turing.AutoReverseDiff(; compile=false), Turing.AutoMooncake(; config=nothing), - Turing.AutoEnzyme(), + # TODO(mhauru) Do we want to run both? For now yes, while building up Enzyme + # integration, but in the long term maybe not? + Turing.AutoEnzyme(; mode=Enzyme.Forward), + Turing.AutoEnzyme(; mode=Enzyme.Reverse), ] end diff --git a/test/test_utils/test_utils.jl b/test/test_utils/test_utils.jl index bf9f2b9b8..b29ae6226 100644 --- a/test/test_utils/test_utils.jl +++ b/test/test_utils/test_utils.jl @@ -1,7 +1,7 @@ """Module for testing the test utils themselves.""" module TestUtilsTests -using ..ADUtils: ADTypeCheckContext, AbstractWrongADBackendError +using ..ADUtils: ADTypeCheckContext, AbstractWrongADBackendError, adbackends using ForwardDiff: ForwardDiff using ReverseDiff: ReverseDiff using Test: @test, @testset, @test_throws @@ -13,12 +13,11 @@ using Zygote: Zygote @testset "ADTypeCheckContext" begin Turing.@model test_model() = x ~ Turing.Normal(0, 1) tm = test_model() - adtypes = ( - Turing.AutoForwardDiff(), - Turing.AutoReverseDiff(), - Turing.AutoZygote(), - # TODO: Mooncake - # Turing.AutoMooncake(config=nothing), + # These tests don't make sense for Enzyme, since it doesn't have its own element type. + # TODO(mhauru): Make these tests work for more Mooncake. + adtypes = filter( + adtype -> !(adtype isa Turing.AutoMooncake || adtype isa Turing.AutoEnzyme), + adbackends, ) for actual_adtype in adtypes sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)