diff --git a/Project.toml b/Project.toml
index 4c1aa8419..417db886e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.14.3"
+version = "0.14.4"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
diff --git a/src/core/compat/reversediff.jl b/src/core/compat/reversediff.jl
index 430c072e0..513ee9694 100644
--- a/src/core/compat/reversediff.jl
+++ b/src/core/compat/reversediff.jl
@@ -23,7 +23,7 @@ function gradient_logp(
     context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
     T = typeof(getlogp(vi))
-    
+
     # Specify objective function.
     function f(θ)
         new_vi = VarInfo(vi, sampler, θ)
@@ -46,12 +46,10 @@ end
 @require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin
     setrdcache(::Val{true}) = RDCache[] = true
     function emptyrdcache()
-        for k in keys(Memoization.caches)
-            if k[1] === typeof(memoized_taperesult)
-                pop!(Memoization.caches, k)
-            end
-        end
+        Memoization.empty_cache!(memoized_taperesult)
+        return
     end
+
     function gradient_logp(
         backend::ReverseDiffAD{true},
         θ::AbstractVector{<:Real},
@@ -61,7 +59,7 @@ end
         context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
     )
         T = typeof(getlogp(vi))
-        
+
         # Specify objective function.
         function f(θ)
             new_vi = VarInfo(vi, sampler, θ)
@@ -81,15 +79,13 @@ end
         f::F
         x::Tx
     end
-    function Memoization._get!(f::Union{Function, Type}, d::IdDict, keys::Tuple{Tuple{RDTapeKey}, Any})
+    function Memoization._get!(f, d::Dict, keys::Tuple{Tuple{RDTapeKey}, Any})
         key = keys[1][1]
-        return Memoization._get!(f, d, (typeof(key.f), typeof(key.x), size(key.x)))
+        return Memoization._get!(f, d, (key.f, typeof(key.x), size(key.x), Threads.threadid()))
     end
     memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x))
-    Memoization.@memoize function memoized_taperesult(k::RDTapeKey)
+    Memoization.@memoize Dict function memoized_taperesult(k::RDTapeKey)
         return compiledtape(k.f, k.x), GradientResult(k.x)
     end
-    memoized_tape(f, x) = memoized_tape(RDTapeKey(f, x))
-    Memoization.@memoize memoized_tape(k::RDTapeKey) = compiledtape(k.f, k.x)
     compiledtape(f, x) = compile(GradientTape(f, x))
 end
diff --git a/test/core/ad.jl b/test/core/ad.jl
index c2a73cee8..6fbfb388e 100644
--- a/test/core/ad.jl
+++ b/test/core/ad.jl
@@ -276,9 +276,13 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
         sample(dir(), HMC(0.01, 1), 1000);
         Turing.setrdcache(true)
         sample(dir(), HMC(0.01, 1), 1000);
-        @test length(Memoization.caches) == 1
+        caches = Memoization.find_caches(Turing.Core.memoized_taperesult)
+        @test length(caches) == 1
+        @test !isempty(first(values(caches)))
         Turing.emptyrdcache()
-        @test length(Memoization.caches) == 0
+        caches = Memoization.find_caches(Turing.Core.memoized_taperesult)
+        @test length(caches) == 1
+        @test isempty(first(values(caches)))
     end
     # FIXME: For some reasons PDMatDistribution AD tests fail with ReverseDiff
     @testset "PDMatDistribution AD" begin
@@ -340,4 +344,24 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
         @test H_f == [1.0 0.0; 0.0 1.0]
         @test H_f == H_r
     end
+
+    @testset "memoization: issue #1393" begin
+        Turing.setadbackend(:reversediff)
+        Turing.setrdcache(true)
+
+        @model function demo(data)
+            sigma ~ Uniform(0.0, 20.0)
+            data ~ Normal(0, sigma)
+        end
+
+        N = 1000
+        for i in 1:5
+            d = Normal(0.0, i)
+            data = rand(d, N)
+            chn = sample(demo(data), NUTS(0.65), 1000)
+            @test mean(Array(chn[:sigma])) ≈ std(data) atol=0.5
+        end
+
+        Turing.emptyrdcache()
+    end
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index f06abc371..12abd5e1d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -15,13 +15,8 @@ include("test_utils/AllUtils.jl")
         include("core/container.jl")
     end
 
-    test_adbackends = if VERSION >= v"1.2"
-        [:forwarddiff, :tracker, :reversediff]
-    else
-        [:forwarddiff, :tracker]
-    end
     Turing.setrdcache(false)
-    for adbackend in test_adbackends
+    for adbackend in (:forwarddiff, :tracker, :reversediff)
         Turing.setadbackend(adbackend)
         @testset "inference: $adbackend" begin
             @testset "samplers" begin