Skip to content

Commit

Permalink
Fix memoization issue (#1414)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Sep 24, 2020
1 parent e6430f1 commit 96a79f3
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.14.3"
version = "0.14.4"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
20 changes: 8 additions & 12 deletions src/core/compat/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function gradient_logp(
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
T = typeof(getlogp(vi))

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
Expand All @@ -46,12 +46,10 @@ end
@require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin
setrdcache(::Val{true}) = RDCache[] = true
function emptyrdcache()
for k in keys(Memoization.caches)
if k[1] === typeof(memoized_taperesult)
pop!(Memoization.caches, k)
end
end
Memoization.empty_cache!(memoized_taperesult)
return
end

function gradient_logp(
backend::ReverseDiffAD{true},
θ::AbstractVector{<:Real},
Expand All @@ -61,7 +59,7 @@ end
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
T = typeof(getlogp(vi))

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
Expand All @@ -81,15 +79,13 @@ end
f::F
x::Tx
end
function Memoization._get!(f::Union{Function, Type}, d::IdDict, keys::Tuple{Tuple{RDTapeKey}, Any})
function Memoization._get!(f, d::Dict, keys::Tuple{Tuple{RDTapeKey}, Any})
key = keys[1][1]
return Memoization._get!(f, d, (typeof(key.f), typeof(key.x), size(key.x)))
return Memoization._get!(f, d, (key.f, typeof(key.x), size(key.x), Threads.threadid()))
end
memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x))
Memoization.@memoize function memoized_taperesult(k::RDTapeKey)
Memoization.@memoize Dict function memoized_taperesult(k::RDTapeKey)
return compiledtape(k.f, k.x), GradientResult(k.x)
end
memoized_tape(f, x) = memoized_tape(RDTapeKey(f, x))
Memoization.@memoize memoized_tape(k::RDTapeKey) = compiledtape(k.f, k.x)
compiledtape(f, x) = compile(GradientTape(f, x))
end
28 changes: 26 additions & 2 deletions test/core/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,13 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
sample(dir(), HMC(0.01, 1), 1000);
Turing.setrdcache(true)
sample(dir(), HMC(0.01, 1), 1000);
@test length(Memoization.caches) == 1
caches = Memoization.find_caches(Turing.Core.memoized_taperesult)
@test length(caches) == 1
@test !isempty(first(values(caches)))
Turing.emptyrdcache()
@test length(Memoization.caches) == 0
caches = Memoization.find_caches(Turing.Core.memoized_taperesult)
@test length(caches) == 1
@test isempty(first(values(caches)))
end
# FIXME: For some reasons PDMatDistribution AD tests fail with ReverseDiff
@testset "PDMatDistribution AD" begin
Expand Down Expand Up @@ -340,4 +344,24 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
@test H_f == [1.0 0.0; 0.0 1.0]
@test H_f == H_r
end

@testset "memoization: issue #1393" begin
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)

@model function demo(data)
sigma ~ Uniform(0.0, 20.0)
data ~ Normal(0, sigma)
end

N = 1000
for i in 1:5
d = Normal(0.0, i)
data = rand(d, N)
chn = sample(demo(data), NUTS(0.65), 1000)
@test mean(Array(chn[:sigma])) std(data) atol=0.5
end

Turing.emptyrdcache()
end
end
7 changes: 1 addition & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,8 @@ include("test_utils/AllUtils.jl")
include("core/container.jl")
end

test_adbackends = if VERSION >= v"1.2"
[:forwarddiff, :tracker, :reversediff]
else
[:forwarddiff, :tracker]
end
Turing.setrdcache(false)
for adbackend in test_adbackends
for adbackend in (:forwarddiff, :tracker, :reversediff)
Turing.setadbackend(adbackend)
@testset "inference: $adbackend" begin
@testset "samplers" begin
Expand Down

2 comments on commit 96a79f3

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/21909

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.14.4 -m "<description of version>" 96a79f31db303d63ce815c1d267aa4c2864c06d5
git push origin v0.14.4

Please sign in to comment.