diff --git a/test/dynamicppl/varinfo.jl b/test/dynamicppl/varinfo.jl index df6c4e40f..39c310e76 100644 --- a/test/dynamicppl/varinfo.jl +++ b/test/dynamicppl/varinfo.jl @@ -50,9 +50,9 @@ using Turing alg = HMC(0.1, 5) spl = DynamicPPL.Sampler(alg, model) v = copy(meta.vals) - DynamicPPL.link!(vi, spl) + DynamicPPL.link!!(vi, spl, model) @test all(x -> DynamicPPL.istrans(vi, x), meta.vns) - DynamicPPL.invlink!(vi, spl) + DynamicPPL.invlink!!(vi, spl, model) @test all(x -> !DynamicPPL.istrans(vi, x), meta.vns) @test meta.vals == v @@ -64,10 +64,10 @@ using Turing @test all(x -> !DynamicPPL.istrans(vi, x), meta.m.vns) v_s = copy(meta.s.vals) v_m = copy(meta.m.vals) - DynamicPPL.link!(vi, spl) + DynamicPPL.link!!(vi, spl, model) @test all(x -> DynamicPPL.istrans(vi, x), meta.s.vns) @test all(x -> DynamicPPL.istrans(vi, x), meta.m.vns) - DynamicPPL.invlink!(vi, spl) + DynamicPPL.invlink!!(vi, spl, model) @test all(x -> !DynamicPPL.istrans(vi, x), meta.s.vns) @test all(x -> !DynamicPPL.istrans(vi, x), meta.m.vns) @test meta.s.vals == v_s @@ -347,7 +347,7 @@ using Turing n = 10 model = state_space(y, length(t)) - @test size(sample(model, NUTS(; adtype=AutoReverseDiff(true)), n), 1) == n + @test size(sample(model, NUTS(; adtype=AutoReverseDiff(; compile=true)), n), 1) == n end if Threads.nthreads() > 1